Skip to content

Commit b964dd8

Browse files
cicichen01 authored and facebook-github-bot committed
Specify use_cpu as bool for datasets (#1259)
Summary: As titled. use_cpu is supposed to be a bool for datasets. However, it was combined with GPU settings that carried more information inside. We separate the variable to clarify the logic. This change also removes the duplicated data adjustment for GPU. Differential Revision: D55167566
1 parent c49d2fb commit b964dd8

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

tests/influence/_utils/common.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
self,
4646
samples: Tensor,
4747
labels: Tensor,
48-
use_gpu=False,
48+
use_gpu: bool = False,
4949
) -> None:
5050
self.samples, self.labels = samples, labels
5151
if use_gpu:
@@ -64,7 +64,7 @@ def __init__(
6464
self,
6565
samples: List[Tensor],
6666
labels: Tensor,
67-
use_gpu=False,
67+
use_gpu: bool = False,
6868
) -> None:
6969
self.samples, self.labels = samples, labels
7070
if use_gpu:
@@ -83,7 +83,11 @@ def __getitem__(self, idx):
8383

8484

8585
class IdentityDataset(ExplicitDataset):
86-
def __init__(self, num_features, use_gpu=False) -> None:
86+
def __init__(
87+
self,
88+
num_features: int,
89+
use_gpu: bool = False,
90+
) -> None:
8791
self.samples = torch.diag(torch.ones(num_features))
8892
self.labels = torch.zeros(num_features).unsqueeze(1)
8993
if use_gpu:
@@ -92,7 +96,13 @@ def __init__(self, num_features, use_gpu=False) -> None:
9296

9397

9498
class RangeDataset(ExplicitDataset):
95-
def __init__(self, low, high, num_features, use_gpu=False) -> None:
99+
def __init__(
100+
self,
101+
low: int,
102+
high: int,
103+
num_features: int,
104+
use_gpu: bool = False,
105+
) -> None:
96106
self.samples = (
97107
torch.arange(start=low, end=high, dtype=torch.float)
98108
.repeat(num_features, 1)
@@ -105,7 +115,7 @@ def __init__(self, low, high, num_features, use_gpu=False) -> None:
105115

106116

107117
class BinaryDataset(ExplicitDataset):
108-
def __init__(self, use_gpu=False) -> None:
118+
def __init__(self, use_gpu: bool = False) -> None:
109119
self.samples = F.normalize(
110120
torch.stack(
111121
(
@@ -249,16 +259,11 @@ def get_random_data(
249259
num_inputs = 2 if unpack_inputs else 1
250260

251261
labels = torch.normal(1, 2, (num_examples, out_features)).double()
252-
labels = labels.cuda() if use_gpu else labels
253-
254262
all_samples = [
255263
torch.normal(0, 1, (num_examples, in_features)).double()
256264
for _ in range(num_inputs)
257265
]
258-
if use_gpu:
259-
all_samples = _move_sample_list_to_cuda(all_samples)
260266

261-
# TODO: no need to pass use_gpu since the data format has already been moved to cuda
262267
train_dataset = (
263268
UnpackDataset(
264269
[samples[:num_train] for samples in all_samples],

0 commit comments

Comments (0)