PePe / intent.py
nileshhanotia's picture
Create intent.py
d6ed2ba verified
raw
history blame
2.2 kB
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Intent labels, keyed by the class index produced by the classification head.
intents = {0: "database_query", 1: "product_description"}

# Pre-trained checkpoint pulled from the Hugging Face Hub. This is an example;
# swap in another open-source sequence-classification model if necessary.
# NOTE(review): the checkpoint name indicates an SST-2 (sentiment) fine-tune, so
# its two output classes presumably do not correspond to the intents above until
# the model is fine-tuned on intent data or replaced -- confirm before relying
# on the predictions.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Size the classification head from the label map so the two cannot drift apart.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(intents)
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Function to classify query intent
def classify_intent(query):
    """Classify a query into one of the configured intents.

    Args:
        query: The query text to classify. Only the first row of the model
            output is used, so this is intended for a single query per call.

    Returns:
        A ``(intent, confidence)`` tuple: the label looked up in ``intents``
        for the highest-scoring class, and that class's softmax probability
        as a Python float.
    """
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip autograd bookkeeping so no graph is built per call.
    with torch.no_grad():
        outputs = model(**inputs)
    # Reduce to the first row before argmax so the result is always a class
    # index, never an index into the flattened logits tensor.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    return intents[predicted_class], probabilities[predicted_class].item()
# Example usage
# Two sample queries, one per intent the classifier is meant to distinguish.
query_1 = "Fetch all products with the keyword 'T-shirt' from the database."
query_2 = "Can you tell me about the description of this Shopify store?"

# Classify both samples first, then report them in order.
intent_1, confidence_1 = classify_intent(query_1)
intent_2, confidence_2 = classify_intent(query_2)

for _number, _query, _intent, _confidence in (
    (1, query_1, intent_1, confidence_1),
    (2, query_2, intent_2, confidence_2),
):
    print(f"Query {_number}: '{_query}'\nIntent: {_intent} with confidence {_confidence}\n")
# Further routing based on classified intent
def handle_query(query):
    """Route a query to the handler that matches its classified intent.

    Args:
        query: The query text; it is classified with ``classify_intent`` and
            then passed unchanged to the selected handler.

    Returns:
        The selected handler's return value, or the string
        ``"Intent not recognized."`` when the intent matches neither handler.
    """
    # The confidence score is returned by the classifier but not used for routing.
    label, _confidence = classify_intent(query)

    if label == "database_query":
        # Natural-language-to-SQL engine.
        return execute_database_query(query)

    if label == "product_description":
        # RAG engine for product descriptions.
        return execute_rag_query(query)

    return "Intent not recognized."
# Placeholder functions for database and RAG query handling
def execute_database_query(query):
    """Placeholder for the SQL-based natural language query generator.

    Args:
        query: The query text. Currently ignored.

    Returns:
        A fixed status message.
    """
    status = "Executing database query..."
    return status
def execute_rag_query(query):
    """Placeholder for the RAG pipeline that retrieves product descriptions.

    Args:
        query: The query text. Currently ignored.

    Returns:
        A fixed status message.
    """
    status = "Executing RAG product description query..."
    return status
# Test the function with different queries
# Run each sample query through the router and print the handler's response.
for _sample in (query_1, query_2):
    print(handle_query(_sample))