128 lines
3.6 KiB
Python
128 lines
3.6 KiB
Python
from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.responses import JSONResponse
|
|
from treatment import Treatment
|
|
import pandas as pd
|
|
import pickle
|
|
import subprocess
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    with open(path, "rb") as f:
        return pickle.load(f)


def predict(treatment: Treatment):
    """Score a single treatment with the pickled EMS model.

    Loads the model, label encoder and scaler from ``pickles/``, encodes
    ``treatment.species``, scales the soak-duration / temperature fields
    and asks the model for a predicted class and its class probabilities.

    Args:
        treatment: Payload whose ``species``, ``emsConcentration``,
            ``soakDuration``, ``lowestTemp`` and ``highestTemp`` attributes
            are read here.

    Returns:
        JSONResponse with three keys: ``result`` (predicted class as int),
        ``confidence_score`` (probability found at the index given by the
        predicted class, times 100) and ``success_rate`` (probability at
        index 1, times 100).

    Raises:
        HTTPException: 400 when the label encoder rejects the species with
            a ``ValueError`` (presumably a species it was not fitted on).
    """
    # The artifacts are re-read on every call rather than cached.
    # NOTE(review): presumably so a retrain that rewrites the pickles takes
    # effect without a restart -- confirm before caching these.
    model = _load_pickle("pickles/ems_model.pkl")
    le = _load_pickle("pickles/label_encoding.pkl")
    scaler = _load_pickle("pickles/scaler_encoding.pkl")

    # Encode the species. transform() yields one code per input item, so
    # take the single element: the feature column must hold a scalar, not a
    # nested one-element array.
    try:
        species = le.transform([treatment.species])[0]
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown species: {treatment.species!r}",
        ) from exc

    # Normalize the three numeric fields that go through the scaler.
    scaled_row = scaler.transform(
        pd.DataFrame(
            {
                "soakDuration": [treatment.soakDuration],
                "lowestTemp": [treatment.lowestTemp],
                "highestTemp": [treatment.highestTemp],
            }
        )
    )[0]

    # emsConcentration is passed to the model unscaled.
    data = pd.DataFrame(
        {
            "species": [species],
            "emsConcentration": [treatment.emsConcentration],
            "soakDuration": [scaled_row[0]],
            "lowestTemp": [scaled_row[1]],
            "highestTemp": [scaled_row[2]],
        }
    )

    prediction_prob = model.predict_proba(data)[0]
    prediction = model.predict(data)[0]

    # NOTE(review): indexing the probabilities by the predicted label (and
    # by the literal 1 below) assumes the model's classes are 0, 1, ... in
    # that order -- confirm against the training script.
    confidence_score = prediction_prob[prediction] * 100

    result = {
        "result": int(prediction),
        "confidence_score": float(confidence_score),
        "success_rate": float(prediction_prob[1] * 100),
    }

    return JSONResponse(content=jsonable_encoder(result))
|
|
|
|
|
|
@app.post("/process")
def process(treatment: Treatment):
    """Handle POST /process by delegating the request body to ``predict``.

    Returns whatever ``predict`` returns for the given treatment.
    """
    response = predict(treatment)
    return response
|
|
|
|
|
|
@app.get("/species")
def get_species():
    """Handle GET /species: list the distinct values of the dataset's first column.

    Reads ``csv/ems_data.csv`` on every call and returns the unique,
    non-missing values of its first column, in order of first appearance.

    Returns:
        JSONResponse whose body is a JSON array of the distinct values.

    Raises:
        HTTPException: 500 when the CSV cannot be read or has no columns.
    """
    try:
        df = pd.read_csv("csv/ems_data.csv")
        # Empty cells are read as NaN, which JSONResponse refuses to
        # serialise, so drop missing values before de-duplicating.
        unique_values = df.iloc[:, 0].dropna().drop_duplicates().tolist()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing CSV: {str(e)}"
        ) from e

    return JSONResponse(content=unique_values)
|
|
|
|
|
|
@app.get("/retrain-model")
def retrain_model():
    """Handle GET /retrain-model: run ``model.py`` in a child process.

    The script is started with the interpreter that is running this server
    (falling back to ``python`` on PATH if that is unknown), its output is
    captured as text, and a summary dict is returned.

    Returns:
        On success: ``{"message": "Retraining succeeded!", "details": <stdout>}``.
        On a non-zero exit status:
        ``{"message": "Retraining failed!", "error": <exception text>,
        "details": <stderr>}``.
    """
    # Local import keeps this block self-contained; the module's import
    # section does not bring in ``sys``.
    import sys

    # A bare "python" is resolved through PATH and may be absent or point at
    # an environment without the training dependencies; sys.executable is
    # the interpreter this process is actually running under.
    interpreter = sys.executable or "python"

    try:
        result = subprocess.run(
            [interpreter, "model.py"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only carries the command and exit status; the script's own
        # error output is in e.stderr, so surface it as well.
        return {
            "message": "Retraining failed!",
            "error": str(e),
            "details": (e.stderr or "").strip(),
        }

    output = result.stdout.strip()
    return {"message": "Retraining succeeded!", "details": output}
|
|
|
|
|
|
@app.post("/upload-csv")
async def upload_csv(file: UploadFile = File(...)):
    """Handle POST /upload-csv: replace ``csv/ems_data.csv`` with an uploaded CSV.

    The upload is checked in memory first and only written to disk once it
    parses, so a rejected upload leaves the existing dataset untouched.

    Args:
        file: The uploaded file; its name must end in ``.csv``
            (case-insensitive).

    Returns:
        ``{"message": "File uploaded and overwritten successfully!"}``.

    Raises:
        HTTPException: 400 when the file name is missing or not ``.csv``,
            or when pandas cannot parse the content; 500 when the file
            cannot be written.
    """
    # Local import keeps this block self-contained; the module's import
    # section does not bring in ``io``.
    import io

    # The filename may be absent, so guard before calling string methods.
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed.")

    content = await file.read()

    # Validate before touching the file on disk. These 400s are raised
    # outside any broad handler so they are not re-wrapped as 500s.
    try:
        pd.read_csv(io.BytesIO(content))
    except Exception as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid CSV."
        ) from exc

    # Save the validated upload to csv/ems_data.csv (overwrite).
    file_path = "csv/ems_data.csv"
    try:
        with open(file_path, "wb") as f:
            f.write(content)
    except OSError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "File uploaded and overwritten successfully!"}
|
|
|
|
@app.get("/get-csv")
def get_csv():
    """Handle GET /get-csv: return the rows of ``csv/ems_data.csv``.

    Returns:
        A list with one dict per CSV row, keyed by column name. Missing
        cells are returned as ``None`` (JSON ``null``).

    Raises:
        HTTPException: 500 when the CSV cannot be read or converted.
    """
    file_path = "csv/ems_data.csv"

    try:
        df = pd.read_csv(file_path)

        # Empty cells are read as NaN, which is rejected when the response
        # is serialised to JSON -- after this function has returned, so the
        # except below would never see it. Cast to object so the missing
        # cells can hold None instead.
        cleaned = df.astype(object).where(df.notna(), None)

        data = cleaned.to_dict(orient="records")
        return data
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing CSV: {str(e)}"
        ) from e
|