|
import io
import json
import logging
import multiprocessing
import os
import random
import time
import uuid

import redis
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
|
|
|
load_dotenv()

# Redis connection settings, read from the environment (or the .env file loaded
# above). Host and port fall back to redis-py's usual defaults. The port is
# converted to int here, so a malformed value fails at startup instead of on the
# first Redis command. An unset port would otherwise reach redis-py as None.
REDIS_HOST = os.getenv('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.getenv('REDIS_PORT', '6379'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()

# Language used when a /process request does not specify one.
default_language = "es"
|
|
|
class ChatbotService:
    """Generate free-text replies with a GPT-2 causal LM whose weights live in Redis.

    Every ``get_response`` call re-reads two Redis keys:
    - the weights from ``model:<model_name>``;
    - the tokenizer vocabulary (JSON) from ``tokenizer:<tokenizer_name>``.
    """

    def __init__(self):
        # decode_responses=False: the stored weights are raw torch.save() bytes,
        # which a decoding client would try (and fail) to decode as UTF-8.
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            password=REDIS_PASSWORD,
            decode_responses=False,
        )
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"

    def get_response(self, user_id, message, language=default_language):
        """Return the model's reply to *message*, or a "not ready" notice.

        Args:
            user_id: Accepted for interface compatibility; not used.
            message: User text inserted into the ``"Usuario: ... Asistente:"`` prompt.
            language: Accepted for interface compatibility; not used.

        Returns:
            The decoded continuation (prompt tokens excluded), stripped. If the
            model or tokenizer cannot be loaded from Redis, a Spanish
            "model not ready" message is returned instead.
        """
        model = self.load_model_from_redis()
        tokenizer = self.load_tokenizer_from_redis()

        if model is None or tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."

        input_text = f"Usuario: {message} Asistente:"
        encoded = tokenizer(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"].to("cpu")

        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"].to("cpu"),
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
                # GPT-2 has no pad token. generate() would otherwise fall back
                # to EOS and log a warning on every call.
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the generated continuation. Removing the prompt with
        # str.replace() fails whenever decode() does not reproduce the prompt
        # verbatim (e.g. tokenization-space clean-up). It also deletes any
        # repetition of the prompt inside the reply.
        continuation = output[0][input_ids.shape[-1]:]
        return tokenizer.decode(continuation, skip_special_tokens=True).strip()

    def load_model_from_redis(self):
        """Return a GPT-2 causal LM with the stored weights, or None if unavailable.

        Returns None in two cases:
        - the ``model:<model_name>`` key is missing or empty;
        - the stored state dict cannot be loaded into ``AutoModelForCausalLM``
          (e.g. its keys or shapes belong to a different architecture).
        """
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None

        model = AutoModelForCausalLM.from_pretrained("gpt2")
        try:
            # torch.load needs a path or file-like object, not raw bytes.
            # NOTE(security): torch.load unpickles the payload; this is only
            # safe while the Redis instance is trusted.
            state_dict = torch.load(io.BytesIO(model_data_bytes))
            model.load_state_dict(state_dict)
        except RuntimeError:
            logging.getLogger(__name__).exception(
                "Weights stored under model:%s do not fit the GPT-2 causal LM; "
                "treating the model as not ready",
                self.model_name,
            )
            return None
        return model

    def load_tokenizer_from_redis(self):
        """Return the GPT-2 tokenizer extended with stored tokens, or None if absent.

        ``tokenizer:<tokenizer_name>`` holds a JSON object mapping token -> id
        (as written by ``tokenizer.get_vocab()``). Only its keys are used.
        """
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if not tokenizer_data_bytes:
            return None

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        known_tokens = tokenizer.get_vocab()
        # add_tokens() expects a token or a list of tokens, not a dict.
        # Pass only the tokens the stock vocabulary lacks.
        missing = [tok for tok in json.loads(tokenizer_data_bytes) if tok not in known_tokens]
        if missing:
            # NOTE(review): newly added ids would also need the model's
            # embeddings resized before they can be fed to generate().
            tokenizer.add_tokens(missing)
        return tokenizer
|
|
|
# Single ChatbotService instance shared by every request handler in this module.
chatbot_service: ChatbotService = ChatbotService()
|
|
|
class UnifiedModel(nn.Module):
    """Fuse the logits of several sequence classifiers into one 3-way prediction.

    Each wrapped model is called with its own ``input_ids``/``attention_mask``
    pair. The resulting ``.logits`` are then:
    1. concatenated along dim 1;
    2. projected to ``hidden_size`` (taken from the first model's config);
    3. classified into 3 classes.

    Args:
        models: Non-empty sequence of modules. Each must expose
            ``config.hidden_size`` and return an object with ``.logits``. The
            projection assumes each model emits 3 logits; the loaders in this
            file build them with ``num_labels=3``.

    Raises:
        ValueError: if ``models`` is empty.
    """

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
        if len(self.models) == 0:
            raise ValueError("UnifiedModel needs at least one model")
        hidden_size = self.models[0].config.hidden_size
        # The projection's output width must equal the classifier's input width.
        # It was hard-coded to 768, which only matched GPT-2-sized models.
        self.projection = nn.Linear(len(self.models) * 3, hidden_size)
        self.classifier = nn.Linear(hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        """Return the fused logits; 3 per row, assuming 2-D per-model logits.

        Args:
            input_ids: One tensor per wrapped model, ordered like ``self.models``.
            attention_mask: One tensor per wrapped model, aligned with ``input_ids``.

        Raises:
            ValueError: if the number of inputs differs from the number of
                models. Without this check, ``zip`` would silently drop the
                surplus models or inputs.
        """
        if not len(input_ids) == len(attention_mask) == len(self.models):
            raise ValueError(
                f"expected {len(self.models)} input_ids/attention_mask tensors, "
                f"got {len(input_ids)}/{len(attention_mask)}"
            )

        per_model_logits = [
            model(input_ids=ids, attention_mask=mask).logits
            for model, ids, mask in zip(self.models, input_ids, attention_mask)
        ]
        concatenated = torch.cat(per_model_logits, dim=1)
        return self.classifier(self.projection(concatenated))

    @staticmethod
    def load_model_from_redis(redis_client):
        """Build a UnifiedModel around a GPT-2 sequence classifier restored from Redis.

        If ``model:unified_model`` is present, it is loaded as a ``state_dict``
        of ``AutoModelForSequenceClassification``. Otherwise the pretrained
        "gpt2" checkpoint is used.

        ``redis_client`` must return raw bytes (``decode_responses=False``),
        because serialized weights are not valid UTF-8.
        """
        model_name = "unified_model"
        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        # GPT-2 has no pad token. Without pad_token_id the sequence-
        # classification head rejects batches larger than one.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id

        model_data_bytes = redis_client.get(f"model:{model_name}")
        if model_data_bytes:
            # torch.load needs a file-like object, not raw bytes.
            # NOTE(security): torch.load unpickles the payload; this is only
            # safe while the Redis instance is trusted.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))

        # One wrapped model: training jobs queued by /process carry a single
        # tokenizer, so wrapping the same model twice made the projection
        # expect twice as many logits as forward() receives.
        return UnifiedModel([model])
|
|
|
class SyntheticDataset(Dataset):
    """Map-style dataset that tokenizes each example with every supplied tokenizer.

    Args:
        tokenizers: Mapping of name -> callable tokenizer. Each tokenizer is
            called with ``padding="max_length", truncation=True, max_length=128``
            and must return ``input_ids`` and ``attention_mask``.
        data: Indexable sequence of dicts with ``'text'`` and ``'label'`` keys.
    """

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensors for example *idx*.

        For each tokenizer name N, the dict holds ``input_ids_N`` and
        ``attention_mask_N``. The ``labels`` entry comes last.
        """
        record = self.data[idx]
        text, label = record['text'], record['label']

        features = {}
        for tok_name, tok in self.tokenizers.items():
            encoded = tok(text, padding="max_length", truncation=True, max_length=128)
            features.update({
                f"input_ids_{tok_name}": torch.tensor(encoded["input_ids"]),
                f"attention_mask_{tok_name}": torch.tensor(encoded["attention_mask"]),
            })
        features["labels"] = torch.tensor(label)
        return features
|
|
|
# In-process chat log: each user_id maps to the list of messages it has sent.
conversation_history: dict = dict()
|
|
|
def _process_load_classifier(redis_client, model_name):
    """Return a 3-label GPT-2 sequence classifier, restoring weights stored in Redis.

    Weights come from ``model:<model_name>`` when that key is set. Otherwise
    the model is used as freshly loaded from the "gpt2" checkpoint.
    """
    model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    model_data_bytes = redis_client.get(f"model:{model_name}")
    if model_data_bytes:
        # torch.load needs a file-like object, not raw bytes.
        # NOTE(security): torch.load unpickles the payload; this is only safe
        # while the Redis instance is trusted.
        model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
    return model


def _process_load_tokenizer(redis_client, tokenizer_name):
    """Return the GPT-2 tokenizer extended with any stored tokens it does not know.

    ``tokenizer:<tokenizer_name>`` holds a JSON object mapping token -> id (as
    written by ``tokenizer.get_vocab()``). Only its keys are used.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
    if tokenizer_data_bytes:
        known_tokens = tokenizer.get_vocab()
        # add_tokens() expects a token or a list of tokens, not a dict.
        missing = [tok for tok in json.loads(tokenizer_data_bytes) if tok not in known_tokens]
        if missing:
            tokenizer.add_tokens(missing)
    return tokenizer


@app.post("/process")
async def process(request: Request):
    """Handle a chat message or enqueue training data.

    The JSON object body takes one of two forms:

    * ``{"train": true, "user_data": [{"text": ..., "label": ...}, ...]}``
      queues the examples on the Redis ``training_queue`` list. A built-in
      sample set is queued when ``user_data`` is missing or empty.
    * ``{"message": ..., "user_id": ..., "language": ...}`` does four things:
      1. classifies the user's last three messages;
      2. generates a reply with ``chatbot_service``;
      3. queues the (text, predicted label) pair for training;
      4. returns ``{"answer": ...}``.

    Raises:
        HTTPException: 400 in any of these cases:
            - the body is not a JSON object;
            - the body has neither a truthy ``train`` nor a non-empty ``message``;
            - ``message`` is not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")

    # decode_responses=False: stored model weights are raw bytes, which a
    # decoding client would try (and fail) to decode as UTF-8.
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False
    )

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    if data.get("train"):
        # Only the tokenizer vocabulary is needed to queue a job, so the GPT-2
        # classifier is not loaded on this path.
        tokenizer = _process_load_tokenizer(redis_client, tokenizer_name)

        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0},
            ]

        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data,
        }))

        return {"message": "Training data received. Model will be updated asynchronously."}

    elif data.get("message"):
        text = data['message']
        if not isinstance(text, str):
            raise HTTPException(status_code=400, detail="'message' must be a string.")
        user_id = data.get("user_id")
        language = data.get("language", default_language)

        # NOTE(review): history grows without bound per user_id. Requests
        # without a user_id all share the None entry.
        history = conversation_history.setdefault(user_id, [])
        history.append(text)
        contextualized_text = " ".join(history[-3:])

        tokenizer = _process_load_tokenizer(redis_client, tokenizer_name)
        # NOTE(review): UnifiedModel(...) creates its projection/classifier
        # layers anew on every request, and only the wrapped classifier's
        # weights are restored here. predicted_class therefore does not come
        # from trained fusion layers.
        unified_model = UnifiedModel([_process_load_classifier(redis_client, model_name)])
        unified_model.to(torch.device("cpu"))
        unified_model.eval()

        encoded = tokenizer(contextualized_text, return_tensors="pt")
        with torch.no_grad():
            logits = unified_model(
                input_ids=[encoded["input_ids"]],
                attention_mask=[encoded["attention_mask"]],
            )
        predicted_class = torch.argmax(logits, dim=-1).item()

        # NOTE(review): model loading and generation are synchronous and run on
        # the event-loop thread, so other requests stall meanwhile.
        response = chatbot_service.get_response(user_id, contextualized_text, language)

        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}],
        }))

        return {"answer": response}

    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
|
|
|
def get_chatbot_response(user_id, question, predicted_class, language):
    """Record *question* in the user's history, then return the chatbot's reply.

    ``predicted_class`` is accepted but not used.
    """
    history = conversation_history.setdefault(user_id, [])
    history.append(question)
    return chatbot_service.get_response(user_id, question, language)
|
|
|
@app.get("/")
async def get_home():
    """Serve the single-page chat UI.

    A fresh UUID is generated for each page load and embedded in the page. The
    page sends it as ``user_id`` with every message to ``/process``.

    Client-side behavior:
    - non-2xx responses and network failures appear in the chat as an error
      message;
    - pressing Enter sends the message, like the "Enviar" button.
    """
    user_id = str(uuid.uuid4())
    html_code = f"""
<!DOCTYPE html>
<html lang="es">
<head>
    <meta charset="UTF-8">
    <title>Chatbot</title>
    <style>
        body {{
            font-family: 'Arial', sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
            display: flex;
            align-items: center;
            justify-content: center;
            min-height: 100vh;
        }}

        .container {{
            background-color: #fff;
            border-radius: 10px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            overflow: hidden;
            width: 400px;
            max-width: 90%;
        }}

        h1 {{
            color: #333;
            text-align: center;
            padding: 20px;
            margin: 0;
            background-color: #f8f9fa;
            border-bottom: 1px solid #eee;
        }}

        #chatbox {{
            height: 400px;
            padding: 20px;
            overflow-y: auto;
        }}

        .message {{
            margin-bottom: 15px;
            padding: 10px;
            border-radius: 5px;
            max-width: 70%;
            animation: slide-in 0.3s ease-out;
        }}

        .user-message {{
            text-align: right;
            background-color: #eee;
            margin-left: 30%;
        }}

        .bot-message {{
            text-align: left;
            background-color: #ccf5ff;
            margin-right: 30%;
        }}

        #input-area {{
            display: flex;
            padding: 10px;
            background-color: #f8f9fa;
            border-top: 1px solid #eee;
        }}

        #message-input {{
            flex: 1;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-right: 10px;
        }}

        #send-button {{
            padding: 10px 15px;
            background-color: #28a745;
            color: white;
            border: none;
            cursor: pointer;
            border-radius: 5px;
            transition: background-color 0.3s ease;
        }}

        #send-button:hover {{
            background-color: #218838;
        }}

        @keyframes slide-in {{
            from {{
                transform: translateX(-100%);
                opacity: 0;
            }}
            to {{
                transform: translateX(0);
                opacity: 1;
            }}
        }}
    </style>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <div id="input-area">
            <input type="hidden" id="user-id" value="{user_id}">
            <input type="text" id="message-input" placeholder="Escribe tu mensaje...">
            <button id="send-button">Enviar</button>
        </div>
    </div>
    <script>
        const chatbox = document.getElementById('chatbox');
        const messageInput = document.getElementById('message-input');
        const sendButton = document.getElementById('send-button');
        const userId = document.getElementById('user-id').value;

        sendButton.addEventListener('click', sendMessage);
        messageInput.addEventListener('keydown', (event) => {{
            if (event.key === 'Enter' && !event.isComposing) {{
                sendMessage();
            }}
        }});

        function sendMessage() {{
            const message = messageInput.value;
            if (message.trim() === '') return;

            appendMessage('user', message);
            messageInput.value = '';

            fetch('/process', {{
                method: 'POST',
                headers: {{
                    'Content-Type': 'application/json'
                }},
                body: JSON.stringify({{ message: message, user_id: userId, language: 'es' }})
            }})
            .then(response => {{
                if (!response.ok) {{
                    throw new Error(`HTTP ${{response.status}}`);
                }}
                return response.json();
            }})
            .then(data => {{
                appendMessage('bot', data.answer);
            }})
            .catch(() => {{
                appendMessage('bot', 'Error: no se pudo obtener una respuesta. Inténtelo de nuevo.');
            }});
        }}

        function appendMessage(sender, message) {{
            const messageElement = document.createElement('div');
            messageElement.classList.add('message', `${{sender}}-message`);
            messageElement.textContent = message;
            chatbox.appendChild(messageElement);
            chatbox.scrollTop = chatbox.scrollHeight;
        }}
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_code)
|
|
|
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
    """Serialize model state dicts and tokenizer vocabularies into Redis.

    Each entry is stored under its dict key ``<name>``:
    - every model's ``state_dict()`` is written with ``torch.save`` to
      ``model:<name>``;
    - every tokenizer's ``get_vocab()`` is written as JSON to
      ``tokenizer:<name>``.

    Args:
        models: Mapping of name -> ``nn.Module``.
        tokenizers: Mapping of name -> object with ``get_vocab()``.
        redis_client: Client exposing ``set(key, value)``.
        model_name: Unused; the key names come from ``models``. Kept for
            backward compatibility.
        tokenizer_name: Unused; the key names come from ``tokenizers``. Kept
            for backward compatibility.
    """
    for name, model in models.items():
        # Serialize in memory. Round-tripping through a file named after the
        # model left stray files in the working directory and could collide
        # between processes.
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        redis_client.set(f"model:{name}", buffer.getvalue())

    for name, tokenizer in tokenizers.items():
        redis_client.set(f"tokenizer:{name}", json.dumps(tokenizer.get_vocab()))
|
|
|
def _training_tokenizer_from_vocab(vocab):
    """Rebuild a GPT-2 tokenizer from the vocabulary dict shipped in a training job.

    Tokens missing from the stock "gpt2" vocabulary are added. EOS is reused
    as the pad token because SyntheticDataset pads every example to
    ``max_length``, and GPT-2 ships without a pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    known_tokens = tokenizer.get_vocab()
    missing = [token for token in vocab if token not in known_tokens]
    if missing:
        # NOTE(review): new ids would also need the model's embeddings resized.
        tokenizer.add_tokens(missing)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _training_run_job(redis_client, job):
    """Fine-tune the unified classifier on one queued job and store the result.

    ``job`` is the decoded queue payload: ``{"tokenizers": {name: vocab_dict},
    "data": [{"text": ..., "label": ...}, ...]}``.
    """
    examples = job["data"]
    if not examples or not job["tokenizers"]:
        logging.getLogger(__name__).warning("Skipping training job without data or tokenizers")
        return

    # The payload carries vocabularies, not tokenizer objects. Rebuild callable
    # tokenizers for SyntheticDataset.
    tokenizers = {
        name: _training_tokenizer_from_vocab(vocab)
        for name, vocab in job["tokenizers"].items()
    }
    first_tokenizer_name = next(iter(tokenizers))

    unified_model = UnifiedModel.load_model_from_redis(redis_client)
    for model in unified_model.models:
        # Batched GPT-2 classification needs a pad id to find the last real token.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizers[first_tokenizer_name].pad_token_id
    unified_model.train()

    train_loader = DataLoader(SyntheticDataset(tokenizers, examples), batch_size=8, shuffle=True)
    optimizer = AdamW(unified_model.parameters(), lr=5e-5)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(10):
        for batch in train_loader:
            input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers]
            attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers]
            labels = batch["labels"].to("cpu")
            outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f"Epoch {epoch}, Loss {loss.item()}")

    # Store the wrapped sequence classifier under "model:unified_model", the key
    # both classifier loaders read. They restore it into a bare
    # AutoModelForSequenceClassification, so the wrapper's own
    # projection/classifier layers are not stored.
    # "model:response_model" is deliberately NOT written: ChatbotService loads
    # that key into AutoModelForCausalLM, which cannot accept these weights.
    push_to_redis(
        {"unified_model": unified_model.models[0]},
        tokenizers,
        redis_client,
        "unified_model",
        first_tokenizer_name,
    )


def continuous_training():
    """Run the background trainer, consuming ``training_queue`` jobs from Redis forever.

    Jobs are taken with a blocking BLPOP (5 s timeout), so an empty queue does
    not spin the CPU. After each processed job the loop pauses 10 s.

    A job that fails is dropped: its traceback is logged and the loop pauses
    5 s before continuing.
    """
    # decode_responses=False: serialized weights read back from Redis are not UTF-8.
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False
    )
    log = logging.getLogger(__name__)

    while True:
        try:
            popped = redis_client.blpop("training_queue", timeout=5)
            if popped is None:
                continue
            _queue_name, raw_job = popped
            _training_run_job(redis_client, json.loads(raw_job))
            time.sleep(10)
        except Exception as e:
            # Top-level boundary of a long-running worker: log and keep going.
            log.exception("Error in continuous training: %s", e)
            time.sleep(5)
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Train in a separate process so model updates run outside the web server.
    # daemon=True: continuous_training() never returns, and multiprocessing
    # joins non-daemonic children at interpreter exit. Without it the program
    # would hang after uvicorn shuts down.
    training_process = multiprocessing.Process(target=continuous_training, daemon=True)
    training_process.start()

    uvicorn.run(app, host="0.0.0.0", port=7860)