|
import cv2 |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
|
|
from vision.ssd.mobilenet_v2_ssd_lite import ( |
|
create_mobilenetv2_ssd_lite, |
|
create_mobilenetv2_ssd_lite_predictor, |
|
) |
|
|
|
# Hugging Face Hub location of the trained MobileNetV2-SSD-Lite checkpoint and
# its matching class-name list.
MODEL_REPO = "fa0311/oita-ken-strawberries-mobilenet"
MODEL_FILENAME = "20250129_053504/mb2-ssd-lite-Epoch-55-Loss-1.508891262114048.pth"
LABELS_FILENAME = "20250129_053504/labels.txt"

# Download both files at import time; hf_hub_download returns a local file path.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
label_path = hf_hub_download(repo_id=MODEL_REPO, filename=LABELS_FILENAME)

# One class name per line; a line's index is the label id that detect_objects
# looks up. Decode as UTF-8 explicitly so the result does not depend on the
# host's locale encoding (class names may be non-ASCII).
with open(label_path, "r", encoding="utf-8") as f:
    class_names = [name.strip() for name in f.readlines()]

# Drop trailing blank lines only: a stray final empty line would otherwise add
# a phantom class and make len(class_names) disagree with the checkpoint.
# Interior blank lines are kept so label ids stay aligned with their indices.
while class_names and not class_names[-1]:
    class_names.pop()

# The network is sized by the number of classes listed in labels.txt.
# NOTE(review): is_test=True presumably configures the net for inference — confirm in vision.ssd.
net = create_mobilenetv2_ssd_lite(len(class_names), is_test=True)
# NOTE(review): checkpoint loading is presumably torch.load (pickle-based), so
# only point MODEL_REPO at a trusted repository.
net.load(model_path)

# NOTE(review): candidate_size presumably caps the boxes considered before NMS — confirm.
predictor = create_mobilenetv2_ssd_lite_predictor(net, candidate_size=200)
|
|
|
|
|
def detect_objects(image, threshold):
    """Run the strawberry detector on an image and draw its detections.

    Args:
        image: RGB image as a NumPy array (as produced by
            ``gr.Image(type="numpy")``), or ``None`` when the user submits
            without uploading an image.
        threshold: Score threshold forwarded to ``predictor.predict``.

    Returns:
        A new RGB array with a rectangle and a ``"<class>: <score>"`` caption
        drawn for every detection, or ``None`` if no image was supplied.
    """
    # Gradio passes None for an empty image input; cv2.cvtColor would raise on
    # it, so show an empty output instead of an error.
    if image is None:
        return None

    # OpenCV interprets drawing colours in BGR order, so annotate a BGR copy
    # and convert back to RGB for Gradio at the end. cvtColor returns a new
    # array, so the caller's image is never modified.
    canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # NOTE(review): the positional 10 is presumably top_k (max detections) in
    # the vision.ssd Predictor API — confirm.
    boxes, labels, probs = predictor.predict(canvas, 10, threshold)

    for box_row, label_idx, prob in zip(boxes, labels, probs):
        # Drawing APIs need integer pixel coordinates.
        x1, y1, x2, y2 = map(int, box_row[:4])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 255, 0), 4)

        caption = f"{class_names[int(label_idx)]}: {float(prob):.2f}"
        cv2.putText(
            canvas,
            caption,
            # Place the caption just inside the top-left corner of the box.
            (x1 + 10, y1 + 25),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 0, 255),
            2,
        )

    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
|
|
|
|
|
# Web UI: an uploaded image and a score-threshold slider are fed to
# detect_objects, and the annotated image it returns is displayed.
_image_input = gr.Image(type="numpy")
_threshold_slider = gr.Slider(0.1, 1.0, value=0.7, label="Detection Threshold")

iface = gr.Interface(
    fn=detect_objects,
    inputs=[_image_input, _threshold_slider],
    outputs=gr.Image(type="numpy"),
    title="SSD Object Detection - Strawberry quality classification",
    description="Upload an image of strawberries to detect objects using MobileNetV2-SSD-Lite.",
)

# Serve the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|
|