"""Gradio demo: run a MobileNetV2-SSD-Lite detector on an uploaded image.

At import time the script downloads the model weights and the labels file
from the Hugging Face Hub, builds the network and its predictor, and defines
a Gradio interface that draws the predicted boxes and labels on the image.
"""

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from vision.ssd.mobilenet_v2_ssd_lite import (
    create_mobilenetv2_ssd_lite,
    create_mobilenetv2_ssd_lite_predictor,
)

MODEL_REPO = "fa0311/oita-ken-strawberries-mobilenet"
MODEL_FILENAME = "20250129_053504/mb2-ssd-lite-Epoch-55-Loss-1.508891262114048.pth"
LABELS_FILENAME = "20250129_053504/labels.txt"

model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
label_path = hf_hub_download(repo_id=MODEL_REPO, filename=LABELS_FILENAME)

# Explicit encoding so the labels decode the same way on every platform
# instead of depending on the locale's default codec.
# Blank lines are deliberately not filtered out: the length of this list is
# the class count handed to the network below.
with open(label_path, "r", encoding="utf-8") as f:
    class_names = [name.strip() for name in f]

net = create_mobilenetv2_ssd_lite(len(class_names), is_test=True)
net.load(model_path)
predictor = create_mobilenetv2_ssd_lite_predictor(net, candidate_size=200)


def detect_objects(image, threshold):
    """Draw the detector's predictions onto ``image`` and return the result.

    Args:
        image: RGB image as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when the user submitted without providing an image.
        threshold: Value of the "Detection Threshold" slider; forwarded to
            ``predictor.predict`` as its third positional argument.

    Returns:
        A new RGB image with one rectangle and one ``"<class>: <prob>"``
        caption per detection, or ``None`` if no image was supplied.
    """
    # Gradio passes None when nothing was uploaded; cv2.cvtColor would raise
    # on it, so hand back an empty output instead.
    if image is None:
        return None

    # cvtColor returns a new array, so the caller's image is never mutated.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # NOTE(review): the literal 10 is presumably the top-k cap on returned
    # detections -- confirm against the predictor's `predict` signature.
    boxes, labels, probs = predictor.predict(image, 10, threshold)

    for box_coords, label_idx, prob in zip(boxes, labels, probs):
        x1, y1, x2, y2 = (int(v) for v in box_coords)
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 0), 4)
        # Convert to plain Python scalars before indexing / formatting.
        label = f"{class_names[int(label_idx)]}: {float(prob):.2f}"
        cv2.putText(
            image,
            label,
            (x1 + 10, y1 + 25),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 0, 255),
            2,
        )

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


iface = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(0.1, 1.0, value=0.7, label="Detection Threshold"),
    ],
    outputs=gr.Image(type="numpy"),
    title="SSD Object Detection - Strawberry quality classification",
    description="Upload an image of strawberries to detect objects using MobileNetV2-SSD-Lite.",
)

if __name__ == "__main__":
    iface.launch()