@@ -36,19 +36,20 @@ def _wrap_model_in_dataparallel(net):
     return torch.nn.DataParallel(net, device_ids=alt_device_ids)


-def _move_sample_to_cuda(samples):
+def _move_sample_list_to_cuda(samples: List[Tensor]) -> List[Tensor]:
     return [s.cuda() for s in samples]


 class ExplicitDataset(Dataset):
-    def __init__(self, samples, labels, use_gpu=False) -> None:
+    def __init__(
+        self,
+        samples: Tensor,
+        labels: Tensor,
+        use_gpu=False,
+    ) -> None:
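+        # samples is a single feature tensor here; the list-of-tensors case is
+        # handled by UnpackDataset below.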
         self.samples, self.labels = samples, labels
         if use_gpu:
-            self.samples = (
-                _move_sample_to_cuda(self.samples)
-                if isinstance(self.samples, list)
-                else self.samples.cuda()
-            )
+            self.samples = self.samples.cuda()
             self.labels = self.labels.cuda()

     def __len__(self) -> int:
@@ -59,14 +60,15 @@ def __getitem__(self, idx):


 class UnpackDataset(Dataset):
-    def __init__(self, samples, labels, use_gpu=False) -> None:
+    def __init__(
+        self,
+        samples: List[Tensor],
+        labels: Tensor,
+        use_gpu=False,
+    ) -> None:
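+        # samples is a list of feature tensors, one per model input; each tensor
+        # is moved to CUDA individually when use_gpu is True.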
         self.samples, self.labels = samples, labels
         if use_gpu:
-            self.samples = (
-                _move_sample_to_cuda(self.samples)
-                if isinstance(self.samples, list)
-                else self.samples.cuda()
-            )
+            self.samples = _move_sample_list_to_cuda(self.samples)
             self.labels = self.labels.cuda()

     def __len__(self) -> int:
@@ -225,6 +227,73 @@ def forward(self, *inputs: Tensor) -> Tensor:
         return self.linear(torch.cat(inputs, dim=1))


+def get_random_data(
+    in_features: int,
+    out_features: int,
+    num_examples: int,
+    use_gpu: bool,
+    unpack_inputs: bool,
+) -> Tuple[Dataset, Dataset, Dataset]:
237+ """
238+ returns train_dataset, test_dataset and hessian_dataset constructed from
239+ random labels and random features, with features having shape
240+ [num_examples x num_features] and labels having shape [num_examples].
241+
242+ Note: the random labels and features for different dataset needs to be
243+ generated together.
244+ Otherwise, the test_tracin_self_influence_15_none_ArnoldiInfluenceFunction_linear1_unpack_inputs
245+ will fail (https://www.internalfb.com/intern/testinfra/testconsole/testrun/7599824580546642/)
246+ """
+
+    num_train = 32
+    num_hessian = 22  # this needs to be high to prevent numerical issues
+    num_inputs = 2 if unpack_inputs else 1
+
+    labels = torch.normal(1, 2, (num_examples, out_features)).double()
+    labels = labels.cuda() if use_gpu else labels
+
+    all_samples = [
+        torch.normal(0, 1, (num_examples, in_features)).double()
+        for _ in range(num_inputs)
+    ]
+    if use_gpu:
+        all_samples = _move_sample_list_to_cuda(all_samples)
+
+    # TODO: no need to pass use_gpu below, since the data has already been moved to cuda
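+    # The slices below overlap: the first num_train rows form the train dataset,
+    # the remaining rows form the test dataset, and the first num_hessian rows
+    # (a subset of the train rows) form the hessian dataset.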
+    train_dataset = (
+        UnpackDataset(
+            [samples[:num_train] for samples in all_samples],
+            labels[:num_train],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(all_samples[0][:num_train], labels[:num_train], use_gpu)
+    )
+
+    hessian_dataset = (
+        UnpackDataset(
+            [samples[:num_hessian] for samples in all_samples],
+            labels[:num_hessian],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(
+            all_samples[0][:num_hessian], labels[:num_hessian], use_gpu
+        )
+    )
+
+    test_dataset = (
+        UnpackDataset(
+            [samples[num_train:] for samples in all_samples],
+            labels[num_train:],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(all_samples[0][num_train:], labels[num_train:], use_gpu)
+    )
+    return (train_dataset, hessian_dataset, test_dataset)
+
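+# Illustrative usage sketch (the argument values here are made up for the example;
+# get_random_model_and_data below shows the real call site):
+#
+#     train_dataset, hessian_dataset, test_dataset = get_random_data(
+#         in_features=5, out_features=3, num_examples=50, use_gpu=False, unpack_inputs=False
+#     )
+#     batch = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))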
+
 def get_random_model_and_data(
     tmpdir,
     unpack_inputs,
@@ -283,38 +352,65 @@ def get_random_model_and_data(
     out_features = 3

     num_samples = 50
-    num_train = 32
-    num_hessian = 22  # this needs to be high to prevent numerical issues
-    all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
-    all_labels = all_labels.cuda() if use_gpu else all_labels
-    train_labels = all_labels[:num_train]
-    test_labels = all_labels[num_train:]
-    hessian_labels = all_labels[:num_hessian]
-
-    if unpack_inputs:
-        all_samples = [
-            torch.normal(0, 1, (num_samples, in_features)).double()
-            for _ in range(num_inputs)
-        ]
-        if use_gpu:
-            all_samples = _move_sample_to_cuda(all_samples)
-
-        train_samples = [ts[:num_train] for ts in all_samples]
-        test_samples = [ts[num_train:] for ts in all_samples]
-        hessian_samples = [ts[:num_hessian] for ts in all_samples]
-    else:
-        all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
-
-        if use_gpu:
-            all_samples = all_samples.cuda()
-        train_samples = all_samples[:num_train]
-        test_samples = all_samples[num_train:]
-        hessian_samples = all_samples[:num_hessian]
-
-    dataset = (
-        ExplicitDataset(train_samples, train_labels, use_gpu)
-        if not unpack_inputs
-        else UnpackDataset(train_samples, train_labels, use_gpu)
+    # num_train = 32
+    # num_hessian = 22  # this needs to be high to prevent numerical issues
+
+    # all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
+    # all_labels = all_labels.cuda() if use_gpu else all_labels
+    # train_labels = all_labels[:num_train]
+    # test_labels = all_labels[num_train:]
+    # hessian_labels = all_labels[:num_hessian]
+
+    # if unpack_inputs:
+    #     all_samples = [
+    #         torch.normal(0, 1, (num_samples, in_features)).double()
+    #         for _ in range(num_inputs)
+    #     ]
+    #     if use_gpu:
+    #         all_samples = _move_sample_to_cuda(all_samples)
+
+    #     train_samples = [ts[:num_train] for ts in all_samples]
+    #     test_samples = [ts[num_train:] for ts in all_samples]
+    #     hessian_samples = [ts[:num_hessian] for ts in all_samples]
+    # else:
+    #     all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
+
+    #     if use_gpu:
+    #         all_samples = all_samples.cuda()
+    #     train_samples = all_samples[:num_train]
+    #     test_samples = all_samples[num_train:]
+    #     hessian_samples = all_samples[:num_hessian]
+
+    # train_dataset = get_random_data(
+    #     in_features, out_features, num_train, use_gpu, unpack_inputs
+    # )
+    # test_dataset = get_random_data(
+    #     in_features, out_features, num_samples - num_train, use_gpu, unpack_inputs
+    # )
+    # # hessian_dataset = get_random_data(
+    # #     in_features, out_features, num_hessian, use_gpu, unpack_inputs
+    # # )
+    # hessian_dataset = (
+    #     ExplicitDataset(
+    #         train_dataset.samples[:num_hessian],
+    #         train_dataset.labels[:num_hessian],
+    #         use_gpu,
+    #     )
+    #     if not unpack_inputs
+    #     else UnpackDataset(
+    #         [samples[:num_hessian] for samples in train_dataset.samples],
+    #         [labels[:num_hessian] for labels in train_dataset.labels],
+    #         use_gpu,
+    #     )
+    # )
+
+    # (
+    #     ExplicitDataset(train_samples, train_labels, use_gpu)
+    #     if not unpack_inputs
+    #     else UnpackDataset(train_samples, train_labels, use_gpu)
+    # )
+    train_dataset, hessian_dataset, test_dataset = get_random_data(
+        in_features, out_features, num_samples, use_gpu, unpack_inputs
     )

     if model_type == "random":
@@ -358,15 +454,19 @@ def get_random_model_and_data(

         # turn input into a single tensor for use by least squares
         tensor_hessian_samples = (
-            hessian_samples if not unpack_inputs else torch.cat(hessian_samples, dim=1)
+            hessian_dataset.samples
+            if not unpack_inputs
+            else torch.cat(hessian_dataset.samples, dim=1)
         )
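+        # When unpack_inputs is True, the per-input feature tensors are
+        # concatenated along dim=1 so least squares sees a single design matrix.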
         version = _parse_version(torch.__version__)
         if version < (1, 9):
-            theta = torch.lstsq(tensor_hessian_samples, hessian_labels).solution[0:1]
+            theta = torch.lstsq(
+                tensor_hessian_samples, hessian_dataset.labels
+            ).solution[0:1]
         else:
             # run least squares to get optimal trained parameters
             theta = torch.linalg.lstsq(
-                hessian_labels,
+                hessian_dataset.labels,
                 tensor_hessian_samples,
             ).solution
             # the first `n` rows of `theta` contains the least squares solution, where
@@ -399,12 +499,12 @@ def get_random_model_and_data(
         # train model using several optimization steps on Hessian data

         # create entire Hessian data as a batch
-        hessian_dataset = (
-            ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
-            if not unpack_inputs
-            else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
-        )
-        batch = next(iter(DataLoader(hessian_dataset, batch_size=num_hessian)))
+        # hessian_dataset = (
+        #     ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
+        #     if not unpack_inputs
+        #     else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
+        # )
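+        # batch_size=len(hessian_dataset) keeps the previous behavior of loading
+        # the entire Hessian dataset as a single batch, without relying on the
+        # hard-coded num_hessian constant.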
+        batch = next(iter(DataLoader(hessian_dataset, batch_size=len(hessian_dataset))))

         optimizer = torch.optim.Adam(net.parameters())
         num_steps = 200
@@ -425,26 +525,13 @@ def get_random_model_and_data(

     training_data = (
         net_adjusted,
-        dataset,
+        train_dataset,
     )

-    hessian_data = (
-        (
-            _move_sample_to_cuda(hessian_samples)
-            if isinstance(hessian_samples, list) and use_gpu
-            else (hessian_samples.cuda() if use_gpu else hessian_samples)
-        ),
-        hessian_labels.cuda() if use_gpu else hessian_labels,
-    )
+    hessian_data = (hessian_dataset.samples, hessian_dataset.labels)
+
+    test_data = (test_dataset.samples, test_dataset.labels)
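+    # The dataset constructors already moved samples and labels to CUDA when
+    # use_gpu is True, so no explicit .cuda() calls are needed here.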

-    test_data = (
-        (
-            _move_sample_to_cuda(test_samples)
-            if isinstance(test_samples, list) and use_gpu
-            else (test_samples.cuda() if use_gpu else test_samples)
-        ),
-        test_labels.cuda() if use_gpu else test_labels,
-    )
     if return_test_data:
         if not return_hessian_data:
             return (*training_data, *test_data)