@@ -36,19 +36,20 @@ def _wrap_model_in_dataparallel(net):
3636 return torch .nn .DataParallel (net , device_ids = alt_device_ids )
3737
3838
def _move_sample_list_to_cuda(samples: List[Tensor]) -> List[Tensor]:
    """
    Return a new list holding the result of calling ``.cuda()`` on every
    element of ``samples``, in the same order. The input list itself is not
    modified.
    """
    moved: List[Tensor] = []
    for sample in samples:
        moved.append(sample.cuda())
    return moved
4141
4242
4343class ExplicitDataset (Dataset ):
def __init__(
    self,
    samples: Tensor,
    labels: Tensor,
    use_gpu=False,
) -> None:
    """
    Store ``samples`` and ``labels`` on the instance.

    When ``use_gpu`` is truthy, both tensors are first moved with
    ``.cuda()`` and the moved tensors are what gets stored.
    """
    if use_gpu:
        samples = samples.cuda()
        labels = labels.cuda()
    self.samples = samples
    self.labels = labels
5354
5455 def __len__ (self ) -> int :
@@ -59,14 +60,15 @@ def __getitem__(self, idx):
5960
6061
6162class UnpackDataset (Dataset ):
def __init__(
    self,
    samples: List[Tensor],
    labels: Tensor,
    use_gpu=False,
) -> None:
    """
    Store a list of sample tensors and a label tensor on the instance.

    When ``use_gpu`` is truthy, every tensor in ``samples`` and the
    ``labels`` tensor are moved with ``.cuda()`` before being stored;
    the samples are then held in a newly built list.
    """
    if use_gpu:
        samples = [sample.cuda() for sample in samples]
        labels = labels.cuda()
    self.samples = samples
    self.labels = labels
7173
7274 def __len__ (self ) -> int :
@@ -225,6 +227,72 @@ def forward(self, *inputs: Tensor) -> Tensor:
225227 return self .linear (torch .cat (inputs , dim = 1 ))
226228
227229
def get_random_data(
    in_features: int,
    out_features: int,
    num_examples: int,
    use_gpu: bool,
    unpack_inputs: bool,
    num_train: int = 32,
    num_hessian: int = 22,
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Build train, Hessian and test datasets from a single draw of random data.

    Features are drawn from N(0, 1) with shape [num_examples x in_features]
    and labels from N(1, 2) with shape [num_examples x out_features]; both
    are converted to double precision. When `unpack_inputs` is True, two
    independent feature tensors are drawn and the datasets are
    `UnpackDataset`s holding a list of tensors; otherwise a single feature
    tensor is drawn and the datasets are `ExplicitDataset`s.

    The rows are split as follows:
        - train:   rows [0, num_train)
        - hessian: rows [0, num_hessian)
        - test:    rows [num_train, num_examples)

    Note: the random labels and features for the different datasets need to
    be generated together. Otherwise, some tests will fail
    (https://fburl.com/testinfra/737jnpip)

    Args:
        in_features: number of columns of every feature tensor.
        out_features: number of columns of the label tensor.
        num_examples: total number of rows drawn. If it does not exceed
            `num_train`, the test dataset is built from empty slices.
        use_gpu: if True, all tensors are moved with `.cuda()` before the
            datasets are built.
        unpack_inputs: selects list-of-tensors (`UnpackDataset`) versus
            single-tensor (`ExplicitDataset`) samples.
        num_train: number of leading rows used for the training dataset.
        num_hessian: number of leading rows used for the Hessian dataset.
            This needs to be high to prevent numerical issues.

    Returns:
        The tuple `(train_dataset, hessian_dataset, test_dataset)`, in that
        order.
    """
    num_inputs = 2 if unpack_inputs else 1

    # Labels are drawn before the features, and all rows are drawn in one go,
    # so the global RNG stream is consumed in a fixed order (see the note in
    # the docstring).
    labels = torch.normal(1, 2, (num_examples, out_features)).double()
    all_samples = [
        torch.normal(0, 1, (num_examples, in_features)).double()
        for _ in range(num_inputs)
    ]
    if use_gpu:
        labels = labels.cuda()
        all_samples = [samples.cuda() for samples in all_samples]

    def _make_dataset(rows: slice) -> Dataset:
        # The tensors already live on the target device at this point, so
        # `use_gpu` is deliberately not forwarded to the dataset constructors
        # (a second `.cuda()` would be a no-op).
        if unpack_inputs:
            return UnpackDataset(
                [samples[rows] for samples in all_samples], labels[rows]
            )
        return ExplicitDataset(all_samples[0][rows], labels[rows])

    train_dataset = _make_dataset(slice(None, num_train))
    hessian_dataset = _make_dataset(slice(None, num_hessian))
    test_dataset = _make_dataset(slice(num_train, None))
    return (train_dataset, hessian_dataset, test_dataset)
294+
295+
228296def get_random_model_and_data (
229297 tmpdir ,
230298 unpack_inputs ,
@@ -283,38 +351,9 @@ def get_random_model_and_data(
283351 out_features = 3
284352
285353 num_samples = 50
286- num_train = 32
287- num_hessian = 22 # this needs to be high to prevent numerical issues
288- all_labels = torch .normal (1 , 2 , (num_samples , out_features )).double ()
289- all_labels = all_labels .cuda () if use_gpu else all_labels
290- train_labels = all_labels [:num_train ]
291- test_labels = all_labels [num_train :]
292- hessian_labels = all_labels [:num_hessian ]
293-
294- if unpack_inputs :
295- all_samples = [
296- torch .normal (0 , 1 , (num_samples , in_features )).double ()
297- for _ in range (num_inputs )
298- ]
299- if use_gpu :
300- all_samples = _move_sample_to_cuda (all_samples )
301354
302- train_samples = [ts [:num_train ] for ts in all_samples ]
303- test_samples = [ts [num_train :] for ts in all_samples ]
304- hessian_samples = [ts [:num_hessian ] for ts in all_samples ]
305- else :
306- all_samples = torch .normal (0 , 1 , (num_samples , in_features )).double ()
307-
308- if use_gpu :
309- all_samples = all_samples .cuda ()
310- train_samples = all_samples [:num_train ]
311- test_samples = all_samples [num_train :]
312- hessian_samples = all_samples [:num_hessian ]
313-
314- dataset = (
315- ExplicitDataset (train_samples , train_labels , use_gpu )
316- if not unpack_inputs
317- else UnpackDataset (train_samples , train_labels , use_gpu )
355+ train_dataset , hessian_dataset , test_dataset = get_random_data (
356+ in_features , out_features , num_samples , use_gpu , unpack_inputs
318357 )
319358
320359 if model_type == "random" :
@@ -358,15 +397,19 @@ def get_random_model_and_data(
358397
359398 # turn input into a single tensor for use by least squares
360399 tensor_hessian_samples = (
361- hessian_samples if not unpack_inputs else torch .cat (hessian_samples , dim = 1 )
400+ hessian_dataset .samples
401+ if not unpack_inputs
402+ else torch .cat (hessian_dataset .samples , dim = 1 )
362403 )
363404 version = _parse_version (torch .__version__ )
364405 if version < (1 , 9 ):
365- theta = torch .lstsq (tensor_hessian_samples , hessian_labels ).solution [0 :1 ]
406+ theta = torch .lstsq (
407+ tensor_hessian_samples , hessian_dataset .labels
408+ ).solution [0 :1 ]
366409 else :
367410 # run least squares to get optimal trained parameters
368411 theta = torch .linalg .lstsq (
369- hessian_labels ,
412+ hessian_dataset . labels ,
370413 tensor_hessian_samples ,
371414 ).solution
372415 # the first `n` rows of `theta` contains the least squares solution, where
@@ -397,14 +440,7 @@ def get_random_model_and_data(
397440 )
398441
399442 # train model using several optimization steps on Hessian data
400-
401- # create entire Hessian data as a batch
402- hessian_dataset = (
403- ExplicitDataset (hessian_samples , hessian_labels , use_gpu )
404- if not unpack_inputs
405- else UnpackDataset (hessian_samples , hessian_labels , use_gpu )
406- )
407- batch = next (iter (DataLoader (hessian_dataset , batch_size = num_hessian )))
443+ batch = next (iter (DataLoader (hessian_dataset , batch_size = len (hessian_dataset ))))
408444
409445 optimizer = torch .optim .Adam (net .parameters ())
410446 num_steps = 200
@@ -425,26 +461,13 @@ def get_random_model_and_data(
425461
426462 training_data = (
427463 net_adjusted ,
428- dataset ,
464+ train_dataset ,
429465 )
430466
431- hessian_data = (
432- (
433- _move_sample_to_cuda (hessian_samples )
434- if isinstance (hessian_samples , list ) and use_gpu
435- else (hessian_samples .cuda () if use_gpu else hessian_samples )
436- ),
437- hessian_labels .cuda () if use_gpu else hessian_labels ,
438- )
467+ hessian_data = (hessian_dataset .samples , hessian_dataset .labels )
468+
469+ test_data = (test_dataset .samples , test_dataset .labels )
439470
440- test_data = (
441- (
442- _move_sample_to_cuda (test_samples )
443- if isinstance (test_samples , list ) and use_gpu
444- else (test_samples .cuda () if use_gpu else test_samples )
445- ),
446- test_labels .cuda () if use_gpu else test_labels ,
447- )
448471 if return_test_data :
449472 if not return_hessian_data :
450473 return (* training_data , * test_data )
0 commit comments