detection-repoter / model.py
xingqiang's picture
UPDATE MODES google/paligemma-3b-ft-coco35l-224
e7d85ce
raw
history blame contribute delete
818 Bytes
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForObjectDetection,
)

from config import MODEL_NAME
class RadarDetectionModel:
    """Run single-image inference with a Hugging Face object-detection checkpoint.

    The checkpoint id comes from ``model_name`` or, when omitted, from
    ``config.MODEL_NAME``, so the model is configured in exactly one place.
    The checkpoint must be loadable by ``AutoModelForObjectDetection`` and
    ship an image processor that implements ``post_process_object_detection``.

    NOTE(review): vision-language checkpoints such as
    ``google/paligemma-3b-ft-coco35l-224`` (previously hard-coded here) are
    not registered for ``AutoModelForObjectDetection`` and fail to load;
    confirm that ``config.MODEL_NAME`` points at a detection model.
    """

    def __init__(self, model_name=None):
        """Load the image processor and the detection model in eval mode.

        Args:
            model_name: Hub id or local path of the checkpoint. When ``None``
                (the default), ``config.MODEL_NAME`` is used.
        """
        if model_name is None:
            model_name = MODEL_NAME
        self.model_name = model_name
        # AutoImageProcessor is the non-deprecated replacement for
        # AutoFeatureExtractor on vision models. The attribute keeps its old
        # name so existing callers that touch `feature_extractor` still work.
        self.feature_extractor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForObjectDetection.from_pretrained(model_name)
        self.model.eval()

    @torch.no_grad()
    def detect(self, image, threshold=0.5):
        """Detect objects in a single image.

        Args:
            image: The input image. Presumably a PIL image: its ``size`` is
                read and reversed into the target size. TODO confirm callers.
            threshold: Minimum detection score kept by post-processing.
                Defaults to 0.5, the value previously hard-coded.

        Returns:
            The post-processed detections for this one image (element 0 of
            the processor's batch output). Presumably a dict with ``scores``,
            ``labels`` and ``boxes``; verify against the processor in use.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)
        # A PIL image's size is (width, height); reversing it yields
        # (height, width) so boxes are rescaled to the original image.
        target_sizes = torch.tensor([image.size[::-1]])
        results = self.feature_extractor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]
        return results