Skip to content

Commit cd5c51e

Browse files
saitcakmak authored and facebook-github-bot committed
Make X, Y, Yvar into properties in datasets (#2004)
Summary: X-link: facebook/Ax#1817. Pull Request resolved: #2004. Since the primary use case of the datasets is to house a couple of tensors, this simplifies the process by making dataset.X/Y/Yvar tensor-valued properties rather than callables. This also eliminates the need to wrap tensors in containers, removing a step from the process. Reviewed By: Balandat. Differential Revision: D48926544. fbshipit-source-id: ca31fd0946c449b501bb5a9e5d3f8f5e38bab715
1 parent d6a1e12 commit cd5c51e

File tree

6 files changed

+124
-96
lines changed

6 files changed

+124
-96
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ def _field_is_shared(
141141
obj = getattr(dataset, fieldname)
142142
if base is None:
143143
base = obj
144-
elif base != obj:
144+
elif isinstance(base, Tensor):
145+
if not torch.equal(base, obj):
146+
return False
147+
elif base != obj: # pragma: no cover
145148
return False
146149

147150
return True
@@ -414,7 +417,7 @@ def construct_inputs_noisy_ei(
414417
X = _get_dataset_field(training_data, "X", first_only=True, assert_shared=True)
415418
return {
416419
"model": model,
417-
"X_observed": X(),
420+
"X_observed": X,
418421
"num_fantasies": num_fantasies,
419422
"maximize": maximize,
420423
}
@@ -624,7 +627,6 @@ def construct_inputs_qNEI(
624627
X_baseline = _get_dataset_field(
625628
training_data,
626629
fieldname="X",
627-
transform=lambda field: field(),
628630
assert_shared=True,
629631
first_only=True,
630632
)
@@ -845,7 +847,6 @@ def construct_inputs_EHVI(
845847
X = _get_dataset_field(
846848
training_data,
847849
fieldname="X",
848-
transform=lambda field: field(),
849850
first_only=True,
850851
assert_shared=True,
851852
)
@@ -908,7 +909,6 @@ def construct_inputs_qEHVI(
908909
X = _get_dataset_field(
909910
training_data,
910911
fieldname="X",
911-
transform=lambda field: field(),
912912
first_only=True,
913913
assert_shared=True,
914914
)
@@ -974,7 +974,6 @@ def construct_inputs_qNEHVI(
974974
X_baseline = _get_dataset_field(
975975
training_data,
976976
fieldname="X",
977-
transform=lambda field: field(),
978977
first_only=True,
979978
assert_shared=True,
980979
)
@@ -1245,7 +1244,6 @@ def get_best_f_analytic(
12451244
Y = _get_dataset_field(
12461245
training_data,
12471246
fieldname="Y",
1248-
transform=lambda field: field(),
12491247
join_rule=lambda field_tensors: torch.cat(field_tensors, dim=-1),
12501248
)
12511249

@@ -1274,15 +1272,13 @@ def get_best_f_mc(
12741272
X_baseline = _get_dataset_field(
12751273
training_data,
12761274
fieldname="X",
1277-
transform=lambda field: field(),
12781275
assert_shared=True,
12791276
first_only=True,
12801277
)
12811278

12821279
Y = _get_dataset_field(
12831280
training_data,
12841281
fieldname="Y",
1285-
transform=lambda field: field(),
12861282
join_rule=lambda field_tensors: torch.cat(field_tensors, dim=-1),
12871283
) # batch_shape x n x d
12881284

botorch/models/utils/parse_training_data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,21 @@ def parse_training_data(
5151
def _parse_model_supervised(
5252
consumer: Model, dataset: SupervisedDataset, **ignore: Any
5353
) -> Dict[str, Tensor]:
54-
parsed_data = {"train_X": dataset.X(), "train_Y": dataset.Y()}
54+
parsed_data = {"train_X": dataset.X, "train_Y": dataset.Y}
5555
if dataset.Yvar is not None:
56-
parsed_data["train_Yvar"] = dataset.Yvar()
56+
parsed_data["train_Yvar"] = dataset.Yvar
5757
return parsed_data
5858

5959

6060
@dispatcher.register(PairwiseGP, RankingDataset)
6161
def _parse_pairwiseGP_ranking(
6262
consumer: PairwiseGP, dataset: RankingDataset, **ignore: Any
6363
) -> Dict[str, Tensor]:
64-
datapoints = dataset.X.values
65-
comparisons = dataset.X.indices
66-
comp_order = dataset.Y()
64+
# TODO: [T163045056] Not sure what the point of the special container is if we have
65+
# to further process it here. We should move this logic into RankingDataset.
66+
datapoints = dataset._X.values
67+
comparisons = dataset._X.indices
68+
comp_order = dataset.Y
6769
comparisons = torch.gather(input=comparisons, dim=-1, index=comp_order)
6870

6971
return {

botorch/utils/datasets.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
from itertools import count, repeat
1313
from typing import Any, Dict, Hashable, Iterable, Optional, TypeVar, Union
1414

15-
from botorch.utils.containers import BotorchContainer, DenseContainer, SliceContainer
15+
import torch
16+
from botorch.utils.containers import BotorchContainer, SliceContainer
1617
from torch import long, ones, Tensor
1718

1819
T = TypeVar("T")
19-
ContainerLike = Union[BotorchContainer, Tensor]
2020
MaybeIterable = Union[T, Iterable[T]]
2121

2222

@@ -25,9 +25,6 @@ class SupervisedDataset:
2525
and an optional `Yvar` that stipulates observations variances so
2626
that `Y[i] ~ N(f(X[i]), Yvar[i])`.
2727
28-
This class object's `__init__` method converts Tensors `src` to
29-
DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
30-
3128
Example:
3229
3330
.. code-block:: python
@@ -42,15 +39,11 @@ class SupervisedDataset:
4239
assert A == B
4340
"""
4441

45-
X: BotorchContainer
46-
Y: BotorchContainer
47-
Yvar: Optional[BotorchContainer]
48-
4942
def __init__(
5043
self,
51-
X: ContainerLike,
52-
Y: ContainerLike,
53-
Yvar: Optional[ContainerLike] = None,
44+
X: Union[BotorchContainer, Tensor],
45+
Y: Union[BotorchContainer, Tensor],
46+
Yvar: Union[BotorchContainer, Tensor, None] = None,
5447
validate_init: bool = True,
5548
) -> None:
5649
r"""Constructs a `SupervisedDataset`.
@@ -62,17 +55,41 @@ def __init__(
6255
the observation noise.
6356
validate_init: If `True`, validates the input shapes.
6457
"""
65-
self.X = _containerize(X)
66-
self.Y = _containerize(Y)
67-
self.Yvar = None if Yvar is None else _containerize(Yvar)
58+
self._X = X
59+
self._Y = Y
60+
self._Yvar = Yvar
6861
if validate_init:
6962
self._validate()
7063

64+
@property
65+
def X(self) -> Tensor:
66+
if isinstance(self._X, Tensor):
67+
return self._X
68+
return self._X()
69+
70+
@property
71+
def Y(self) -> Tensor:
72+
if isinstance(self._Y, Tensor):
73+
return self._Y
74+
return self._Y()
75+
76+
@property
77+
def Yvar(self) -> Optional[Tensor]:
78+
if self._Yvar is None or isinstance(self._Yvar, Tensor):
79+
return self._Yvar
80+
return self._Yvar()
81+
7182
def _validate(self) -> None:
7283
shape_X = self.X.shape
73-
shape_X = shape_X[: len(shape_X) - len(self.X.event_shape)]
84+
if isinstance(self._X, BotorchContainer):
85+
shape_X = shape_X[: len(shape_X) - len(self._X.event_shape)]
86+
else:
87+
shape_X = shape_X[:-1]
7488
shape_Y = self.Y.shape
75-
shape_Y = shape_Y[: len(shape_Y) - len(self.Y.event_shape)]
89+
if isinstance(self._Y, BotorchContainer):
90+
shape_Y = shape_Y[: len(shape_Y) - len(self._Y.event_shape)]
91+
else:
92+
shape_Y = shape_Y[:-1]
7693
if shape_X != shape_Y:
7794
raise ValueError("Batch dimensions of `X` and `Y` are incompatible.")
7895
if self.Yvar is not None and self.Yvar.shape != self.Y.shape:
@@ -81,9 +98,9 @@ def _validate(self) -> None:
8198
@classmethod
8299
def dict_from_iter(
83100
cls,
84-
X: MaybeIterable[ContainerLike],
85-
Y: MaybeIterable[ContainerLike],
86-
Yvar: Optional[MaybeIterable[ContainerLike]] = None,
101+
X: MaybeIterable[Union[BotorchContainer, Tensor]],
102+
Y: MaybeIterable[Union[BotorchContainer, Tensor]],
103+
Yvar: Optional[MaybeIterable[Union[BotorchContainer, Tensor]]] = None,
87104
*,
88105
keys: Optional[Iterable[Hashable]] = None,
89106
) -> Dict[Hashable, SupervisedDataset]:
@@ -106,9 +123,13 @@ def dict_from_iter(
106123
def __eq__(self, other: Any) -> bool:
107124
return (
108125
type(other) is type(self)
109-
and self.X == other.X
110-
and self.Y == other.Y
111-
and self.Yvar == other.Yvar
126+
and torch.equal(self.X, other.X)
127+
and torch.equal(self.Y, other.Y)
128+
and (
129+
other.Yvar is None
130+
if self.Yvar is None
131+
else torch.equal(self.Yvar, other.Yvar)
132+
)
112133
)
113134

114135

@@ -121,9 +142,9 @@ class FixedNoiseDataset(SupervisedDataset):
121142

122143
def __init__(
123144
self,
124-
X: ContainerLike,
125-
Y: ContainerLike,
126-
Yvar: ContainerLike,
145+
X: Union[BotorchContainer, Tensor],
146+
Y: Union[BotorchContainer, Tensor],
147+
Yvar: Union[BotorchContainer, Tensor],
127148
validate_init: bool = True,
128149
) -> None:
129150
r"""Initialize a `FixedNoiseDataset` -- deprecated!"""
@@ -159,11 +180,11 @@ class RankingDataset(SupervisedDataset):
159180
dataset = RankingDataset(X, Y)
160181
"""
161182

162-
X: SliceContainer
163-
Y: BotorchContainer
164-
165183
def __init__(
166-
self, X: SliceContainer, Y: ContainerLike, validate_init: bool = True
184+
self,
185+
X: SliceContainer,
186+
Y: Union[BotorchContainer, Tensor],
187+
validate_init: bool = True,
167188
) -> None:
168189
r"""Construct a `RankingDataset`.
169190
@@ -177,8 +198,8 @@ def __init__(
177198
def _validate(self) -> None:
178199
super()._validate()
179200

180-
Y = self.Y()
181-
arity = self.X.indices.shape[-1]
201+
Y = self.Y
202+
arity = self._X.indices.shape[-1]
182203
if Y.min() < 0 or Y.max() >= arity:
183204
raise ValueError("Invalid ranking(s): out-of-bounds ranks detected.")
184205

@@ -202,13 +223,3 @@ def _validate(self) -> None:
202223

203224
# Same as: torch.where(y_diff == 0, y_incr + 1, 1)
204225
y_incr = y_incr - y_diff + 1
205-
206-
207-
def _containerize(value: ContainerLike) -> BotorchContainer:
208-
r"""Converts Tensor-valued arguments to DenseContainer under the assumption
209-
that said arguments house collections of feature vectors.
210-
"""
211-
if isinstance(value, Tensor):
212-
return DenseContainer(value, event_shape=value.shape[-1:])
213-
else:
214-
return value

0 commit comments

Comments
 (0)