
Commit f94d3ee

cicichen01 authored and facebook-github-bot committed
Separate the data generation logic (#1258)
Summary:
Pull Request resolved: #1258

As titled. This is part of the get_random_model_and_data() method simplification.

Reviewed By: cyrjano
Differential Revision: D55224026
fbshipit-source-id: 27a45efb8a534d92ee6191bf4d25f292cca84d4f
1 parent f8dc1b7 commit f94d3ee


tests/influence/_utils/common.py

Lines changed: 95 additions & 72 deletions
@@ -36,19 +36,20 @@ def _wrap_model_in_dataparallel(net):
     return torch.nn.DataParallel(net, device_ids=alt_device_ids)
 
 
-def _move_sample_to_cuda(samples):
+def _move_sample_list_to_cuda(samples: List[Tensor]) -> List[Tensor]:
     return [s.cuda() for s in samples]
 
 
 class ExplicitDataset(Dataset):
-    def __init__(self, samples, labels, use_gpu=False) -> None:
+    def __init__(
+        self,
+        samples: Tensor,
+        labels: Tensor,
+        use_gpu=False,
+    ) -> None:
         self.samples, self.labels = samples, labels
         if use_gpu:
-            self.samples = (
-                _move_sample_to_cuda(self.samples)
-                if isinstance(self.samples, list)
-                else self.samples.cuda()
-            )
+            self.samples = self.samples.cuda()
             self.labels = self.labels.cuda()
 
     def __len__(self) -> int:
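A minimal usage sketch of the tightened contract (illustrative, not part of the diff; ExplicitDataset and torch come from this test module's imports): samples must now be a single Tensor, since the isinstance list check has moved out of the class.

import torch

# Sketch: ExplicitDataset now takes a plain Tensor for samples;
# list-of-Tensor inputs belong to UnpackDataset instead.
samples = torch.normal(0, 1, (8, 5)).double()
labels = torch.normal(1, 2, (8, 3)).double()
dataset = ExplicitDataset(samples, labels, use_gpu=False)
assert len(dataset) == 8  # assuming __len__ counts examples, per the context above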
@@ -59,14 +60,15 @@ def __getitem__(self, idx):
 
 
 class UnpackDataset(Dataset):
-    def __init__(self, samples, labels, use_gpu=False) -> None:
+    def __init__(
+        self,
+        samples: List[Tensor],
+        labels: Tensor,
+        use_gpu=False,
+    ) -> None:
         self.samples, self.labels = samples, labels
         if use_gpu:
-            self.samples = (
-                _move_sample_to_cuda(self.samples)
-                if isinstance(self.samples, list)
-                else self.samples.cuda()
-            )
+            self.samples = _move_sample_list_to_cuda(self.samples)
             self.labels = self.labels.cuda()
 
     def __len__(self) -> int:
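The multi-input counterpart (again a sketch, not part of the diff): UnpackDataset keeps the list-of-Tensor handling and delegates the GPU move to the renamed _move_sample_list_to_cuda, so the isinstance branch disappears here too.

import torch

# Sketch: one feature tensor per model input, plus a single label tensor.
samples = [torch.normal(0, 1, (8, 5)).double() for _ in range(2)]
labels = torch.normal(1, 2, (8, 3)).double()
dataset = UnpackDataset(samples, labels, use_gpu=False)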
@@ -225,6 +227,72 @@ def forward(self, *inputs: Tensor) -> Tensor:
         return self.linear(torch.cat(inputs, dim=1))
 
 
+def get_random_data(
+    in_features: int,
+    out_features: int,
+    num_examples: int,
+    use_gpu: bool,
+    unpack_inputs: bool,
+) -> Tuple[Dataset, Dataset, Dataset]:
+    """
+    Returns train_dataset, hessian_dataset, and test_dataset constructed from
+    random labels and random features, with features having shape
+    [num_examples x in_features] and labels having shape
+    [num_examples x out_features].
+
+    Note: the random labels and features for the different datasets need to be
+    generated together.
+    Otherwise, some tests will fail (https://fburl.com/testinfra/737jnpip)
+    """
+
+    num_train = 32
+    num_hessian = 22  # this needs to be high to prevent numerical issues
+    num_inputs = 2 if unpack_inputs else 1
+
+    labels = torch.normal(1, 2, (num_examples, out_features)).double()
+    labels = labels.cuda() if use_gpu else labels
+
+    all_samples = [
+        torch.normal(0, 1, (num_examples, in_features)).double()
+        for _ in range(num_inputs)
+    ]
+    if use_gpu:
+        all_samples = _move_sample_list_to_cuda(all_samples)
+
+    # TODO: no need to pass use_gpu since the data has already been moved to cuda
+    train_dataset = (
+        UnpackDataset(
+            [samples[:num_train] for samples in all_samples],
+            labels[:num_train],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(all_samples[0][:num_train], labels[:num_train], use_gpu)
+    )
+
+    hessian_dataset = (
+        UnpackDataset(
+            [samples[:num_hessian] for samples in all_samples],
+            labels[:num_hessian],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(
+            all_samples[0][:num_hessian], labels[:num_hessian], use_gpu
+        )
+    )
+
+    test_dataset = (
+        UnpackDataset(
+            [samples[num_train:] for samples in all_samples],
+            labels[num_train:],
+            use_gpu,
+        )
+        if unpack_inputs
+        else ExplicitDataset(all_samples[0][num_train:], labels[num_train:], use_gpu)
+    )
+    return (train_dataset, hessian_dataset, test_dataset)
+
+
 def get_random_model_and_data(
     tmpdir,
     unpack_inputs,
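A hedged usage sketch of the new helper (the in_features value is illustrative; out_features=3 and num_examples=50 echo the constants in get_random_model_and_data below, and the full-batch DataLoader call mirrors the pattern used later in this diff):

from torch.utils.data import DataLoader

train_dataset, hessian_dataset, test_dataset = get_random_data(
    in_features=5,      # illustrative value
    out_features=3,     # matches out_features below
    num_examples=50,    # matches num_samples below
    use_gpu=False,
    unpack_inputs=True,
)
# Splits derived from the helper: 32 train, 22 hessian, 50 - 32 = 18 test examples.
batch = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))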
@@ -283,38 +351,9 @@ def get_random_model_and_data(
     out_features = 3
 
     num_samples = 50
-    num_train = 32
-    num_hessian = 22  # this needs to be high to prevent numerical issues
-    all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
-    all_labels = all_labels.cuda() if use_gpu else all_labels
-    train_labels = all_labels[:num_train]
-    test_labels = all_labels[num_train:]
-    hessian_labels = all_labels[:num_hessian]
-
-    if unpack_inputs:
-        all_samples = [
-            torch.normal(0, 1, (num_samples, in_features)).double()
-            for _ in range(num_inputs)
-        ]
-        if use_gpu:
-            all_samples = _move_sample_to_cuda(all_samples)
 
-        train_samples = [ts[:num_train] for ts in all_samples]
-        test_samples = [ts[num_train:] for ts in all_samples]
-        hessian_samples = [ts[:num_hessian] for ts in all_samples]
-    else:
-        all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
-
-        if use_gpu:
-            all_samples = all_samples.cuda()
-        train_samples = all_samples[:num_train]
-        test_samples = all_samples[num_train:]
-        hessian_samples = all_samples[:num_hessian]
-
-    dataset = (
-        ExplicitDataset(train_samples, train_labels, use_gpu)
-        if not unpack_inputs
-        else UnpackDataset(train_samples, train_labels, use_gpu)
+    train_dataset, hessian_dataset, test_dataset = get_random_data(
+        in_features, out_features, num_samples, use_gpu, unpack_inputs
     )
 
     if model_type == "random":
@@ -358,15 +397,19 @@ def get_random_model_and_data(
 
     # turn input into a single tensor for use by least squares
     tensor_hessian_samples = (
-        hessian_samples if not unpack_inputs else torch.cat(hessian_samples, dim=1)
+        hessian_dataset.samples
+        if not unpack_inputs
+        else torch.cat(hessian_dataset.samples, dim=1)
     )
     version = _parse_version(torch.__version__)
     if version < (1, 9):
-        theta = torch.lstsq(tensor_hessian_samples, hessian_labels).solution[0:1]
+        theta = torch.lstsq(
+            tensor_hessian_samples, hessian_dataset.labels
+        ).solution[0:1]
     else:
         # run least squares to get optimal trained parameters
         theta = torch.linalg.lstsq(
-            hessian_labels,
+            hessian_dataset.labels,
             tensor_hessian_samples,
         ).solution
     # the first `n` rows of `theta` contains the least squares solution, where
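For readers comparing the two branches: torch.lstsq was deprecated in PyTorch 1.9 and later removed, and the replacement torch.linalg.lstsq(A, B) solves min ||A @ X - B||. A generic sketch of the modern call (not the repo's exact invocation):

import torch

# Generic sketch: recover theta from y = X @ theta via least squares.
X = torch.normal(0, 1, (22, 5)).double()   # 22 examples, 5 features
theta_true = torch.ones(5, 1).double()
y = X @ theta_true
theta = torch.linalg.lstsq(X, y).solution  # shape (5, 1)
assert torch.allclose(theta, theta_true, atol=1e-6)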
@@ -397,14 +440,7 @@ def get_random_model_and_data(
     )
 
     # train model using several optimization steps on Hessian data
-
-    # create entire Hessian data as a batch
-    hessian_dataset = (
-        ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
-        if not unpack_inputs
-        else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
-    )
-    batch = next(iter(DataLoader(hessian_dataset, batch_size=num_hessian)))
+    batch = next(iter(DataLoader(hessian_dataset, batch_size=len(hessian_dataset))))
 
     optimizer = torch.optim.Adam(net.parameters())
     num_steps = 200
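The batch_size=len(hessian_dataset) idiom replaces the hard-coded num_hessian and materializes the whole dataset as a single batch. A self-contained sketch with a stand-in TensorDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10.0).unsqueeze(1))
(full_batch,) = next(iter(DataLoader(ds, batch_size=len(ds))))
assert full_batch.shape[0] == len(ds)  # one batch holds every example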
@@ -425,26 +461,13 @@ def get_random_model_and_data(
 
     training_data = (
         net_adjusted,
-        dataset,
+        train_dataset,
     )
 
-    hessian_data = (
-        (
-            _move_sample_to_cuda(hessian_samples)
-            if isinstance(hessian_samples, list) and use_gpu
-            else (hessian_samples.cuda() if use_gpu else hessian_samples)
-        ),
-        hessian_labels.cuda() if use_gpu else hessian_labels,
-    )
+    hessian_data = (hessian_dataset.samples, hessian_dataset.labels)
+
+    test_data = (test_dataset.samples, test_dataset.labels)
 
-    test_data = (
-        (
-            _move_sample_to_cuda(test_samples)
-            if isinstance(test_samples, list) and use_gpu
-            else (test_samples.cuda() if use_gpu else test_samples)
-        ),
-        test_labels.cuda() if use_gpu else test_labels,
-    )
     if return_test_data:
         if not return_hessian_data:
             return (*training_data, *test_data)

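How a caller might unpack the flattened return (a sketch; the flag names appear in the hunk above, but the full signature is truncated in this diff, so the call below is hypothetical):

# Hypothetical call: with return_test_data=True and return_hessian_data=False,
# the function returns (net, train_dataset, test_samples, test_labels).
net, train_dataset, test_samples, test_labels = get_random_model_and_data(
    tmpdir, unpack_inputs=False, return_test_data=True
)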