# WarBot reply-generation utilities: model loading, text cleanup and response generation.
from transformers import AutoTokenizer ,AutoModelForCausalLM
import re
# Speller and punctuation:
import os
import yaml
import torch
from torch import package
# not very necessary
import textwrap
from textwrap3 import wrap
# util function to get expected len after tokenizing
def get_length_param(text: str, tokenizer) -> str:
    """Map the token count of *text* onto a one-character length-bucket code.

    Buckets: '1' (<=15 tokens), '2' (<=50), '3' (<=256), '-' (longer).
    The code is fed into the model prompt as a length hint.
    """
    n_tokens = len(tokenizer.encode(text))
    if n_tokens <= 15:
        return '1'
    if n_tokens <= 50:
        return '2'
    if n_tokens <= 256:
        return '3'
    return '-'
def remove_duplicates(S):
    """Strip Latin-letter runs from *S* and drop repeated words.

    Keeps the first occurrence of each word, preserving order, and returns
    the words joined by single spaces.

    BUGFIX: the original used ``subst not in result`` — a *substring* test on
    the accumulated string — so a word was wrongly discarded whenever it
    occurred inside an earlier word (e.g. 'он' after 'конь'). It was also
    quadratic due to repeated string concatenation.
    """
    S = re.sub(r'[a-zA-Z]+', '', S)  # remove English words
    seen = set()
    kept = []
    for word in S.split():
        if word not in seen:
            seen.add(word)
            kept.append(word)
    return ' '.join(kept)
def removeSigns(S):
    """Truncate *S* right after its last '.' or '!'.

    Returns *S* unchanged when neither terminator is present, so a reply
    with no sentence ending is kept whole.
    """
    cut = max(S.rfind("."), S.rfind("!"))
    return S if cut < 0 else S[:cut + 1]
def prepare_punct():
    """Fetch and load Silero's text-enhancement (punctuation) model.

    Downloads the models manifest, resolves the latest 'te' package URL,
    caches the package under ``downloaded_model/`` and loads the pickled
    model via ``torch.package``. Returns the loaded punctuation model.
    """
    manifest = 'latest_silero_models.yml'
    torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                                   manifest,
                                   progress=False)
    with open(manifest, 'r') as fh:
        catalog = yaml.load(fh, Loader=yaml.SafeLoader)
    latest_conf = catalog.get('te_models').get('latest')

    # Download the model package itself, skipping if already cached locally.
    package_url = latest_conf.get('package')
    cache_dir = "downloaded_model"
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, os.path.basename(package_url))
    if not os.path.isfile(local_path):
        torch.hub.download_url_to_file(package_url, local_path, progress=True)

    importer = package.PackageImporter(local_path)
    return importer.load_pickle("te_model", "model")
def initialize():
    """Load the fine-tuned dialogue model, its tokenizer and the punctuation model.

    Returns a ``(model, tokenizer, model_punct)`` tuple ready for
    :func:`get_response`.
    """
    torch.backends.quantized.engine = 'qnnpack'  # required for quantized inference on this machine's CPU architecture
    fit_checkpoint = "WarBot"
    tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
    # BUGFIX: this variable was previously spelled with a Cyrillic 'с'
    # (U+0441) homoglyph — normalized to ASCII to avoid a maintenance trap.
    model_punct = prepare_punct()
    return (model, tokenizer, model_punct)
def split_string(string, n=256):
    """Chop *string* into consecutive chunks of at most *n* characters."""
    chunks = []
    start = 0
    while start < len(string):
        chunks.append(string[start:start + n])
        start += n
    return chunks
def get_response(quote: str, model, tokenizer, model_punct):
    """Generate a cleaned-up reply to *quote* with the dialogue model.

    Encodes the quote with a speaker/length-hint prefix, samples a reply,
    then strips the echoed quote, non-allowed characters, long digit runs,
    duplicate words and ellipses, and finally restores punctuation with
    *model_punct* (best-effort).
    """
    # Prompt format: "|speaker_id|length_bucket|" + quote + EOS.
    # (fixed typo: was 'user_inpit_ids')
    user_input_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
                                      + quote + tokenizer.eos_token, return_tensors="pt")
    chat_history_ids = user_input_ids  # placeholder: no multi-turn history yet

    # Short prompts tolerate a larger no-repeat window.
    tokens_count = len(tokenizer.encode(quote))
    no_repeat_ngram_size = 2 if tokens_count < 15 else 1

    output_id = model.generate(
        chat_history_ids,
        num_return_sequences=1,  # use more for variants, then index [i]
        max_length=200,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.4,  # was 0.6; 0 for greedy
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    response = tokenizer.decode(output_id[0], skip_special_tokens=True)
    # BUGFIX: was re.sub(r'[UNK]', ...) — a character class that stripped
    # every literal 'U', 'N', 'K' from the reply. Remove the token literally,
    # before the character filter below discards its brackets.
    response = response.replace('[UNK]', '')
    response = removeSigns(response)
    response = response.split(quote)[-1]  # keep only the text after the echoed quote
    # BUGFIX: 'a-zA-z' also admitted '[', '\', ']', '^', '_', '`' — use 'A-Za-z'.
    response = re.sub(r'[^0-9А-Яа-яЁёA-Za-z;., !()/\-+:?]', '',
                      response)  # keep only alphanumerics and basic punctuation
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))  # drop runs of 4+ digits, then dedupe words
    response = re.sub(r'\.\.+', '', response)  # remove "....." runs

    # Punctuation enhancement is best-effort; the model chokes on long input,
    # so split into <=170-char pieces first and tolerate failures.
    max_len = 170
    try:
        if len(response) > max_len:
            pieces = wrap(response, max_len)
            pieces = [model_punct.enhance_text(piece, lan='ru') for piece in pieces]
            response = ''.join(pieces)
        else:
            response = model_punct.enhance_text(response, lan='ru')
    except Exception:
        pass  # sometimes the string is getting too long
    return response
#if __name__ == '__main__':
#model,tokenizer,model_punct = initialize()
#quote = "Это хорошо, но глядя на ролик, когда ефиопские толпы в Израиле громят машины и нападают на улице на израильтян - задумаешься, куда все движется"
#print('please wait...')
#response = wrap(get_response(quote,model,tokenizer,model_punct),60)
#for phrase in response:
# print(phrase)