import logging
import os
import warnings

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

from langchain_community.document_loaders import DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# ==========================
# Logging Configuration
# ==========================
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("LegalChatbot")
logger.debug("Initializing Legal Chatbot application...")

# ==========================
# Suppress Warnings
# ==========================
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")

# ==========================
# Load Environment Variables
# ==========================
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")
os.environ["HF_HOME"] = HF_HOME

# Ensure the HF_HOME directory exists
os.makedirs(HF_HOME, exist_ok=True)

# Validate required environment variables
if not TOGETHER_AI_API:
raise ValueError("The TOGETHER_AI_API environment variable is missing. Please set it in your .env file.") | |

# ==========================
# Initialize Embeddings
# ==========================
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
    )
    logger.info("Embeddings successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing embeddings: {e}")
    raise RuntimeError("Oops! Something went wrong while setting up embeddings. Please check the configuration and try again.")

# ==========================
# Load FAISS Vectorstore
# ==========================
try:
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    # Retrieve the top 3 most similar chunks per query.
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
logger.info("Vectorstore successfully loaded.") | |
except Exception as e: | |
logger.error(f"Error loading FAISS vectorstore: {e}") | |
raise RuntimeError("We couldn't load the vector database. Please ensure the database file is available and try again.") | |

# ==========================
# Define Prompt Template
# ==========================
prompt_template = """<s>[INST] You are a legal chatbot specializing in the Indian Penal Code. Provide concise, context-aware answers in a conversational tone. Avoid presenting the response as a question-answer format unless explicitly required.
If the answer cannot be derived from the given context, respond with: "I'm sorry, I couldn't find relevant information for your query."
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
[/INST]"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

# ==========================
# Initialize Together API
# ==========================
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
    logger.info("Together API successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Something went wrong with the Together API setup. Please verify your API key and configuration.")

# ==========================
# Conversational Retrieval Chain
# ==========================
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)
logger.info("Conversational Retrieval Chain initialized.")

# ==========================
# FastAPI Backend
# ==========================
app = FastAPI()


class ChatRequest(BaseModel):
    question: str


class ChatResponse(BaseModel):
    answer: str

@app.get("/")
async def root():
    return {"message": "Hello! Welcome to the Legal Chatbot. I'm here to assist you with your legal queries related to the Indian Penal Code. How can I help you today?"}

@app.post("/chat", response_model=ChatResponse)  # the "/chat" path is assumed from the handler name
async def chat(request: ChatRequest):
    try:
        logger.debug(f"Received user question: {request.question}")
        result = qa.invoke({"question": request.question})
        answer = result.get("answer")
        # The prompt instructs the model to apologize when the context lacks an
        # answer; normalize that sentinel to the full fallback message.
        if not answer or "I couldn't find relevant information" in answer:
            answer = "I'm sorry, I couldn't find relevant information for your query. Please try rephrasing or providing more details."
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Oops! Something went wrong on our end. Please try again later.")

# ==========================
# Run Uvicorn Server
# ==========================
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860)
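
# ==========================
# Example Request
# ==========================
# A hypothetical smoke test against the /chat endpoint defined above
# (host and port assume the uvicorn defaults in this file):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does Section 302 of the IPC deal with?"}'
#
# The response body has the shape: {"answer": "..."}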