panelforge commited on
Commit
3e820bd
·
verified ·
1 Parent(s): 30f564a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -41
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- import torch
5
  from diffusers import DiffusionPipeline
 
6
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
- # Default model version
11
- model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Default model V8
12
 
13
  if torch.cuda.is_available():
14
  torch_dtype = torch.float16
@@ -21,21 +20,13 @@ pipe = pipe.to(device)
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
23
 
24
- def update_model_version(version, state_model_repo_id):
25
- """Update the model version dynamically based on the selected version."""
26
- global model_repo_id, pipe
27
- model_repo_id = f"John6666/wai-ani-nsfw-ponyxl-{version}-sdxl"
28
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
29
- pipe = pipe.to(device)
30
- print(f"Model switched to {model_repo_id}")
31
- state_model_repo_id.set(model_repo_id) # Update the state with the new model version
32
-
33
- @gradio.GPU # [uncomment to use ZeroGPU]
34
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
35
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
36
  selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
37
  selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
38
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
 
39
  if active_tab == "Prompt Input":
40
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
41
  else:
@@ -66,7 +57,6 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
66
 
67
  generator = torch.Generator().manual_seed(seed)
68
 
69
- # Generate the image with the final prompts
70
  image = pipe(
71
  prompt=final_prompt,
72
  negative_prompt=full_negative_prompt,
@@ -79,7 +69,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
79
 
80
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
81
 
82
- # CSS for button styling and horizontal layout
83
  css = """
84
  #col-container {
85
  margin: 0 auto;
@@ -127,29 +117,17 @@ css = """
127
  #prompt {
128
  margin-bottom: 20px;
129
  }
130
-
131
- .button-group {
132
- display: flex;
133
- justify-content: space-between;
134
- }
135
-
136
- .button-group .gradio-button {
137
- flex: 1;
138
- margin: 0 10px;
139
- text-align: center;
140
- }
141
  """
142
 
143
- # Gradio interface setup
144
  with gr.Blocks(css=css) as demo:
145
 
146
  with gr.Row():
147
  with gr.Column(elem_id="left-column"):
148
  gr.Markdown("""# Rainbow Media X""")
 
149
  result = gr.Image(label="Result", show_label=False, elem_id="result")
150
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False, elem_id="prompt-info")
151
 
152
- # Advanced Settings and Run Button
153
  with gr.Accordion("Advanced Settings", open=False):
154
  negative_prompt = gr.Textbox(
155
  label="Negative prompt",
@@ -206,7 +184,6 @@ with gr.Blocks(css=css) as demo:
206
 
207
  with gr.Column(elem_id="right-column"):
208
  active_tab = gr.State("Prompt Input")
209
- model_version_state = gr.State("v8") # Default model version is v8
210
 
211
  with gr.Tabs() as tabs:
212
  with gr.TabItem("Prompt Input") as prompt_tab:
@@ -237,17 +214,6 @@ with gr.Blocks(css=css) as demo:
237
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
238
  tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
239
 
240
- # Horizontal buttons to switch models
241
- with gr.Row(elem_id="button-group"):
242
- link_button_v7 = gr.Button("V7 Model", variant="primary")
243
- link_button_v8 = gr.Button("V8 Model", variant="primary")
244
- link_button_v11 = gr.Button("V11 Model", variant="primary")
245
-
246
- # Set the model version based on the button clicked
247
- link_button_v7.click(update_model_version, inputs=[gr.Text("v7")], outputs=[model_version_state])
248
- link_button_v8.click(update_model_version, inputs=[gr.Text("v8")], outputs=[model_version_state])
249
- link_button_v11.click(update_model_version, inputs=[gr.Text("v11")], outputs=[model_version_state])
250
-
251
  run_button.click(
252
  infer,
253
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
@@ -258,4 +224,23 @@ with gr.Blocks(css=css) as demo:
258
  outputs=[result, seed, prompt_info]
259
  )
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  demo.queue().launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import spaces # Restored import for spaces
5
  from diffusers import DiffusionPipeline
6
+ import torch
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Default model version
 
 
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
+ @spaces.GPU # Restored decorator to enable GPU use
 
 
 
 
 
 
 
 
 
24
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
25
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
26
  selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
27
  selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
28
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
29
+
30
  if active_tab == "Prompt Input":
31
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
32
  else:
 
57
 
58
  generator = torch.Generator().manual_seed(seed)
59
 
 
60
  image = pipe(
61
  prompt=final_prompt,
62
  negative_prompt=full_negative_prompt,
 
69
 
70
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
71
 
72
+
73
  css = """
74
  #col-container {
75
  margin: 0 auto;
 
117
  #prompt {
118
  margin-bottom: 20px;
119
  }
 
 
 
 
 
 
 
 
 
 
 
120
  """
121
 
 
122
  with gr.Blocks(css=css) as demo:
123
 
124
  with gr.Row():
125
  with gr.Column(elem_id="left-column"):
126
  gr.Markdown("""# Rainbow Media X""")
127
+
128
  result = gr.Image(label="Result", show_label=False, elem_id="result")
129
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False, elem_id="prompt-info")
130
 
 
131
  with gr.Accordion("Advanced Settings", open=False):
132
  negative_prompt = gr.Textbox(
133
  label="Negative prompt",
 
184
 
185
  with gr.Column(elem_id="right-column"):
186
  active_tab = gr.State("Prompt Input")
 
187
 
188
  with gr.Tabs() as tabs:
189
  with gr.TabItem("Prompt Input") as prompt_tab:
 
214
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
215
  tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
216
 
 
 
 
 
 
 
 
 
 
 
 
217
  run_button.click(
218
  infer,
219
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
 
224
  outputs=[result, seed, prompt_info]
225
  )
226
 
227
+ link_button_v7 = gr.Button("V7", elem_id="link-v7", size="sm")
228
+ link_button_v8 = gr.Button("V8", elem_id="link-v8", size="sm")
229
+ link_button_v11 = gr.Button("V11", elem_id="link-v11", size="sm")
230
+
231
+ def update_model_version(version):
232
+ global model_repo_id
233
+ if version == "v7":
234
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v7-sdxl"
235
+ elif version == "v8":
236
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
237
+ elif version == "v11":
238
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v11-sdxl"
239
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
240
+ pipe = pipe.to(device)
241
+
242
+ link_button_v7.click(update_model_version, inputs=["v7"], outputs=[])
243
+ link_button_v8.click(update_model_version, inputs=["v8"], outputs=[])
244
+ link_button_v11.click(update_model_version, inputs=["v11"], outputs=[])
245
+
246
  demo.queue().launch()