Commit 9588325

add tests from ask
1 parent 7870aff commit 9588325

File tree

1 file changed: +113, -1 lines changed


test/test_data/test_utils.py

Lines changed: 113 additions & 1 deletion
@@ -1,17 +1,33 @@
-from typing import Mapping
+import warnings
+from typing import List, Mapping
 
 import numpy as np
 
+import pandas as pd
+
 import pytest
 
 from sklearn.datasets import fetch_openml
 
+from scipy.sparse import csr_matrix, spmatrix
+
+from autoPyTorch.constants import (
+    BINARY,
+    CLASSIFICATION_TASKS,
+    CONTINUOUS,
+    MULTICLASS,
+    MULTICLASSMULTIOUTPUT,
+    CONTINUOUSMULTIOUTPUT,
+    TABULAR_REGRESSION,
+    TABULAR_CLASSIFICATION,
+)
 from autoPyTorch.data.utils import (
     default_dataset_compression_arg,
     get_dataset_compression_mapping,
     megabytes,
     reduce_dataset_size_if_too_large,
     reduce_precision,
+    subsample,
     validate_dataset_compression_arg
 )
 from autoPyTorch.utils.common import subsampler
@@ -37,6 +53,102 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
     assert megabytes(X_converted) < megabytes(X)
 
 
+@pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)])
+@pytest.mark.parametrize("x_type", [list, np.ndarray, csr_matrix, pd.DataFrame])
+@pytest.mark.parametrize(
+    "y, task, output",
+    [
+        (np.asarray([0] * 15 + [1] * 15), TABULAR_CLASSIFICATION, BINARY),
+        (np.asarray([0] * 10 + [1] * 10 + [2] * 10), TABULAR_CLASSIFICATION, MULTICLASS),
+        (np.asarray([[1, 0, 1]] * 30), TABULAR_CLASSIFICATION, MULTICLASSMULTIOUTPUT),
+        (np.asarray([1.0] * 30), TABULAR_REGRESSION, CONTINUOUS),
+        (np.asarray([[1.0, 1.0, 1.0]] * 30), TABULAR_REGRESSION, CONTINUOUSMULTIOUTPUT),
+    ],
+)
+@pytest.mark.parametrize("y_type", [list, np.ndarray, pd.DataFrame, pd.Series])
+@pytest.mark.parametrize("random_state", [0])
+@pytest.mark.parametrize("sample_size", [0.25, 0.5, 5, 10])
+def test_subsample_validity(X, x_type, y, y_type, random_state, sample_size, task, output):
+    """Asserts the validity of the function with all valid types
+    We want to make sure that `subsample` works correctly with all the types listed
+    as x_type and y_type.
+    We also want to make sure it works with all kinds of target types.
+    The output should maintain the types, and subsample the correct amount.
+    (test adapted from autosklearn)
+    """
+    assert len(X) == len(y)  # Make sure our test data is correct
+
+    if y_type == pd.Series and output in [
+        MULTICLASSMULTIOUTPUT,
+        CONTINUOUSMULTIOUTPUT,
+    ]:
+        # We can't have a pd.Series with multiple values as it's 1 dimensional
+        pytest.skip("Can't have pd.Series as y when task is n-dimensional")
+
+    # Convert our data to its given x_type or y_type
+    def convert(arr, objtype):
+        if objtype == np.ndarray:
+            return arr
+        elif objtype == list:
+            return arr.tolist()
+        else:
+            return objtype(arr)
+
+    X = convert(X, x_type)
+    y = convert(y, y_type)
+
+    # Subsample the data, ignoring any warnings
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        X_sampled, y_sampled = subsample(
+            X,
+            y=y,
+            random_state=random_state,
+            sample_size=sample_size,
+            is_classification=task in CLASSIFICATION_TASKS,
+        )
+
+    # Function to get the type of an obj
+    def dtype(obj):
+        if isinstance(obj, List):
+            if isinstance(obj[0], List):
+                return type(obj[0][0])
+            else:
+                return type(obj[0])
+
+        elif isinstance(obj, pd.DataFrame):
+            return obj.dtypes
+
+        else:
+            return obj.dtype
+
+    # Check that the types of X remain the same after subsampling
+    if isinstance(X, pd.DataFrame):
+        # Dataframe can have multiple types, one per column
+        assert list(dtype(X_sampled)) == list(dtype(X))
+    else:
+        assert dtype(X_sampled) == dtype(X)
+
+    # Check that the types of y remain the same after subsampling
+    if isinstance(y, pd.DataFrame):
+        assert list(dtype(y_sampled)) == list(dtype(y))
+    else:
+        assert dtype(y_sampled) == dtype(y)
+
+    # Function to get the size of an object
+    def size(obj):
+        if isinstance(obj, spmatrix):  # spmatrix doesn't support __len__
+            return obj.shape[0] if obj.shape[0] > 1 else obj.shape[1]
+        else:
+            return len(obj)
+
+    # check the right amount of samples were taken
+    if sample_size < 1:
+        assert size(X_sampled) == int(sample_size * size(X))
+    else:
+        assert size(X_sampled) == sample_size
+
+
 def test_validate_dataset_compression_arg():
 
     data_compression_args = validate_dataset_compression_arg({}, 10)
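As a quick illustration of the behaviour the new test pins down, here is a minimal, hypothetical usage sketch of the `subsample` helper from `autoPyTorch.data.utils`. It mirrors the call signature used in the diff above; the fraction-versus-count interpretation of `sample_size` (and the `int` truncation for fractional sizes) is assumed solely from the test's own assertions, not from the library documentation.

    import numpy as np

    from autoPyTorch.constants import CLASSIFICATION_TASKS, TABULAR_CLASSIFICATION
    from autoPyTorch.data.utils import subsample

    # Toy data matching the shapes used in the test: 30 rows, binary target.
    X = np.asarray([[1, 1, 1]] * 30)
    y = np.asarray([0] * 15 + [1] * 15)

    # sample_size < 1 is read as a fraction of the rows: int(0.25 * 30) == 7 rows
    # (assumption based on the size assertion in the test).
    X_frac, y_frac = subsample(
        X,
        y=y,
        random_state=0,
        sample_size=0.25,
        is_classification=TABULAR_CLASSIFICATION in CLASSIFICATION_TASKS,
    )
    assert len(X_frac) == int(0.25 * len(X))

    # sample_size >= 1 is read as an absolute number of rows.
    X_abs, y_abs = subsample(X, y=y, random_state=0, sample_size=10, is_classification=True)
    assert len(X_abs) == 10

The parametrization over list, np.ndarray, csr_matrix, pd.DataFrame, and pd.Series then checks that the same call preserves the container type and dtype of whatever was passed in.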
