1- from typing import Mapping
1+ import warnings
2+ from typing import List , Mapping
23
34import numpy as np
45
6+ import pandas as pd
7+
58import pytest
69
710from sklearn .datasets import fetch_openml
811
12+ from scipy .sparse import csr_matrix , spmatrix
13+
14+ from autoPyTorch .constants import (
15+ BINARY ,
16+ CLASSIFICATION_TASKS ,
17+ CONTINUOUS ,
18+ MULTICLASS ,
19+ MULTICLASSMULTIOUTPUT ,
20+ CONTINUOUSMULTIOUTPUT ,
21+ TABULAR_REGRESSION ,
22+ TABULAR_CLASSIFICATION ,
23+ )
924from autoPyTorch .data .utils import (
1025 default_dataset_compression_arg ,
1126 get_dataset_compression_mapping ,
1227 megabytes ,
1328 reduce_dataset_size_if_too_large ,
1429 reduce_precision ,
30+ subsample ,
1531 validate_dataset_compression_arg
1632)
1733from autoPyTorch .utils .common import subsampler
@@ -37,6 +53,102 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
3753 assert megabytes (X_converted ) < megabytes (X )
3854
3955
56+ @pytest .mark .parametrize ("X" , [np .asarray ([[1 , 1 , 1 ]] * 30 )])
57+ @pytest .mark .parametrize ("x_type" , [list , np .ndarray , csr_matrix , pd .DataFrame ])
58+ @pytest .mark .parametrize (
59+ "y, task, output" ,
60+ [
61+ (np .asarray ([0 ] * 15 + [1 ] * 15 ), TABULAR_CLASSIFICATION , BINARY ),
62+ (np .asarray ([0 ] * 10 + [1 ] * 10 + [2 ] * 10 ), TABULAR_CLASSIFICATION , MULTICLASS ),
63+ (np .asarray ([[1 , 0 , 1 ]] * 30 ), TABULAR_CLASSIFICATION , MULTICLASSMULTIOUTPUT ),
64+ (np .asarray ([1.0 ] * 30 ), TABULAR_REGRESSION , CONTINUOUS ),
65+ (np .asarray ([[1.0 , 1.0 , 1.0 ]] * 30 ), TABULAR_REGRESSION , CONTINUOUSMULTIOUTPUT ),
66+ ],
67+ )
68+ @pytest .mark .parametrize ("y_type" , [list , np .ndarray , pd .DataFrame , pd .Series ])
69+ @pytest .mark .parametrize ("random_state" , [0 ])
70+ @pytest .mark .parametrize ("sample_size" , [0.25 , 0.5 , 5 , 10 ])
71+ def test_subsample_validity (X , x_type , y , y_type , random_state , sample_size , task , output ):
72+ """Asserts the validity of the function with all valid types
73+ We want to make sure that `subsample` works correctly with all the types listed
74+ as x_type and y_type.
75+ We also want to make sure it works with all kinds of target types.
76+ The output should maintain the types, and subsample the correct amount.
77+ (test adapted from autosklearn)
78+ """
79+ assert len (X ) == len (y ) # Make sure our test data is correct
80+
81+ if y_type == pd .Series and output in [
82+ MULTICLASSMULTIOUTPUT ,
83+ CONTINUOUSMULTIOUTPUT ,
84+ ]:
85+ # We can't have a pd.Series with multiple values as it's 1 dimensional
86+ pytest .skip ("Can't have pd.Series as y when task is n-dimensional" )
87+
88+ # Convert our data to its given x_type or y_type
89+ def convert (arr , objtype ):
90+ if objtype == np .ndarray :
91+ return arr
92+ elif objtype == list :
93+ return arr .tolist ()
94+ else :
95+ return objtype (arr )
96+
97+ X = convert (X , x_type )
98+ y = convert (y , y_type )
99+
100+ # Subsample the data, ignoring any warnings
101+ with warnings .catch_warnings ():
102+ warnings .simplefilter ("ignore" )
103+ X_sampled , y_sampled = subsample (
104+ X ,
105+ y = y ,
106+ random_state = random_state ,
107+ sample_size = sample_size ,
108+ is_classification = task in CLASSIFICATION_TASKS ,
109+ )
110+
111+ # Function to get the type of an obj
112+ def dtype (obj ):
113+ if isinstance (obj , List ):
114+ if isinstance (obj [0 ], List ):
115+ return type (obj [0 ][0 ])
116+ else :
117+ return type (obj [0 ])
118+
119+ elif isinstance (obj , pd .DataFrame ):
120+ return obj .dtypes
121+
122+ else :
123+ return obj .dtype
124+
125+ # Check that the types of X remain the same after subsampling
126+ if isinstance (X , pd .DataFrame ):
127+ # Dataframe can have multiple types, one per column
128+ assert list (dtype (X_sampled )) == list (dtype (X ))
129+ else :
130+ assert dtype (X_sampled ) == dtype (X )
131+
132+ # Check that the types of y remain the same after subsampling
133+ if isinstance (y , pd .DataFrame ):
134+ assert list (dtype (y_sampled )) == list (dtype (y ))
135+ else :
136+ assert dtype (y_sampled ) == dtype (y )
137+
138+ # Function to get the size of an object
139+ def size (obj ):
140+ if isinstance (obj , spmatrix ): # spmatrix doesn't support __len__
141+ return obj .shape [0 ] if obj .shape [0 ] > 1 else obj .shape [1 ]
142+ else :
143+ return len (obj )
144+
145+ # check the right amount of samples were taken
146+ if sample_size < 1 :
147+ assert size (X_sampled ) == int (sample_size * size (X ))
148+ else :
149+ assert size (X_sampled ) == sample_size
150+
151+
40152def test_validate_dataset_compression_arg ():
41153
42154 data_compression_args = validate_dataset_compression_arg ({}, 10 )
0 commit comments