Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from huggingface_hub import PyTorchModelHubMixin | |
from torchvision.models import mobilenet_v3_large | |
from torchvision.transforms import v2 | |
from PIL import Image | |
class TrashMobileNet(nn.Module, PyTorchModelHubMixin):
    """MobileNetV3-Large with a replaceable linear head for trash classification.

    The backbone is created with torchvision's default pretrained weights
    (``weights="DEFAULT"``) and all of its parameters are frozen; only the final
    classifier layer, resized to ``num_classes`` outputs, is left trainable.
    ``PyTorchModelHubMixin`` supplies Hub loading such as ``from_pretrained``.

    Args:
        num_classes: Number of output logits produced by the new head.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        # The attribute name "model" must not change: it prefixes every
        # state_dict key, so checkpoints saved with it depend on it.
        self.model = mobilenet_v3_large(weights="DEFAULT")

        # Freeze the whole pretrained network first...
        self.model.requires_grad_(False)

        # ...then swap the last classifier layer for one sized to our classes.
        classifier = self.model.classifier
        in_dim = classifier[-1].in_features
        classifier[-1] = nn.Linear(in_dim, num_classes)

        # Explicitly mark the new head trainable (a freshly built layer already
        # is by default; this makes the intent obvious).
        classifier[-1].requires_grad_(True)

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the input batch ``x``."""
        return self.model(x)
# Hugging Face Hub repository that holds the fine-tuned TrashMobileNet checkpoint.
model_name = "pradanaadn/trash-clasification"

# Load the checkpoint and switch straight to inference mode; eval() returns the
# module itself, so the chained call binds the same object.
model = TrashMobileNet.from_pretrained(model_name).eval()

# Per-image preprocessing used by predict():
#   1. resize the PIL image to 224x224,
#   2. convert it to a torchvision tensor image,
#   3. cast to float32 and rescale pixel values into [0, 1].
# NOTE(review): there is no Normalize step; ImageNet-pretrained MobileNetV3 is
# usually fed mean/std-normalized input. Confirm this matches how the
# checkpoint was trained before adding one.
transform = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)
def predict(image):
    """Classify one image into the six trash categories.

    Args:
        image: A ``PIL.Image.Image`` or an array accepted by
            ``PIL.Image.fromarray``. Any PIL mode is accepted; it is converted
            to RGB before preprocessing.

    Returns:
        dict[str, float]: each category label mapped to its softmax
        probability, the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was supplied (e.g. "Classify" clicked before
            uploading).
    """
    # Order must match the output indices of the trained classifier head.
    labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

    if image is None:
        # Surface a readable message in the UI instead of a crash inside
        # Image.fromarray(None).
        raise gr.Error("Please upload an image first.")

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # The backbone takes 3-channel input. RGBA (PNG with alpha), grayscale
    # ("L") or palette ("P") images would otherwise yield a tensor with the
    # wrong number of channels.
    image = image.convert("RGB")

    # Add a leading batch dimension of 1.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)

    probabilities = torch.softmax(logits, dim=1)[0].tolist()
    return {label: float(prob) for label, prob in zip(labels, probabilities)}
# Example rows for gr.Examples: each row is a [image_path, caption] list.
examples = [
    [path, caption]
    for path, caption in (
        ("examples/cardbox.jpeg", "A cardboard box"),
        ("examples/glass.jpeg", "A glass bottle"),
        ("examples/plastic.png", "Mixed trash"),
    )
]
# Page layout: uploader and button on the left, class probabilities on the
# right, and a set of clickable example images underneath.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",  # predict() receives a PIL.Image
                elem_id="image_upload",
                label="Upload Image",
            )
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            # Six top classes = every label predict() can return.
            output_label = gr.Label(num_top_classes=6, label="Classification Results")

    gr.Markdown("### Example Images")
    # NOTE(review): each example row also carries a caption, but only one input
    # component is wired here, so presumably only the image path is used —
    # confirm the captions are meant to be dropped.
    # cache_examples=True asks Gradio to precompute predict() on the examples.
    gr.Examples(
        examples=examples,
        fn=predict,
        inputs=[input_image],
        outputs=[output_label],
        cache_examples=True,
    )

    # Run the classifier on the uploaded image when the button is pressed.
    submit_btn.click(predict, inputs=[input_image], outputs=[output_label])
# Launch the interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module for tests or reuse of
# `iface`/`predict` does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()