
Commit 0665ea5

yucu authored and facebook-github-bot committed
Final round of Pyre enablement (#1330)
Summary:
Pull Request resolved: #1330

Primarily use infer to annotate lots of easy cases.

`pyre -n --target fbcode//pytorch/captum/captum/... infer -i pytorch/captum/captum/`
`pyre -n --target fbcode//pytorch/captum/tests/... infer -i pytorch/captum/tests/`

Reviewed By: vivekmig

Differential Revision: D61053870

fbshipit-source-id: 37f96e7e5590dc57ddbcaa3c1da106149f5e09bf
1 parent 09aa048 · commit 0665ea5
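
For context, `pyre infer -i` writes back the annotations it can derive from a function body, which is where most of the `-> None` and simple return types below come from. A minimal sketch of the kind of edit it produces, on a hypothetical helper (not a function from this commit):

    # Before: Pyre flags the missing return type (pyre-fixme[3]).
    def update_config(cfg: dict, n_samples: int):
        cfg["n_samples"] = n_samples

    # After `pyre infer -i`: the inferred annotation is written in place.
    def update_config(cfg: dict, n_samples: int) -> None:
        cfg["n_samples"] = n_samples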


57 files changed, +452 -559 lines changed (only part of this large commit's diff is shown below)

captum/_utils/common.py

Lines changed: 10 additions & 17 deletions
@@ -363,10 +363,9 @@ def _expand_target(
     return target


-# pyre-fixme[3]: Return type must be annotated.
 def _expand_feature_mask(
     feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
-):
+) -> Tuple[Tensor, ...]:
     # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
     #  typing.Tuple[Tensor, ...]]`.
     is_feature_mask_tuple = _is_tuple(feature_mask)
@@ -379,18 +378,17 @@ def _expand_feature_mask(
         )
         for feature_mask_elem in feature_mask
     )
-    return _format_output(is_feature_mask_tuple, feature_mask_new)
+    return _format_output(is_feature_mask_tuple, feature_mask_new)  # type: ignore


-# pyre-fixme[3]: Return type must be annotated.
 def _expand_and_update_baselines(
     inputs: Tuple[Tensor, ...],
     n_samples: int,
     # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
     #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
     kwargs: dict,
     draw_baseline_from_distrib: bool = False,
-):
+) -> None:
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
     def get_random_baseline_indices(bsz, baseline):
@@ -432,10 +430,9 @@ def get_random_baseline_indices(bsz, baseline):
     kwargs["baselines"] = baselines


-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
 #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
-def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
+def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict) -> None:
     if "additional_forward_args" not in kwargs:
         return
     additional_forward_args = kwargs["additional_forward_args"]
@@ -451,10 +448,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
     kwargs["additional_forward_args"] = additional_forward_args


-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
 #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
-def _expand_and_update_target(n_samples: int, kwargs: dict):
+def _expand_and_update_target(n_samples: int, kwargs: dict) -> None:
     if "target" not in kwargs:
         return
     target = kwargs["target"]
@@ -465,10 +461,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
     kwargs["target"] = target


-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
 #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
-def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
+def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
     if "feature_mask" not in kwargs:
         return

@@ -573,10 +568,9 @@ def _format_outputs(
 # pyre-fixme[24] Callable requires 2 arguments
 def _construct_future_forward(original_forward: Callable) -> Callable:
     # pyre-fixme[3] return type not specified
-    # pyre-ignore
-    def future_forward(*args, **kwargs):
-        # pyre-ignore
-        fut = torch.futures.Future()
+    def future_forward(*args: Any, **kwargs: Any):
+        # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
+        fut: torch.futures.Future[Tensor] = torch.futures.Future()
         fut.set_result(original_forward(*args, **kwargs))
         return fut

@@ -921,8 +915,7 @@ def input_tensor_hook(input_grad: Tensor):
     ]


-# pyre-fixme[3]: Return type must be annotated.
-def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]):
+def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
     """
     Returns the max feature mask index
     The feature mask should be formatted to tuple of tensors at first.
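
The `_construct_future_forward` hunk above only tightens annotations; the wrapper still returns an already-resolved future. A rough usage sketch, assuming a simple forward function that returns a Tensor:

    import torch

    def forward_fn(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    # The wrapped callable returns a Future[Tensor] resolved with the
    # original forward's output.
    future_forward = _construct_future_forward(forward_fn)
    fut = future_forward(torch.ones(3))
    print(fut.wait())  # tensor([2., 2., 2.])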

captum/_utils/gradient.py

Lines changed: 5 additions & 1 deletion
@@ -986,7 +986,11 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
             out = loss

         sample_grad_wrapper.compute_param_sample_gradients(
-            out, loss_mode=reduction_type
+            out,
+            # pyre-fixme[6]: In call `SampleGradientWrapper.
+            #  compute_param_sample_gradients`, for argument `loss_mode`,
+            #  expected `str` but got `Optional[str]`.
+            loss_mode=reduction_type,  # type: ignore
         )
         if layer_modules is not None:
             layer_parameters = _extract_parameters_from_layers(layer_modules)

captum/_utils/models/linear_model/model.py

Lines changed: 1 addition & 2 deletions
@@ -41,7 +41,6 @@ def __init__(self, train_fn: Callable, **kwargs) -> None:
         # pyre-fixme[4]: Attribute must be annotated.
         self.construct_kwargs = kwargs

-    # pyre-fixme[3]: Return type must be annotated.
     def _construct_model_params(
         self,
         in_features: Optional[int] = None,
@@ -52,7 +51,7 @@ def _construct_model_params(
         weight_values: Optional[Tensor] = None,
         bias_value: Optional[Tensor] = None,
         classes: Optional[Tensor] = None,
-    ):
+    ) -> None:
         r"""
         Lazily initializes a linear model. This will be called for you in a
         train method.

captum/_utils/models/linear_model/train.py

Lines changed: 3 additions & 5 deletions
@@ -9,9 +9,8 @@
 from torch.utils.data import DataLoader


-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def l2_loss(x1, x2, weights=None):
+def l2_loss(x1, x2, weights=None) -> torch.Tensor:
     if weights is None:
         return torch.mean((x1 - x2) ** 2) / 2.0
     else:
@@ -236,7 +235,7 @@ def get_point(datapoint):

 class NormLayer(nn.Module):
     # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, mean, std, n=None, eps=1e-8) -> None:
+    def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None:
         super().__init__()
         # pyre-fixme[4]: Attribute must be annotated.
         self.mean = mean
@@ -251,7 +250,6 @@ def forward(self, x):
         return (x - self.mean) / (self.std + self.eps)


-# pyre-fixme[3]: Return type must be annotated.
 def sklearn_train_linear_model(
     model: LinearModel,
     dataloader: DataLoader,
@@ -260,7 +258,7 @@ def sklearn_train_linear_model(
     norm_input: bool = False,
     # pyre-fixme[2]: Parameter must be annotated.
     **fit_kwargs,
-):
+) -> Dict[str, float]:
     r"""
     Alternative method to train with sklearn. This does introduce some slight
     overhead as we convert the tensors to numpy and then convert the resulting
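
The `l2_loss` helper that gains a `torch.Tensor` return annotation above is just a halved mean-squared error; a quick check of the unweighted branch:

    import torch

    x1 = torch.tensor([1.0, 3.0])
    x2 = torch.tensor([0.0, 1.0])
    # mean((x1 - x2) ** 2) / 2 = mean([1.0, 4.0]) / 2 = 1.25
    print(l2_loss(x1, x2))  # tensor(1.2500)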

captum/_utils/progress.py

Lines changed: 16 additions & 21 deletions
@@ -5,7 +5,7 @@
 import sys
 import warnings
 from time import time
-from typing import cast, Iterable, Optional, Sized, TextIO
+from typing import Any, cast, Iterable, Optional, Sized, TextIO

 from captum._utils.typing import Literal

@@ -61,15 +61,17 @@ class NullProgress:
     progress bars.
     """

-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs):
+    def __init__(
+        self,
+        # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
+        iterable: Optional[Iterable] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         del args, kwargs
         self.iterable = iterable

-    # pyre-fixme[3]: Return type must be annotated.
-    def __enter__(self):
+    def __enter__(self) -> "NullProgress":
         return self

     # pyre-fixme[2]: Parameter must be annotated.
@@ -87,12 +89,10 @@ def __iter__(self):
         for it in self.iterable:
             yield it

-    # pyre-fixme[3]: Return type must be annotated.
-    def update(self, amount: int = 1):
+    def update(self, amount: int = 1) -> None:
         pass

-    # pyre-fixme[3]: Return type must be annotated.
-    def close(self):
+    def close(self) -> None:
         pass


@@ -133,8 +133,7 @@ def __init__(
         self.closed = False
         self._is_parent = False

-    # pyre-fixme[3]: Return type must be annotated.
-    def __enter__(self):
+    def __enter__(self) -> "SimpleProgress":
         self._is_parent = True
         self._refresh()
         return self
@@ -158,8 +157,7 @@ def __iter__(self):
             self.update()
         self.close()

-    # pyre-fixme[3]: Return type must be annotated.
-    def _refresh(self):
+    def _refresh(self) -> None:
         progress_str = self.desc + ": " if self.desc else ""
         if self.total:
             # e.g., progress: 60% 3/5
@@ -172,8 +170,7 @@ def _refresh(self):
         end = "\n" if self._is_parent else ""
         print("\r" + progress_str, end=end, file=self.file)

-    # pyre-fixme[3]: Return type must be annotated.
-    def update(self, amount: int = 1):
+    def update(self, amount: int = 1) -> None:
         if self.closed:
             return
         self.cur += amount
@@ -183,8 +180,7 @@ def update(self, amount: int = 1):
             self._refresh()
             self.last_print_t = cur_t

-    # pyre-fixme[3]: Return type must be annotated.
-    def close(self):
+    def close(self) -> None:
         if not self.closed and not self._is_parent:
             self._refresh()
             print(file=self.file)  # end with new line
@@ -197,8 +193,7 @@ def progress(
     iterable: Optional[Iterable] = None,
     desc: Optional[str] = None,
     total: Optional[int] = None,
-    # pyre-fixme[2]: Parameter must be annotated.
-    use_tqdm=True,
+    use_tqdm: bool = True,
     file: Optional[TextIO] = None,
     mininterval: float = 0.5,
     # pyre-fixme[2]: Parameter must be annotated.
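
The `NullProgress` annotations above (`__enter__` returning `"NullProgress"`, `update`/`close` returning `None`) do not change behavior: it stays a no-op stand-in for a tqdm-style progress bar. A small usage sketch, assuming an arbitrary iterable of batches:

    from captum._utils.progress import NullProgress

    batches = [1, 2, 3]
    # Mimics the tqdm context-manager/iterator protocol without printing.
    with NullProgress(batches) as pbar:
        for batch in pbar:
            pass  # process batch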

captum/_utils/sample_gradient.py

Lines changed: 4 additions & 3 deletions
@@ -103,7 +103,7 @@ class SampleGradientWrapper:
     """

     # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, model, layer_modules=None) -> None:
+    def __init__(self, model, layer_modules: Optional[List[Module]] = None) -> None:
         # pyre-fixme[4]: Attribute must be annotated.
         self.model = model
         self.hooks_added = False
@@ -162,8 +162,9 @@ def _reset(self) -> None:
         self.activation_dict = defaultdict(list)
         self.gradient_dict = defaultdict(list)

-    # pyre-fixme[2]: Parameter must be annotated.
-    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None:
+    def compute_param_sample_gradients(
+        self, loss_blob: Tensor, loss_mode: str = "mean"
+    ) -> None:
         assert (
             loss_mode.upper() in LossMode.__members__
         ), f"Provided loss mode {loss_mode} is not valid"
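
With the tightened signature, `compute_param_sample_gradients` expects a loss `Tensor` and a plain `str` loss mode. A hedged sketch of a call site; `model`, `loss_fn`, `inputs`, and `targets` are placeholders:

    wrapper = SampleGradientWrapper(model, layer_modules=None)
    wrapper.add_hooks()
    loss = loss_fn(model(inputs), targets)  # a Tensor
    wrapper.compute_param_sample_gradients(loss, loss_mode="mean")
    wrapper.remove_hooks()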

captum/attr/_core/dataloader_attr.py

Lines changed: 4 additions & 8 deletions
@@ -30,9 +30,8 @@ class InputRole:


 # default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def _concat_tensors(accum, cur_output, _):
+def _concat_tensors(accum, cur_output, _) -> Tensor:
     return cur_output if accum is None else torch.cat([accum, cur_output])


@@ -185,7 +184,6 @@ def __init__(self, attr_method: Attribution) -> None:

         self.attr_method.forward_func = self._forward_with_dataloader

-    # pyre-fixme[3]: Return type must be annotated.
     def _forward_with_dataloader(
         self,
         batched_perturbed_feature_indices: Tensor,
@@ -199,7 +197,7 @@ def _forward_with_dataloader(
         to_metric: Optional[Callable],
         show_progress: bool,
         feature_idx_to_mask_idx: Dict[int, List[int]],
-    ):
+    ) -> Tensor:
         """
         Wrapper of the original given forward_func to be used in the attribution method
         It iterates over the dataloader with the given forward_func
@@ -468,10 +466,8 @@ def attribute(

         return _format_output(is_inputs_tuple, attr)

-    # pyre-fixme[3]: Return type must be annotated.
-    def attribute_future(
-        self,
-    ):
+    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
+    def attribute_future(self) -> Callable:
         r"""
         This method is not implemented for DataLoaderAttribution.
         """

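The `_concat_tensors` default reducer now declares its `Tensor` return explicitly; it simply concatenates each new batch output onto the accumulator along the batch dimension:

    import torch

    acc = None
    for out in (torch.ones(2), torch.zeros(2)):
        acc = _concat_tensors(acc, out, None)
    print(acc)  # tensor([1., 1., 0., 0.])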