Skip to content

Commit b16740e

Browse files
rey-esptswast
andauthored
feat: support forecast_limit_lower_bound and forecast_limit_upper_bound in ARIMA_PLUS (and ARIMA_PLUS_XREG) models (#1305)
* feat: support forecast_limit_lower_bound and forecast_limit_upper_bound in ARIMA_PLUS (and ARIMA_PLUS_XREG) models * update doc string * Update test_forecasting.py - remove upper bound * add TODO * Apply suggestions from code review --------- Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 9b777a0 commit b16740e

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,5 @@ coverage.xml
6060
system_tests/local_test_setup
6161

6262
# Make sure a generated file isn't accidentally committed.
63-
demo.ipynb
6463
pylintrc
6564
pylintrc.test

bigframes/ml/forecasting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
"holiday_region": "holidayRegion",
3737
"clean_spikes_and_dips": "cleanSpikesAndDips",
3838
"adjust_step_changes": "adjustStepChanges",
39+
"forecast_limit_upper_bound": "forecastLimitUpperBound",
40+
"forecast_limit_lower_bound": "forecastLimitLowerBound",
3941
"time_series_length_fraction": "timeSeriesLengthFraction",
4042
"min_time_series_length": "minTimeSeriesLength",
4143
"max_time_series_length": "maxTimeSeriesLength",
@@ -78,6 +80,17 @@ class ARIMAPlus(base.SupervisedTrainableWithIdColPredictor):
7880
adjust_step_changes (bool, default True):
7981
Determines whether or not to perform automatic step change detection and adjustment in the model training pipeline.
8082
83+
forecast_limit_upper_bound (float or None, default None):
84+
The upper bound of the forecasting values. When you specify the ``forecast_limit_upper_bound`` option, all of the forecast values must be less than the specified value.
85+
For example, if you set ``forecast_limit_upper_bound`` to 100, then all of the forecast values are less than 100.
86+
Also, all values greater than or equal to the ``forecast_limit_upper_bound`` value are excluded from modelling.
87+
The forecasting limit ensures that forecasts stay within limits.
88+
89+
forecast_limit_lower_bound (float or None, default None):
90+
The lower bound of the forecasting values where the minimum value allowed is 0. When you specify the ``forecast_limit_lower_bound`` option, all of the forecast values must be greater than the specified value.
91+
For example, if you set ``forecast_limit_lower_bound`` to 0, then all of the forecast values are larger than 0. Also, all values less than or equal to the ``forecast_limit_lower_bound`` value are excluded from modelling.
92+
The forecasting limit ensures that forecasts stay within limits.
93+
8194
time_series_length_fraction (float or None, default None):
8295
The fraction of the interpolated length of the time series that's used to model the time series trend component. All of the time points of the time series are used to model the non-trend component.
8396
@@ -106,6 +119,8 @@ def __init__(
106119
holiday_region: Optional[str] = None,
107120
clean_spikes_and_dips: bool = True,
108121
adjust_step_changes: bool = True,
122+
forecast_limit_lower_bound: Optional[float] = None,
123+
forecast_limit_upper_bound: Optional[float] = None,
109124
time_series_length_fraction: Optional[float] = None,
110125
min_time_series_length: Optional[int] = None,
111126
max_time_series_length: Optional[int] = None,
@@ -121,6 +136,8 @@ def __init__(
121136
self.holiday_region = holiday_region
122137
self.clean_spikes_and_dips = clean_spikes_and_dips
123138
self.adjust_step_changes = adjust_step_changes
139+
self.forecast_limit_upper_bound = forecast_limit_upper_bound
140+
self.forecast_limit_lower_bound = forecast_limit_lower_bound
124141
self.time_series_length_fraction = time_series_length_fraction
125142
self.min_time_series_length = min_time_series_length
126143
self.max_time_series_length = max_time_series_length
@@ -175,6 +192,10 @@ def _bqml_options(self) -> dict:
175192

176193
if self.include_drift:
177194
options["include_drift"] = True
195+
if self.forecast_limit_upper_bound is not None:
196+
options["forecast_limit_upper_bound"] = self.forecast_limit_upper_bound
197+
if self.forecast_limit_lower_bound is not None:
198+
options["forecast_limit_lower_bound"] = self.forecast_limit_lower_bound
178199

179200
return options
180201

tests/system/large/ml/test_forecasting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def test_arima_plus_model_fit_params(
154154
holiday_region="US",
155155
clean_spikes_and_dips=False,
156156
adjust_step_changes=False,
157+
forecast_limit_lower_bound=0.0,
157158
time_series_length_fraction=0.5,
158159
min_time_series_length=10,
159160
trend_smoothing_window_size=5,
@@ -183,6 +184,8 @@ def test_arima_plus_model_fit_params(
183184
assert reloaded_model.holiday_region == "US"
184185
assert reloaded_model.clean_spikes_and_dips is False
185186
assert reloaded_model.adjust_step_changes is False
187+
# TODO(b/391399223): API must return forecastLimitLowerBound for the following assertion
188+
# assert reloaded_model.forecast_limit_lower_bound == 0.0
186189
assert reloaded_model.time_series_length_fraction == 0.5
187190
assert reloaded_model.min_time_series_length == 10
188191
assert reloaded_model.trend_smoothing_window_size == 5

0 commit comments

Comments
 (0)