Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ def fit_pipeline(
names=[eval_metric] if eval_metric is not None else None,
all_supported_metrics=False).pop()

pipeline_options = self.pipeline_options.copy().update(pipeline_options) if pipeline_options is not None \
pipeline_options = {**self.pipeline_options, **pipeline_options} if pipeline_options is not None \
else self.pipeline_options.copy()

assert pipeline_options is not None
Expand Down
57 changes: 57 additions & 0 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,63 @@ def test_pipeline_fit(openml_id,
assert not os.path.exists(cv_model_path)


@pytest.mark.parametrize('openml_id', (40984,))
@pytest.mark.parametrize("budget", [1])
def test_pipeline_fit_pass_pipeline_options(
        openml_id,
        backend,
        budget,
        n_samples
):
    """Check that ``fit_pipeline`` accepts user-supplied ``pipeline_options``.

    Fits a single pipeline on a small OpenML dataset while passing
    ``pipeline_options={'early_stopping': 100}`` and verifies that the run
    succeeds, the returned objects have the expected types, the fitted
    pipeline is picklable, and the model artefact was written to disk.

    Args:
        openml_id: OpenML dataset id to fetch (parametrized).
        backend: autoPyTorch backend fixture providing the run directories.
        budget: epoch budget for the single pipeline fit (parametrized).
        n_samples: fixture limiting the number of rows used from the dataset.
    """
    # Get the data and check that contents of data-manager make sense
    X, y = sklearn.datasets.fetch_openml(
        data_id=int(openml_id),
        return_X_y=True, as_frame=True
    )
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X[:n_samples], y[:n_samples], random_state=1)

    # Search for a good configuration
    # ensemble_size=0 disables ensemble building; only a single pipeline fit
    # is exercised here.
    estimator = TabularClassificationTask(
        backend=backend,
        ensemble_size=0
    )

    dataset = estimator.get_dataset(X_train=X_train,
                                    y_train=y_train,
                                    X_test=X_test,
                                    y_test=y_test)

    # Fit the default configuration, overriding 'early_stopping' via
    # pipeline_options — the feature under test.
    configuration = estimator.get_search_space(dataset).get_default_configuration()
    pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset,
                                                                    configuration=configuration,
                                                                    run_time_limit_secs=50,
                                                                    budget_type='epochs',
                                                                    budget=budget,
                                                                    pipeline_options={'early_stopping': 100}
                                                                    )
    assert isinstance(dataset, BaseDataset)
    assert isinstance(run_info, RunInfo)
    assert isinstance(run_info.config, Configuration)

    assert isinstance(run_value, RunValue)
    assert 'SUCCESS' in str(run_value.status)

    # Make sure that the pipeline can be pickled
    dump_file = os.path.join(tempfile.gettempdir(), 'automl.dump.pkl')
    with open(dump_file, 'wb') as f:
        pickle.dump(pipeline, f)

    num_run_dir = estimator._backend.get_numrun_directory(
        run_info.seed, run_value.additional_info['num_run'], budget=float(budget))
    model_path = os.path.join(num_run_dir, estimator._backend.get_model_filename(
        run_info.seed, run_value.additional_info['num_run'], budget=float(budget)))

    # The fitted model must always be written to the num_run directory.
    # NOTE(review): the comment this was copied from also mentioned a cv
    # model, but this test uses the default resampling strategy and asserts
    # nothing about a cv model.
    assert os.path.exists(model_path)


@pytest.mark.parametrize('openml_id', (40984,))
@pytest.mark.parametrize('resampling_strategy,resampling_strategy_args',
((HoldoutValTypes.holdout_validation, {'val_share': 0.8}),
Expand Down