|
from transformers import GPTNeoForCausalLM, GPT2Tokenizer |
|
import gradio as gr |
|
|
|
# The weights and the tokenizer are loaded from the same pretrained checkpoint id.
MODEL_NAME = "EleutherAI/gpt-neo-125M"

model = GPTNeoForCausalLM.from_pretrained(MODEL_NAME)

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
|
|
|
# Few-shot conversation used as the generation prefix. It ends on an open
# "person: " turn; chat_base() appends the user's message directly after it
# and the model's continuation is parsed for Hassan's reply.
prompt = """This is a discussion between a person and Hassan Kane, an entrepreneur.



person: What are you working on?

Hassan: This new AI community building the future of Africa

person: Where are you?

Hassan: In Lagos for a week, then Paris or London.

person: How's it going?

Hassan: Not bad.. Just trying to hit EV (escape velocity) with my startup

person: """
|
|
|
def my_split(s, seps):
    """Split ``s`` on every separator in ``seps``, applied one after another.

    The text is first cut on the first separator; every resulting piece is
    then cut on the next separator, and so on. Pieces keep their original
    order and empty pieces are preserved. With no separators the result is
    ``[s]``.
    """
    pieces = [s]
    for sep in seps:
        # Re-split every piece produced so far on the current separator.
        pieces = [part for piece in pieces for part in piece.split(sep)]
    return pieces
|
|
|
|
|
def chat_base(input):
    """Generate Hassan's reply to a single user message.

    The message is appended to the module-level few-shot ``prompt``, the
    model samples a continuation, and one reply line is extracted from it.

    Args:
        input: The user's message. The name shadows the builtin but is kept
            so existing callers that pass it by keyword keep working.

    Returns:
        The first non-empty line the model produced after the user's own
        line, with any ``"Hassan: "`` speaker tag and surrounding whitespace
        removed. If the model never started a new line, whatever it appended
        to the user's line is returned instead (possibly an empty string).
    """
    p = prompt + input
    input_ids = tokenizer(p, return_tensors="pt").input_ids
    prompt_len = input_ids.shape[-1]

    # ``max_length`` counts the prompt tokens as well, so a fixed cap leaves
    # little or no room to generate once prompt + message grows. Size the cap
    # relative to the actual input length instead.
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.7,
        max_length=prompt_len + 60,
    )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # ``len(p)`` assumes decoding reproduces the prompt character-for-character,
    # and special tokens (e.g. end-of-text) would otherwise leak into the reply.
    result = tokenizer.batch_decode(
        gen_tokens[:, prompt_len:], skip_special_tokens=True
    )[0]

    pieces = my_split(result, ["]", "\n"])
    # pieces[0] is whatever the model appended to the user's own line; the
    # reply starts on a later line. Skip blank lines rather than indexing
    # [1] blindly, which raised IndexError when no separator was generated
    # and returned '' when the reply was preceded by an empty line.
    later_lines = [piece for piece in pieces[1:] if piece.strip()]
    result = later_lines[0] if later_lines else pieces[0]

    if "Hassan: " in result:
        result = result.split("Hassan: ")[-1]

    return result.strip()
|
|
|
import gradio as gr |
|
|
|
def chat(message):
    """Answer ``message`` and append the exchange to the stored history.

    The history is read with ``gr.get_state()`` (an empty list is used when
    nothing is stored yet), extended with a ``(message, response)`` tuple and
    written back with ``gr.set_state()``.

    Args:
        message: The user's message, forwarded unchanged to ``chat_base``.

    Returns:
        The reply text produced by ``chat_base``.
    """
    history = gr.get_state() or []
    print(history)

    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)

    # NOTE: an HTML transcript used to be concatenated here from the raw,
    # unescaped messages, but it was never returned or stored. Only the reply
    # text is part of this function's contract, so that dead code is gone.
    return response
|
|
|
# Single-textbox UI: the question is passed to chat_base and the returned
# reply is rendered as plain text.
question_box = gr.inputs.Textbox(label="Ask Hassan a Question")
iface = gr.Interface(
    chat_base,
    question_box,
    "text",
    allow_screenshot=False,
    allow_flagging=False,
)

iface.launch()
|
|