diff --git a/routes/protected_prediction.py b/routes/protected_prediction.py index 179d16d..43e5552 100644 --- a/routes/protected_prediction.py +++ b/routes/protected_prediction.py @@ -18,7 +18,7 @@ def predict_auto(request: AutoPredictionRequest): series = df['amount'] - result = auto_arima_forecast(series, forecast_periods=1) + result = auto_arima_forecast(series, forecast_periods=request.future_steps) return AutoPredictionResponse( rmse=result["rmse"], @@ -47,14 +47,14 @@ def predict_manual(request: ManualPredictionRequest): p, d, q = request.arima_model - result = manual_arima_forecast(series, p=p, d=d, q=q, forecast_periods=1) + result = manual_arima_forecast(series, p=p, d=d, q=q, forecast_periods=request.future_steps) return ManualPredictionResponse( arima_order=tuple(result["arima_order"]), prediction=result["prediction"], lower=result["lower"], upper=result["upper"], - success=True + success=True, ) except ValueError as ve: diff --git a/schema/prediction.py b/schema/prediction.py index 496c917..7ecccd9 100644 --- a/schema/prediction.py +++ b/schema/prediction.py @@ -8,6 +8,7 @@ class BasePredictionRequest(BaseModel): value_column: str='sold_qty' date_column: str='date' date_regroup: bool=False + future_steps: int=1 class AutoPredictionRequest(BasePredictionRequest): diff --git a/utils/statistic/auto_arima.py b/utils/statistic/auto_arima.py index 94efb93..a06a44b 100644 --- a/utils/statistic/auto_arima.py +++ b/utils/statistic/auto_arima.py @@ -6,7 +6,7 @@ import pandas as pd import numpy as np warnings.filterwarnings("ignore", category=FutureWarning) -def auto_arima_forecast(series: pd.Series, train_ratio=0.8, forecast_periods: int = 1) -> dict: +def auto_arima_forecast(series: pd.Series, train_ratio=0.7, forecast_periods: int = 1) -> dict: if series is None or series.empty: raise ValueError("Data tidak valid atau kosong.")