- handler.py +1070 -0
- requirements.txt +5 -3
handler.py
ADDED
@@ -0,0 +1,1070 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
warnings.filterwarnings('ignore')
|
3 |
+
|
4 |
+
import subprocess, io, os, sys, time
|
5 |
+
|
6 |
+
is_production = True
|
7 |
+
os.environ['CUDA_HOME'] = '/usr/local/cuda-11.7/' if is_production else '/usr/local/cuda-12.1/'
|
8 |
+
|
9 |
+
run_gradio = False
|
10 |
+
|
11 |
+
if run_gradio:
|
12 |
+
os.system("pip install gradio==3.50.2")
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
from loguru import logger
|
17 |
+
|
18 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
19 |
+
|
20 |
+
if is_production:
|
21 |
+
os.chdir("/repository")
|
22 |
+
sys.path.insert(0, '/repository')
|
23 |
+
|
24 |
+
if os.environ.get('IS_MY_DEBUG') is None:
|
25 |
+
result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
|
26 |
+
print(f'pip install GroundingDINO = {result}')
|
27 |
+
|
28 |
+
# result = subprocess.run(['pip', 'list'], check=True)
|
29 |
+
# print(f'pip list = {result}')
|
30 |
+
|
31 |
+
sys.path.insert(0, '/repository/GroundingDINO' if is_production else "./GroundingDINO")
|
32 |
+
|
33 |
+
import argparse
|
34 |
+
import copy
|
35 |
+
|
36 |
+
import numpy as np
|
37 |
+
import torch
|
38 |
+
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
39 |
+
|
40 |
+
# Grounding DINO
|
41 |
+
import GroundingDINO.groundingdino.datasets.transforms as T
|
42 |
+
from GroundingDINO.groundingdino.models import build_model
|
43 |
+
from GroundingDINO.groundingdino.util import box_ops
|
44 |
+
from GroundingDINO.groundingdino.util.slconfig import SLConfig
|
45 |
+
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
46 |
+
|
47 |
+
import cv2
|
48 |
+
import numpy as np
|
49 |
+
import matplotlib
|
50 |
+
matplotlib.use('AGG')
|
51 |
+
plt = matplotlib.pyplot
|
52 |
+
# import matplotlib.pyplot as plt
|
53 |
+
|
54 |
+
groundingdino_enable = True
|
55 |
+
sam_enable = True
|
56 |
+
inpainting_enable = True
|
57 |
+
ram_enable = True
|
58 |
+
|
59 |
+
lama_cleaner_enable = True
|
60 |
+
|
61 |
+
kosmos_enable = False
|
62 |
+
|
63 |
+
# qwen_enable = True
|
64 |
+
# from qwen_utils import *
|
65 |
+
|
66 |
+
if os.environ.get('IS_MY_DEBUG') is not None:
|
67 |
+
sam_enable = False
|
68 |
+
ram_enable = False
|
69 |
+
inpainting_enable = False
|
70 |
+
kosmos_enable = False
|
71 |
+
|
72 |
+
if lama_cleaner_enable:
|
73 |
+
try:
|
74 |
+
from lama_cleaner.model_manager import ModelManager
|
75 |
+
from lama_cleaner.schema import Config as lama_Config
|
76 |
+
except Exception as e:
|
77 |
+
lama_cleaner_enable = False
|
78 |
+
|
79 |
+
# segment anything
|
80 |
+
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
81 |
+
|
82 |
+
# diffusers
|
83 |
+
import PIL
|
84 |
+
import requests
|
85 |
+
import torch
|
86 |
+
from io import BytesIO
|
87 |
+
from diffusers import StableDiffusionInpaintPipeline
|
88 |
+
from huggingface_hub import hf_hub_download
|
89 |
+
|
90 |
+
from util_computer import computer_info
|
91 |
+
|
92 |
+
# relate anything
|
93 |
+
from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
|
94 |
+
from ram_train_eval import RamModel, RamPredictor
|
95 |
+
from mmengine.config import Config as mmengine_Config
|
96 |
+
|
97 |
+
if lama_cleaner_enable:
|
98 |
+
from lama_cleaner.helper import (
|
99 |
+
load_img,
|
100 |
+
numpy_to_bytes,
|
101 |
+
resize_max_size,
|
102 |
+
)
|
103 |
+
|
104 |
+
# from transformers import AutoProcessor, AutoModelForVision2Seq
|
105 |
+
import ast
|
106 |
+
|
107 |
+
if kosmos_enable:
|
108 |
+
os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
|
109 |
+
# os.system("pip install transformers==4.32.0")
|
110 |
+
|
111 |
+
from kosmos_utils import *
|
112 |
+
|
113 |
+
from util_tencent import getTextTrans
|
114 |
+
|
115 |
+
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
116 |
+
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
117 |
+
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
118 |
+
sam_checkpoint = './sam_vit_h_4b8939.pth'
|
119 |
+
output_dir = "outputs"
|
120 |
+
|
121 |
+
device = 'cpu'
|
122 |
+
os.makedirs(output_dir, exist_ok=True)
|
123 |
+
groundingdino_model = None
|
124 |
+
sam_device = "cuda"
|
125 |
+
sam_model = None
|
126 |
+
|
127 |
+
|
128 |
+
def get_sam_vit_h_4b8939():
    """Download the SAM ViT-H checkpoint into the working directory if missing."""
    if not os.path.exists('./sam_vit_h_4b8939.pth'):
        logger.info(f"get sam_vit_h_4b8939.pth...")
        # check=True raises CalledProcessError if the download fails.
        result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
        print(f'wget sam_vit_h_4b8939.pth result = {result}')
|
133 |
+
|
134 |
+
# --- module-level initialization: eagerly build SAM, then reset globals ---
get_sam_vit_h_4b8939()
logger.info(f"initialize SAM model...")
sam_device = "cuda"
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
sam_predictor = SamPredictor(sam_model)
sam_mask_generator = SamAutomaticMaskGenerator(sam_model)

# NOTE(review): the sam_mask_generator built just above is immediately
# discarded here — presumably re-created lazily via load_sam_model(); confirm
# this is intentional and not a leftover.
sam_mask_generator = None
# Placeholders for lazily-loaded models (populated by the load_* functions).
sd_model = None
lama_cleaner_model= None
ram_model = None
kosmos_model = None
kosmos_processor = None
|
147 |
+
|
148 |
+
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    """Build a GroundingDINO model from *model_config_path* and load the
    checkpoint *filename* from the Hugging Face Hub repo *repo_id*.

    Returns the model in eval mode, with weights mapped onto *device*.
    """
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    args.device = device

    # Downloads (or reuses the cached copy of) the checkpoint file.
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    # strict=False tolerates key mismatches between config and checkpoint.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model
|
159 |
+
|
160 |
+
def plot_boxes_to_image(image_pil, tgt):
    """Draw detection boxes and labels onto *image_pil* (mutated in place).

    *tgt* is a dict with keys "size" (H, W), "boxes" (normalized cxcywh
    tensors) and "labels" (one string per box).  Returns the annotated image
    and an "L"-mode mask image with the box areas filled white.
    """
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H
        box = box * torch.Tensor([W, H, W, H])
        # from xywh to xyxy
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color per box so adjacent detections are distinguishable
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # draw
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)

        # Measure the label with whichever API this Pillow version offers
        # (textbbox is the modern one; textsize was removed in Pillow 10).
        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        # bbox = draw.textbbox((x0, y0), str(label))
        draw.rectangle(bbox, fill=color)

        # Best effort: render the label with the DejaVu font shipped inside
        # the cv2 package; silently keep the plain background on any failure.
        try:
            font = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
            font_size = 36
            new_font = ImageFont.truetype(font, font_size)

            draw.text((x0+2, y0+2), str(label), font=new_font, fill="white")
        except Exception as e:
            pass

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask
|
208 |
+
|
209 |
+
def load_image(image_path):
    """Load *image_path* (a filesystem path or an open PIL image) and return
    ``(pil_image, tensor)``.

    The tensor is resized to a short side of 800 (long side capped at 1333)
    and normalized with ImageNet statistics, as GroundingDINO expects.
    """
    image_pil = (image_path if isinstance(image_path, PIL.Image.Image)
                 else Image.open(image_path).convert("RGB"))

    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    tensor_image, _ = preprocess(image_pil, None)  # shape (3, h, w)
    return image_pil, tensor_image
|
225 |
+
|
226 |
+
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run GroundingDINO on one preprocessed image and return detections.

    Parameters
    ----------
    model : GroundingDINO model (must expose ``.tokenizer``).
    image : torch.Tensor, (3, H, W) normalized image.
    caption : str — detection prompt; lower-cased and terminated with '.'.
    box_threshold : float — keep queries whose best token score exceeds this.
    text_threshold : float — tokens above this are folded into the phrase.
    with_logits : bool — append "(score)" to each phrase when True.

    Returns ``(boxes_filt, pred_phrases)``: filtered (num_filt, 4) boxes in
    normalized cxcywh and the matching phrase strings.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Filter output: keep queries whose best token score clears box_threshold.
    # Boolean-mask indexing already yields fresh tensors, so no clone() needed.
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]    # (num_filt, 4)

    # Map each kept query back to the caption tokens above text_threshold.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases
|
260 |
+
|
261 |
+
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent RGBA image.

    With ``random_color`` a random hue is used; otherwise a fixed
    dodger-blue with alpha 0.6.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    overlay = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
269 |
+
|
270 |
+
def show_box(box, ax, label):
    """Draw an xyxy box with its text label on a matplotlib axis."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                         edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
    ax.text(x0, y0, label)
|
275 |
+
|
276 |
+
def xywh_to_xyxy(box, sizeW, sizeH):
    """Convert a normalized center-format box (cx, cy, w, h) into absolute
    corner format (x0, y0, x1, y1), returned as a numpy array."""
    if isinstance(box, list):
        box = torch.Tensor(box)
    scaled = box * torch.Tensor([sizeW, sizeH, sizeW, sizeH])
    half = scaled[2:] / 2
    xyxy = torch.cat([scaled[:2] - half, scaled[:2] + half])
    return xyxy.numpy()
|
284 |
+
|
285 |
+
def mask_extend(img, box, extend_pixels=10, useRectangle=True):
    """Grow the masked region *box* (xyxy) by *extend_pixels* on every side.

    Crops the box region, resizes it to the enlarged size (or, with
    ``useRectangle``, paints the enlarged patch solid white), and pastes it
    back centered on the original box.  Mutates and returns *img*; the
    entries of *box* are coerced to int in place.
    """
    box[0] = int(box[0])
    box[1] = int(box[1])
    box[2] = int(box[2])
    box[3] = int(box[3])
    region = img.crop(tuple(box))
    new_width = box[2] - box[0] + 2*extend_pixels
    new_height = box[3] - box[1] + 2*extend_pixels

    region_BILINEAR = region.resize((int(new_width), int(new_height)))
    if useRectangle:
        # Replace the stretched crop with a solid white rectangle.
        region_draw = ImageDraw.Draw(region_BILINEAR)
        region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))
    img.paste(region_BILINEAR, (int(box[0]-extend_pixels), int(box[1]-extend_pixels)))
    return img
|
300 |
+
|
301 |
+
def mix_masks(imgs):
    """Union a list of binary PIL mask images into one "L"-mode mask.

    Uses De Morgan's law: the union of the masks is the complement of the
    intersection of their complements.
    """
    inverted = [1 - np.asarray(img.convert("1")) for img in imgs]
    combined = inverted[0]
    for inv in inverted[1:]:
        combined = np.multiply(combined, inv)
    union = 1 - combined
    return Image.fromarray(np.uint8(255 * union))
|
307 |
+
|
308 |
+
def set_device():
    """Select the compute device: 'cpu' when IS_MY_DEBUG is set, otherwise
    'cuda' if available.  Prints and returns the choice."""
    debug_mode = os.environ.get('IS_MY_DEBUG') is not None
    if debug_mode:
        device = 'cpu'
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'device={device}')
    return device
|
315 |
+
|
316 |
+
def load_groundingdino_model(device):
    """Load the GroundingDINO detector from the HF Hub and return it on *device*."""
    # initialize groundingdino model
    logger.info(f"initialize groundingdino model...")
    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
    return groundingdino_model
|
321 |
+
|
322 |
+
|
323 |
+
|
324 |
+
def load_sam_model(device):
    """(Re)build the global SAM model, predictor and mask generator on *device*."""
    # initialize SAM
    global sam_model, sam_predictor, sam_mask_generator, sam_device
    get_sam_vit_h_4b8939()  # make sure the checkpoint file exists locally
    logger.info(f"initialize SAM model...")
    sam_device = device
    sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
    sam_predictor = SamPredictor(sam_model)
    sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
|
333 |
+
|
334 |
+
def load_sd_model(device):
    """Load the Stable Diffusion inpainting pipeline into the global sd_model.

    In debug mode (IS_MY_DEBUG set) the model is left as None so the process
    can start without GPU weights.
    """
    # initialize stable-diffusion-inpainting
    global sd_model
    logger.info(f"initialize stable-diffusion-inpainting...")
    sd_model = None
    if os.environ.get('IS_MY_DEBUG') is None:
        # fp16 revision keeps the download/memory footprint small.
        sd_model = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision="fp16",
            # "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )
        sd_model = sd_model.to(device)
|
347 |
+
|
348 |
+
def load_lama_cleaner_model(device):
    """Load the 'lama' inpainting model into the global lama_cleaner_model."""
    # initialize lama_cleaner
    global lama_cleaner_model
    logger.info(f"initialize lama_cleaner...")

    lama_cleaner_model = ModelManager(
        name='lama',
        device=device,
    )
|
357 |
+
|
358 |
+
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
    """Erase the masked region of *image* with the lama_cleaner model.

    Parameters
    ----------
    image : np.ndarray — source image, (H, W, C).
    mask : np.ndarray — non-zero pixels are inpainted away.
    cleaner_size_limit : int — longest edge the image/mask are downscaled to
        before inpainting; -1 keeps the original size.

    Returns the inpainted image as a PIL Image, or None on any failure
    (errors are logged, never raised).
    """
    import random  # used when config.sd_seed == -1; was a latent NameError (no import anywhere)
    try:
        logger.info(f'_______lama_cleaner_process_______1____')
        ori_image = image
        # Mask shaped like the transposed image: rotate the image to match.
        # NOTE(review): assumes a 90-degree orientation mismatch between the
        # drawn mask and the image — confirm against the sketch-tool caller.
        if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
            # rotate image
            logger.info(f'_______lama_cleaner_process_______2____')
            ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
            logger.info(f'_______lama_cleaner_process_______3____')
            image = ori_image

        logger.info(f'_______lama_cleaner_process_______4____')
        original_shape = ori_image.shape
        logger.info(f'_______lama_cleaner_process_______5____')
        interpolation = cv2.INTER_CUBIC

        size_limit = cleaner_size_limit
        if size_limit == -1:
            logger.info(f'_______lama_cleaner_process_______6____')
            size_limit = max(image.shape)
        else:
            logger.info(f'_______lama_cleaner_process_______7____')
            size_limit = int(size_limit)

        logger.info(f'_______lama_cleaner_process_______8____')
        # Only the ldm/hd_strategy settings matter for the 'lama' model; the
        # sd_* and cv2_* fields are required by the schema but unused here.
        config = lama_Config(
            ldm_steps=25,
            ldm_sampler='plms',
            zits_wireframe=True,
            hd_strategy='Original',
            hd_strategy_crop_margin=196,
            hd_strategy_crop_trigger_size=1280,
            hd_strategy_resize_limit=2048,
            prompt='',
            use_croper=False,
            croper_x=0,
            croper_y=0,
            croper_height=512,
            croper_width=512,
            sd_mask_blur=5,
            sd_strength=0.75,
            sd_steps=50,
            sd_guidance_scale=7.5,
            sd_sampler='ddim',
            sd_seed=42,
            cv2_flag='INPAINT_NS',
            cv2_radius=5,
        )

        logger.info(f'_______lama_cleaner_process_______9____')
        if config.sd_seed == -1:
            config.sd_seed = random.randint(1, 999999999)

        logger.info(f'_______lama_cleaner_process_______10____')
        image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)

        logger.info(f'_______lama_cleaner_process_______11____')
        mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

        logger.info(f'_______lama_cleaner_process_______12____')
        res_np_img = lama_cleaner_model(image, mask, config)
        logger.info(f'_______lama_cleaner_process_______13____')
        torch.cuda.empty_cache()

        logger.info(f'_______lama_cleaner_process_______14____')
        # Round-trip through PNG bytes to get a clean PIL image back.
        image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
        logger.info(f'_______lama_cleaner_process_______15____')
    except Exception as e:
        logger.info(f'lama_cleaner_process[Error]:' + str(e))
        image = None
    return image
|
433 |
+
|
434 |
+
class Ram_Predictor(RamPredictor):
    """RamPredictor variant that takes an explicit device and builds eagerly."""

    def __init__(self, config, device='cpu'):
        # Intentionally does not call super().__init__; builds directly so the
        # device can be chosen up front.
        self.config = config
        self.device = torch.device(device)
        self._build_model()

    def _build_model(self):
        # Construct the RAM relation model and optionally restore weights.
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
        # NOTE(review): leaves the model in train() mode even though it is only
        # used for prediction here — presumably mirrors RamPredictor; confirm.
        self.model.train()
|
445 |
+
|
446 |
+
def load_ram_model(device):
    """Load the relate-anything (RAM) relation predictor into the global ram_model.

    Skipped entirely in debug mode (IS_MY_DEBUG set), where no checkpoint is
    available.
    """
    # load ram model
    global ram_model
    if os.environ.get('IS_MY_DEBUG') is not None:
        return
    model_path = "./checkpoints/ram_epoch12.pth"
    ram_config = dict(
        model=dict(
            pretrained_model_name_or_path='bert-base-uncased',
            load_pretrained_weights=False,
            num_transformer_layer=2,
            input_feature_size=256,
            output_feature_size=768,
            cls_feature_size=512,
            num_relation_classes=56,
            pred_type='attention',
            loss_type='multi_label_ce',
        ),
        load_from=model_path,
    )
    ram_config = mmengine_Config(ram_config)
    ram_model = Ram_Predictor(ram_config, device)
|
468 |
+
|
469 |
+
# visualization
|
470 |
+
def draw_selected_mask(mask, draw):
    """Paint the subject mask onto *draw* as translucent red, pixel by pixel."""
    fill = (255, 0, 0, 153)
    for row, col in np.transpose(np.nonzero(mask)):
        draw.point((col, row), fill=fill)
|
475 |
+
|
476 |
+
def draw_object_mask(mask, draw):
    """Paint the object mask onto *draw* as translucent blue, pixel by pixel."""
    fill = (0, 0, 255, 153)
    for row, col in np.transpose(np.nonzero(mask)):
        draw.point((col, row), fill=fill)
|
481 |
+
|
482 |
+
def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
    """Render "<word1> <word2> <word3>" (red / black / blue) on a white banner.

    The font size starts at 40 and shrinks until the three words fit within
    *width*.  Measurement happens BEFORE drawing, so the banner is painted
    exactly once (the old code redrew on the same canvas every iteration and
    relied on ``font.getsize``, which was removed in Pillow 10, making the
    broad except swallow the whole routine).  On any font error a blank
    white banner is returned.
    """
    color_red = (255, 0, 0)
    color_black = (0, 0, 0)
    color_blue = (0, 0, 255)
    font_size = 40

    # White banner; returned as-is if the font cannot be loaded.
    image = Image.new('RGB', (width, 60), (255, 255, 255))

    def _text_width(font, text):
        # Pillow >= 10 removed FreeTypeFont.getsize; prefer getbbox.
        if hasattr(font, 'getbbox'):
            left, _, right, _ = font.getbbox(text)
            return right - left
        return font.getsize(text)[0]

    try:
        font = ImageFont.truetype(font_path, font_size)

        # Shrink the font until the words (plus spacing) fit the width.
        while font_size > 1:
            word_spacing = font_size / 2
            total_width = (sum(_text_width(font, w) for w in (word1, word2, word3))
                           + word_spacing * 3)
            if total_width <= width:
                break
            font_size -= 1
            font = ImageFont.truetype(font_path, font_size)

        # Draw once, at the final size.
        draw = ImageDraw.Draw(image)
        word_spacing = font_size / 2
        x_offset = word_spacing
        for word, color in ((word1, color_red), (word2, color_black), (word3, color_blue)):
            draw.text((x_offset, 0), word, color, font=font)
            x_offset += _text_width(font, word) + word_spacing
    except Exception as e:
        # Best effort: keep the blank banner when the font file is missing.
        pass

    return image
|
528 |
+
|
529 |
+
def concatenate_images_vertical(image1, image2):
    """Stack *image1* on top of *image2* on a transparent RGBA canvas sized
    to the wider of the two images."""
    w1, h1 = image1.size
    w2, h2 = image2.size
    stacked = Image.new('RGBA', (max(w1, w2), h1 + h2))
    stacked.paste(image1, (0, 0))
    stacked.paste(image2, (0, h1))
    return stacked
|
544 |
+
|
545 |
+
def relate_anything(input_image, k):
    """Detect pairwise relations in *input_image* with SAM + the RAM model.

    Segments the image with the global sam_mask_generator, feeds the masks'
    feature vectors to the global ram_model, and renders the top-*k* relation
    triplets as annotated PIL images (subject mask in red, object mask in
    blue, relation name in a title strip).  *input_image* may be downscaled
    in place.
    """
    logger.info(f'relate_anything_1_{input_image.size}_')
    w, h = input_image.size
    # Downscale in place so the longest edge is at most 1500 px.
    max_edge = 1500
    if w > max_edge or h > max_edge:
        ratio = max(w, h) / max_edge
        new_size = (int(w / ratio), int(h / ratio))
        input_image.thumbnail(new_size)

    logger.info(f'relate_anything_2_')
    # load image
    pil_image = input_image.convert('RGBA')
    image = np.array(input_image)
    sam_masks = sam_mask_generator.generate(image)
    filtered_masks = sort_and_deduplicate(sam_masks)

    logger.info(f'relate_anything_3_')
    # Stack each mask's feature vector into a (1, num_masks, feat) batch.
    feat_list = []
    for fm in filtered_masks:
        feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
        feat_list.append(feat)
    feat = torch.cat(feat_list, dim=1).to(device)
    matrix_output, rel_triplets = ram_model.predict(feat)

    logger.info(f'relate_anything_4_')
    pil_image_list = []
    for i, rel in enumerate(rel_triplets[:k]):
        # Each triplet is (subject mask index, object mask index, relation id).
        s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
        relation = relation_classes[r]

        mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)

        draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
        draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)

        current_pil_image = pil_image.copy()
        current_pil_image.alpha_composite(mask_image)

        title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
        concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
        pil_image_list.append(concate_pil_image)

    logger.info(f'relate_anything_5_{len(pil_image_list)}')
    return pil_image_list
|
590 |
+
|
591 |
+
# UI option labels for where the inpaint/remove mask comes from: either drawn
# by the user on the input image, or produced from a typed detection prompt.
mask_source_draw = "draw a mask on input image"
mask_source_segment = "type what to detect below"
|
593 |
+
|
594 |
+
def get_time_cost(run_task_time, time_cost_str):
    """Append the elapsed milliseconds since *run_task_time* to *time_cost_str*.

    A zero *run_task_time* marks the first call and resets the trace to
    'start'.  Returns ``(current time in ms, updated trace string)``.
    """
    now_ms = int(time.time() * 1000)
    if run_task_time == 0:
        trace = 'start'
    else:
        delta = now_ms - run_task_time
        trace = f'{time_cost_str}-->{delta}' if time_cost_str != '' else f'{delta}'
    return now_ms, trace
|
604 |
+
|
605 |
+
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                      iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
    """Detect objects named in *text_prompt* (GroundingDINO) and optionally segment them (SAM).

    Parameters mirror the Gradio widgets wired up in main_gradio().  The early
    error paths return the 7-tuple the Gradio callback expects (gallery items,
    gallery update, time-cost text, textbox update, kosmos image, kosmos text,
    entities), but the success path at the bottom returns a single PIL mask.
    NOTE(review): this asymmetric return looks like a hack/truncation of the
    original pipeline (the inpaint/remove/kosmos stages are gone) -- the Gradio
    outputs wiring in main_gradio() no longer matches; confirm callers.
    """
    # Translate Chinese prompts to English (getTextTrans is defined elsewhere in this file).
    text_prompt = getTextTrans(text_prompt, source='zh', target='en')
    inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')

    # Per-step timing accumulator: run_task_time is a timestamp, time_cost_str a log string.
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    text_prompt = text_prompt.strip()
    # A detection prompt is mandatory unless the mask is drawn by hand.
    if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
        if text_prompt == '':
            return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None

    if input_image is None:
        return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None

    # Unix timestamp used both as a log tag and to build unique temp-file names.
    file_temp = int(time.time())
    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')

    output_images = []

    # load_image() returns (PIL image, tensor) -- presumably GroundingDINO preprocessing; verify.
    image_pil, image = load_image(input_image.convert("RGB"))
    input_img = input_image
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    size = image_pil.size
    H, W = size[1], size[0]

    # --- run the GroundingDINO detector (skipped when the mask is hand-drawn) ---
    if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
        pass
    else:
        # The custom CUDA ops only exist when groundingdino's C extension built; otherwise stay on CPU.
        groundingdino_device = 'cpu'
        if device != 'cpu':
            try:
                from groundingdino import _C
                groundingdino_device = 'cuda:0'
            except:
                warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")

        boxes_filt, pred_phrases = get_grounding_output(
            groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
        )
        if boxes_filt.size(0) == 0:
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
            return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
        boxes_filt_ori = copy.deepcopy(boxes_filt)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
    if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
        image = np.array(input_img)
        if sam_predictor:
            sam_predictor.set_image(image)

        # Convert boxes from normalized cx/cy/w/h to absolute x0/y0/x1/y1 pixel coordinates.
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        if sam_predictor:
            boxes_filt = boxes_filt.to(sam_device)
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])

            masks, _, _, _ = sam_predictor.predict_torch(
                point_coords = None,
                point_labels = None,
                boxes = transformed_boxes,
                multimask_output = False,
            )
            # masks: (num_boxes, 1, H, W), e.g. [9, 1, 512, 512]
            assert sam_checkpoint, 'sam_checkpoint is not found!'
        else:
            # No SAM available: fall back to using the rectangular boxes themselves as binary masks.
            masks = torch.zeros(len(boxes_filt), 1, H, W)
            mask_count = 0
            for box in boxes_filt:
                masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
                mask_count += 1
            masks = torch.where(masks > 0, True, False)
            run_mode = "rectangle"

        # Render detections + masks with matplotlib, save to a temp jpg, then read it back
        # (round-trip through disk is how the original drew the annotated preview).
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes_filt, pred_phrases):
            show_box(box.cpu().numpy(), plt.gca(), label)
        plt.axis('off')
        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
        plt.savefig(image_path, bbox_inches="tight")
        plt.clf()
        plt.close('all')
        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        os.remove(image_path)
        output_images.append(Image.fromarray(segment_image_result))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    # NOTE(review): debug leftover -- prints the predictor object to stdout.
    print(sam_predictor)

    # An empty inpaint prompt degrades 'inpainting' into plain 'remove'.
    if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
        task_type = 'remove'

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
    if mask_source_radio == mask_source_draw:
        # NOTE(review): input_mask_pil / input_mask are not defined anywhere in this
        # function -- taking this branch raises NameError.  The upstream code derived
        # them from the sketch layer of input_image; confirm before relying on it.
        mask_pil = input_mask_pil
        mask = input_mask
    else:
        # NOTE(review): 'masks' is unbound here when task_type == 'detection'
        # (the segment branch above never ran) -- NameError in that case.
        masks_ori = copy.deepcopy(masks)
        if inpaint_mode == 'merge':
            # Union of all per-box masks into one (1, 1, H, W) mask.
            masks = torch.sum(masks, dim=0).unsqueeze(0)
            masks = torch.where(masks > 0, True, False)
        mask = masks[0][0].cpu().numpy()
        mask_pil = Image.fromarray(mask)
    output_images.append(mask_pil.convert("RGB"))
    return mask_pil
|
722 |
+
|
723 |
+
def change_radio_display(task_type, mask_source_radio):
    """Return visibility updates for the UI widgets that depend on the task type.

    Produces one gr.*.update(...) per widget, in the order main_gradio() wires
    them: text prompt, inpaint prompt, mask-source radio, relation slider,
    gallery, kosmos input, kosmos image, kosmos highlighted text.
    """
    # Defaults: plain detection shows only the text prompt and the gallery.
    show_text_prompt = True
    show_inpaint_prompt = False
    show_mask_source = False
    show_num_relation = False
    show_gallery = True
    show_kosmos_input = False
    show_kosmos_output = False
    show_kosmos_text = False

    # Kosmos-2 replaces the prompt/gallery pair with its own input/output widgets.
    if task_type == "Kosmos-2" and kosmos_enable:
        show_text_prompt = False
        show_gallery = False
        show_kosmos_input = True
        show_kosmos_output = True
        show_kosmos_text = True

    if task_type == "inpainting":
        show_inpaint_prompt = True
    if task_type in ("inpainting", "remove"):
        show_mask_source = True
        # A hand-drawn mask makes the detection prompt irrelevant.
        if mask_source_radio == mask_source_draw:
            show_text_prompt = False
    if task_type == "relate anything":
        show_text_prompt = False
        show_num_relation = True

    return (gr.Textbox.update(visible=show_text_prompt),
            gr.Textbox.update(visible=show_inpaint_prompt),
            gr.Radio.update(visible=show_mask_source),
            gr.Slider.update(visible=show_num_relation),
            gr.Gallery.update(visible=show_gallery),
            gr.Radio.update(visible=show_kosmos_input),
            gr.Image.update(visible=show_kosmos_output),
            gr.HighlightedText.update(visible=show_kosmos_text))
|
760 |
+
|
761 |
+
def get_model_device(module):
    """Best-effort lookup of the device a model's weights live on.

    Unwraps DataParallel, then scans the immediate children for a 'weight'
    parameter and returns its .device.  Returns the strings 'None', 'UnKnown'
    or 'Error' (original spellings preserved) for the degenerate cases.
    """
    try:
        if module is None:
            return 'None'
        # DataParallel wraps the real model in .module.
        if isinstance(module, torch.nn.DataParallel):
            module = module.module
        for child in module.children():
            params = getattr(child, "_parameters", None)
            if params is not None and "weight" in params:
                return params["weight"].device
        # No weighted child found.
        return 'UnKnown'
    except Exception:
        # Anything unexpected (non-module argument, etc.) collapses to 'Error'.
        return 'Error'
|
775 |
+
|
776 |
+
def main_gradio(args):
    """Build the Gradio UI and launch it on 0.0.0.0:args.port.

    Widget visibility is driven by change_radio_display(); the Run button is
    wired to run_anything_task().  Uses Gradio 3.x APIs (.style(), per-class
    *.update helpers) -- presumably pinned in requirements; verify before upgrading.
    """
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            with gr.Column():
                # Only offer task types whose backing models were enabled at startup.
                task_types = ["detection"]
                if sam_enable:
                    task_types.append("segment")
                if inpainting_enable:
                    task_types.append("inpainting")
                if lama_cleaner_enable:
                    task_types.append("remove")
                if ram_enable:
                    task_types.append("relate anything")
                if kosmos_enable:
                    task_types.append("Kosmos-2")

                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
                                       height=512, brush_color='#00FFFF', mask_opacity=0.6)
                task_type = gr.Radio(task_types, value="detection",
                                     label='Task type', visible=True)
                mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
                                             value=mask_source_segment, label="Mask from",
                                             visible=False)
                text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
                inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
                num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)

                kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)

                run_button = gr.Button(label="Run", visible=True)
                with gr.Accordion("Advanced options", open=False) as advanced_options:
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
                    )
                    inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
                    with gr.Row():
                        with gr.Column(scale=1):
                            remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
                        with gr.Column(scale=1):
                            remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')

            with gr.Column():
                image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
                                           ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
                time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)

                # Kosmos-2 outputs (hidden unless that task type is selected).
                kosmos_output = gr.Image(type="pil", label="result images", visible=False)
                kosmos_text_output = gr.HighlightedText(
                    label="Generated Description",
                    combine_adjacent=False,
                    show_legend=True,
                    visible=False,
                ).style(color_map=color_map)
                # Records which text span (label) is currently selected.
                selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)

                # Hidden textbox carrying the current `entities` as a python-literal string.
                entity_output = gr.Textbox(visible=False)

                # Extract the selected span's label index from the selection event.
                def get_text_span_label(evt: gr.SelectData):
                    if evt.value[-1] is None:
                        return -1
                    return int(evt.value[-1])
                # Feed the selection index into `selected`.
                kosmos_text_output.select(get_text_span_label, None, selected)

                # Redraw the output image when the span (entity) selection changes.
                def update_output_image(img_input, image_output, entities, idx):
                    # `entities` round-trips through the hidden Textbox as a literal string.
                    entities = ast.literal_eval(entities)
                    updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
                    return updated_image
                selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])

        # NOTE(review): outputs expect 7 values, but run_anything_task's success
        # path now returns a single PIL mask -- confirm the wiring still works.
        run_button.click(fn=run_anything_task, inputs=[
                input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
                outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)

        mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
                                 outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
        task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
                         outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
                                  image_gallery, kosmos_input, kosmos_output, kosmos_text_output
                                  ])

        # Footer markdown: credit links, conditional on which features are enabled.
        DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
        if lama_cleaner_enable:
            DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
        if kosmos_enable:
            DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2). <br>'
        if ram_enable:
            DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
        DESCRIPTION += f'Thanks for their excellent work.'
        DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
                       <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
        gr.Markdown(DESCRIPTION)

    print(f'device = {device}')
    print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
    computer_info()
    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
|
885 |
+
|
886 |
+
import signal
|
887 |
+
import json
|
888 |
+
from datetime import date, datetime, timedelta
|
889 |
+
from gevent import pywsgi
|
890 |
+
import base64
|
891 |
+
|
892 |
+
def imgFile_to_base64(image_file):
    """Read the file at *image_file* and return its content as a base64 string."""
    with open(image_file, "rb") as fp:
        raw = fp.read()
    # b64encode yields bytes; decode to a plain str for JSON transport.
    return base64.b64encode(raw).decode("utf8")
|
898 |
+
|
899 |
+
def base64_to_bytes(im_b64):
    """Decode a base64 string back into raw bytes (inverse of imgFile_to_base64)."""
    return base64.b64decode(im_b64.encode("utf-8"))
|
903 |
+
|
904 |
+
def base64_to_PILImage(im_b64):
    """Decode a base64 string and open it as a PIL image, without touching disk."""
    buffer = io.BytesIO(base64_to_bytes(im_b64))
    return Image.open(buffer)
|
908 |
+
|
909 |
+
class API_Starter:
    """Flask + gevent HTTP wrapper exposing the remove pipeline at POST /imgCLeaner.

    NOTE(review): reads the module-level `args` (set in the __main__ block) for
    the port and log filename -- this class only works when the script entry
    point ran first.
    """
    def __init__(self):
        # Imports are deferred so gradio-only runs don't need flask installed.
        from flask import Flask, request, jsonify, make_response
        from flask_cors import CORS, cross_origin
        import logging

        app = Flask(__name__)
        app.logger.setLevel(logging.ERROR)
        # Wide-open CORS: any origin may call any route.
        CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})

        @app.route('/imgCLeaner', methods=['GET', 'POST'])
        @cross_origin()
        def processAssist():
            # JSON protocol: code 0 = ok, -1 = wrong method, -2 = pipeline failure.
            if request.method == 'GET':
                ret_json = {'code': -1, 'reason':'no support to get'}
            elif request.method == 'POST':
                request_data = request.data.decode('utf-8')
                data = json.loads(request_data)
                result = self.handle_data(data)
                if result is None:
                    ret_json = {'code': -2, 'reason':'handle error'}
                else:
                    ret_json = {'code': 0, 'result':result}
            return jsonify(ret_json)

        self.app = app
        # Timestamped log file per server start (logger is loguru -- it has .add()).
        now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
        # Exit cleanly on Ctrl+C instead of leaving gevent hanging.
        signal.signal(signal.SIGINT, self.signal_handler)

    def handle_data(self, data):
        """Run the 'remove' pipeline on a request payload.

        Expects keys: 'img' (base64 image), 'remove_texts' (detection prompt),
        'mask_extend' (pixel margin).  Returns {'imgs': [base64, ...]} or None.
        """
        im_b64 = data['img']
        img = base64_to_PILImage(im_b64)
        remove_texts = data['remove_texts']
        remove_mask_extend = data['mask_extend']
        results = run_anything_task(input_image = img,
                                    text_prompt = f"{remove_texts}",
                                    task_type = 'remove',
                                    inpaint_prompt = '',
                                    box_threshold = 0.3,
                                    text_threshold = 0.25,
                                    iou_threshold = 0.8,
                                    inpaint_mode = "merge",
                                    mask_source_radio = "type what to detect below",
                                    remove_mode = "rectangle",   # ["segment", "rectangle"]
                                    remove_mask_extend = f"{remove_mask_extend}",
                                    num_relation = 5,
                                    kosmos_input = None,
                                    cleaner_size_limit = -1,
                                    )
        # NOTE(review): run_anything_task's success path now returns a single PIL
        # mask, not a tuple -- results[0] would index into the image object here.
        # This matches the original tuple-returning API; confirm which is intended.
        output_images = results[0]
        if output_images is None:
            return None
        ret_json_images = []
        file_temp = int(time.time())
        count = 0
        # Only the last image (the final cleaned result) is returned to the client.
        output_images = output_images[-1:]
        for image_pil in output_images:
            try:
                img_format = image_pil.format.lower()
            except Exception as e:
                # PIL images created in memory have format=None -- default to png.
                img_format = 'png'
            image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
            count += 1
            try:
                image_pil.save(image_path)
            except Exception as e:
                # Fall back for numpy-array results.
                Image.fromarray(image_pil).save(image_path)
            # Round-trip through disk to reuse imgFile_to_base64, then clean up.
            im_b64 = imgFile_to_base64(image_path)
            ret_json_images.append(im_b64)
            os.remove(image_path)
        data = {
            'imgs': ret_json_images,
        }
        return data

    def signal_handler(self, signal, frame):
        # SIGINT handler: announce and exit the whole process.
        print('\nSignal Catched! You have just type Ctrl+C!')
        sys.exit(0)

    def run(self):
        """Serve self.app forever on 0.0.0.0:args.port via gevent's WSGI server."""
        from gevent import pywsgi
        logger.info(f'\nargs={args}\n')
        computer_info()
        print(f"Start a api server: http://0.0.0.0:{args.port}/imgCLeaner")
        server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
        server.serve_forever()
|
996 |
+
|
997 |
+
def main_api(args):
    """Entry point for API mode: refuse port 0, otherwise start the HTTP server."""
    if args.port == 0:
        print('Please give valid port!')
        return
    # Blocks forever inside pywsgi's serve loop.
    API_Starter().run()
|
1003 |
+
|
1004 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument("--port", "-p", type=int, default=7860, help="port")
    # parse_known_args so unrecognized platform arguments are tolerated.
    args, _ = parser.parse_known_args()
    print(f'args = {args}')

    # Dump installed packages for remote debugging (skipped when IS_MY_DEBUG is set).
    if os.environ.get('IS_MY_DEBUG') is None:
        os.system("pip list")

    device = set_device()
    if device == 'cpu':
        # Kosmos-2 is disabled on CPU-only machines.
        kosmos_enable = False

    # Load each model whose *_enable flag (module-level globals) is on.
    if kosmos_enable:
        kosmos_model, kosmos_processor = load_kosmos_model(device)

    if groundingdino_enable:
        groundingdino_model = load_groundingdino_model('cpu')

    if sam_enable:
        load_sam_model(device)

    if inpainting_enable:
        load_sd_model(device)

    if lama_cleaner_enable:
        load_lama_cleaner_model(device)

    if ram_enable:
        load_ram_model(device)

    if os.environ.get('IS_MY_DEBUG') is None:
        os.system("pip list")

    # NOTE(review): the diff rendering loses indentation, so it is ambiguous whether
    # everything below lives inside this __main__ guard (as reconstructed here) or at
    # module level.  If EndpointHandler is meant to be discovered by HuggingFace
    # Inference Endpoints it must be at module level -- verify against the deployment.
    def just_fucking_get_sd_mask(input_pil, prompt):
        # Debug shortcut: run the full pipeline with hard-coded defaults; per the
        # run_anything_task body above, this returns a single PIL mask image.
        return run_anything_task(input_pil, prompt, "inpainting", "", 0.3, 0.25, 0.8, "merge", "type what to detect below", "segment", "10", 5, "Brief")

    # NOTE(review): startup smoke tests against a hard-coded local file; these crash
    # at launch if chick.png is absent.
    just_fucking_get_sd_mask(Image.open("chick.png"), "face . shoes").save("fucking.png")
    just_fucking_get_sd_mask(Image.open("chick.png"), "face . shoes").save("fucking2.png")

    class EndpointHandler():
        # Minimal inference-endpoint-style handler: fetch an image by URL,
        # compute its person-mask, upload the mask to file.io, return the link.
        def __init__(self, path=""):
            pass

        def __call__(self, data):
            original_link = data.get("original_link")
            # NOTE(review): verify=False disables TLS certificate verification.
            response = requests.get(original_link, verify=False)
            byte_arr = response.content
            original_image = Image.open(io.BytesIO(byte_arr))

            mask_pil = just_fucking_get_sd_mask(original_image, "person")

            # Serialize the mask to PNG bytes in memory.
            img_byte_arr = io.BytesIO()
            mask_pil.save(img_byte_arr, format="PNG")
            img_byte_arr = img_byte_arr.getvalue()

            # Upload to file.io (ephemeral third-party file host).
            response = requests.post('https://file.io', files={'file': img_byte_arr})
            link = response.json()['link']

            return link

    # NOTE(review): startup debug invocation against a hard-coded external URL.
    print(EndpointHandler()({
        "original_link": "https://cdn.karneval-megastore.de/images/rep_art/gra/310/6/310698/justice-league-wonder-woman-damenkostum-lizenzware-blau-gold-rot.jpg"
    }))
|
requirements.txt
CHANGED
@@ -15,8 +15,10 @@ setuptools
|
|
15 |
supervision
|
16 |
termcolor
|
17 |
timm
|
18 |
-
torch==2.0.0
|
19 |
-
torchvision==0.15.1
|
|
|
|
|
20 |
|
21 |
gevent
|
22 |
yapf
|
@@ -33,7 +35,7 @@ transformers==4.27.4
|
|
33 |
# lama-cleaner==1.2.4
|
34 |
lama-cleaner@git+https://github.com/yizhangliu/lama-cleaner.git@main
|
35 |
|
36 |
-
mmcv==
|
37 |
mmengine
|
38 |
openmim==0.3.9
|
39 |
|
|
|
15 |
supervision
|
16 |
termcolor
|
17 |
timm
|
18 |
+
# torch==2.0.0 # is production
|
19 |
+
# torchvision==0.15.1 # is production
|
20 |
+
# torch
|
21 |
+
# torchvision
|
22 |
|
23 |
gevent
|
24 |
yapf
|
|
|
35 |
# lama-cleaner==1.2.4
|
36 |
lama-cleaner@git+https://github.com/yizhangliu/lama-cleaner.git@main
|
37 |
|
38 |
+
mmcv==1.7.1
|
39 |
mmengine
|
40 |
openmim==0.3.9
|
41 |
|