panelforge committed
Commit 0f8e37d · verified · 1 parent: fdfc2c5

Update app.py

Files changed (1)
  1. app.py +58 -96
app.py CHANGED
@@ -1,54 +1,69 @@
  import gradio as gr
  import numpy as np
  import random
- import spaces #[uncomment to use ZeroGPU]
+ import spaces
  from diffusers import DiffusionPipeline, DPMSolverSDEScheduler
  import torch
  from huggingface_hub import hf_hub_download
  from ultralytics import YOLO
- import cv2
  from PIL import Image
+ import cv2

  device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Your diffusion model
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
+ adetailer_model_id = "Bingsu/adetailer" # Your ADetailer model
+
+ # Load the YOLO model for face detection
+ yolo_model_path = hf_hub_download(adetailer_model_id, "face_yolov8n.pt")
+ yolo_model = YOLO(yolo_model_path)
+
+ if torch.cuda.is_available():
+     torch_dtype = torch.float16
+ else:
+     torch_dtype = torch.float32

- # Load your main diffusion pipeline
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
  pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config, algorithm_type="dpmsolver++", solver_order=2, use_karras_sigmas=True)
  pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024

- # Download the ADetailer YOLOv8 face detection model
- yolo_model_path = hf_hub_download(repo_id="Bingsu/adetailer", filename="face_yolov8n.pt")
- yolo_model = YOLO(yolo_model_path)
-
- def fix_eyes_with_adetailer(image):
-     # Convert PIL image to OpenCV format for YOLO
+ def correct_anime_face(image):
+     # Convert to OpenCV format
      img = np.array(image)
      img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

-     # Run the YOLO model on the image
+     # Detect faces
      results = yolo_model(img)

-     # Visualize and process the output
-     pred = results[0].plot() # Draw bounding boxes and other detections
-     pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
+     for detection in results[0].boxes:
+         x1, y1, x2, y2 = map(int, detection.xyxy[0].tolist())
+
+         # Crop the face region
+         face = img[y1:y2, x1:x2]
+         face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))
+
+         # Prompt for the correction model
+         prompt = "Enhance this anime character's face, fix eyes and make features more vivid."
+
+         # Process the face with the anime correction model
+         corrected_face = pipe(prompt=prompt, image=face_pil).images[0] # NOTE: needs an image-to-image pipeline; replace with your correction model
+
+         # Resize to the crop size, convert RGB -> BGR, and paste back into the original image
+         img[y1:y2, x1:x2] = cv2.cvtColor(np.array(corrected_face.resize((x2 - x1, y2 - y1))), cv2.COLOR_RGB2BGR)

-     # Convert the processed image back to PIL format
-     corrected_image = Image.fromarray(pred)
-     return corrected_image
+     # Convert back to PIL
+     final_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+     return final_image

- @spaces.GPU #[uncomment to use ZeroGPU]
+ @spaces.GPU
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

      generator = torch.Generator().manual_seed(seed)

-     # Generate the initial image with the diffusion model
      image = pipe(
          prompt=prompt,
          negative_prompt=negative_prompt,
@@ -58,9 +73,9 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
          height=height,
          generator=generator
      ).images[0]
-
-     # Apply ADetailer to fix the eyes after generating the image
-     corrected_image = fix_eyes_with_adetailer(image)
+
+     # Correct the anime face in the generated image
+     corrected_image = correct_anime_face(image)

      return corrected_image, seed

@@ -70,94 +85,41 @@ examples = [
      "A delicious ceviche cheesecake slice",
  ]

- css="""#col-container {margin: 0 auto; max-width: 640px;}"""
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 640px;
+ }
+ """

  with gr.Blocks(css=css) as demo:
-
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Text-to-Image Gradio Template
-         """)
+         gr.Markdown("# Text-to-Image Gradio Template")

          with gr.Row():
-
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
+             prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
              run_button = gr.Button("Run", scale=0)

          result = gr.Image(label="Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
+             negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", visible=False)
+             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
              randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

              with gr.Row():
+                 width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                 height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, #Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, #Replace with defaults that work for your model
-                 )
-
              with gr.Row():
-
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, #Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, #Replace with defaults that work for your model
-                 )
+                 guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
+                 num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=2)

-     gr.Examples(
-         examples=examples,
-         inputs=[prompt]
-     )
+     gr.Examples(examples=examples, inputs=[prompt])

-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-         outputs=[result, seed]
-     )
+     gr.on(triggers=[run_button.click, prompt.submit],
+           fn=infer,
+           inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+           outputs=[result, seed])

  demo.queue().launch()
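
Note on the new correct_anime_face step: pipe as loaded here is a text-to-image pipeline (an SDXL checkpoint), and its __call__ takes no image argument, so pipe(prompt=prompt, image=face_pil) would raise a TypeError. A minimal sketch of how the face pass could be wired instead, reusing the same checkpoint through diffusers' AutoPipelineForImage2Image; the refinement prompt, the strength value, and the 1024-pixel working size are illustrative assumptions, not part of the commit:

import cv2
import numpy as np
import torch
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from ultralytics import YOLO

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "John6666/wai-ani-nsfw-ponyxl-v8-sdxl", torch_dtype=torch_dtype
).to(device)

# Reuse the already-loaded components as an image-to-image pipeline
# instead of passing `image` to the text-to-image pipeline.
img2img = AutoPipelineForImage2Image.from_pipe(pipe)

yolo_model = YOLO(hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt"))

def correct_anime_face(image):
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    for box in yolo_model(img)[0].boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        face = Image.fromarray(cv2.cvtColor(img[y1:y2, x1:x2], cv2.COLOR_BGR2RGB))
        # Work at a size SDXL handles well; tiny crops denoise poorly.
        fixed = img2img(
            prompt="masterpiece, detailed anime face, beautiful detailed eyes",
            image=face.resize((1024, 1024)),
            strength=0.35,  # low strength preserves the original face
        ).images[0]
        # Scale back to the crop size and paste into the BGR frame.
        fixed = fixed.resize(face.size)
        img[y1:y2, x1:x2] = cv2.cvtColor(np.array(fixed), cv2.COLOR_RGB2BGR)
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

Low strength keeps the pass close to the original crop; raising it trades identity for stronger correction. Unlike the committed loop, the corrected crop is resized back to the detection box before pasting, so the array assignment shapes match.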
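The commit also restores import spaces and the @spaces.GPU decorator, which on a ZeroGPU Space attaches a GPU only while the decorated function runs. Continuing the sketch above, a longer job can declare its expected runtime via the decorator's duration parameter (the 120-second budget here is illustrative):

import spaces

@spaces.GPU(duration=120)  # illustrative; the default allocation is shorter
def infer(prompt):
    # CUDA is available only inside this call on ZeroGPU hardware
    return pipe(prompt=prompt).images[0]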