import logging
import os
import warnings
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from langchain_community.document_loaders import DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
# ==========================
# Logging Configuration
# ==========================
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("LegalChatbot")
logger.debug("Initializing Legal Chatbot application...")
# ==========================
# Suppress Warnings
# ==========================
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
# ==========================
# Load Environment Variables
# ==========================
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")
os.environ["HF_HOME"] = HF_HOME
# Ensure the HF_HOME directory exists
os.makedirs(HF_HOME, exist_ok=True)
# Validate required environment variables
if not TOGETHER_AI_API:
    raise ValueError("The TOGETHER_AI environment variable is missing. Please set it in your .env file.")
# ==========================
# Initialize Embeddings
# ==========================
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
    )
    logger.info("Embeddings successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing embeddings: {e}")
    raise RuntimeError("Oops! Something went wrong while setting up embeddings. Please check the configuration and try again.")
# ==========================
# Load FAISS Vectorstore
# ==========================
try:
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    # "max-length" is not a valid FAISS search kwarg (it was silently ignored), so only "k" is passed.
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    logger.info("Vectorstore successfully loaded.")
except Exception as e:
    logger.error(f"Error loading FAISS vectorstore: {e}")
    raise RuntimeError("We couldn't load the vector database. Please ensure the database file is available and try again.")
# ==========================
# Define Prompt Template
# ==========================
prompt_template = """<s>[INST]You are a legal chatbot specializing in the Indian Penal Code. Provide concise, context-aware answers in a conversational tone. Avoid presenting the response in question-answer format unless explicitly required.
If the answer cannot be derived from the given context, respond with: "I'm sorry, I couldn't find relevant information for your query."
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
[/INST]"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
# ==========================
# Initialize Together API
# ==========================
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
    logger.info("Together API successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Something went wrong with the Together API setup. Please verify your API key and configuration.")
# ==========================
# Conversational Retrieval Chain
# ==========================
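# Note: this memory object is module-level, so chat history is shared
# across every client of the API rather than tracked per session.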
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)
logger.info("Conversational Retrieval Chain initialized.")
# ==========================
# FastAPI Backend
# ==========================
app = FastAPI()
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str
@app.get("/")
async def root():
    return {"message": "Hello! Welcome to the Legal Chatbot. I'm here to assist you with your legal queries related to the Indian Penal Code. How can I help you today?"}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.debug(f"Received user question: {request.question}")
        result = qa.invoke({"question": request.question})
        answer = result.get("answer")
        # The prompt instructs the model to reply with this exact fallback phrase
        # when the context lacks an answer, so check for that phrase here.
        if not answer or "I'm sorry, I couldn't find relevant information" in answer:
            answer = "I'm sorry, I couldn't find relevant information for your query. Please try rephrasing or providing more details."
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Oops! Something went wrong on our end. Please try again later.")
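# Example request once the server is running (the host, port, and sample
# question below are illustrative only):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the IPC say about theft?"}'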
# ==========================
# Run Uvicorn Server
# ==========================
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860)