panelforge commited on
Commit
02074a8
·
verified ·
1 Parent(s): 8849eb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -27
app.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags # Import tags here
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace to the model you would like to use
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
@@ -21,17 +21,31 @@ MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
  @spaces.GPU # [uncomment to use ZeroGPU]
24
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, tag_selection_1, tag_selection_2, tag_selection_3, tag_selection_4, use_tags, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
25
 
26
- # Determine final prompt
27
  if use_tags:
28
- participant_tags = [participant_tags[tag] for tag in participant_tags if tag in participant_tags]
29
- tribe_tags = [tribe_tags[tag] for tag in tribe_tags if tag in tribe_tags]
30
- skin_tone_tags = [skin_tone_tags[tag] for tag in skin_tone_tags if tag in skin_tone_tags]
31
- body_type_tags = [body_type_tags[tag] for tag in body_type_tags if tag in body_type_tags]
32
- tags_text = ', '.join(selected_tags_1 + selected_tags_2 + selected_tags_3 + selected_tags_4)
33
-
34
-
 
 
 
 
 
 
 
 
 
 
35
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {tags_text}'
36
  else:
37
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
@@ -90,7 +104,7 @@ with gr.Blocks(css=css) as demo:
90
  # Tabbed interface to select either Prompt or Tags
91
  with gr.Tabs() as tabs:
92
  with gr.TabItem("Prompt Input"):
93
- prompt = gr.Text(
94
  label="Prompt",
95
  show_label=False,
96
  max_lines=1,
@@ -100,15 +114,21 @@ with gr.Blocks(css=css) as demo:
100
  use_tags = gr.State(False)
101
 
102
  with gr.TabItem("Tag Selection"):
103
- # Separate each tag section vertically
104
- with gr.Column():
105
- participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Select Tags (Style)")
106
- with gr.Column():
107
- tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Select Tags (Theme)")
108
- with gr.Column():
109
- skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Select Tags (Other)")
110
- with gr.Column():
111
- body_type_tags = gr.CheckboxGroup(choices=list(body_type_tags.keys()), label="Select Tags (Additional)")
 
 
 
 
 
 
112
 
113
  use_tags = gr.State(True)
114
 
@@ -116,7 +136,7 @@ with gr.Blocks(css=css) as demo:
116
  run_button = gr.Button("Run", scale=0, elem_id="run-button")
117
 
118
  with gr.Accordion("Advanced Settings", open=False):
119
- negative_prompt = gr.Text(
120
  label="Negative prompt",
121
  max_lines=1,
122
  placeholder="Enter a negative prompt",
@@ -172,11 +192,14 @@ with gr.Blocks(css=css) as demo:
172
  inputs=[prompt]
173
  )
174
 
175
- gr.on(
176
- triggers=[run_button.click, prompt.submit],
177
- fn=infer,
178
- inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, tag_selection_1, tag_selection_2, tag_selection_3, tag_selection_4, use_tags],
179
- outputs=[result, seed, prompt_info] # Include prompt_info in the outputs
180
- )
 
 
 
181
 
182
  demo.queue().launch()
 
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags # Import tags here
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
 
21
  MAX_IMAGE_SIZE = 1024
22
 
23
  @spaces.GPU # [uncomment to use ZeroGPU]
24
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
25
+ selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
26
+ selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
27
+ selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
28
+ selected_camera_tags, selected_atmosphere_tags, use_tags, progress=gr.Progress(track_tqdm=True)):
29
 
30
+ # Construct final prompt using selected tags if `use_tags` is True
31
  if use_tags:
32
+ selected_tags = (
33
+ [participant_tags[tag] for tag in selected_participant_tags] +
34
+ [tribe_tags[tag] for tag in selected_tribe_tags] +
35
+ [skin_tone_tags[tag] for tag in selected_skin_tone_tags] +
36
+ [body_type_tags[tag] for tag in selected_body_type_tags] +
37
+ [tattoo_tags[tag] for tag in selected_tattoo_tags] +
38
+ [piercing_tags[tag] for tag in selected_piercing_tags] +
39
+ [expression_tags[tag] for tag in selected_expression_tags] +
40
+ [eye_tags[tag] for tag in selected_eye_tags] +
41
+ [hair_style_tags[tag] for tag in selected_hair_style_tags] +
42
+ [position_tags[tag] for tag in selected_position_tags] +
43
+ [fetish_tags[tag] for tag in selected_fetish_tags] +
44
+ [location_tags[tag] for tag in selected_location_tags] +
45
+ [camera_tags[tag] for tag in selected_camera_tags] +
46
+ [atmosphere_tags[tag] for tag in selected_atmosphere_tags]
47
+ )
48
+ tags_text = ', '.join(selected_tags)
49
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {tags_text}'
50
  else:
51
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
 
104
  # Tabbed interface to select either Prompt or Tags
105
  with gr.Tabs() as tabs:
106
  with gr.TabItem("Prompt Input"):
107
+ prompt = gr.Textbox(
108
  label="Prompt",
109
  show_label=False,
110
  max_lines=1,
 
114
  use_tags = gr.State(False)
115
 
116
  with gr.TabItem("Tag Selection"):
117
+ # Tag selection checkboxes for each tag group
118
+ selected_participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Participant Tags")
119
+ selected_tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Tribe Tags")
120
+ selected_skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Skin Tone Tags")
121
+ selected_body_type_tags = gr.CheckboxGroup(choices=list(body_type_tags.keys()), label="Body Type Tags")
122
+ selected_tattoo_tags = gr.CheckboxGroup(choices=list(tattoo_tags.keys()), label="Tattoo Tags")
123
+ selected_piercing_tags = gr.CheckboxGroup(choices=list(piercing_tags.keys()), label="Piercing Tags")
124
+ selected_expression_tags = gr.CheckboxGroup(choices=list(expression_tags.keys()), label="Expression Tags")
125
+ selected_eye_tags = gr.CheckboxGroup(choices=list(eye_tags.keys()), label="Eye Tags")
126
+ selected_hair_style_tags = gr.CheckboxGroup(choices=list(hair_style_tags.keys()), label="Hair Style Tags")
127
+ selected_position_tags = gr.CheckboxGroup(choices=list(position_tags.keys()), label="Position Tags")
128
+ selected_fetish_tags = gr.CheckboxGroup(choices=list(fetish_tags.keys()), label="Fetish Tags")
129
+ selected_location_tags = gr.CheckboxGroup(choices=list(location_tags.keys()), label="Location Tags")
130
+ selected_camera_tags = gr.CheckboxGroup(choices=list(camera_tags.keys()), label="Camera Tags")
131
+ selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
132
 
133
  use_tags = gr.State(True)
134
 
 
136
  run_button = gr.Button("Run", scale=0, elem_id="run-button")
137
 
138
  with gr.Accordion("Advanced Settings", open=False):
139
+ negative_prompt = gr.Textbox(
140
  label="Negative prompt",
141
  max_lines=1,
142
  placeholder="Enter a negative prompt",
 
192
  inputs=[prompt]
193
  )
194
 
195
+ run_button.click(
196
+ infer,
197
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
198
+ selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
199
+ selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
200
+ selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
201
+ selected_camera_tags, selected_atmosphere_tags, use_tags],
202
+ outputs=[result, seed, prompt_info]
203
+ )
204
 
205
  demo.queue().launch()