Commit c2f78f2

fix issue with missing classes
1 parent b2658e5 commit c2f78f2

File tree

autoPyTorch/api/base_task.py
autoPyTorch/data/utils.py
autoPyTorch/pipeline/components/training/trainer/__init__.py

3 files changed: +18 -7 lines changed

autoPyTorch/api/base_task.py

Lines changed: 5 additions & 2 deletions
@@ -12,7 +12,7 @@
 import unittest.mock
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
 
@@ -299,6 +299,7 @@ def _get_dataset_input_validator(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
+        dataset_compression: Optional[Mapping[str, Any]] = None,
     ) -> Tuple[BaseDataset, BaseInputValidator]:
         """
         Returns an object of a child class of `BaseDataset` and
@@ -341,6 +342,7 @@ def get_dataset(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
+        dataset_compression: Optional[Mapping[str, Any]] = None,
     ) -> BaseDataset:
         """
         Returns an object of a child class of `BaseDataset` according to the current task.
@@ -375,7 +377,8 @@ def get_dataset(
             y_test=y_test,
             resampling_strategy=resampling_strategy,
             resampling_strategy_args=resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name,
+            dataset_compression=dataset_compression)
 
         return dataset
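
With this change, the dataset_compression mapping is threaded from the public API into the dataset and input-validator construction rather than being dropped at get_dataset. A minimal sketch of how a caller might now pass it, assuming the tabular child class of BaseTask exposes the base get_dataset unchanged; the data and values below are purely illustrative:

import numpy as np

from autoPyTorch.api.tabular_classification import TabularClassificationTask

# Illustrative toy data; any tabular X/y works here.
X = np.random.rand(500, 8)
y = np.random.randint(0, 3, size=500)

api = TabularClassificationTask()

# The keys mirror the dataset_compression format validated in
# autoPyTorch/data/utils.py; the concrete values are illustrative.
dataset = api.get_dataset(
    X_train=X,
    y_train=y,
    dataset_compression={
        "memory_allocation": 0.1,   # fraction of the memory limit
        "methods": ["precision"],   # e.g. downcast float64 to float32
    },
)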

autoPyTorch/data/utils.py

Lines changed: 3 additions & 3 deletions
@@ -240,8 +240,8 @@ def validate_dataset_compression_arg(
                     f"\nmemory_allocation = {memory_allocation}"
                     f"\ndataset_compression = {dataset_compression}"
                 )
-            # convert to int so we can directly use
-            dataset_compression["memory_allocation"] = floor(memory_allocation * memory_limit)
+            # convert to required memory so we can directly use
+            dataset_compression["memory_allocation"] = memory_allocation * memory_limit
 
         # "methods" must be non-empty sequence
         if (
@@ -464,7 +464,7 @@ def megabytes(arr: DatasetCompressionInputType) -> float:
 
 def reduce_dataset_size_if_too_large(
     X: DatasetCompressionInputType,
-    memory_allocation: int,
+    memory_allocation: float,
     is_classification: bool,
     random_state: Union[int, np.random.RandomState],
     y: Optional[SupportedTargetTypes] = None,
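
The practical effect of this change: a fractional memory_allocation is no longer truncated to a whole number of megabytes when converted against the memory limit, and reduce_dataset_size_if_too_large is annotated to accept the resulting float. A small worked sketch with illustrative numbers:

from math import floor

memory_limit = 4096       # MB; illustrative
memory_allocation = 0.1   # fraction of the limit

old_budget = floor(memory_allocation * memory_limit)  # 409 (truncated to int)
new_budget = memory_allocation * memory_limit         # 409.6 (exact float)

# The exact float budget now flows into reduce_dataset_size_if_too_large
# instead of the lossy floor()ed value.
print(old_budget, new_budget)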

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -18,7 +18,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from autoPyTorch.constants import STRING_TO_TASK_TYPES
+from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import (
@@ -257,6 +257,14 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
         if 'optimize_metric' in X and X['optimize_metric'] not in [m.name for m in metrics]:
             metrics.extend(get_metrics(dataset_properties=X['dataset_properties'], names=[X['optimize_metric']]))
         additional_losses = X['additional_losses'] if 'additional_losses' in X else None
+
+        # Ensure that the split is not missing any class.
+        labels = X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
+        if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
+            unique_labels = len(np.unique(labels))
+            if unique_labels < X['dataset_properties']['output_shape']:
+                raise ValueError(f"Expected number of unique labels {unique_labels} in train split: {X['split_id']}"
+                                 f" to be = num_classes {X['dataset_properties']['output_shape']}.")
         self.choice.prepare(
             model=X['network'],
             metrics=metrics,
@@ -268,7 +276,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
             metrics_during_training=X['metrics_during_training'],
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
-            labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
+            labels=labels,
             step_interval=X['step_interval']
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
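
The new guard is easiest to see in isolation. A minimal sketch of the same check outside the pipeline, with hand-made stand-ins for X['y_train'], the train split indices, and the output_shape dataset property:

import numpy as np

# Three classes overall, but the (hand-picked) train indices miss class 2.
y_train = np.array([0, 0, 1, 1, 2, 0, 1, 2])
train_indices = np.array([0, 1, 2, 3, 5, 6])  # selects only classes 0 and 1
num_classes = 3                               # stands in for output_shape

labels = y_train[train_indices]
unique_labels = len(np.unique(labels))
if unique_labels < num_classes:
    # Mirrors the ValueError raised in TrainerChoice._fit above.
    raise ValueError(f"Expected number of unique labels {unique_labels} in the"
                     f" train split to equal num_classes {num_classes}.")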
