import traceback

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Fine-tuned Qwen2-7B chat model for the Level Measurement Guide, hosted on the Hugging Face Hub.
model_name = "lei-HuggingFace/Qwen2-7B-4it-Chat_Level_Measurement_Guide_07222024"

# 4-bit NF4 quantization with nested (double) quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def load_model():
    """Load the quantized model onto the available device(s); return None on failure."""
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",  # place layers on GPU/CPU automatically
            quantization_config=bnb_config,
            trust_remote_code=True,
        )
        model.config.use_cache = True  # enable the KV cache for faster generation
        print(f"Model loaded successfully. Device: {model.device}")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return None


try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {str(e)}")
    traceback.print_exc()
    tokenizer = None

model = load_model()
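
# Optional sanity check (an addition, not part of the original script):
# PreTrainedModel.get_memory_footprint() reports parameter memory in bytes,
# which is handy for confirming the 4-bit load actually took effect.
if model is not None:
    print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")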


def generate_response(message, history):
    """Gradio ChatInterface callback: rebuild the conversation, then sample a reply."""
    try:
        if model is None or tokenizer is None:
            return "Model or tokenizer failed to load. Please check the logs and try again."

        # Convert Gradio's (user, assistant) tuple history into chat-template messages.
        messages = []
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        # add_generation_prompt=True appends the assistant turn header so the
        # model answers the user instead of continuing the last message.
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=256,
                do_sample=True,  # sampled decoding: temperature + nucleus (top-p) + top-k
                temperature=0.7,
                top_p=0.95,
                top_k=40,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        traceback.print_exc()
        return error_message
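
# Quick smoke test of the callback outside the UI (the prompt is illustrative,
# not from the original script); uncomment to run once at startup:
# print(generate_response("Can you explain the principle behind ultrasonic level sensors?", []))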


iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
    title="Level Measurement Guide Chatbot (Optimized Quantized Model)",
    description="Chat with the optimized quantized fine-tuned Level Measurement Guide model.",
    theme="soft",
    examples=[
        "What are the key considerations for level measurement in industrial settings?",
        "Can you explain the principle behind ultrasonic level sensors?",
        "What are the advantages of using radar level sensors?",
    ],
    cache_examples=False,
    # retry_btn / undo_btn / clear_btn are Gradio 3.x/4.x ChatInterface options;
    # Gradio 5 removed them, so drop these kwargs when upgrading.
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

iface.launch()
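
# When running locally rather than on a hosted Space, share=True is a standard
# Gradio launch option that exposes a temporary public URL, e.g.:
# iface.launch(share=True)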