# (removed: non-Python extraction artifacts — file-size / git-blame / line-number table residue)
# Main library for WarBot
from transformers import AutoTokenizer ,AutoModelForCausalLM, AutoModelForSeq2SeqLM
import re
# Speller and punctuation:
import os
import yaml
import torch
from torch import package
# not very necessary
#import textwrap
from textwrap3 import wrap
import replicate #imaging
# util function to get expected len after tokenizing
def get_length_param(text: str, tokenizer) -> str:
    """Map the token count of *text* onto a coarse length bucket.

    Buckets: '1' (<=15 tokens), '2' (<=50), '3' (<=256), '-' (longer).
    The bucket char is used as a control token in the generation prompt.
    """
    n_tokens = len(tokenizer.encode(text))
    for limit, bucket in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= limit:
            return bucket
    return '-'
def remove_duplicates(S):
    """Strip Latin letter runs from *S* and drop repeated words.

    Word order of first occurrences is preserved.

    Fix: the original tested `subst not in result` against the accumulated
    *string*, so a word that happened to be a substring of earlier output
    (e.g. "на" after "нам") was wrongly dropped; it also built the result
    with quadratic `+=` concatenation. A seen-set gives exact word-level
    dedup in O(n).
    """
    S = re.sub(r'[a-zA-Z]+', '', S)  # Remove english
    seen = set()
    kept = []
    for word in S.split():
        if word not in seen:
            seen.add(word)
            kept.append(word)
    return " ".join(kept)
def removeSigns(S):
    """Truncate *S* just after the last '.' or '!'.

    If neither terminator occurs, *S* is returned unchanged — this trims
    a trailing unfinished sentence from generated text.
    """
    cut = max(S.rfind("."), S.rfind("!"))
    return S if cut < 0 else S[:cut + 1]
def prepare_punct():
    """Download and unpack the Silero text-enhancement (punctuation) model.

    Fetches the Silero model catalog, downloads the latest 'te' package into
    a local cache directory (skipping the download if already present), and
    returns the unpickled punctuation model.
    """
    # Quantized backend required on Unix builds of PyTorch.
    torch.backends.quantized.engine = 'qnnpack'

    catalog_file = 'latest_silero_models.yml'
    torch.hub.download_url_to_file(
        'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
        catalog_file,
        progress=False)
    with open(catalog_file, 'r') as fh:
        catalog = yaml.load(fh, Loader=yaml.SafeLoader)
    latest_conf = catalog.get('te_models').get('latest')

    # Download the packaged model once; reuse the cached copy afterwards.
    package_url = latest_conf.get('package')
    cache_dir = "downloaded_model"
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, os.path.basename(package_url))
    if not os.path.isfile(local_path):
        torch.hub.download_url_to_file(package_url, local_path, progress=True)

    importer = package.PackageImporter(local_path)
    return importer.load_pickle("te_model", "model")
def initialize():
    """Load every model WarBot needs.

    Returns a 6-tuple: (causal LM, its tokenizer, punctuation model,
    ru->en translation model, its tokenizer, replicate image-model version).

    Fix: the local variable previously spelled `model_punсt` contained a
    Cyrillic 'с' (U+0441) — renamed to pure ASCII.
    """
    # Conversational model (fine-tuned checkpoint directory "WarBot").
    fit_checkpoint = "WarBot"
    tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
    model_punct = prepare_punct()

    # Translation model (Russian -> English).
    # SECURITY: hard-coded API token committed to source — rotate this key
    # and load it from the environment / a secret store instead.
    os.environ['REPLICATE_API_TOKEN'] = '2254e586b1380c49a948fd00d6802d45962492e4'
    translation_model_name = "Helsinki-NLP/opus-mt-ru-en"
    translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

    # Image model (Stable Diffusion via replicate), pinned to a fixed version.
    imageModel = replicate.models.get("stability-ai/stable-diffusion")
    imgModel_version = imageModel.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478")

    return (model, tokenizer, model_punct, translation_model, translation_tokenizer, imgModel_version)
def translate(text: str, translation_model, translation_tokenizer):
    """Translate *text* from Russian to English.

    Returns the translated string, or "" if anything in the pipeline fails
    (deliberate best-effort behavior — callers treat "" as "no translation").

    Fixes: bare `except:` narrowed to `Exception` so KeyboardInterrupt /
    SystemExit propagate; unused `src`/`trg` locals removed.
    """
    try:
        batch = translation_tokenizer([text], return_tensors="pt")
        generated_ids = translation_model.generate(**batch)
        return translation_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception:
        return ""
def generate_image(prompt: str, imgModel_version):
    """Generate an image from *prompt* and return its URL.

    '?' characters are stripped from the prompt first. Returns "" if the
    prediction fails (deliberate best-effort behavior).

    Fix: bare `except:` narrowed to `Exception` so KeyboardInterrupt /
    SystemExit propagate.
    """
    prompt = prompt.replace("?", "")
    try:
        return imgModel_version.predict(prompt=prompt)[0]
    except Exception:
        return ""
def split_string(string, n=256):
    """Chop *string* into consecutive pieces of at most *n* characters."""
    pieces = []
    start = 0
    while start < len(string):
        pieces.append(string[start:start + n])
        start += n
    return pieces
def get_response(quote: str, model, tokenizer, model_punct, temperature=0.2):
    """Generate WarBot's reply to *quote*.

    Pipeline: encode with a "|speaker|length-bucket|" control prefix ->
    sample from the causal LM -> strip/clean the decoded text -> restore
    punctuation with *model_punct* -> clean punctuation-model artifacts.

    Returns the cleaned reply, or the diagnostic strings
    "Exception in tokenization" / "Exception" on encode/generate failure,
    or "" when punctuation enhancement fails.

    Fixes vs. original: bare excepts narrowed to Exception; `r'[UNK]'` was a
    character class matching the letters U/N/K, now matches the literal
    token; the `A-z` range in the whitelist also matched "[ \\ ] ^ _ `",
    now a proper A-Z range; join hoisted out of the chunk loop; the short
    branch's enhance_text call is now guarded too; typo `user_inpit_ids`.
    """
    try:
        user_input_ids = tokenizer.encode(
            f"|0|{get_length_param(quote, tokenizer)}|" + quote + tokenizer.eos_token,
            return_tensors="pt")
    except Exception:
        return "Exception in tokenization"

    chat_history_ids = user_input_ids  # no dialogue history yet

    # Short prompts tolerate a wider no-repeat n-gram window.
    no_repeat_ngram_size = 2 if len(tokenizer.encode(quote)) < 15 else 1

    try:
        output_id = model.generate(
            chat_history_ids,
            num_return_sequences=1,  # use >1 for variants, then print [i]
            max_length=200,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=temperature,  # 0 would mean greedy decoding
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    except Exception:
        return "Exception"  # generation failed

    response = tokenizer.decode(output_id[0], skip_special_tokens=True)
    response = removeSigns(response)
    response = response.split(quote)[-1]  # keep only the text after the quote
    # Keep only alphanumerics and basic punctuation.
    response = re.sub(r'[^0-9А-Яа-яЁёa-zA-Z;., !()/\-+:?]', '', response)
    # Drop runs of 4+ digits, then de-duplicate words.
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))
    response = re.sub(r'\.\.+', '', response)  # remove "....." runs

    # Restore punctuation; long texts are enhanced in <=200-char chunks.
    try:
        if len(response) > 200:
            chunks = wrap(response, 200)
            response = ''.join(model_punct.enhance_text(c, lan='ru') for c in chunks)
        else:
            response = model_punct.enhance_text(response, lan='ru')
    except Exception:
        return ""  # exception in punctuation enhancement

    return _clean_punct_artifacts(response)


def _clean_punct_artifacts(response):
    """Remove typical artifacts left by the punctuation model."""
    response = re.sub(r'\[UNK\]', '', response)  # drop the literal [UNK] token
    response = re.sub(r',+', ',', response)      # collapse multi-commas
    response = re.sub(r'-+', ',', response)      # dash runs become a comma (original behavior)
    response = re.sub(r'\.\?', '?', response)    # fix ".?"
    response = re.sub(r'\,\?', '?', response)    # fix ",?"
    response = re.sub(r'\.\!', '!', response)    # fix ".!"
    response = re.sub(r'\.\,', ',', response)    # fix ".,"
    response = re.sub(r'\.\)', '.', response)    # fix ".)"
    response = response.replace('[]', '')        # leftover empty brackets
    return response
if __name__ == '__main__':
    # Smoke-test snippet, intentionally disabled (kept as a string literal).
    # NOTE(review): initialize() returns 6 values but this demo unpacks 3 —
    # update before re-enabling.
    # Fix: removed a stray trailing "|" after the closing quotes (extraction
    # artifact) that made the file a syntax error.
    """
    quote = "Здравствуй, Жопа, Новый Год, выходи на ёлку!"
    model, tokenizer, model_punct = initialize()
    response = ""
    while not response:
        response = get_response(quote, model, tokenizer, model_punct,temperature=0.2)
    print(response)
    """