Skip to content

Commit 5f31375

Browse files
Merge pull request #217 from KhiopsML/216-coreapi-does-not-detect-unknown-parameters
Fix unknown arguments being silently ignored
2 parents fe00d30 + a4ccd56 commit 5f31375

File tree

4 files changed

+110
-57
lines changed

4 files changed

+110
-57
lines changed

khiops/core/api.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,11 @@ def _preprocess_task_arguments(task_args):
299299
)
300300
)
301301

302+
# Flatten kwargs
303+
if "kwargs" in task_args:
304+
task_args.update(task_args["kwargs"])
305+
del task_args["kwargs"]
306+
302307
return task_called_with_domain
303308

304309

@@ -336,10 +341,10 @@ def _preprocess_format_spec(detect_format, header_line, field_separator):
336341
def _clean_task_args(task_args):
337342
"""Cleans the task arguments
338343
339-
More precisely:
340-
- It removes command line arguments (they already are in another object).
341-
- It removes parameters removed from the API and warns about it.
342-
- It removes renamed API parameters and warns about it.
344+
More precisely it removes:
345+
- Command line arguments (they already are in another object).
346+
- Parameters removed from the API and warns about it.
347+
- Renamed API parameters and warns about it.
343348
"""
344349
# Remove non-task parameters
345350
command_line_arg_names = [
@@ -353,7 +358,6 @@ def _clean_task_args(task_args):
353358
"trace",
354359
"stdout_file_path",
355360
"stderr_file_path",
356-
"kwargs",
357361
]
358362
for arg_name in command_line_arg_names + other_arg_names:
359363
if arg_name in task_args:

khiops/sklearn/estimators.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1464,9 +1464,11 @@ def _fit_prepare_training_function_inputs(self, dataset, computation_dir):
14641464
# Build the optional parameters from a copy of the estimator parameters
14651465
kwargs = self.get_params()
14661466

1467-
# Remove 'key' and 'output_dir'
1467+
# Remove non core.api params
14681468
del kwargs["key"]
14691469
del kwargs["output_dir"]
1470+
del kwargs["auto_sort"]
1471+
del kwargs["internal_sort"]
14701472

14711473
# Set the sampling percentage to a 100%
14721474
kwargs["sample_percentage"] = 100

khiops/sklearn/tables.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,11 +166,27 @@ def __init__(self, X, y=None, categorical_target=True, key=None):
166166
y,
167167
categorical_target=categorical_target,
168168
)
169-
# A sparse matrix
169+
# A scipy.sparse.spmatrix
170170
elif isinstance(X, sp.spmatrix):
171171
self._init_tables_from_sparse_matrix(
172172
X, y, categorical_target=categorical_target
173173
)
174+
# Special rejection for scipy.sparse.sparray (to pass the sklearn tests)
175+
# Note: We don't use scipy.sparse.sparray because it is not implemented in scipy
176+
# 1.10 which is the latest supporting py3.8
177+
elif isinstance(
178+
X,
179+
(
180+
sp.bsr_array,
181+
sp.coo_array,
182+
sp.csc_array,
183+
sp.csr_array,
184+
sp.dia_array,
185+
sp.dok_array,
186+
sp.lil_array,
187+
),
188+
):
189+
check_array(X, accept_sparse=False)
174190
# A tuple spec
175191
elif isinstance(X, tuple):
176192
warnings.warn(

tests/test_core.py

Lines changed: 81 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -102,11 +102,11 @@ def test_analysis_results(self):
102102
with self.assertWarns(UserWarning):
103103
results = kh.read_analysis_results_file(ref_json_report)
104104
results.write_report_file(output_report)
105-
files_equal_or_fail(ref_report, output_report)
105+
assert_files_equal(self, ref_report, output_report)
106106
else:
107107
results = kh.read_analysis_results_file(ref_json_report)
108108
results.write_report_file(output_report)
109-
files_equal_or_fail(ref_report, output_report)
109+
assert_files_equal(self, ref_report, output_report)
110110

111111
def test_coclustering_results(self):
112112
"""Tests for the coclustering_results module"""
@@ -149,7 +149,7 @@ def test_coclustering_results(self):
149149
else:
150150
results = kh.read_coclustering_results_file(ref_json_report)
151151
results.write_report_file(output_report)
152-
files_equal_or_fail(ref_report, output_report)
152+
assert_files_equal(self, ref_report, output_report)
153153
for dimension in results.coclustering_report.dimensions:
154154
ref_hierarchy_report = os.path.join(
155155
ref_reports_dir, f"{report}_hierarchy_{dimension.name}.txt"
@@ -160,7 +160,9 @@ def test_coclustering_results(self):
160160
dimension.write_hierarchy_structure_report_file(
161161
output_hierarchy_report
162162
)
163-
files_equal_or_fail(ref_hierarchy_report, output_hierarchy_report)
163+
assert_files_equal(
164+
self, ref_hierarchy_report, output_hierarchy_report
165+
)
164166

165167
def test_binary_dictionary_domain(self):
166168
"""Test binary dictionary write"""
@@ -204,13 +206,13 @@ def test_binary_dictionary_domain(self):
204206
for domain in (domain_from_api, domain_from_json):
205207
# Dump domain object as kdic file and compare it to the reference
206208
domain.export_khiops_dictionary_file(output_kdic)
207-
files_equal_or_fail(ref_kdic, output_kdic)
209+
assert_files_equal(self, ref_kdic, output_kdic)
208210

209211
# Make a copy of the domain object, then dump it as kdic file and
210212
# compare it to the reference
211213
domain_copy = domain.copy()
212214
domain_copy.export_khiops_dictionary_file(copy_output_kdic)
213-
files_equal_or_fail(ref_kdic, copy_output_kdic)
215+
assert_files_equal(self, ref_kdic, copy_output_kdic)
214216

215217
def test_dictionary(self):
216218
"""Tests for the dictionary module"""
@@ -276,19 +278,13 @@ def test_dictionary(self):
276278
else:
277279
domain = kh.read_dictionary_file(ref_kdicj)
278280
domain.export_khiops_dictionary_file(output_kdic)
279-
files_equal_or_fail(ref_kdic, output_kdic)
281+
assert_files_equal(self, ref_kdic, output_kdic)
280282

281283
domain_copy = domain.copy()
282284
domain_copy.export_khiops_dictionary_file(copy_output_kdic)
283-
files_equal_or_fail(ref_kdic, copy_output_kdic)
284-
285-
def test_api_scenario_generation(self):
286-
"""Tests the scenarios generated by the API
285+
assert_files_equal(self, ref_kdic, copy_output_kdic)
287286

288-
These tests are not exhaustive, executed with the minimal parameters to trigger
289-
the more complex scenario generation code (lists, key-value sections) when they
290-
are present.
291-
"""
287+
def _build_mock_api_method_parameters(self):
292288
# Pseudo-mock data to test the creation of scenarios
293289
datasets = ["Adult", "SpliceJunction", "Customer"]
294290
additional_data_tables = {
@@ -571,6 +567,15 @@ def test_api_scenario_generation(self):
571567
},
572568
}
573569

570+
return method_test_args
571+
572+
def test_api_scenario_generation(self):
573+
"""Tests the scenarios generated by the API
574+
575+
These tests are not exhaustive, executed with the minimal parameters to trigger
576+
the more complex scenario generation code (lists, key-value sections) when they
577+
are present.
578+
"""
574579
# Set the root directory of these tests
575580
test_resources_dir = os.path.join(resources_dir(), "scenario_generation", "api")
576581

@@ -580,44 +585,69 @@ def test_api_scenario_generation(self):
580585
kh.set_runner(test_runner)
581586

582587
# Run test for all methods and all mock datasets parameters
588+
method_test_args = self._build_mock_api_method_parameters()
583589
for method_name, method_full_args in method_test_args.items():
584-
self._test_method_scenario_generation(
585-
test_runner,
586-
method_name,
587-
method_full_args,
588-
)
590+
# Set the runners test name
591+
test_runner.test_name = method_name
592+
593+
# Clean the directory for this method's tests
594+
cleanup_dir(test_runner.output_scenario_dir, "*/output/*._kh", verbose=True)
595+
596+
# Test for each dataset mock parameters
597+
for dataset, dataset_method_args in method_full_args.items():
598+
test_runner.subtest_name = dataset
599+
with self.subTest(dataset=dataset, method=method_name):
600+
# Execute the method
601+
method = getattr(kh, method_name)
602+
dataset_args = dataset_method_args["args"]
603+
dataset_kwargs = dataset_method_args["kwargs"]
604+
method(*dataset_args, **dataset_kwargs)
605+
606+
# Compare the reference with the output
607+
assert_files_equal(
608+
self,
609+
test_runner.ref_scenario_path,
610+
test_runner.output_scenario_path,
611+
line_comparator=scenario_line_comparator,
612+
)
589613

590614
# Restore the default runner
591615
kh.set_runner(default_runner)
592616

593-
def _test_method_scenario_generation(
594-
self,
595-
runner,
596-
method_name,
597-
method_full_args,
598-
):
599-
# Set the runners test name
600-
runner.test_name = method_name
601-
602-
# Clean the directory for this method's tests
603-
cleanup_dir(runner.output_scenario_dir, "*/output/*._kh", verbose=True)
617+
def test_unknown_argument_in_api_method(self):
618+
"""Tests if core.api raises ValueError when an unknown argument is passed"""
619+
# Obtain mock arguments for each API call
620+
method_test_args = self._build_mock_api_method_parameters()
604621

605622
# Test for each dataset mock parameters
606-
for dataset, dataset_method_args in method_full_args.items():
607-
runner.subtest_name = dataset
608-
with self.subTest(dataset=dataset, method=method_name):
609-
# Execute the method
610-
method = getattr(kh, method_name)
611-
dataset_args = dataset_method_args["args"]
612-
dataset_kwargs = dataset_method_args["kwargs"]
613-
method(*dataset_args, **dataset_kwargs)
614-
615-
# Compare the reference with the output
616-
files_equal_or_fail(
617-
runner.ref_scenario_path,
618-
runner.output_scenario_path,
619-
line_comparator=scenario_line_comparator,
620-
)
623+
for method_name, method_full_args in method_test_args.items():
624+
for dataset, dataset_method_args in method_full_args.items():
625+
# Test only for the Adult dataset
626+
if dataset != "Adult":
627+
continue
628+
629+
with self.subTest(method=method_name):
630+
# These methods do not have kwargs so they cannot have extra args
631+
if method_name in [
632+
"detect_data_table_format",
633+
"export_dictionary_as_json",
634+
]:
635+
continue
636+
637+
# Execute the method with an invalid parameter
638+
method = getattr(kh, method_name)
639+
dataset_args = dataset_method_args["args"]
640+
dataset_kwargs = dataset_method_args["kwargs"]
641+
dataset_kwargs["INVALID_PARAM"] = False
642+
643+
# Check that the call raised ValueError
644+
with self.assertRaises(ValueError) as context:
645+
method(*dataset_args, **dataset_kwargs)
646+
647+
# Check the message
648+
expected_msg = "Unknown argument 'INVALID_PARAM'"
649+
output_msg = str(context.exception)
650+
self.assertEqual(output_msg, expected_msg)
621651

622652
def test_general_options(self):
623653
"""Test that the general options are written to the scenario file"""
@@ -642,7 +672,8 @@ def test_general_options(self):
642672
kh.check_database("a.kdic", "dict_name", "data.txt")
643673

644674
# Compare the reference with the output
645-
files_equal_or_fail(
675+
assert_files_equal(
676+
self,
646677
test_runner.ref_scenario_path,
647678
test_runner.output_scenario_path,
648679
line_comparator=scenario_line_comparator,
@@ -2075,8 +2106,8 @@ def find_first_different_byte(ref_line, output_line):
20752106
return first_diff_pos, first_diff_ref_byte, first_diff_output_byte
20762107

20772108

2078-
def files_equal_or_fail(
2079-
ref_file_path, output_file_path, line_comparator=default_line_comparator
2109+
def assert_files_equal(
2110+
test_suite, ref_file_path, output_file_path, line_comparator=default_line_comparator
20802111
):
20812112
"""Portably tests if two files are equal by comparing line-by-line"""
20822113
# Read all lines from the files
@@ -2089,7 +2120,7 @@ def files_equal_or_fail(
20892120
ref_file_len = len(ref_file_lines)
20902121
output_file_len = len(output_file_lines)
20912122
if ref_file_len != output_file_len:
2092-
raise ValueError(
2123+
test_suite.fail(
20932124
"Files have different number of lines\n"
20942125
+ f"Ref file : {shorten_path(ref_file_path, 5)}\n"
20952126
+ f"Output file : {shorten_path(output_file_path, 5)}\n"

0 commit comments

Comments
 (0)