1212from itertools import count , repeat
1313from typing import Any , Dict , Hashable , Iterable , Optional , TypeVar , Union
1414
15- from botorch .utils .containers import BotorchContainer , DenseContainer , SliceContainer
15+ import torch
16+ from botorch .utils .containers import BotorchContainer , SliceContainer
1617from torch import long , ones , Tensor
1718
1819T = TypeVar ("T" )
19- ContainerLike = Union [BotorchContainer , Tensor ]
2020MaybeIterable = Union [T , Iterable [T ]]
2121
2222
@@ -25,9 +25,6 @@ class SupervisedDataset:
2525 and an optional `Yvar` that stipulates observations variances so
2626 that `Y[i] ~ N(f(X[i]), Yvar[i])`.
2727
28- This class object's `__init__` method converts Tensors `src` to
29- DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
30-
3128 Example:
3229
3330 .. code-block:: python
@@ -42,15 +39,11 @@ class SupervisedDataset:
4239 assert A == B
4340 """
4441
45- X : BotorchContainer
46- Y : BotorchContainer
47- Yvar : Optional [BotorchContainer ]
48-
4942 def __init__ (
5043 self ,
51- X : ContainerLike ,
52- Y : ContainerLike ,
53- Yvar : Optional [ ContainerLike ] = None ,
44+ X : Union [ BotorchContainer , Tensor ] ,
45+ Y : Union [ BotorchContainer , Tensor ] ,
46+ Yvar : Union [ BotorchContainer , Tensor , None ] = None ,
5447 validate_init : bool = True ,
5548 ) -> None :
5649 r"""Constructs a `SupervisedDataset`.
@@ -62,17 +55,41 @@ def __init__(
6255 the observation noise.
6356 validate_init: If `True`, validates the input shapes.
6457 """
65- self .X = _containerize ( X )
66- self .Y = _containerize ( Y )
67- self .Yvar = None if Yvar is None else _containerize ( Yvar )
58+ self ._X = X
59+ self ._Y = Y
60+ self ._Yvar = Yvar
6861 if validate_init :
6962 self ._validate ()
7063
64+ @property
65+ def X (self ) -> Tensor :
66+ if isinstance (self ._X , Tensor ):
67+ return self ._X
68+ return self ._X ()
69+
70+ @property
71+ def Y (self ) -> Tensor :
72+ if isinstance (self ._Y , Tensor ):
73+ return self ._Y
74+ return self ._Y ()
75+
76+ @property
77+ def Yvar (self ) -> Optional [Tensor ]:
78+ if self ._Yvar is None or isinstance (self ._Yvar , Tensor ):
79+ return self ._Yvar
80+ return self ._Yvar ()
81+
7182 def _validate (self ) -> None :
7283 shape_X = self .X .shape
73- shape_X = shape_X [: len (shape_X ) - len (self .X .event_shape )]
84+ if isinstance (self ._X , BotorchContainer ):
85+ shape_X = shape_X [: len (shape_X ) - len (self ._X .event_shape )]
86+ else :
87+ shape_X = shape_X [:- 1 ]
7488 shape_Y = self .Y .shape
75- shape_Y = shape_Y [: len (shape_Y ) - len (self .Y .event_shape )]
89+ if isinstance (self ._Y , BotorchContainer ):
90+ shape_Y = shape_Y [: len (shape_Y ) - len (self ._Y .event_shape )]
91+ else :
92+ shape_Y = shape_Y [:- 1 ]
7693 if shape_X != shape_Y :
7794 raise ValueError ("Batch dimensions of `X` and `Y` are incompatible." )
7895 if self .Yvar is not None and self .Yvar .shape != self .Y .shape :
@@ -81,9 +98,9 @@ def _validate(self) -> None:
8198 @classmethod
8299 def dict_from_iter (
83100 cls ,
84- X : MaybeIterable [ContainerLike ],
85- Y : MaybeIterable [ContainerLike ],
86- Yvar : Optional [MaybeIterable [ContainerLike ]] = None ,
101+ X : MaybeIterable [Union [ BotorchContainer , Tensor ] ],
102+ Y : MaybeIterable [Union [ BotorchContainer , Tensor ] ],
103+ Yvar : Optional [MaybeIterable [Union [ BotorchContainer , Tensor ] ]] = None ,
87104 * ,
88105 keys : Optional [Iterable [Hashable ]] = None ,
89106 ) -> Dict [Hashable , SupervisedDataset ]:
@@ -106,9 +123,13 @@ def dict_from_iter(
106123 def __eq__ (self , other : Any ) -> bool :
107124 return (
108125 type (other ) is type (self )
109- and self .X == other .X
110- and self .Y == other .Y
111- and self .Yvar == other .Yvar
126+ and torch .equal (self .X , other .X )
127+ and torch .equal (self .Y , other .Y )
128+ and (
129+ other .Yvar is None
130+ if self .Yvar is None
131+ else torch .equal (self .Yvar , other .Yvar )
132+ )
112133 )
113134
114135
@@ -121,9 +142,9 @@ class FixedNoiseDataset(SupervisedDataset):
121142
122143 def __init__ (
123144 self ,
124- X : ContainerLike ,
125- Y : ContainerLike ,
126- Yvar : ContainerLike ,
145+ X : Union [ BotorchContainer , Tensor ] ,
146+ Y : Union [ BotorchContainer , Tensor ] ,
147+ Yvar : Union [ BotorchContainer , Tensor ] ,
127148 validate_init : bool = True ,
128149 ) -> None :
129150 r"""Initialize a `FixedNoiseDataset` -- deprecated!"""
@@ -159,11 +180,11 @@ class RankingDataset(SupervisedDataset):
159180 dataset = RankingDataset(X, Y)
160181 """
161182
162- X : SliceContainer
163- Y : BotorchContainer
164-
165183 def __init__ (
166- self , X : SliceContainer , Y : ContainerLike , validate_init : bool = True
184+ self ,
185+ X : SliceContainer ,
186+ Y : Union [BotorchContainer , Tensor ],
187+ validate_init : bool = True ,
167188 ) -> None :
168189 r"""Construct a `RankingDataset`.
169190
@@ -177,8 +198,8 @@ def __init__(
177198 def _validate (self ) -> None :
178199 super ()._validate ()
179200
180- Y = self .Y ()
181- arity = self .X .indices .shape [- 1 ]
201+ Y = self .Y
202+ arity = self ._X .indices .shape [- 1 ]
182203 if Y .min () < 0 or Y .max () >= arity :
183204 raise ValueError ("Invalid ranking(s): out-of-bounds ranks detected." )
184205
@@ -202,13 +223,3 @@ def _validate(self) -> None:
202223
203224 # Same as: torch.where(y_diff == 0, y_incr + 1, 1)
204225 y_incr = y_incr - y_diff + 1
205-
206-
207- def _containerize (value : ContainerLike ) -> BotorchContainer :
208- r"""Converts Tensor-valued arguments to DenseContainer under the assumption
209- that said arguments house collections of feature vectors.
210- """
211- if isinstance (value , Tensor ):
212- return DenseContainer (value , event_shape = value .shape [- 1 :])
213- else :
214- return value
0 commit comments