# # # from fastapi import FastAPI
# # # from pydantic import BaseModel
# # # import pickle
# # # import pandas as pd
# # # from preprocessing import preprocess_input

# # # app = FastAPI()

# # # # Load model
# # # with open("app/models/model.pkl", "rb") as f:
# # #     model = pickle.load(f)

# # # # Load encoders + scaler
# # # with open("app/models/encoders.pkl", "rb") as f:
# # #     encoders, scaler, mean_opv_doses, final_feature_order, df_temp_prep = pickle.load(f)

# # # # Mapping
# # # class_map = {0: "Positive", 1: "Negative", 2: "NPENT"}

# # # # Input schema
# # # class PredictRequest(BaseModel):
# # #     Province: str
# # #     District: str
# # #     Case_Contact: str
# # #     OPV_Doses: str
# # #     Specimen_Number: str
# # #     Specimen_condition: str
# # #     Body_Temp: str
# # #     Sore_Throat: str
# # #     Fatigue: str
# # #     Limb_Discomfort: str


# # # @app.post("/predict")
# # # def predict(req: PredictRequest):

# # #     raw = {
# # #         "Province": req.Province,
# # #         "District": req.District,
# # #         "Case/Contact": req.Case_Contact,
# # #         "OPV Doses": req.OPV_Doses,
# # #         "Specimen Number": req.Specimen_Number,
# # #         "Specimen condition": req.Specimen_condition,
# # #         "Body Temp": req.Body_Temp,
# # #         "Sore Throat": req.Sore_Throat,
# # #         "Fatigue": req.Fatigue,
# # #         "Limb Discomfort": req.Limb_Discomfort,
# # #     }

# # #     X = preprocess_input(raw, encoders, scaler, mean_opv_doses, final_feature_order, df_temp_prep)
# # #     pred_enc = model.predict(X)
# # #     pred_class = class_map[int(pred_enc.item())]

# # #     return {"prediction": pred_class}
# # # app/main.py
# # from fastapi import FastAPI, HTTPException
# # from fastapi.middleware.cors import CORSMiddleware
# # from pydantic import BaseModel
# # import pickle
# # import os

# # from app.preprocessing import preprocess_input
# # # , encoder_province, encoder_district

# # APP_ROOT = os.path.dirname(os.path.abspath(__file__))
# # MODELS_DIR = os.path.join(APP_ROOT, "models")

# # # Paths (adjust filenames if you used different names)
# # LABEL_ENCODERS_PATH = os.path.join(MODELS_DIR, "label_encoders.pkl")
# # SCALER_PATH = os.path.join(MODELS_DIR, "scaler.pkl")
# # MODEL_PATH = os.path.join(MODELS_DIR, "catboost_model.pkl")

# # # Load artifacts
# # with open(LABEL_ENCODERS_PATH, "rb") as f:
# #     label_encoders = pickle.load(f)

# # with open(SCALER_PATH, "rb") as f:
# #     scaler = pickle.load(f)

# # with open(MODEL_PATH, "rb") as f:
# #     model = pickle.load(f)

# # # Derive readable class names from your target encoder
# # le_target = label_encoders.get("Final_cell_culture_result")
# # class_map = {i: name for i, name in enumerate(le_target.classes_)}

# # # Choose a sensible mean OPV doses (if you have real mean, replace here)
# # mean_opv_doses = 2.0

# # # final_feature_order must match the training feature order:
# # # Replace or adjust if you used different order during training
# # final_feature_order = ['Province', 'District', 'Case/Contact', 'OPV Doses',
# #                        'Specimen Number', 'Specimen condition', 'Body Temp',
# #                        'Sore Throat', 'Fatigue', 'Limb Discomfort']

# # app = FastAPI(title="CatBoost Prediction API")

# # # Allow requests from localhost and your Flutter dev server (adjust origins as needed)
# # app.add_middleware(
# #     CORSMiddleware,
# #     allow_origins=["*"],   # in production lock this down to your domains
# #     allow_credentials=True,
# #     allow_methods=["*"],
# #     allow_headers=["*"],
# # )

# # # Pydantic request model (snake/camel mapping kept simple)
# # class PredictRequest(BaseModel):
# #     age_in_months: float
# #     province: str
# #     district: str
# #     case_or_contact: str
# #     specimen_cond: str
# #     date_of_onset: str
# #     date_of_collection: str
# #     date_received: str
# #     opv_doses_received: float
# #     sore_throat: str
# #     fatigue: str
# #     limb_discomfort: str

# # @app.post("/predict")
# # def predict(req: PredictRequest):
# #     # Build raw_input keys matching preprocess_input expected keys
# #     raw = {
# #         "Province": req.Province,
# #         "District": req.District,
# #         "Case/Contact": req.Case_Contact or "Case",
# #         "OPV Doses": req.OPV_Doses,
# #         "Specimen Number": req.Specimen_Number,
# #         "Specimen condition": req.Specimen_condition or "Good",
# #         "Body Temp": req.Body_Temp,
# #         "Sore Throat": req.Sore_Throat,
# #         "Fatigue": req.Fatigue,
# #         "Limb Discomfort": req.Limb_Discomfort,
# #     }

# #     try:
# #         X = preprocess_input(raw, label_encoders, scaler, mean_opv_doses, final_feature_order)
# #     except Exception as e:
# #         raise HTTPException(status_code=400, detail=f"Preprocessing failed: {e}")

# #     try:
# #         # model.predict should accept a DataFrame; if not adjust to numpy
# #         pred_enc = model.predict(X)
# #         pred_index = int(pred_enc.item()) if hasattr(pred_enc, "item") else int(pred_enc)
# #         pred_label = class_map.get(pred_index, str(pred_index))
# #     except Exception as e:
# #         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")

# #     return {"prediction": pred_label}


# # main.py

# from fastapi import FastAPI
# from pydantic import BaseModel
# import pickle
# from app.preprocessing import preprocess_input

# app = FastAPI(title="Polio Prediction API")

# # Load trained ML model
# with open("app/models/model.pkl", "rb") as f:
#     model = pickle.load(f)


# # ---------- CLEAN API INPUT MODEL ----------
# class PredictionInput(BaseModel):
#     province: str
#     district: str
#     case_contact: str
#     opv_doses: str
#     specimen_number: str
#     specimen_condition: str
#     body_temp: str
#     sore_throat: str
#     fatigue: str
#     limb_discomfort: str


# # ---------- API ENDPOINT ----------
# @app.post("/predict")
# def predict(data: PredictionInput):
    
#     # Convert API field names → original model field names
#     raw_input = {
#         "Province": data.province,
#         "District": data.district,
#         "Case/Contact": data.case_contact,
#         "OPV Doses": data.opv_doses,
#         "Specimen Number": data.specimen_number,
#         "Specimen condition": data.specimen_condition,
#         "Body Temp": data.body_temp,
#         "Sore Throat": data.sore_throat,
#         "Fatigue": data.fatigue,
#         "Limb Discomfort": data.limb_discomfort
#     }

#     # Preprocess
#     processed_df = preprocess_input(raw_input)

#     # Predict
#     prediction = model.predict(processed_df)[0]
#     probability = model.predict_proba(processed_df)[0].max()

#     return {
#         "prediction": int(prediction),
#         "probability": float(probability)
#     }


# @app.get("/")
# def home():
#     return {"message": "Polio Prediction API running."}


import os

import pandas as pd
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from app.schemas import SampleInput
from app.load_artifacts import (
    loaded_catboost_model,
    loaded_label_encoders,
    loaded_scaler,
    class_names_readable,
)
from app.preprocessing import preprocess_single_input

app = FastAPI(
    title="Poliovirus Detection API",
    description="Serve CatBoost model predictions via API",
    version="1.0.0",
)

# CORS is enforced by browsers, so it matters for the Flutter web build;
# native mobile clients ignore it. Allowed origins are read from the
# comma-separated CORS_ALLOW_ORIGINS environment variable, e.g.
# "https://app.example.org,http://localhost:5000". When the variable is unset
# or empty, every origin is allowed, matching the previous behaviour.
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# A wildcard origin must not be combined with credentials. FastAPI's CORS docs
# disallow it, and Starlette handles the combination by reflecting the
# request's Origin together with Access-Control-Allow-Credentials. That would
# let any website send cookie-authenticated requests. Credentials are therefore
# only enabled when the allowed origins are listed explicitly.
_CORS_ALLOW_CREDENTIALS = "*" not in _CORS_ORIGINS

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_credentials=_CORS_ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict")
def predict(input_data: SampleInput):
    """Run the CatBoost model on one sample and return the readable class.

    Args:
        input_data: Request body validated against ``SampleInput``.

    Returns:
        On success, a dict with two keys:

        - ``"prediction"``: the model's encoded output mapped through
          ``class_names_readable``.
        - ``"processed_input"``: the first (presumably only) row of the
          preprocessed frame, as a column-to-value dict.

        If any loaded artifact is ``None``, a 503 ``JSONResponse`` is returned
        instead. Its body is
        ``{"error": "Model or preprocessing components not loaded"}``.
    """
    # Use a 5xx status rather than 200 so clients and monitoring can tell this
    # failure apart from a real prediction. The body keeps the same "error"
    # key, so clients that inspect it keep working. class_names_readable is
    # checked too, because it is indexed below.
    required_artifacts = (
        loaded_catboost_model,
        loaded_label_encoders,
        loaded_scaler,
        class_names_readable,
    )
    if any(artifact is None for artifact in required_artifacts):
        return JSONResponse(
            status_code=503,
            content={"error": "Model or preprocessing components not loaded"},
        )

    # Pydantic v2 renamed .dict() to .model_dump(). The old name still works
    # but is deprecated, so prefer the new one and fall back for Pydantic v1.
    if hasattr(input_data, "model_dump"):
        raw_input_dict = input_data.model_dump()
    else:
        raw_input_dict = input_data.dict()

    processed_df = preprocess_single_input(raw_input_dict=raw_input_dict)

    prediction_encoded = loaded_catboost_model.predict(processed_df)
    # NOTE(review): assumes predict() returns a NumPy array holding exactly one
    # integer-like label for the single input row. .item() raises ValueError
    # if more than one element comes back.
    prediction_class = class_names_readable[int(prediction_encoded.item())]

    return {
        "prediction": prediction_class,
        "processed_input": processed_df.to_dict(orient="records")[0],
    }
