"""FastAPI service exposing an ``InferenceRecipe`` model over HTTP.

Endpoints:

* ``GET  /api/v1/health``    -- model / device status report.
* ``POST /api/v1/inference`` -- run the model on a base64-encoded ``.npy``
  audio payload and return base64-encoded ``.npy`` audio plus optional text.
"""

import base64
import io
import logging
import os
from pathlib import Path

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from inference import InferenceRecipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    """Request body for ``POST /api/v1/inference``."""

    # Base64 text wrapping the bytes of a NumPy ``.npy`` file (read with np.load).
    audio_data: str
    # Passed through unchanged as the second argument of ``model.inference``.
    sample_rate: int


class AudioResponse(BaseModel):
    """Response body for ``POST /api/v1/inference``."""

    # Base64 text wrapping the bytes written by ``np.save`` for the output audio.
    audio_data: str
    # Taken from the ``"text"`` key of the model result; empty when absent.
    text: str = ""


# Model initialization status
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Global model instance
model = None


class _InvalidAudioPayload(ValueError):
    """Raised when the client-supplied ``audio_data`` cannot be decoded."""


def _decode_audio(audio_data: str):
    """Turn the base64 ``audio_data`` field into whatever ``np.load`` yields.

    Args:
        audio_data: Base64 text expected to wrap the bytes of a ``.npy`` file.

    Returns:
        The object produced by ``np.load`` for the decoded bytes.

    Raises:
        _InvalidAudioPayload: If the text is not valid base64 or the decoded
            bytes cannot be loaded by NumPy with pickling disabled.
    """
    try:
        raw_bytes = base64.b64decode(audio_data)
        # The bytes come straight from the request body, so never let NumPy
        # unpickle them; being explicit keeps this safe on older NumPy too.
        return np.load(io.BytesIO(raw_bytes), allow_pickle=False)
    except (ValueError, EOFError, OSError) as exc:
        # binascii.Error is a ValueError subclass, so bad base64 lands here
        # alongside NumPy's own format / empty-file errors.
        raise _InvalidAudioPayload(str(exc)) from exc


def _encode_audio(audio) -> str:
    """Serialize ``audio`` with ``np.save`` and return it as base64 text."""
    buffer = io.BytesIO()
    np.save(buffer, audio)
    return base64.b64encode(buffer.getvalue()).decode()


def initialize_model() -> bool:
    """Initialize the model from mounted directory.

    The directory is read from the ``MODEL_PATH`` environment variable and
    defaults to ``/app/models``. The outcome is recorded in
    ``INITIALIZATION_STATUS`` so the health endpoint can report it.

    Returns:
        ``True`` when the model was created, ``False`` when any step failed
        (the failure text is stored under ``INITIALIZATION_STATUS["error"]``).
    """
    global model, INITIALIZATION_STATUS

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing model on device: %s", device)

        model_path = os.getenv("MODEL_PATH", "/app/models")
        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        model = InferenceRecipe(model_path, device=device)
    except Exception as e:
        # Startup must not crash the server: record the failure so that
        # /api/v1/health and the 503 response can surface it.
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception("Failed to initialize model: %s", e)
        return False

    INITIALIZATION_STATUS["model_loaded"] = True
    # Drop any error left over from an earlier failed attempt so the health
    # report does not show "loaded" next to a stale error message.
    INITIALIZATION_STATUS["error"] = None
    logger.info("Model initialized successfully")
    return True


# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
# lifespan handlers; kept here so the file works on the FastAPI version in use.
@app.on_event("startup")
async def startup_event() -> None:
    """Initialize model on startup."""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    Returns:
        A dict with ``status`` (``"healthy"`` once the model is loaded, else
        ``"initializing"``), ``gpu_available`` and ``initialization_status``.
        When a model instance exists, its device, model path and the presence
        of its ``mimi``, ``text_tokenizer`` and ``lm_gen`` attributes are
        included as well.
    """
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "gpu_available": torch.cuda.is_available(),
        "initialization_status": INITIALIZATION_STATUS,
    }

    if model is not None:
        status.update({
            "device": str(model.device),
            "model_path": str(model.model_path),
            "mimi_loaded": model.mimi is not None,
            "tokenizer_loaded": model.text_tokenizer is not None,
            "lm_loaded": model.lm_gen is not None,
        })

    return status


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference on audio input.

    Args:
        request: Base64-encoded ``.npy`` audio plus its sample rate.

    Returns:
        The model's output audio re-encoded as base64 ``.npy`` bytes, together
        with the result's ``"text"`` entry (empty string when missing).

    Raises:
        HTTPException: 503 while the model is not loaded, 400 when
            ``audio_data`` cannot be decoded, 500 for any other failure.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    try:
        # Decode audio from base64
        audio_array = _decode_audio(request.audio_data)

        # Run inference
        result = model.inference(audio_array, request.sample_rate)

        # Encode output audio
        audio_b64 = _encode_audio(result["audio"])

        return AudioResponse(
            audio_data=audio_b64,
            text=result.get("text", ""),
        )
    except _InvalidAudioPayload as e:
        # The request body is at fault, so answer with a client error rather
        # than reporting a server failure.
        logger.warning("Rejected malformed audio payload: %s", e)
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio_data: {e}",
        ) from e
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)