dejanseo commited on
Commit
efcea6b
·
verified ·
1 Parent(s): 0505e4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -30
app.py CHANGED
@@ -1,19 +1,12 @@
1
  import streamlit as st
2
  import requests
3
  from bs4 import BeautifulSoup
4
- #from transformers import AutoTokenizer, AutoModelForCausalLM # Removed for GGUF
5
  import json
 
6
  import re
7
- from llama_cpp import Llama # For GGUF
8
- import os # For checking GGUF file
9
 
10
- # --- Constants ---
11
- GGUF_MODEL_PATH = "gemma-2-2b-it.Q4_K_M.gguf" # Or your GGUF file path
12
- if not os.path.exists(GGUF_MODEL_PATH):
13
- st.error(f"Error: GGUF model not found at: {GGUF_MODEL_PATH}. Please download it and place it next to the app.")
14
- st.stop()
15
-
16
- # --- Scraping Function (No changes needed) ---
17
  def scrape_url(url):
18
  try:
19
  headers = {
@@ -32,26 +25,18 @@ def scrape_url(url):
32
  return None
33
 
34
 
35
- # --- Load GGUF model with llama-cpp ---
36
  @st.cache_resource
37
  def load_model():
38
- try:
39
- llm = Llama(model_path=GGUF_MODEL_PATH, n_gpu_layers=20) # Change n_gpu_layers if needed
40
- return llm
41
- except Exception as e:
42
- st.error(f"Error loading model: {e}")
43
- return None
44
-
45
 
46
- model = load_model()
47
 
48
- if model is None:
49
- st.error("Failed to load the model. Please check the logs for errors.")
50
- st.stop()
51
-
52
-
53
-
54
- # --- Generate JSON Output with Llama-cpp ---
55
  def generate_json_output(text):
56
  prompt = f"""You are a web page text scanner. Your task is to carefully review text from a web page.
57
 
@@ -66,9 +51,10 @@ Answer the following questions:
66
  You should output your answers strictly in the following JSON format, but do NOT use markdown:
67
  {{\"brand\": \"<brand>\", \"intent\": \"<intent>\"}}
68
  """
69
-
70
- output = model(prompt, max_tokens=256, stop=["\n"], echo=False)
71
- response = output['choices'][0]['text']
 
72
 
73
  output_json = None
74
  try:
@@ -86,7 +72,7 @@ You should output your answers strictly in the following JSON format, but do NOT
86
 
87
  return response, output_json
88
 
89
- # --- Streamlit App ---
90
  def main():
91
  st.title("Google Brand and Intent Detection")
92
  st.write("Google's brand and intent detection reverse engineered from Chrome by [DEJAN AI](https://dejan.ai/).")
 
1
  import streamlit as st
2
  import requests
3
  from bs4 import BeautifulSoup
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import json
6
+ import torch
7
  import re
 
 
8
 
9
+ # Function to scrape a URL using Beautiful Soup
 
 
 
 
 
 
10
  def scrape_url(url):
11
  try:
12
  headers = {
 
25
  return None
26
 
27
 
 
28
@st.cache_resource
def load_model():
    """Load the Gemma tokenizer and causal-LM once per Streamlit session.

    Decorated with st.cache_resource, so Streamlit reruns reuse the same
    objects instead of re-downloading / re-instantiating the model.

    Returns:
        (tokenizer, model): the AutoTokenizer and AutoModelForCausalLM
        for "google/gemma-2-2b-it".
    """
    model_id = "google/gemma-2-2b-it"
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # bfloat16 + device_map="auto" lets accelerate place weights on the
    # available device(s); low_cpu_mem_usage avoids a full fp32 CPU copy.
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    return tok, lm

tokenizer, model = load_model()
39
 
 
 
 
 
 
 
 
40
  def generate_json_output(text):
41
  prompt = f"""You are a web page text scanner. Your task is to carefully review text from a web page.
42
 
 
51
  You should output your answers strictly in the following JSON format, but do NOT use markdown:
52
  {{\"brand\": \"<brand>\", \"intent\": \"<intent>\"}}
53
  """
54
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
55
+ outputs = model.generate(**inputs, max_new_tokens=256, return_dict_in_generate=True)
56
+ generated_tokens = outputs.sequences[:, inputs.input_ids.shape[1]:]
57
+ response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
58
 
59
  output_json = None
60
  try:
 
72
 
73
  return response, output_json
74
 
75
+ # Streamlit app
76
  def main():
77
  st.title("Google Brand and Intent Detection")
78
  st.write("Google's brand and intent detection reverse engineered from Chrome by [DEJAN AI](https://dejan.ai/).")