@@ -45,7 +45,7 @@ def __init__(
4545 self ,
4646 samples : Tensor ,
4747 labels : Tensor ,
48- use_gpu = False ,
48+ use_gpu : bool = False ,
4949 ) -> None :
5050 self .samples , self .labels = samples , labels
5151 if use_gpu :
@@ -64,7 +64,7 @@ def __init__(
6464 self ,
6565 samples : List [Tensor ],
6666 labels : Tensor ,
67- use_gpu = False ,
67+ use_gpu : bool = False ,
6868 ) -> None :
6969 self .samples , self .labels = samples , labels
7070 if use_gpu :
@@ -83,7 +83,11 @@ def __getitem__(self, idx):
8383
8484
8585class IdentityDataset (ExplicitDataset ):
86- def __init__ (self , num_features , use_gpu = False ) -> None :
86+ def __init__ (
87+ self ,
88+ num_features : int ,
89+ use_gpu : bool = False ,
90+ ) -> None :
8791 self .samples = torch .diag (torch .ones (num_features ))
8892 self .labels = torch .zeros (num_features ).unsqueeze (1 )
8993 if use_gpu :
@@ -92,7 +96,13 @@ def __init__(self, num_features, use_gpu=False) -> None:
9296
9397
9498class RangeDataset (ExplicitDataset ):
95- def __init__ (self , low , high , num_features , use_gpu = False ) -> None :
99+ def __init__ (
100+ self ,
101+ low : int ,
102+ high : int ,
103+ num_features : int ,
104+ use_gpu : bool = False ,
105+ ) -> None :
96106 self .samples = (
97107 torch .arange (start = low , end = high , dtype = torch .float )
98108 .repeat (num_features , 1 )
@@ -105,7 +115,7 @@ def __init__(self, low, high, num_features, use_gpu=False) -> None:
105115
106116
107117class BinaryDataset (ExplicitDataset ):
108- def __init__ (self , use_gpu = False ) -> None :
118+ def __init__ (self , use_gpu : bool = False ) -> None :
109119 self .samples = F .normalize (
110120 torch .stack (
111121 (
@@ -249,16 +259,11 @@ def get_random_data(
249259 num_inputs = 2 if unpack_inputs else 1
250260
251261 labels = torch .normal (1 , 2 , (num_examples , out_features )).double ()
252- labels = labels .cuda () if use_gpu else labels
253-
254262 all_samples = [
255263 torch .normal (0 , 1 , (num_examples , in_features )).double ()
256264 for _ in range (num_inputs )
257265 ]
258- if use_gpu :
259- all_samples = _move_sample_list_to_cuda (all_samples )
260266
261- # TODO: no need to pass use_gpu since the data format has already been moved to cuda
262267 train_dataset = (
263268 UnpackDataset (
264269 [samples [:num_train ] for samples in all_samples ],
0 commit comments