import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models import mobilenet_v3_large
from torchvision.transforms import v2
from PIL import Image


class TrashMobileNet(nn.Module, PyTorchModelHubMixin):
    """MobileNetV3-Large with a frozen backbone and a trainable 6-way head."""

    def __init__(self, num_classes=6):
        super().__init__()
        self.model = mobilenet_v3_large(weights="DEFAULT")

        # Freeze the pretrained backbone.
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the final classifier layer; only this layer stays trainable.
        num_features = self.model.classifier[-1].in_features
        self.model.classifier[-1] = nn.Linear(num_features, num_classes)
        for param in self.model.classifier[-1].parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)


# Load the fine-tuned weights from the Hugging Face Hub.
model_name = "pradanaadn/trash-clasification"
model = TrashMobileNet.from_pretrained(model_name)
model.eval()

# Preprocessing: resize to the 224x224 input expected by MobileNetV3 and
# convert to a float tensor scaled to [0, 1].
transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])


def predict(image):
    """Classify a single image and return a {label: probability} mapping."""
    labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

    # Accept either a PIL image or a NumPy array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    image_tensor = transform(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        probabilities = probabilities[0].tolist()

    # Create dictionary of label-probability pairs
    return {label: float(prob) for label, prob in zip(labels, probabilities)}


# Example images bundled with the app. gr.Examples expects one value per input
# component, so with a single image input each example is just a file path.
examples = [
    "examples/cardbox.jpeg",   # a cardboard box
    "examples/glass.jpeg",     # a glass bottle
    "examples/plastic.png",    # mixed trash
]

with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="image_upload",
            )
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=6,
            )

    gr.Markdown("### Example Images")
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=output_label,
        fn=predict,
        cache_examples=True,
    )

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_label,
    )

# Launch the interface
iface.launch()
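
# Optional sanity check (a minimal sketch, not part of the app itself): calling
# predict() directly should return a probability for each of the six labels.
# This assumes the examples/ directory referenced above is present; uncomment
# and run it in a separate session to spot-check the model without the UI.
#
# sample = Image.open("examples/cardbox.jpeg").convert("RGB")
# for label, prob in sorted(predict(sample).items(), key=lambda kv: -kv[1]):
#     print(f"{label}: {prob:.3f}")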