panelforge commited on
Commit
ffa7df2
·
verified ·
1 Parent(s): 2de8388

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -191
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- import torch
5
- import spaces
6
  from diffusers import DiffusionPipeline
7
- import importlib
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
@@ -15,227 +15,194 @@ else:
15
  torch_dtype = torch.float32
16
 
17
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe.to(device)
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
- # Function to load tags dynamically based on the selected tab
24
- def load_tags(active_tab):
25
- try:
26
- if active_tab == "Gay":
27
- return importlib.import_module('tags_gay') # dynamically import the tags_gay module
28
- elif active_tab == "Straight":
29
- return importlib.import_module('tags_straight') # dynamically import the tags_straight module
30
- elif active_tab == "Lesbian":
31
- return importlib.import_module('tags_lesbian') # dynamically import the tags_lesbian module
32
- else:
33
- raise ValueError(f"Unknown tab: {active_tab}")
34
- except Exception as e:
35
- print(f"Error loading tags for {active_tab}: {str(e)}")
36
- raise
37
-
38
- @spaces.GPU
39
- def infer(
40
- prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
41
- selected_participant_tags, selected_tribe_tags, selected_role_tags, selected_skin_tone_tags, selected_body_type_tags,
42
- selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags, selected_hair_style_tags,
43
- selected_position_tags, selected_fetish_tags, selected_location_tags, selected_camera_tags, selected_atmosphere_tags,
44
- active_tab, progress=gr.Progress(track_tqdm=True)
45
- ):
46
- # Dynamically load the correct tags module based on active tab
47
- tags_module = load_tags(active_tab)
48
-
49
- # Get the tag dictionaries from the loaded module
50
- participant_tags = tags_module.participant_tags
51
- tribe_tags = tags_module.tribe_tags
52
- role_tags = tags_module.role_tags
53
- skin_tone_tags = tags_module.skin_tone_tags
54
- body_type_tags = tags_module.body_type_tags
55
- tattoo_tags = tags_module.tattoo_tags
56
- piercing_tags = tags_module.piercing_tags
57
- expression_tags = tags_module.expression_tags
58
- eye_tags = tags_module.eye_tags
59
- hair_style_tags = tags_module.hair_style_tags
60
- position_tags = tags_module.position_tags
61
- fetish_tags = tags_module.fetish_tags
62
- location_tags = tags_module.location_tags
63
- camera_tags = tags_module.camera_tags
64
- atmosphere_tags = tags_module.atmosphere_tags
65
-
66
- # Build the tag list using selected tags from each group
67
- tag_list = []
68
-
69
- # Add selected participant tags
70
- for tag in selected_participant_tags:
71
- if tag in participant_tags:
72
- tag_list.append(participant_tags[tag])
73
-
74
- # Add selected tribe tags
75
- for tag in selected_tribe_tags:
76
- if tag in tribe_tags:
77
- tag_list.append(tribe_tags[tag])
78
-
79
- # Add selected role tags
80
- for tag in selected_role_tags:
81
- if tag in role_tags:
82
- tag_list.append(role_tags[tag])
83
-
84
- # Add selected skin tone tags
85
- for tag in selected_skin_tone_tags:
86
- if tag in skin_tone_tags:
87
- tag_list.append(skin_tone_tags[tag])
88
-
89
- # Add selected body type tags
90
- for tag in selected_body_type_tags:
91
- if tag in body_type_tags:
92
- tag_list.append(body_type_tags[tag])
93
-
94
- # Add selected tattoo tags
95
- for tag in selected_tattoo_tags:
96
- if tag in tattoo_tags:
97
- tag_list.append(tattoo_tags[tag])
98
-
99
- # Add selected piercing tags
100
- for tag in selected_piercing_tags:
101
- if tag in piercing_tags:
102
- tag_list.append(piercing_tags[tag])
103
-
104
- # Add selected expression tags
105
- for tag in selected_expression_tags:
106
- if tag in expression_tags:
107
- tag_list.append(expression_tags[tag])
108
-
109
- # Add selected eye tags
110
- for tag in selected_eye_tags:
111
- if tag in eye_tags:
112
- tag_list.append(eye_tags[tag])
113
-
114
- # Add selected hair style tags
115
- for tag in selected_hair_style_tags:
116
- if tag in hair_style_tags:
117
- tag_list.append(hair_style_tags[tag])
118
-
119
- # Add selected position tags
120
- for tag in selected_position_tags:
121
- if tag in position_tags:
122
- tag_list.append(position_tags[tag])
123
-
124
- # Add selected fetish tags
125
- for tag in selected_fetish_tags:
126
- if tag in fetish_tags:
127
- tag_list.append(fetish_tags[tag])
128
-
129
- # Add selected location tags
130
- for tag in selected_location_tags:
131
- if tag in location_tags:
132
- tag_list.append(location_tags[tag])
133
-
134
- # Add selected camera tags
135
- for tag in selected_camera_tags:
136
- if tag in camera_tags:
137
- tag_list.append(camera_tags[tag])
138
-
139
- # Add selected atmosphere tags
140
- for tag in selected_atmosphere_tags:
141
- if tag in atmosphere_tags:
142
- tag_list.append(atmosphere_tags[tag])
143
-
144
- # Construct final prompt
145
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {', '.join(tag_list)}"
146
-
147
- # Negative prompt
148
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
149
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
150
 
151
- # Handle random seed if needed
152
  if randomize_seed:
153
  seed = random.randint(0, MAX_SEED)
154
- generator = torch.Generator(device=device).manual_seed(seed)
155
-
156
- # Generate the image
157
- try:
158
- image = pipe(
159
- prompt=final_prompt,
160
- negative_prompt=full_negative_prompt,
161
- guidance_scale=guidance_scale,
162
- num_inference_steps=num_inference_steps,
163
- width=width,
164
- height=height,
165
- generator=generator
166
- ).images[0]
167
- except Exception as e:
168
- print(f"Error generating image: {str(e)}")
169
- raise
170
-
171
- return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"
172
-
173
- # Gradio UI setup
 
 
 
 
174
  css = """
175
  #col-container {
176
  margin: 0 auto;
177
  max-width: 640px;
178
  }
 
 
 
179
  """
180
 
181
  with gr.Blocks(css=css) as demo:
 
182
  with gr.Column(elem_id="col-container"):
183
- gr.Markdown("# Image Generator with Tags and Prompts")
184
 
 
185
  result = gr.Image(label="Result", show_label=False)
 
 
186
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
 
 
187
  active_tab = gr.State("Prompt Input")
188
 
 
189
  with gr.Tabs() as tabs:
190
- # Tab setup for different categories
191
- with gr.TabItem("Prompt Input"):
192
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
193
- tabs.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
194
-
195
- with gr.TabItem("Straight"):
196
- tags_module = load_tags("Straight")
197
- selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
198
- selected_tribe_tags = gr.CheckboxGroup(choices=list(tags_module.tribe_tags.keys()), label="Tribe Tags")
199
- selected_role_tags = gr.CheckboxGroup(choices=list(tags_module.role_tags.keys()), label="Role Tags")
200
- selected_skin_tone_tags = gr.CheckboxGroup(choices=list(tags_module.skin_tone_tags.keys()), label="Skin Tone Tags")
201
- selected_body_type_tags = gr.CheckboxGroup(choices=list(tags_module.body_type_tags.keys()), label="Body Type Tags")
202
- selected_tattoo_tags = gr.CheckboxGroup(choices=list(tags_module.tattoo_tags.keys()), label="Tattoo Tags")
203
- selected_piercing_tags = gr.CheckboxGroup(choices=list(tags_module.piercing_tags.keys()), label="Piercing Tags")
204
- selected_expression_tags = gr.CheckboxGroup(choices=list(tags_module.expression_tags.keys()), label="Expression Tags")
205
- selected_eye_tags = gr.CheckboxGroup(choices=list(tags_module.eye_tags.keys()), label="Eye Tags")
206
- selected_hair_style_tags = gr.CheckboxGroup(choices=list(tags_module.hair_style_tags.keys()), label="Hair Style Tags")
207
- selected_position_tags = gr.CheckboxGroup(choices=list(tags_module.position_tags.keys()), label="Position Tags")
208
- selected_fetish_tags = gr.CheckboxGroup(choices=list(tags_module.fetish_tags.keys()), label="Fetish Tags")
209
- selected_location_tags = gr.CheckboxGroup(choices=list(tags_module.location_tags.keys()), label="Location Tags")
210
- selected_camera_tags = gr.CheckboxGroup(choices=list(tags_module.camera_tags.keys()), label="Camera Tags")
211
- selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
212
- tabs.select(lambda: "Straight", inputs=None, outputs=active_tab)
213
-
214
- # Advanced settings
 
 
 
 
 
 
215
  with gr.Accordion("Advanced Settings", open=False):
216
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
217
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
218
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  with gr.Row():
221
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
222
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  with gr.Row():
225
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7)
226
- num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=35)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
- run_button = gr.Button("Run")
229
  run_button.click(
230
  infer,
231
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
232
- selected_participant_tags, selected_tribe_tags, selected_role_tags,
233
- selected_skin_tone_tags, selected_body_type_tags, selected_tattoo_tags,
234
- selected_piercing_tags, selected_expression_tags, selected_eye_tags,
235
- selected_hair_style_tags, selected_position_tags, selected_fetish_tags,
236
- selected_location_tags, selected_camera_tags, selected_atmosphere_tags,
237
- active_tab],
238
  outputs=[result, seed, prompt_info]
239
  )
240
 
241
- demo.queue().launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import spaces # [uncomment to use ZeroGPU]
 
5
  from diffusers import DiffusionPipeline
6
+ import torch
7
+ from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
 
15
  torch_dtype = torch.float32
16
 
17
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
+ pipe = pipe.to(device)
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
+ @spaces.GPU # [uncomment to use ZeroGPU]
24
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
25
+ selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
26
+ selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
27
+ selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
28
+ selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
29
+
30
+ if active_tab == "Prompt Input":
31
+ # Use the user-provided prompt
32
+ final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
33
+ else:
34
+ # Use tags from the "Tag Selection" tab
35
+ selected_tags = (
36
+ [participant_tags[tag] for tag in selected_participant_tags] +
37
+ [tribe_tags[tag] for tag in selected_tribe_tags] +
38
+ [skin_tone_tags[tag] for tag in selected_skin_tone_tags] +
39
+ [body_type_tags[tag] for tag in selected_body_type_tags] +
40
+ [tattoo_tags[tag] for tag in selected_tattoo_tags] +
41
+ [piercing_tags[tag] for tag in selected_piercing_tags] +
42
+ [expression_tags[tag] for tag in selected_expression_tags] +
43
+ [eye_tags[tag] for tag in selected_eye_tags] +
44
+ [hair_style_tags[tag] for tag in selected_hair_style_tags] +
45
+ [position_tags[tag] for tag in selected_position_tags] +
46
+ [fetish_tags[tag] for tag in selected_fetish_tags] +
47
+ [location_tags[tag] for tag in selected_location_tags] +
48
+ [camera_tags[tag] for tag in selected_camera_tags] +
49
+ [atmosphere_tags[tag] for tag in selected_atmosphere_tags]
50
+ )
51
+ tags_text = ', '.join(selected_tags)
52
+ final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {tags_text}'
53
+
54
+ # Concatenate user-provided negative prompt with additional restrictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
56
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
57
 
 
58
  if randomize_seed:
59
  seed = random.randint(0, MAX_SEED)
60
+
61
+ generator = torch.Generator().manual_seed(seed)
62
+
63
+ # Generate the image with the final prompts
64
+ image = pipe(
65
+ prompt=final_prompt,
66
+ negative_prompt=full_negative_prompt,
67
+ guidance_scale=guidance_scale,
68
+ num_inference_steps=num_inference_steps,
69
+ width=width,
70
+ height=height,
71
+ generator=generator
72
+ ).images[0]
73
+
74
+ # Return image, seed, and the used prompts
75
+ return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
76
+
77
+
78
+ examples = [
79
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
80
+ "An astronaut riding a green horse",
81
+ "A delicious ceviche cheesecake slice",
82
+ ]
83
+
84
  css = """
85
  #col-container {
86
  margin: 0 auto;
87
  max-width: 640px;
88
  }
89
+ #run-button {
90
+ width: 100%;
91
+ }
92
  """
93
 
94
  with gr.Blocks(css=css) as demo:
95
+
96
  with gr.Column(elem_id="col-container"):
97
+ gr.Markdown("""# Text-to-Image Gradio Template""")
98
 
99
+ # Display result image at the top
100
  result = gr.Image(label="Result", show_label=False)
101
+
102
+ # Add a textbox to display the prompts used for generation
103
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
104
+
105
+ # State to track active tab
106
  active_tab = gr.State("Prompt Input")
107
 
108
+ # Tabbed interface to select either Prompt or Tags
109
  with gr.Tabs() as tabs:
110
+ with gr.TabItem("Prompt Input") as prompt_tab:
111
+ prompt = gr.Textbox(
112
+ label="Prompt",
113
+ show_label=False,
114
+ max_lines=1,
115
+ placeholder="Enter your prompt",
116
+ container=False,
117
+ )
118
+ prompt_tab.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
119
+
120
+ with gr.TabItem("Tag Selection") as tag_tab:
121
+ # Tag selection checkboxes for each tag group
122
+ selected_participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Participant Tags")
123
+ selected_tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Tribe Tags")
124
+ selected_skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Skin Tone Tags")
125
+ selected_body_type_tags = gr.CheckboxGroup(choices=list(body_type_tags.keys()), label="Body Type Tags")
126
+ selected_tattoo_tags = gr.CheckboxGroup(choices=list(tattoo_tags.keys()), label="Tattoo Tags")
127
+ selected_piercing_tags = gr.CheckboxGroup(choices=list(piercing_tags.keys()), label="Piercing Tags")
128
+ selected_expression_tags = gr.CheckboxGroup(choices=list(expression_tags.keys()), label="Expression Tags")
129
+ selected_eye_tags = gr.CheckboxGroup(choices=list(eye_tags.keys()), label="Eye Tags")
130
+ selected_hair_style_tags = gr.CheckboxGroup(choices=list(hair_style_tags.keys()), label="Hair Style Tags")
131
+ selected_position_tags = gr.CheckboxGroup(choices=list(position_tags.keys()), label="Position Tags")
132
+ selected_fetish_tags = gr.CheckboxGroup(choices=list(fetish_tags.keys()), label="Fetish Tags")
133
+ selected_location_tags = gr.CheckboxGroup(choices=list(location_tags.keys()), label="Location Tags")
134
+ selected_camera_tags = gr.CheckboxGroup(choices=list(camera_tags.keys()), label="Camera Tags")
135
+ selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
136
+ tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
137
+
138
+ # Full-width "Run" button
139
+ run_button = gr.Button("Run", scale=0, elem_id="run-button")
140
+
141
  with gr.Accordion("Advanced Settings", open=False):
142
+ negative_prompt = gr.Textbox(
143
+ label="Negative prompt",
144
+ max_lines=1,
145
+ placeholder="Enter a negative prompt",
146
+ visible=True,
147
+ )
148
+
149
+ seed = gr.Slider(
150
+ label="Seed",
151
+ minimum=0,
152
+ maximum=MAX_SEED,
153
+ step=1,
154
+ value=0,
155
+ )
156
+
157
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
158
 
159
  with gr.Row():
160
+ width = gr.Slider(
161
+ label="Width",
162
+ minimum=256,
163
+ maximum=MAX_IMAGE_SIZE,
164
+ step=32,
165
+ value=1024,
166
+ )
167
+
168
+ height = gr.Slider(
169
+ label="Height",
170
+ minimum=256,
171
+ maximum=MAX_IMAGE_SIZE,
172
+ step=32,
173
+ value=1024,
174
+ )
175
 
176
  with gr.Row():
177
+ guidance_scale = gr.Slider(
178
+ label="Guidance scale",
179
+ minimum=0.0,
180
+ maximum=10.0,
181
+ step=0.1,
182
+ value=7,
183
+ )
184
+
185
+ num_inference_steps = gr.Slider(
186
+ label="Number of inference steps",
187
+ minimum=1,
188
+ maximum=50,
189
+ step=1,
190
+ value=35,
191
+ )
192
+
193
+ gr.Examples(
194
+ examples=examples,
195
+ inputs=[prompt]
196
+ )
197
 
 
198
  run_button.click(
199
  infer,
200
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
201
+ selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
202
+ selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
203
+ selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
204
+ selected_camera_tags, selected_atmosphere_tags, active_tab],
 
 
205
  outputs=[result, seed, prompt_info]
206
  )
207
 
208
+ demo.queue().launch()