# Hugging Face Space — status: Sleeping
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Checkpoint fetched from the Hugging Face Hub. This is only an example;
# swap in another open-source model if necessary.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
# NOTE(review): this checkpoint is fine-tuned on SST-2 (a sentiment task), so
# its two output classes are presumably sentiment labels, not intents. The
# index -> intent mapping below is only meaningful once the model has been
# fine-tuned on intent data — confirm before relying on predictions.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Class index produced by the classifier -> intent name used for routing.
intents = dict(enumerate(("database_query", "product_description")))
# Function to classify query intent
def classify_intent(query, *, clf_model=None, clf_tokenizer=None, labels=None):
    """Classify *query* into one of the configured intents.

    Args:
        query: The query text to classify (passed straight to the tokenizer).
        clf_model: Optional classifier to use instead of the module-level
            ``model``. It is called with the tokenizer's output as keyword
            arguments and must return an object exposing ``.logits``.
        clf_tokenizer: Optional tokenizer to use instead of the module-level
            ``tokenizer``.
        labels: Optional mapping from class index to intent name; defaults to
            the module-level ``intents``.

    Returns:
        ``(intent_name, confidence)`` for the first item of the batch, where
        ``confidence`` is the softmax probability of the predicted class.

    Raises:
        KeyError: If the predicted class index has no entry in *labels*.
    """
    clf_model = model if clf_model is None else clf_model
    clf_tokenizer = tokenizer if clf_tokenizer is None else clf_tokenizer
    labels = intents if labels is None else labels

    inputs = clf_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = clf_model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    first_row = probabilities[0]
    # Reduce over the class axis of the first row only. A bare argmax would
    # flatten the whole (batch, num_labels) tensor and could return an index
    # that is not a class id at all.
    predicted_class = torch.argmax(first_row, dim=-1).item()
    return labels[predicted_class], first_row[predicted_class].item()


# Example usage: classify two sample queries and print the predictions.
query_1, query_2 = (
    "Fetch all products with the keyword 'T-shirt' from the database.",
    "Can you tell me about the description of this Shopify store?",
)
(intent_1, confidence_1), (intent_2, confidence_2) = (
    classify_intent(query_1),
    classify_intent(query_2),
)
print(f"Query 1: '{query_1}'\nIntent: {intent_1} with confidence {confidence_1}\n")
print(f"Query 2: '{query_2}'\nIntent: {intent_2} with confidence {confidence_2}\n")


# Further routing based on classified intent
def handle_query(query, *, min_confidence=0.0):
    """Route *query* to the handler for its classified intent.

    Args:
        query: The query text to classify and dispatch.
        min_confidence: Minimum classifier confidence (a softmax probability,
            so in [0, 1]) required to dispatch. The default of 0.0 accepts
            every prediction.

    Returns:
        The selected handler's result. Returns ``"Intent not recognized."``
        when the intent has no handler or the confidence is below
        *min_confidence*.
    """
    intent, confidence = classify_intent(query)
    # Built per call so handlers defined later in the module are resolved at call time.
    handlers = {
        "database_query": execute_database_query,  # natural-language-to-SQL engine
        "product_description": execute_rag_query,  # RAG engine for product descriptions
    }
    handler = handlers.get(intent)
    if handler is None or confidence < min_confidence:
        return "Intent not recognized."
    return handler(query)


# Placeholder functions for database and RAG query handling
def execute_database_query(query):
    """Stub handler for database-style queries.

    TODO: integrate with a natural-language-to-SQL query generator. The
    *query* argument is currently ignored and a fixed status string is
    returned.
    """
    status_message = "Executing database query..."
    return status_message


def execute_rag_query(query):
    """Stub handler for product-description queries.

    TODO: integrate with a RAG pipeline that retrieves product descriptions.
    The *query* argument is currently ignored and a fixed status string is
    returned.
    """
    status_message = "Executing RAG product description query..."
    return status_message


# Exercise the full classify-and-route pipeline on both sample queries.
for _sample_query in (query_1, query_2):
    print(handle_query(_sample_query))