panelforge commited on
Commit
30f564a
·
verified ·
1 Parent(s): a125903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- import spaces # [uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
  import torch
 
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,13 +30,12 @@ def update_model_version(version, state_model_repo_id):
31
  print(f"Model switched to {model_repo_id}")
32
  state_model_repo_id.set(model_repo_id) # Update the state with the new model version
33
 
34
- @spaces.GPU # [uncomment to use ZeroGPU]
35
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
36
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
37
  selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
38
  selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
39
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
40
-
41
  if active_tab == "Prompt Input":
42
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
43
  else:
@@ -208,6 +206,8 @@ with gr.Blocks(css=css) as demo:
208
 
209
  with gr.Column(elem_id="right-column"):
210
  active_tab = gr.State("Prompt Input")
 
 
211
  with gr.Tabs() as tabs:
212
  with gr.TabItem("Prompt Input") as prompt_tab:
213
  prompt = gr.Textbox(
@@ -244,9 +244,9 @@ with gr.Blocks(css=css) as demo:
244
  link_button_v11 = gr.Button("V11 Model", variant="primary")
245
 
246
  # Set the model version based on the button clicked
247
- link_button_v7.click(update_model_version, inputs=["v7"], outputs=[])
248
- link_button_v8.click(update_model_version, inputs=["v8"], outputs=[])
249
- link_button_v11.click(update_model_version, inputs=["v11"], outputs=[])
250
 
251
  run_button.click(
252
  infer,
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
 
4
  import torch
5
+ from diffusers import DiffusionPipeline
6
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
30
  print(f"Model switched to {model_repo_id}")
31
  state_model_repo_id.set(model_repo_id) # Update the state with the new model version
32
 
33
+ @gradio.GPU # [uncomment to use ZeroGPU]
34
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
35
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
36
  selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
37
  selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
38
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
 
39
  if active_tab == "Prompt Input":
40
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
41
  else:
 
206
 
207
  with gr.Column(elem_id="right-column"):
208
  active_tab = gr.State("Prompt Input")
209
+ model_version_state = gr.State("v8") # Default model version is v8
210
+
211
  with gr.Tabs() as tabs:
212
  with gr.TabItem("Prompt Input") as prompt_tab:
213
  prompt = gr.Textbox(
 
244
  link_button_v11 = gr.Button("V11 Model", variant="primary")
245
 
246
  # Set the model version based on the button clicked
247
+ link_button_v7.click(update_model_version, inputs=[gr.Text("v7")], outputs=[model_version_state])
248
+ link_button_v8.click(update_model_version, inputs=[gr.Text("v8")], outputs=[model_version_state])
249
+ link_button_v11.click(update_model_version, inputs=[gr.Text("v11")], outputs=[model_version_state])
250
 
251
  run_button.click(
252
  infer,