adaface-neurips committed
Commit bef4321 · Parent: b527a08

Rename functions and variables

Files changed (2)
  1. lib/pipline_ConsistentID.py +22 -23
  2. models/insightface +1 -0
lib/pipline_ConsistentID.py CHANGED
 
@@ -80,9 +80,9 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         self.id_image_processor = CLIPImageProcessor()
         self.crop_size = 512
 
-        # FaceID
-        self.app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
-        self.app.prepare(ctx_id=0, det_size=(640, 640))
+        # face_app: FaceAnalysis object
+        self.face_app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider'])
+        self.face_app.prepare(ctx_id=0, det_size=(640, 640))
 
         if not os.path.exists(consistentID_weight_path):
             ### Download pretrained models
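
The new `root='models/insightface'` argument points insightface at a repo-local model directory instead of the default `~/.insightface` (presumably the reason the `models/insightface` entry is committed below). A minimal standalone sketch of the same detector setup, with a hypothetical input image path:

```python
import cv2
from insightface.app import FaceAnalysis

# With root='models/insightface', the buffalo_l model pack is loaded from
# (or auto-downloaded into) models/insightface/models/buffalo_l.
face_app = FaceAnalysis(name="buffalo_l", root="models/insightface",
                        providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(640, 640))  # det_size: detector input resolution

image = cv2.imread("subject.jpg")  # hypothetical test image, BGR ndarray
faces = face_app.get(image)        # list of detected Face objects
if faces:
    print(faces[0].normed_embedding.shape)  # (512,): L2-normalized ArcFace embedding
```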
 
@@ -172,8 +172,8 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
     # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
     # clip_encoder maps image parts to image-space diffusion prompts.
     # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
-    def get_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
-                                facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
+    def extract_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
+                                    facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
 
         hidden_states = []
         uncond_hidden_states = []
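
The replacement described in the comment is a boolean-mask assignment over the token axis: the facial class-token rows of `prompt_embeds` are overwritten with fused embeddings while every other token embedding stays untouched. A toy illustration of the indexing pattern (shapes, positions, and the fusion step are illustrative, not the pipeline's exact code):

```python
import torch

prompt_embeds = torch.randn(1, 77, 768)     # text prompt embeddings
fused_facial_embeds = torch.randn(4, 768)   # fused (image, text) embeddings for 4 class tokens
class_tokens_mask = torch.zeros(1, 77, dtype=torch.bool)
class_tokens_mask[0, 5:9] = True            # e.g. positions of "nose eyes ears mouth"

text_local_id_embeds = prompt_embeds.clone()
# Boolean-mask assignment: only the 4 masked rows change.
text_local_id_embeds[class_tokens_mask] = fused_facial_embeds

# The remaining 73 token embeddings are identical to the originals.
assert torch.equal(text_local_id_embeds[~class_tokens_mask],
                   prompt_embeds[~class_tokens_mask])
```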
 
@@ -200,13 +200,13 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
 
     @torch.inference_mode()
     # Extract OpenCLIP embeddings from the input image and map them to face prompt embeddings.
-    def get_global_id_embeds(self, faceid_embeds, face_image, s_scale, shortcut=False):
-
-        clip_image = self.clip_preprocessor(images=face_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
-        clip_image_embeds = self.clip_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
-
+    def extract_global_id_embeds(self, face_image_obj, s_scale=1.0, shortcut=False):
+        clip_image_ts = self.clip_preprocessor(images=face_image_obj, return_tensors="pt").pixel_values
+        clip_image_ts = clip_image_ts.to(self.device, dtype=self.torch_dtype)
+        clip_image_embeds = self.clip_encoder(clip_image_ts, output_hidden_states=True).hidden_states[-2]
+        uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image_ts), output_hidden_states=True).hidden_states[-2]
+
+        faceid_embeds = self.extract_faceid(face_image_obj)
         faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
         # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
         # clip_image_embeds are used as queries to transform faceid_embeds.
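
The method reads the penultimate hidden layer of the CLIP vision encoder (`hidden_states[-2]`) rather than the final pooled output, and builds the unconditional branch by encoding an all-zeros image. A sketch of that pattern with stock `transformers` classes; per the `image_proj_model` comment, the pipeline's own `clip_encoder` is a 1280-dim OpenCLIP ViT, whereas this sketch uses the smaller openai ViT-L/14 for illustration:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor()
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("subject.jpg")  # hypothetical input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values  # [1, 3, 224, 224]

with torch.inference_mode():
    # Penultimate layer: [1, 257, 1024] for ViT-L/14 (1 CLS token + 256 patch tokens).
    clip_image_embeds = encoder(pixel_values,
                                output_hidden_states=True).hidden_states[-2]
    # Unconditional branch: a black (all-zeros) image of the same shape.
    uncond_clip_image_embeds = encoder(torch.zeros_like(pixel_values),
                                       output_hidden_states=True).hidden_states[-2]
```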
 
@@ -222,9 +222,9 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
             attn_processor.scale = scale
 
     @torch.inference_mode()
-    def extract_faceid(self, face_image):
-        faceid_image = np.array(face_image)
-        faces = self.app.get(faceid_image)
+    def extract_faceid(self, face_image_obj):
+        faceid_image = np.array(face_image_obj)
+        faces = self.face_app.get(faceid_image)
         if faces==[]:
             faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
         else:
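
When no face is detected, the method falls back to an all-zeros [1, 512] embedding (`torch.zeros_like(torch.empty((1, 512)))` is just a roundabout `torch.zeros(1, 512)`). A hedged sketch of the full function's shape, assuming the detected branch takes the first face's `normed_embedding`, as in other FaceID-style pipelines:

```python
import numpy as np
import torch

def extract_faceid_sketch(face_app, face_image_obj):
    # PIL Image -> ndarray for the insightface detector.
    faceid_image = np.array(face_image_obj)
    faces = face_app.get(faceid_image)
    if len(faces) == 0:
        # No face detected: a zero embedding keeps downstream shapes valid.
        return torch.zeros(1, 512)
    # Assumed branch: the first face's 512-dim ArcFace embedding as [1, 512].
    return torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
```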
 
@@ -377,8 +377,7 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         self.vae = None
 
     # input_subj_image_obj: an Image object.
-    def generate_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
-        faceid_embeds = self.extract_faceid(face_image=input_subj_image_obj)
+    def extract_double_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
         face_caption = "The person has one nose, two eyes, two ears, and a mouth."
         key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
 
 
@@ -403,9 +402,9 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
 
         # 5. Prepare the input ID images
         # global_id_embeds: [1, 4, 768]
-        # get_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
+        # extract_global_id_embeds() extracts OpenCLIP embeddings from the input image and maps them to global face prompt embeddings.
         global_id_embeds, uncond_global_id_embeds = \
-            self.get_global_id_embeds(faceid_embeds, face_image=input_subj_image_obj, s_scale=1.0, shortcut=False)
+            self.extract_global_id_embeds(face_image_obj=input_subj_image_obj, s_scale=1.0, shortcut=False)
 
         # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
         parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
 
@@ -423,13 +422,13 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         # text_local_id_embeds: [1, 77, 768]
         # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
         # to text_global_id_embeds on the rest 73 tokens.
-        # get_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
+        # extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
         # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
         # parsed_image_parts2: [1, 5, 3, 224, 224]
         text_local_id_embeds, uncond_text_local_id_embeds = \
-            self.get_local_facial_embeds(text_embeds, uncond_text_embeds, \
-                                         parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
-                                         calc_uncond=calc_uncond)
+            self.extract_local_facial_embeds(text_embeds, uncond_text_embeds, \
+                                             parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
+                                             calc_uncond=calc_uncond)
 
         # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
         text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
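
The [1, 81, 768] shape noted in the comment falls out of appending the 4 global ID tokens to the 77 CLIP text tokens; a quick shape check:

```python
import torch

text_embeds = torch.randn(1, 77, 768)      # standard CLIP text token embeddings
global_id_embeds = torch.randn(1, 4, 768)  # 4 global face-ID prompt tokens

text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
assert text_global_id_embeds.shape == (1, 81, 768)
```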
 
@@ -508,7 +507,7 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
 
         # 3. Encode input prompt
         coarse_prompt_embeds, fine_prompt_embeds = \
-            self.generate_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
+            self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
 
         # 7. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
models/insightface ADDED
@@ -0,0 +1 @@
 
 
+/home/lish/adaprompt/models/insightface