Skip to content

Commit 9a3170c

Browse files
committed
Merge remote-tracking branch 'origin/main' into tswast-remote-function-local-testing
2 parents 0475c39 + ac1a188 commit 9a3170c

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

tests/system/large/ml/test_ensemble.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def test_xgbregressor_default_params(penguins_df_default_index, dataset_id):
4646
reloaded_model = model.to_gbq(
4747
f"{dataset_id}.temp_configured_xgbregressor_model", replace=True
4848
)
49-
# TODO(b/340888645): fix type error
49+
assert reloaded_model._bqml_model is not None
5050
assert (
5151
f"{dataset_id}.temp_configured_xgbregressor_model"
52-
in reloaded_model._bqml_model.model_name # type: ignore
52+
in reloaded_model._bqml_model.model_name
5353
)
5454

5555

@@ -98,10 +98,10 @@ def test_xgbregressor_dart_booster_multiple_params(
9898
reloaded_model = model.to_gbq(
9999
f"{dataset_id}.temp_configured_xgbregressor_model", replace=True
100100
)
101-
# TODO(b/340888645): fix type error
101+
assert reloaded_model._bqml_model is not None
102102
assert (
103103
f"{dataset_id}.temp_configured_xgbregressor_model"
104-
in reloaded_model._bqml_model.model_name # type: ignore
104+
in reloaded_model._bqml_model.model_name
105105
)
106106
assert reloaded_model.booster == "DART"
107107
assert reloaded_model.dart_normalized_type == "TREE"
@@ -148,10 +148,10 @@ def test_xgbclassifier_default_params(penguins_df_default_index, dataset_id):
148148
reloaded_model = model.to_gbq(
149149
f"{dataset_id}.temp_configured_xgbclassifier_model", replace=True
150150
)
151-
# TODO(b/340888645): fix type error
151+
assert reloaded_model._bqml_model is not None
152152
assert (
153153
f"{dataset_id}.temp_configured_xgbclassifier_model"
154-
in reloaded_model._bqml_model.model_name # type: ignore
154+
in reloaded_model._bqml_model.model_name
155155
)
156156

157157

@@ -199,10 +199,10 @@ def test_xgbclassifier_dart_booster_multiple_params(
199199
reloaded_model = model.to_gbq(
200200
f"{dataset_id}.temp_configured_xgbclassifier_model", replace=True
201201
)
202-
# TODO(b/340888645): fix type error
202+
assert reloaded_model._bqml_model is not None
203203
assert (
204204
f"{dataset_id}.temp_configured_xgbclassifier_model"
205-
in reloaded_model._bqml_model.model_name # type: ignore
205+
in reloaded_model._bqml_model.model_name
206206
)
207207
assert reloaded_model.booster == "DART"
208208
assert reloaded_model.dart_normalized_type == "TREE"
@@ -250,10 +250,10 @@ def test_randomforestregressor_default_params(penguins_df_default_index, dataset
250250
reloaded_model = model.to_gbq(
251251
f"{dataset_id}.temp_configured_randomforestregressor_model", replace=True
252252
)
253-
# TODO(b/340888645): fix type error
253+
assert reloaded_model._bqml_model is not None
254254
assert (
255255
f"{dataset_id}.temp_configured_randomforestregressor_model"
256-
in reloaded_model._bqml_model.model_name # type: ignore
256+
in reloaded_model._bqml_model.model_name
257257
)
258258

259259

@@ -297,10 +297,10 @@ def test_randomforestregressor_multiple_params(penguins_df_default_index, datase
297297
reloaded_model = model.to_gbq(
298298
f"{dataset_id}.temp_configured_randomforestregressor_model", replace=True
299299
)
300-
# TODO(b/340888645): fix type error
300+
assert reloaded_model._bqml_model is not None
301301
assert (
302302
f"{dataset_id}.temp_configured_randomforestregressor_model"
303-
in reloaded_model._bqml_model.model_name # type: ignore
303+
in reloaded_model._bqml_model.model_name
304304
)
305305
assert reloaded_model.tree_method == "AUTO"
306306
assert reloaded_model.colsample_bytree == 0.95
@@ -344,18 +344,17 @@ def test_randomforestclassifier_default_params(penguins_df_default_index, datase
344344
reloaded_model = model.to_gbq(
345345
f"{dataset_id}.temp_configured_randomforestclassifier_model", replace=True
346346
)
347-
# TODO(b/340888645): fix type error
347+
assert reloaded_model._bqml_model is not None
348348
assert (
349349
f"{dataset_id}.temp_configured_randomforestclassifier_model"
350-
in reloaded_model._bqml_model.model_name # type: ignore
350+
in reloaded_model._bqml_model.model_name
351351
)
352352

353353

354354
@pytest.mark.flaky(retries=2)
355355
def test_randomforestclassifier_multiple_params(penguins_df_default_index, dataset_id):
356-
# TODO(b/340888645): fix type error
357356
model = bigframes.ml.ensemble.RandomForestClassifier(
358-
tree_method="AUTO", # type: ignore
357+
tree_method="auto",
359358
min_tree_child_weight=2,
360359
colsample_bytree=0.95,
361360
colsample_bylevel=0.95,
@@ -391,12 +390,12 @@ def test_randomforestclassifier_multiple_params(penguins_df_default_index, datas
391390
reloaded_model = model.to_gbq(
392391
f"{dataset_id}.temp_configured_randomforestclassifier_model", replace=True
393392
)
394-
# TODO(b/340888645): fix type error
393+
assert reloaded_model._bqml_model is not None
395394
assert (
396395
f"{dataset_id}.temp_configured_randomforestclassifier_model"
397-
in reloaded_model._bqml_model.model_name # type: ignore
396+
in reloaded_model._bqml_model.model_name
398397
)
399-
assert reloaded_model.tree_method == "AUTO"
398+
assert reloaded_model.tree_method == "auto"
400399
assert reloaded_model.colsample_bytree == 0.95
401400
assert reloaded_model.colsample_bylevel == 0.95
402401
assert reloaded_model.colsample_bynode == 0.95

0 commit comments

Comments
 (0)