supersolar committed · verified
Commit 1501500 · 1 Parent(s): dba3ac4

Create florencegpu1.py

Files changed (1)
  1. utils/florencegpu1.py +58 -0
utils/florencegpu1.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ from typing import Union, Any, Tuple, Dict
+ from unittest.mock import patch
+
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ from transformers.dynamic_module_utils import get_imports
+
+ FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
+ # FLORENCE_CHECKPOINT = "microsoft/Florence-2-large-ft"
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
+ FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
+
+
+ def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+     """Workaround for https://huggingface.co/microsoft/phi-1_5/discussions/72:
+     drop the flash_attn import so the remote code loads without it installed."""
+     if not str(filename).endswith("/modeling_florence2.py"):
+         return get_imports(filename)
+     imports = get_imports(filename)
+     # Guard the removal so this does not raise if flash_attn is absent.
+     if "flash_attn" in imports:
+         imports.remove("flash_attn")
+     return imports
+
+
+ def load_florence_model(
+     device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
+ ) -> Tuple[Any, Any]:
+     # Half precision on CUDA, full precision on CPU.
+     torch_dtype = torch.float16 if torch.device(device).type == "cuda" else torch.float32
+     # Apply the get_imports workaround defined above; otherwise it is never used.
+     with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+         # Load the requested checkpoint rather than a hardcoded model id,
+         # so the `checkpoint` parameter and FLORENCE_CHECKPOINT default take effect.
+         model = AutoModelForCausalLM.from_pretrained(
+             checkpoint, torch_dtype=torch_dtype, trust_remote_code=True
+         ).to(device)
+         processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
+     return model, processor
+
+
+ def run_florence_inference(
+     model: Any,
+     processor: Any,
+     device: torch.device,
+     image: Image.Image,
+     task: str,
+     text: str = ""
+ ) -> Tuple[str, Dict]:
+     # Florence-2 prompts are a task token optionally followed by free text.
+     prompt = task + text
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         # Cast pixel values to the model's dtype (float16 on GPU) to avoid a dtype mismatch.
+         pixel_values=inputs["pixel_values"].to(model.dtype),
+         max_new_tokens=1024,
+         num_beams=3
+     )
+     generated_text = processor.batch_decode(
+         generated_ids, skip_special_tokens=False)[0]
+     # Convert the raw decoded string into a task-specific structured result.
+     response = processor.post_process_generation(
+         generated_text, task=task, image_size=image.size)
+     return generated_text, response
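
For reference, a minimal usage sketch of the helpers added in this commit (not part of the diff; the image path is a placeholder, and the response dict is assumed to be keyed by the task token, as in the Florence-2 model card):

import torch
from PIL import Image

from utils.florencegpu1 import (
    FLORENCE_DETAILED_CAPTION_TASK,
    load_florence_model,
    run_florence_inference,
)

# Pick a device, load the model and processor once, then reuse them.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device)

image = Image.open("sample.jpg")  # placeholder input image
_, response = run_florence_inference(
    model, processor, device, image, FLORENCE_DETAILED_CAPTION_TASK
)
# The parsed result is expected under the task token key.
print(response[FLORENCE_DETAILED_CAPTION_TASK])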