File size: 2,203 Bytes
d6ed2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Intent labels, keyed by the class index the classifier predicts
intents = dict(enumerate(("database_query", "product_description")))

# Hugging Face checkpoint shared by the model and its tokenizer.
# NOTE(review): the checkpoint name suggests SST-2 sentiment fine-tuning, so its
# two output classes presumably do not correspond to the intents above —
# confirm, or swap in a checkpoint fine-tuned for intent classification.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"  # Example, change to other open-source models if necessary
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(intents),  # one output class per intent label (2)
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Function to classify query intent
def classify_intent(query):
    """Classify a query and return its intent label with the model's confidence.

    The query is tokenized with the module-level ``tokenizer``, scored by the
    module-level ``model``, and the highest-probability class index is mapped
    to a label through the module-level ``intents`` dict.

    Args:
        query: Text to classify. Only the first row of the model output is
            read, so a single query string is the expected input.

    Returns:
        A ``(intent_label, confidence)`` tuple, where ``confidence`` is the
        softmax probability of the predicted class as a Python float.
    """
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Work on the first (only) row so the argmax index is always a class
    # index, never a position in the flattened (batch * classes) tensor.
    row_probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class = torch.argmax(row_probabilities, dim=-1).item()
    return intents[predicted_class], row_probabilities[predicted_class].item()

# Sample queries, one aimed at each intent
query_1 = "Fetch all products with the keyword 'T-shirt' from the database."
query_2 = "Can you tell me about the description of this Shopify store?"

# Classify both samples (in order) before reporting anything
(intent_1, confidence_1), (intent_2, confidence_2) = (
    classify_intent(query_1),
    classify_intent(query_2),
)

# Report the predicted intent and its confidence for each sample
print(f"Query 1: '{query_1}'\nIntent: {intent_1} with confidence {confidence_1}\n")
print(f"Query 2: '{query_2}'\nIntent: {intent_2} with confidence {confidence_2}\n")

# Further routing based on classified intent
def handle_query(query):
    """Dispatch a query to the handler matching its classified intent.

    Args:
        query: Text passed to ``classify_intent`` and then to the handler.

    Returns:
        The result of ``execute_database_query`` for a "database_query"
        intent, the result of ``execute_rag_query`` for a
        "product_description" intent, or the string "Intent not recognized."
        for any other label.
    """
    # The confidence score is not used for routing.
    label, _score = classify_intent(query)

    if label == "database_query":
        # Natural-language-to-SQL engine
        return execute_database_query(query)

    if label == "product_description":
        # RAG engine for product descriptions
        return execute_rag_query(query)

    return "Intent not recognized."

# Placeholder functions for database and RAG query handling
def execute_database_query(query):
    """Placeholder for the database path; returns a fixed status message.

    Args:
        query: The user's query text (currently unused by this stub).

    Returns:
        The string "Executing database query...".
    """
    # TODO: integrate with the SQL-based natural language query generator
    status = "Executing database query..."
    return status

def execute_rag_query(query):
    """Placeholder for the RAG path; returns a fixed status message.

    Args:
        query: The user's query text (currently unused by this stub).

    Returns:
        The string "Executing RAG product description query...".
    """
    # TODO: integrate with the RAG pipeline that retrieves product descriptions
    status = "Executing RAG product description query..."
    return status

# Test the function with different queries
# Route each sample query in turn and print the handler's response
for _sample_query in (query_1, query_2):
    print(handle_query(_sample_query))