From b929e64b222cd57519513043b47b91ec01ca3a38 Mon Sep 17 00:00:00 2001 From: fhm Date: Wed, 9 Jul 2025 05:11:19 +0700 Subject: [PATCH] fix --- routes/protected_prediction.py | 6 +++--- schema/prediction.py | 1 + utils/statistic/auto_arima.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/routes/protected_prediction.py b/routes/protected_prediction.py index 179d16d..43e5552 100644 --- a/routes/protected_prediction.py +++ b/routes/protected_prediction.py @@ -18,7 +18,7 @@ def predict_auto(request: AutoPredictionRequest): series = df['amount'] - result = auto_arima_forecast(series, forecast_periods=1) + result = auto_arima_forecast(series, forecast_periods=request.future_steps) return AutoPredictionResponse( rmse=result["rmse"], @@ -47,14 +47,14 @@ def predict_manual(request: ManualPredictionRequest): p, d, q = request.arima_model - result = manual_arima_forecast(series, p=p, d=d, q=q, forecast_periods=1) + result = manual_arima_forecast(series, p=p, d=d, q=q, forecast_periods=request.future_steps) return ManualPredictionResponse( arima_order=tuple(result["arima_order"]), prediction=result["prediction"], lower=result["lower"], upper=result["upper"], - success=True + success=True, ) except ValueError as ve: diff --git a/schema/prediction.py b/schema/prediction.py index 496c917..7ecccd9 100644 --- a/schema/prediction.py +++ b/schema/prediction.py @@ -8,6 +8,7 @@ class BasePredictionRequest(BaseModel): value_column: str='sold_qty' date_column: str='date' date_regroup: bool=False + future_steps: int=1 class AutoPredictionRequest(BasePredictionRequest): diff --git a/utils/statistic/auto_arima.py b/utils/statistic/auto_arima.py index 94efb93..a06a44b 100644 --- a/utils/statistic/auto_arima.py +++ b/utils/statistic/auto_arima.py @@ -6,7 +6,7 @@ import pandas as pd import numpy as np warnings.filterwarnings("ignore", category=FutureWarning) -def auto_arima_forecast(series: pd.Series, train_ratio=0.8, forecast_periods: int = 1) -> dict: +def auto_arima_forecast(series: pd.Series, train_ratio=0.7, forecast_periods: int = 1) -> dict: if series is None or series.empty: raise ValueError("Data tidak valid atau kosong.")