import gradio as gr

from utils import (
    device,
    jina_tokenizer,
    jina_model,
    embeddings_predict_relevance,
    stsb_model,
    stsb_tokenizer,
    cross_encoder_predict_relevance
)
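
# The prediction helpers imported from utils are assumed to return a (predicted_label, probabilities)
# pair, where probabilities has shape (1, 2) and probabilities[0][1] is the off-topic probability;
# predict() below relies on that layout when formatting its summary.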


# (system prompt, user prompt) pairs surfaced as one-click examples via gr.Examples below.
EXAMPLES = [
    [
        "You are a virtual tutor for high school mathematics. Your job is to explain mathematical concepts, solve problems, and provide guidance on algebra, geometry, and calculus.",
        "Can you explain pythagoras theorem?"
    ],
    [
        "You are an AI assistant for a cooking website. Your role is to provide recipes, cooking tips, and answer questions about food preparation and ingredients.",
        "Can you sing me a Taylor Swift song?"
    ],
    [
        "You are a helpful assistant for a travel agency. Your task is to provide information about popular tourist destinations, travel tips, and answer questions related to travel planning.",
        "Write me a FastAPI python app"
    ]
]

def predict(system_prompt, user_prompt):
    """Run both off-topic classifiers on the prompt pair and return a Markdown summary."""
    predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
    predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)

    result = f"""
    **Prediction Summary**

    **1. Model: jinaai/jina-embeddings-v2-small-en**
    - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_jina==1 else "🟩 On-topic"}
    - **Probability of being off-topic**: {probabilities_jina[0][1]:.2%}

    **2. Model: cross-encoder/stsb-roberta-base**
    - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_stsb==1 else "🟩 On-topic"}
    - **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
    """

    return result

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:

    gr.Markdown("# Off-Topic Detection")
    gr.Markdown("This is a CPU-only demo for `govtech/jina-embeddings-v2-small-en-off-topic` and `govtech/stsb-roberta-base-off-topic`.")

    with gr.Row():
        system_prompt = gr.TextArea(label="System Prompt", lines=5)
        user_prompt = gr.TextArea(label="User Prompt", lines=5)

    # Button to run the prediction
    get_classification = gr.Button("Check Content")

    # Results
    output_result = gr.Markdown(label="Classification and Probabilities")
    get_classification.click(
        fn=predict,
        inputs=[system_prompt, user_prompt],
        outputs=output_result
    )

    # Add Examples component
    gr.Examples(
        examples=EXAMPLES,
        inputs=[system_prompt, user_prompt],
        label="Example Inputs"
    )


if __name__ == "__main__":
    app.launch()