Skip to content

Commit a6d767f

Browse files
cicichen01facebook-github-bot
authored andcommitted
Separate the data generation logic
Summary: As titled. This is part of get_random_model_and_data() method simplification. Differential Revision: D55224026
1 parent 325e114 commit a6d767f

File tree

1 file changed

+158
-71
lines changed

1 file changed

+158
-71
lines changed

tests/influence/_utils/common.py

Lines changed: 158 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,20 @@ def _wrap_model_in_dataparallel(net):
3636
return torch.nn.DataParallel(net, device_ids=alt_device_ids)
3737

3838

39-
def _move_sample_to_cuda(samples):
39+
def _move_sample_list_to_cuda(samples: List[Tensor]) -> List[Tensor]:
4040
return [s.cuda() for s in samples]
4141

4242

4343
class ExplicitDataset(Dataset):
44-
def __init__(self, samples, labels, use_gpu=False) -> None:
44+
def __init__(
45+
self,
46+
samples: Tensor,
47+
labels: Tensor,
48+
use_gpu=False,
49+
) -> None:
4550
self.samples, self.labels = samples, labels
4651
if use_gpu:
47-
self.samples = (
48-
_move_sample_to_cuda(self.samples)
49-
if isinstance(self.samples, list)
50-
else self.samples.cuda()
51-
)
52+
self.samples = self.samples.cuda()
5253
self.labels = self.labels.cuda()
5354

5455
def __len__(self) -> int:
@@ -59,14 +60,15 @@ def __getitem__(self, idx):
5960

6061

6162
class UnpackDataset(Dataset):
62-
def __init__(self, samples, labels, use_gpu=False) -> None:
63+
def __init__(
64+
self,
65+
samples: List[Tensor],
66+
labels: Tensor,
67+
use_gpu=False,
68+
) -> None:
6369
self.samples, self.labels = samples, labels
6470
if use_gpu:
65-
self.samples = (
66-
_move_sample_to_cuda(self.samples)
67-
if isinstance(self.samples, list)
68-
else self.samples.cuda()
69-
)
71+
self.samples = _move_sample_list_to_cuda(self.samples)
7072
self.labels = self.labels.cuda()
7173

7274
def __len__(self) -> int:
@@ -225,6 +227,73 @@ def forward(self, *inputs: Tensor) -> Tensor:
225227
return self.linear(torch.cat(inputs, dim=1))
226228

227229

230+
def get_random_data(
231+
in_features: int,
232+
out_features: int,
233+
num_examples: int,
234+
use_gpu: bool,
235+
unpack_inputs: bool,
236+
) -> Tuple[Dataset, Dataset, Dataset]:
237+
"""
238+
returns train_dataset, test_dataset and hessian_dataset constructed from
239+
random labels and random features, with features having shape
240+
[num_examples x num_features] and labels having shape [num_examples].
241+
242+
Note: the random labels and features for different dataset needs to be
243+
generated together.
244+
Otherwise, the test_tracin_self_influence_15_none_ArnoldiInfluenceFunction_linear1_unpack_inputs
245+
will fail (https://www.internalfb.com/intern/testinfra/testconsole/testrun/7599824580546642/)
246+
"""
247+
248+
num_train = 32
249+
num_hessian = 22 # this needs to be high to prevent numerical issues
250+
num_inputs = 2 if unpack_inputs else 1
251+
252+
labels = torch.normal(1, 2, (num_examples, out_features)).double()
253+
labels = labels.cuda() if use_gpu else labels
254+
255+
all_samples = [
256+
torch.normal(0, 1, (num_examples, in_features)).double()
257+
for _ in range(num_inputs)
258+
]
259+
if use_gpu:
260+
all_samples = _move_sample_list_to_cuda(all_samples)
261+
262+
# TODO: no need to pass use_gpu since the data format has already been moved to cuda
263+
train_dataset = (
264+
UnpackDataset(
265+
[samples[:num_train] for samples in all_samples],
266+
labels[:num_train],
267+
use_gpu,
268+
)
269+
if unpack_inputs
270+
else ExplicitDataset(all_samples[0][:num_train], labels[:num_train], use_gpu)
271+
)
272+
273+
hessian_dataset = (
274+
UnpackDataset(
275+
[samples[:num_hessian] for samples in all_samples],
276+
labels[:num_hessian],
277+
use_gpu,
278+
)
279+
if unpack_inputs
280+
else ExplicitDataset(
281+
all_samples[0][:num_hessian], labels[:num_hessian], use_gpu
282+
)
283+
)
284+
285+
test_dataset = (
286+
UnpackDataset(
287+
[samples[num_train:] for samples in all_samples],
288+
labels[num_train:],
289+
use_gpu,
290+
)
291+
if unpack_inputs
292+
else ExplicitDataset(all_samples[0][num_train:], labels[num_train:], use_gpu)
293+
)
294+
return (train_dataset, hessian_dataset, test_dataset)
295+
296+
228297
def get_random_model_and_data(
229298
tmpdir,
230299
unpack_inputs,
@@ -283,38 +352,65 @@ def get_random_model_and_data(
283352
out_features = 3
284353

285354
num_samples = 50
286-
num_train = 32
287-
num_hessian = 22 # this needs to be high to prevent numerical issues
288-
all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
289-
all_labels = all_labels.cuda() if use_gpu else all_labels
290-
train_labels = all_labels[:num_train]
291-
test_labels = all_labels[num_train:]
292-
hessian_labels = all_labels[:num_hessian]
293-
294-
if unpack_inputs:
295-
all_samples = [
296-
torch.normal(0, 1, (num_samples, in_features)).double()
297-
for _ in range(num_inputs)
298-
]
299-
if use_gpu:
300-
all_samples = _move_sample_to_cuda(all_samples)
301-
302-
train_samples = [ts[:num_train] for ts in all_samples]
303-
test_samples = [ts[num_train:] for ts in all_samples]
304-
hessian_samples = [ts[:num_hessian] for ts in all_samples]
305-
else:
306-
all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
307-
308-
if use_gpu:
309-
all_samples = all_samples.cuda()
310-
train_samples = all_samples[:num_train]
311-
test_samples = all_samples[num_train:]
312-
hessian_samples = all_samples[:num_hessian]
313-
314-
dataset = (
315-
ExplicitDataset(train_samples, train_labels, use_gpu)
316-
if not unpack_inputs
317-
else UnpackDataset(train_samples, train_labels, use_gpu)
355+
# num_train = 32
356+
# num_hessian = 22 # this needs to be high to prevent numerical issues
357+
358+
# all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
359+
# all_labels = all_labels.cuda() if use_gpu else all_labels
360+
# train_labels = all_labels[:num_train]
361+
# test_labels = all_labels[num_train:]
362+
# hessian_labels = all_labels[:num_hessian]
363+
364+
# if unpack_inputs:
365+
# all_samples = [
366+
# torch.normal(0, 1, (num_samples, in_features)).double()
367+
# for _ in range(num_inputs)
368+
# ]
369+
# if use_gpu:
370+
# all_samples = _move_sample_to_cuda(all_samples)
371+
372+
# train_samples = [ts[:num_train] for ts in all_samples]
373+
# test_samples = [ts[num_train:] for ts in all_samples]
374+
# hessian_samples = [ts[:num_hessian] for ts in all_samples]
375+
# else:
376+
# all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
377+
378+
# if use_gpu:
379+
# all_samples = all_samples.cuda()
380+
# train_samples = all_samples[:num_train]
381+
# test_samples = all_samples[num_train:]
382+
# hessian_samples = all_samples[:num_hessian]
383+
384+
# train_dataset = get_random_data(
385+
# in_features, out_features, num_train, use_gpu, unpack_inputs
386+
# )
387+
# test_dataset = get_random_data(
388+
# in_features, out_features, num_samples - num_train, use_gpu, unpack_inputs
389+
# )
390+
# # hessian_dataset = get_random_data(
391+
# # in_features, out_features, num_hessian, use_gpu, unpack_inputs
392+
# # )
393+
# hessian_dataset = (
394+
# ExplicitDataset(
395+
# train_dataset.samples[:num_hessian],
396+
# train_dataset.labels[:num_hessian],
397+
# use_gpu,
398+
# )
399+
# if not unpack_inputs
400+
# else UnpackDataset(
401+
# [samples[:num_hessian] for samples in train_dataset.samples],
402+
# [labels[:num_hessian] for labels in train_dataset.labels],
403+
# use_gpu,
404+
# )
405+
# )
406+
407+
# (
408+
# ExplicitDataset(train_samples, train_labels, use_gpu)
409+
# if not unpack_inputs
410+
# else UnpackDataset(train_samples, train_labels, use_gpu)
411+
# )
412+
train_dataset, hessian_dataset, test_dataset = get_random_data(
413+
in_features, out_features, num_samples, use_gpu, unpack_inputs
318414
)
319415

320416
if model_type == "random":
@@ -358,15 +454,19 @@ def get_random_model_and_data(
358454

359455
# turn input into a single tensor for use by least squares
360456
tensor_hessian_samples = (
361-
hessian_samples if not unpack_inputs else torch.cat(hessian_samples, dim=1)
457+
hessian_dataset.samples
458+
if not unpack_inputs
459+
else torch.cat(hessian_dataset.samples, dim=1)
362460
)
363461
version = _parse_version(torch.__version__)
364462
if version < (1, 9):
365-
theta = torch.lstsq(tensor_hessian_samples, hessian_labels).solution[0:1]
463+
theta = torch.lstsq(
464+
tensor_hessian_samples, hessian_dataset.labels
465+
).solution[0:1]
366466
else:
367467
# run least squares to get optimal trained parameters
368468
theta = torch.linalg.lstsq(
369-
hessian_labels,
469+
hessian_dataset.labels,
370470
tensor_hessian_samples,
371471
).solution
372472
# the first `n` rows of `theta` contains the least squares solution, where
@@ -399,12 +499,12 @@ def get_random_model_and_data(
399499
# train model using several optimization steps on Hessian data
400500

401501
# create entire Hessian data as a batch
402-
hessian_dataset = (
403-
ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
404-
if not unpack_inputs
405-
else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
406-
)
407-
batch = next(iter(DataLoader(hessian_dataset, batch_size=num_hessian)))
502+
# hessian_dataset = (
503+
# ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
504+
# if not unpack_inputs
505+
# else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
506+
# )
507+
batch = next(iter(DataLoader(hessian_dataset, batch_size=len(hessian_dataset))))
408508

409509
optimizer = torch.optim.Adam(net.parameters())
410510
num_steps = 200
@@ -425,26 +525,13 @@ def get_random_model_and_data(
425525

426526
training_data = (
427527
net_adjusted,
428-
dataset,
528+
train_dataset,
429529
)
430530

431-
hessian_data = (
432-
(
433-
_move_sample_to_cuda(hessian_samples)
434-
if isinstance(hessian_samples, list) and use_gpu
435-
else (hessian_samples.cuda() if use_gpu else hessian_samples)
436-
),
437-
hessian_labels.cuda() if use_gpu else hessian_labels,
438-
)
531+
hessian_data = (hessian_dataset.samples, hessian_dataset.labels)
532+
533+
test_data = (test_dataset.samples, test_dataset.labels)
439534

440-
test_data = (
441-
(
442-
_move_sample_to_cuda(test_samples)
443-
if isinstance(test_samples, list) and use_gpu
444-
else (test_samples.cuda() if use_gpu else test_samples)
445-
),
446-
test_labels.cuda() if use_gpu else test_labels,
447-
)
448535
if return_test_data:
449536
if not return_hessian_data:
450537
return (*training_data, *test_data)

0 commit comments

Comments
 (0)