Skip to content

Commit 2b9f4ae

Browse files
jsawrukfacebook-github-bot
authored andcommitted
Clean up pyre issues in tests/helpers/influence/common.py (#1455)
Summary: Pull Request resolved: #1455 Fix some pyre issues in pytorch/captum/tests/helpers/influence/common.py Reviewed By: cyrjano Differential Revision: D66902046 fbshipit-source-id: 8b74f638a2060330e8665ff6103a6f255e1a205e
1 parent 2a2b41d commit 2b9f4ae

File tree

1 file changed

+18
-31
lines changed

1 file changed

+18
-31
lines changed

tests/helpers/influence/common.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,16 @@
2323
from torch.utils.data import DataLoader, Dataset
2424

2525

26-
# pyre-fixme[3]: Return type must be annotated.
2726
# pyre-fixme[2]: Parameter must be annotated.
28-
def _isSorted(x, key=lambda x: x, descending=True):
27+
def _isSorted(x, key=lambda x: x, descending=True) -> bool:
2928
if descending:
30-
return all([key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1)])
29+
return all(key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1))
3130
else:
32-
return all([key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1)])
31+
return all(key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1))
3332

3433

35-
# pyre-fixme[3]: Return type must be annotated.
3634
# pyre-fixme[2]: Parameter must be annotated.
37-
def _wrap_model_in_dataparallel(net):
35+
def _wrap_model_in_dataparallel(net) -> Module:
3836
alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)]
3937
net = net.cuda()
4038
return torch.nn.DataParallel(net, device_ids=alt_device_ids)
@@ -60,9 +58,7 @@ def __init__(
6058
def __len__(self) -> int:
6159
return len(self.samples)
6260

63-
# pyre-fixme[3]: Return type must be annotated.
64-
# pyre-fixme[2]: Parameter must be annotated.
65-
def __getitem__(self, idx):
61+
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
6662
return (self.samples[idx], self.labels[idx])
6763

6864

@@ -83,8 +79,7 @@ def __len__(self) -> int:
8379
return len(self.samples[0])
8480

8581
# pyre-fixme[3]: Return type must be annotated.
86-
# pyre-fixme[2]: Parameter must be annotated.
87-
def __getitem__(self, idx):
82+
def __getitem__(self, idx: int):
8883
"""
8984
The signature of the returning item is: List[List], where the contents
9085
are: [sample_0, sample_1, ...] + [labels] (two lists concacenated).
@@ -98,10 +93,8 @@ def __init__(
9893
num_features: int,
9994
use_gpu: bool = False,
10095
) -> None:
101-
# pyre-fixme[4]: Attribute must be annotated.
102-
self.samples = torch.diag(torch.ones(num_features))
103-
# pyre-fixme[4]: Attribute must be annotated.
104-
self.labels = torch.zeros(num_features).unsqueeze(1)
96+
self.samples: Tensor = torch.diag(torch.ones(num_features))
97+
self.labels: Tensor = torch.zeros(num_features).unsqueeze(1)
10598
if use_gpu:
10699
self.samples = self.samples.cuda()
107100
self.labels = self.labels.cuda()
@@ -115,23 +108,22 @@ def __init__(
115108
num_features: int,
116109
use_gpu: bool = False,
117110
) -> None:
118-
# pyre-fixme[4]: Attribute must be annotated.
119-
self.samples = (
111+
self.samples: Tensor = (
120112
torch.arange(start=low, end=high, dtype=torch.float)
121113
.repeat(num_features, 1)
122114
.transpose(1, 0)
123115
)
124-
# pyre-fixme[4]: Attribute must be annotated.
125-
self.labels = torch.arange(start=low, end=high, dtype=torch.float).unsqueeze(1)
116+
self.labels: Tensor = torch.arange(
117+
start=low, end=high, dtype=torch.float
118+
).unsqueeze(1)
126119
if use_gpu:
127120
self.samples = self.samples.cuda()
128121
self.labels = self.labels.cuda()
129122

130123

131124
class BinaryDataset(ExplicitDataset):
132125
def __init__(self, use_gpu: bool = False) -> None:
133-
# pyre-fixme[4]: Attribute must be annotated.
134-
self.samples = F.normalize(
126+
self.samples: Tensor = F.normalize(
135127
torch.stack(
136128
(
137129
torch.Tensor([1, 1]),
@@ -161,8 +153,7 @@ def __init__(self, use_gpu: bool = False) -> None:
161153
)
162154
)
163155
)
164-
# pyre-fixme[4]: Attribute must be annotated.
165-
self.labels = torch.cat(
156+
self.labels: Tensor = torch.cat(
166157
(
167158
torch.Tensor([1]).repeat(12, 1),
168159
torch.Tensor([-1]).repeat(12, 1),
@@ -350,13 +341,10 @@ def get_random_model_and_data(
350341
tmpdir,
351342
# pyre-fixme[2]: Parameter must be annotated.
352343
unpack_inputs,
353-
# pyre-fixme[2]: Parameter must be annotated.
354-
return_test_data=True,
344+
return_test_data: bool = True,
355345
gpu_setting: Optional[str] = None,
356-
# pyre-fixme[2]: Parameter must be annotated.
357-
return_hessian_data=False,
358-
# pyre-fixme[2]: Parameter must be annotated.
359-
model_type="random",
346+
return_hessian_data: bool = False,
347+
model_type: str = "random",
360348
):
361349
"""
362350
returns a model, training data, and optionally data for computing the hessian
@@ -534,10 +522,9 @@ def generate_symmetric_matrix_given_eigenvalues(
534522
return torch.matmul(Q, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q.T))
535523

536524

537-
# pyre-fixme[3]: Return type must be annotated.
538525
def generate_assymetric_matrix_given_eigenvalues(
539526
eigenvalues: Union[Tensor, List[float]]
540-
):
527+
) -> Tensor:
541528
"""
542529
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
543530
generate assymetric random matrix with specified eigenvalues. this is used in

0 commit comments

Comments
 (0)