import os
import threading
from io import BytesIO

import streamlit as st
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
from werkzeug.utils import secure_filename

from utils import setup_and_predict

# Configure upload folder and allowed extensions
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

# Name given to the background thread that serves the FastAPI app, so a later
# run of this script can tell whether the server is already up.
API_THREAD_NAME = "fastapi-predict-server"

os.makedirs(UPLOAD_FOLDER, exist_ok=True)


def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is checked, case-insensitively.
    """
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _upload_path(filename):
    """Return ``(sanitized_name, path_inside_UPLOAD_FOLDER)`` for *filename*."""
    safe_name = secure_filename(filename)
    return safe_name, os.path.join(UPLOAD_FOLDER, safe_name)


def save_uploaded_file(uploaded_file):
    """Write a Streamlit uploaded file into UPLOAD_FOLDER and return its path.

    A previously saved file with the same sanitized name is overwritten.
    """
    _, filepath = _upload_path(uploaded_file.name)
    with open(filepath, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return filepath


# FastAPI app for Postman requests
api = FastAPI()


class PredictionResponse(BaseModel):
    """JSON body returned by ``POST /predict``."""

    filename: str  # sanitized name the upload was stored under
    result: str  # value from setup_and_predict, or an error message if it raised


@api.post("/predict", response_model=PredictionResponse)
async def predict_api(file: UploadFile = File(...)):
    """Store an uploaded image in UPLOAD_FOLDER and return its prediction.

    Raises:
        HTTPException: 400 when the upload has no file name or a disallowed
            extension. (Returning a plain error dict would fail validation
            against ``response_model`` and reach the client as a 500.)
    """
    if not file.filename or not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Invalid file type")

    filename, filepath = _upload_path(file.filename)
    contents = await file.read()
    with open(filepath, "wb") as f:
        f.write(contents)

    try:
        # setup_and_predict is a plain (synchronous) call; run it in the worker
        # thread pool so a slow prediction does not block this server's event loop.
        result = await run_in_threadpool(setup_and_predict, filepath)
    except Exception as e:
        result = f"Unable to process the request: {e}"
    return PredictionResponse(filename=filename, result=result)


def run_api():
    """Serve the FastAPI app on 0.0.0.0:8000; blocks until the server stops."""
    uvicorn.run(api, host="0.0.0.0", port=8000)


def _ensure_api_running():
    """Start ``run_api`` in a daemon thread unless that thread is already alive.

    Streamlit re-executes this script on every user interaction, so an
    unconditional start would launch a new server (which then fails to bind
    port 8000) on each rerun. Best-effort: two reruns racing through this
    check at the same instant could still both start a thread.
    """
    if any(t.name == API_THREAD_NAME and t.is_alive() for t in threading.enumerate()):
        return
    # daemon=True so the API thread does not keep the process alive on shutdown.
    threading.Thread(target=run_api, name=API_THREAD_NAME, daemon=True).start()


# Start the API before building the UI so an error in the UI code below
# cannot prevent the server from starting.
_ensure_api_running()

# Streamlit app
st.title("Upload an Image for Prediction")
uploaded_file = st.file_uploader("Choose an image...", type=sorted(ALLOWED_EXTENSIONS))

if uploaded_file is not None and allowed_file(uploaded_file.name):
    filepath = save_uploaded_file(uploaded_file)
    try:
        # Show the image first so it is visible even if prediction fails.
        st.image(filepath, caption='Uploaded Image.', use_column_width=True)
        result = setup_and_predict(filepath)
    except Exception as e:
        st.error(f"Unable to process the request: {e}")
    else:
        st.write("Prediction Result:")
        st.write(result)

    # Option to download the image
    with open(filepath, "rb") as image_file:
        st.download_button(label="Download Image", data=image_file, file_name=uploaded_file.name)
elif uploaded_file is not None:
    # Only warn when something was actually uploaded and rejected.
    st.warning("Please upload a valid image file.")