# Hugging Face Space file header (scraped page metadata, kept as a comment):
#   author: pradanaadn
#   commit: ef48eea (verified) - "Rename main.py to app.py"
import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models import mobilenet_v3_large
from torchvision.transforms import v2
from PIL import Image
class TrashMobileNet(nn.Module, PyTorchModelHubMixin):
    """MobileNetV3-Large whose last classifier layer is swapped for a new head.

    Every parameter of the torchvision network is frozen; only the freshly
    created final ``nn.Linear`` layer is left trainable.
    """

    def __init__(self, num_classes=6):
        """Build the frozen backbone and attach the replacement head.

        Args:
            num_classes: Output size of the replacement linear layer.
        """
        super().__init__()

        # weights="DEFAULT" asks torchvision for its default pretrained
        # weights.
        # NOTE(review): this happens on every construction, including when a
        # checkpoint is loaded right afterwards - confirm that is intended.
        backbone = mobilenet_v3_large(weights="DEFAULT")

        # Freeze the whole network before the head is replaced.
        backbone.requires_grad_(False)

        # Swap the last classifier layer for one sized to ``num_classes`` and
        # make sure its parameters are trainable.
        head_inputs = backbone.classifier[-1].in_features
        backbone.classifier[-1] = nn.Linear(head_inputs, num_classes)
        backbone.classifier[-1].requires_grad_(True)

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
# Hub repository id passed to ``from_pretrained`` (spelling kept verbatim).
model_name = "pradanaadn/trash-clasification"

# Load the checkpoint and switch the network to evaluation mode;
# ``nn.Module.eval()`` returns the module itself, so the calls can be chained.
model = TrashMobileNet.from_pretrained(model_name).eval()

# Preprocessing applied to each input before inference: resize to 224x224,
# convert to a tensor image, then cast to float32 with value scaling enabled.
transform = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)
def predict(image):
    """Classify one image into the six trash categories.

    Args:
        image: A ``PIL.Image.Image``, any array accepted by
            ``Image.fromarray``, or ``None`` when no image was supplied.

    Returns:
        A dict mapping each label name to its softmax probability, or
        ``None`` when ``image`` is ``None``.
    """
    # The handler can be invoked before anything has been uploaded; return an
    # empty result instead of failing inside ``Image.fromarray(None)``.
    if image is None:
        return None

    labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # MobileNetV3 takes 3-channel input; grayscale or RGBA images would
    # otherwise fail in the first convolution.
    image = image.convert("RGB")

    # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W).
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        # ``tolist()`` already yields plain Python floats.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

    # Pair each label with its probability.
    return dict(zip(labels, probabilities))
# Sample inputs offered in the UI, one row per example: [image path, caption].
# NOTE(review): the gr.Examples call in this file is given a single input
# component, so the caption column has no component to populate - confirm
# whether the captions are meant to be displayed.
examples = [
    ["examples/cardbox.jpeg", "A cardboard box"],
    ["examples/glass.jpeg", "A glass bottle"],
    ["examples/plastic.png", "Mixed trash"],
]
# UI definition: an upload column with a Classify button next to a results
# column, followed by an example gallery.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; confirm the placement of the Markdown/Examples section.
with gr.Blocks() as iface:
    with gr.Row():
        # Left column: image upload plus the button that triggers inference.
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="image_upload",
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Right column: per-class probabilities for all six labels.
        with gr.Column():
            output_label = gr.Label(
                num_top_classes=6,
                label="Classification Results",
            )

    gr.Markdown("### Example Images")
    gr.Examples(
        examples=examples,
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        cache_examples=True,
    )

    # Run the classifier on the uploaded image when the button is pressed.
    submit_btn.click(fn=predict, inputs=input_image, outputs=output_label)

# Launch the interface
iface.launch()