from share import *
import config

import random
import sys

import cv2
import einops
import gradio as gr
import matplotlib
matplotlib.use('Agg')  # headless backend; select before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import torch
from icecream import ic
from pytorch_lightning import seed_everything

from annotator.util import resize_image, HWC3
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# Make the SAM package importable from the parent directory.
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# Build the ControlNet model on CPU, load the fine-tuned weights, then move it to GPU.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./farfetch_controlnet.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape

        # Build a scribble-style hint map: dark input pixels (min channel < 127) become white.
        detected_map = np.zeros_like(img, dtype=np.uint8)
        detected_map[np.min(img, axis=2) < 127] = 255

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01.
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)

        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
        ic(x_samples[0])
        ic(results)
    return [255 - detected_map] + results
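
# A minimal usage sketch (hypothetical values and file name, for illustration only):
# `process` can be called outside Gradio with any HxWx3 uint8 numpy array. The first
# returned item is the inverted scribble map; the rest are generated samples.
#
#   img = cv2.cvtColor(cv2.imread('sketch.png'), cv2.COLOR_BGR2RGB)
#   outputs = process(img, 'a leather handbag', 'best quality, extremely detailed',
#                     'lowres, bad quality', num_samples=1, image_resolution=512,
#                     ddim_steps=20, guess_mode=False, strength=1.0, scale=9.0,
#                     seed=-1, eta=0.0)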
def segment_anything(input_image, model_type="vit_h", device="cuda"):
    """
    Run the SAM model on the last image of a Gradio gallery and composite the
    segmented subject onto a white background.

    Args:
        input_image: list of Gradio gallery items (dicts with a 'name' file path).
        model_type: SAM model type, defaults to "vit_h".
        device: device to run on, defaults to "cuda".
    """
    for i in input_image:
        ic(type(i))
        ic(i)

    sam_checkpoint = "./sam_vit_h_4b8939.pth"

    # Gallery items arrive as dicts, e.g. {'name': '/tmp/gradio/.../image.png',
    # 'data': '...', 'is_file': True}; read the last generated image from disk
    # and make sure it is in RGB format.
    image_path = input_image[-1]['name']
    image = cv2.imread(image_path)
    input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if input_image.shape[2] == 3:
        image = input_image
    else:
        raise ValueError("Input image must be in RGB format.")

    # Load the SAM model and set up the predictor.
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    # Foreground prompt points and their labels.
    input_point = np.array([[280, 280], [220, 220]])
    input_label = np.array([1, 1])

    # Predict a single mask.
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Binarize and feather the mask, then composite the subject onto a white background.
    segmentation_mask = masks[0]
    binary_mask = np.where(segmentation_mask > 0.5, 1, 0)
    white_background = np.ones_like(image) * 255
    binary_mask = cv2.GaussianBlur(binary_mask.astype(np.float32), (15, 15), 0)
    new_image = white_background * (1 - binary_mask[..., np.newaxis]) + image * binary_mask[..., np.newaxis]
    ic(new_image)
    # Optionally save for debugging:
    # plt.imshow(new_image.astype(np.uint8)); plt.axis('off'); plt.savefig('pic3.png')
    new_image = new_image.clip(0, 255).astype(np.uint8)
    return [new_image]
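
# A minimal usage sketch (hypothetical paths, for illustration only): the function
# expects a Gradio gallery payload, i.e. a list of dicts whose 'name' field is a
# readable image path; only the last entry is segmented.
#
#   masked = segment_anything([{'name': 'generated.png', 'is_file': True}])
#   cv2.imwrite('masked.png', cv2.cvtColor(masked[0], cv2.COLOR_RGB2BGR))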
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with farfetch")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            sam_button = gr.Button("Sam")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            with gr.Row():
                sam_output = gr.Gallery(label='sam_Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    sam_button.click(fn=segment_anything, inputs=[result_gallery], outputs=[sam_output])

block.launch(server_name='0.0.0.0', share=True)