
Commit 774c422

yucu authored and facebook-github-bot committed
Enable pyre for Captum open source part- 2/2 (#1319)
Summary: Pull Request resolved: #1319. Differential Revision: D60837583.

1 parent 2434cf7 · commit 774c422

Note: GitHub hides most of this large commit's diff by default; only the changes to captum/_utils/av.py and captum/_utils/common.py are reproduced below.

62 files changed: +1577, -33 lines
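For readers new to Pyre: the `# pyre-strict` header added to each module opts it into strict mode, where unannotated functions, parameters, and attributes are reported as errors instead of being implicitly typed as `Any`; each `# pyre-fixme[N]` comment suppresses exactly one error of code `N` on the line that follows, so the code keeps type-checking while annotations are tightened incrementally. A minimal sketch of the mechanism (illustrative only, not code from this commit):

```python
# pyre-strict


class Accumulator:
    def __init__(self) -> None:
        # Strict mode flags this attribute because it has no annotation;
        # the suppression records the debt without fixing it yet.
        # pyre-fixme[4]: Attribute must be annotated.
        self.values = []

    def add(self, value: int) -> None:
        self.values.append(value)
```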

captum/_utils/av.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+# pyre-strict
+
 import glob
 import os
 import re
@@ -66,12 +68,14 @@ def __init__(
                     which the activation vectors are computed
             """
 
+            # pyre-fixme[4]: Attribute must be annotated.
             self.av_filesearch = AV._construct_file_search(
                 path, model_id, identifier, layer, num_id
             )
 
             files = glob.glob(self.av_filesearch)
 
+            # pyre-fixme[4]: Attribute must be annotated.
             self.files = AV.sort_files(files)
 
         def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]:
```
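The two `pyre-fixme[4]` suppressions above mark attributes whose types Pyre cannot infer from the helper calls. A hedged sketch of how such suppressions are eventually retired, using a hypothetical `FileIndex` class rather than Captum's real `AVDataset`, and assuming the search pattern is a `str` and the file list a `List[str]`:

```python
import glob
from typing import List


class FileIndex:
    """Stand-in showing annotated attributes that need no pyre-fixme[4]."""

    def __init__(self, pattern: str) -> None:
        # Explicit attribute annotations satisfy strict mode directly.
        self.search_pattern: str = pattern
        self.files: List[str] = sorted(glob.glob(pattern))


print(FileIndex("*.py").files)
```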
```diff
@@ -346,6 +350,7 @@ def _compute_and_save_activations(
         inputs: Union[Tensor, Tuple[Tensor, ...]],
         identifier: str,
         num_id: str,
+        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         load_from_disk: bool = True,
     ) -> None:
@@ -395,6 +400,8 @@ def _compute_and_save_activations(
         AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id)
 
     @staticmethod
+    # pyre-fixme[3]: Return annotation cannot be `Any`.
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any:
         r"""
         Helper to extract input from labels when getting items from a Dataset. Assumes
@@ -490,6 +497,8 @@ def sort_files(files: List[str]) -> List[str]:
         lexigraphical sort.
         """
 
+        # pyre-fixme[3]: Return type must be annotated.
+        # pyre-fixme[2]: Parameter must be annotated.
         def split_alphanum(s):
             r"""
             Splits string into a list of strings and numbers
```
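Codes [2] and [3] in this file flag `Any` in signatures such as `_unpack_data`: the annotation is present but checks nothing. A common way to retire such suppressions is to replace `Any` with a `TypeVar`; the sketch below is an illustration under assumed types, not Captum's actual fix (full precision would likely need overloads):

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T")


def unpack_data(data: Union[T, Tuple[T, object]]) -> T:
    """Return the input portion when a dataset yields (input, label) pairs."""
    if isinstance(data, tuple):
        return data[0]
    return data


print(unpack_data((3, "label")))  # 3
print(unpack_data(7))  # 7
```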

captum/_utils/common.py

Lines changed: 93 additions & 5 deletions
```diff
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+
+# pyre-strict
 import typing
 from enum import Enum
 from functools import reduce
@@ -68,10 +70,18 @@ def safe_div(
 
 
 @typing.overload
+# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
+#  is incompatible with the return type of the implementation (`bool`).
+# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
+# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 
 @typing.overload
+# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
+#  is incompatible with the return type of the implementation (`bool`).
+# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
+# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
 
```
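Most of the suppressions in this file attach to the `@typing.overload` + `Literal` idiom: a boolean flag whose literal value selects the return type, which the Pyre version targeted here could not yet reconcile with the `bool` implementation signature (hence codes [24], [31], and [43]). The idiom itself is standard typing; a self-contained sketch:

```python
from typing import List, Literal, Union, overload


@overload
def first_or_all(items: List[int], all_items: Literal[True]) -> List[int]: ...


@overload
def first_or_all(items: List[int], all_items: Literal[False]) -> int: ...


def first_or_all(items: List[int], all_items: bool) -> Union[int, List[int]]:
    # Callers passing a literal True/False get the precise return type;
    # the implementation is checked against the looser bool signature.
    return list(items) if all_items else items[0]


whole: List[int] = first_or_all([1, 2, 3], True)
single: int = first_or_all([1, 2, 3], False)
print(whole, single)
```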

```diff
@@ -230,6 +240,8 @@ def _format_tensor_into_tuples(
     return inputs
 
 
+# pyre-fixme[3]: Return annotation cannot be `Any`.
+# pyre-fixme[2]: Parameter annotation cannot be `Any`.
 def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
     return (
         inputs
@@ -257,16 +269,21 @@ def _format_additional_forward_args(additional_forward_args: None) -> None: ...
 
 @overload
 def _format_additional_forward_args(
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
     additional_forward_args: Union[Tensor, Tuple]
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
 ) -> Tuple: ...
 
 
 @overload
 def _format_additional_forward_args(
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any,
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
 ) -> Union[None, Tuple]: ...
 
 
+# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
 def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
     if additional_forward_args is not None and not isinstance(
         additional_forward_args, tuple
@@ -276,9 +293,11 @@ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None,
 
 
 def _expand_additional_forward_args(
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any,
     n_steps: int,
     expansion_type: ExpansionTypes = ExpansionTypes.repeat,
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
 ) -> Union[None, Tuple]:
     def _expand_tensor_forward_arg(
         additional_forward_arg: Tensor,
@@ -343,9 +362,12 @@ def _expand_target(
     return target
 
 
+# pyre-fixme[3]: Return type must be annotated.
 def _expand_feature_mask(
     feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
 ):
+    # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
+    #  typing.Tuple[Tensor, ...]]`.
     is_feature_mask_tuple = _is_tuple(feature_mask)
     feature_mask = _format_tensor_into_tuples(feature_mask)
     feature_mask_new = tuple(
@@ -359,12 +381,17 @@ def _expand_feature_mask(
     return _format_output(is_feature_mask_tuple, feature_mask_new)
 
 
+# pyre-fixme[3]: Return type must be annotated.
 def _expand_and_update_baselines(
     inputs: Tuple[Tensor, ...],
     n_samples: int,
+    # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+    #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
     kwargs: dict,
     draw_baseline_from_distrib: bool = False,
 ):
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def get_random_baseline_indices(bsz, baseline):
         num_ref_samples = baseline.shape[0]
         return np.random.choice(num_ref_samples, n_samples * bsz).tolist()
@@ -404,6 +431,9 @@ def get_random_baseline_indices(bsz, baseline):
     kwargs["baselines"] = baselines
 
 
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+#  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
 def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
     if "additional_forward_args" not in kwargs:
         return
@@ -420,6 +450,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
     kwargs["additional_forward_args"] = additional_forward_args
 
 
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+#  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
 def _expand_and_update_target(n_samples: int, kwargs: dict):
     if "target" not in kwargs:
         return
@@ -431,6 +464,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
     kwargs["target"] = target
 
 
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+#  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
 def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
     if "feature_mask" not in kwargs:
         return
```
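The recurring `pyre-fixme[24]` notes that bare `dict`, `tuple`, and `Callable` are unparameterized generics. Retiring those suppressions means spelling the parameters out; the sketch below shows the shape of that fix with an assumed kwargs layout and a hypothetical function name, not Captum's actual follow-up:

```python
from typing import Any, Callable, Dict, Tuple

import torch
from torch import Tensor


def expand_and_update_target(n_samples: int, kwargs: Dict[str, Any]) -> None:
    """Parameterized counterpart of the bare `dict` signatures above."""
    if "target" not in kwargs:
        return
    target = kwargs["target"]
    if isinstance(target, Tensor):
        kwargs["target"] = target.repeat_interleave(n_samples, dim=0)


# A fully parameterized Callable, unlike the bare `Callable` above.
reducer: Callable[[Tuple[Tensor, ...]], Tensor] = torch.cat

kwargs: Dict[str, Any] = {"target": torch.tensor([1, 2])}
expand_and_update_target(3, kwargs)
print(kwargs["target"])  # tensor([1, 1, 1, 2, 2, 2])
print(reducer((torch.zeros(1), torch.ones(1))))  # tensor([0., 1.])
```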
```diff
@@ -444,14 +480,24 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
 
 
 @typing.overload
+# pyre-fixme[43]: The implementation of `_format_output` does not accept all
+#  possible arguments of overload defined on line `449`.
 def _format_output(
-    is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
+    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_inputs_tuple: Literal[True],
+    output: Tuple[Tensor, ...],
 ) -> Tuple[Tensor, ...]: ...
 
 
 @typing.overload
+# pyre-fixme[43]: The implementation of `_format_output` does not accept all
+#  possible arguments of overload defined on line `455`.
 def _format_output(
-    is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
+    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_inputs_tuple: Literal[False],
+    output: Tuple[Tensor, ...],
 ) -> Tensor: ...
 
 
@@ -474,18 +520,30 @@ def _format_output(
             "The input is a single tensor however the output isn't."
             "The number of output tensors is: {}".format(len(output))
         )
+    # pyre-fixme[7]: Expected `Union[Tensor, typing.Tuple[Tensor, ...]]` but got
+    #  `Union[tuple[Tensor], Tensor]`.
     return output if is_inputs_tuple else output[0]
 
 
 @typing.overload
+# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
+#  possible arguments of overload defined on line `483`.
 def _format_outputs(
-    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
+    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_multiple_inputs: Literal[False],
+    outputs: List[Tuple[Tensor, ...]],
 ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
 
 @typing.overload
+# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
+#  possible arguments of overload defined on line `489`.
 def _format_outputs(
-    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
+    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_multiple_inputs: Literal[True],
+    outputs: List[Tuple[Tensor, ...]],
 ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...
 
 
@@ -512,9 +570,12 @@ def _format_outputs(
 
 
 def _run_forward(
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     forward_func: Callable,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     inputs: Any,
     target: TargetType = None,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any = None,
 ) -> Union[Tensor, Future[Tensor]]:
     forward_func_args = signature(forward_func).parameters
@@ -529,6 +590,8 @@ def _run_forward(
 
     output = forward_func(
         *(
+            # pyre-fixme[60]: Concatenation not yet support for multiple variadic
+            #  tuples: `*inputs, *additional_forward_args`.
             (*inputs, *additional_forward_args)
             if additional_forward_args is not None
             else inputs
```
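`pyre-fixme[60]` above concerns `(*inputs, *additional_forward_args)`: the checker at this revision could not type the concatenation of two variadic tuples, although at runtime it is ordinary tuple splatting. A standalone illustration of the call pattern (the names here are hypothetical):

```python
from typing import Optional, Tuple


def forward(a: int, b: int, scale: int = 1) -> int:
    return (a + b) * scale


inputs: Tuple[int, ...] = (2, 3)
extra: Optional[Tuple[int, ...]] = (10,)

# Same shape as _run_forward's call: splat inputs, then any additional args.
args = (*inputs, *extra) if extra is not None else inputs
print(forward(*args))  # 50
```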
```diff
@@ -606,6 +669,8 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
         elif isinstance(target[0], tuple):
             return torch.stack(
                 [
+                    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
+                    #  parameter.
                     output[(i,) + cast(Tuple, targ_elem)]
                     for i, targ_elem in enumerate(target)
                 ]
@@ -639,9 +704,11 @@ def _verify_select_column(
 
 def _verify_select_neuron(
     layer_output: Tuple[Tensor, ...],
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     selector: Union[int, Tuple[Union[int, slice], ...], Callable],
 ) -> Tensor:
     if callable(selector):
+        # pyre-fixme[7]: Expected `Tensor` but got `object`.
         return selector(layer_output if len(layer_output) > 1 else layer_output[0])
 
     assert len(layer_output) == 1, (
@@ -688,6 +755,9 @@ def _extract_device(
 
 def _reduce_list(
     val_list: Sequence[TupleOrTensorOrBoolGeneric],
+    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+    # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
+    #  `typing.List[<element type>]` to avoid runtime subscripting errors.
     red_func: Callable[[List], Any] = torch.cat,
 ) -> TupleOrTensorOrBoolGeneric:
     """
@@ -702,21 +772,28 @@
     """
     assert len(val_list) > 0, "Cannot reduce empty list!"
     if isinstance(val_list[0], torch.Tensor):
+        # pyre-fixme[16]: `bool` has no attribute `device`.
         first_device = val_list[0].device
+        # pyre-fixme[16]: `bool` has no attribute `to`.
         return red_func([elem.to(first_device) for elem in val_list])
     elif isinstance(val_list[0], bool):
+        # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
         return any(val_list)
     elif isinstance(val_list[0], tuple):
         final_out = []
+        # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
+        #  but got `TupleOrTensorOrBoolGeneric`.
         for i in range(len(val_list[0])):
             final_out.append(
+                # pyre-fixme[16]: `bool` has no attribute `__getitem__`.
                 _reduce_list([val_elem[i] for val_elem in val_list], red_func)
             )
     else:
         raise AssertionError(
             "Elements to be reduced can only be"
             "either Tensors or tuples containing Tensors."
         )
+    # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
     return tuple(final_out)
 
```
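The dense run of `pyre-fixme[16]`/`[7]`/`[6]` suppressions in `_reduce_list` stems from narrowing a `TypeVar`-typed value with `isinstance`: the check informs the checker about `val_list[0]` only, not about the element type as a whole, so neither attribute access nor the generic return type can be proven branch by branch. A minimal reproduction of that shape (not Captum code):

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T", bound=Union[bool, List[int]])


def reduce_any(values: Sequence[T]) -> T:
    # isinstance narrows values[0], not the element type of `values`,
    # so a checker cannot prove that either branch returns T.
    if isinstance(values[0], bool):
        return any(values)  # inferred as `bool`, not provably `T`
    return [x for v in values for x in v]  # inferred as `List[int]`


print(reduce_any([True, False]))  # True
print(reduce_any([[1, 2], [3]]))  # [1, 2, 3]
```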

```diff
@@ -756,6 +833,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
     return torch.cat([single_inp.flatten() for single_inp in inp])
 
 
+# pyre-fixme[3]: Return annotation cannot be `Any`.
 def _get_module_from_name(model: Module, layer_name: str) -> Any:
     r"""
     Returns the module (layer) object, given its (string) name
@@ -772,7 +850,11 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any:
 
 
 def _register_backward_hook(
-    module: Module, hook: Callable, attr_obj: Any
+    module: Module,
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    hook: Callable,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    attr_obj: Any,
 ) -> List[torch.utils.hooks.RemovableHandle]:
     grad_out: Dict[device, Tensor] = {}
 
@@ -784,6 +866,7 @@ def forward_hook(
         nonlocal grad_out
         grad_out = {}
 
+        # pyre-fixme[53]: Captured variable `grad_out` is not annotated.
        def output_tensor_hook(output_grad: Tensor) -> None:
            grad_out[output_grad.device] = output_grad
 
@@ -795,7 +878,11 @@ def output_tensor_hook(output_grad: Tensor) -> None:
         else:
             out.register_hook(output_tensor_hook)
 
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def pre_hook(module, inp):
+        # pyre-fixme[53]: Captured variable `module` is not annotated.
+        # pyre-fixme[3]: Return type must be annotated.
         def input_tensor_hook(input_grad: Tensor):
             if len(grad_out) == 0:
                 return
@@ -820,6 +907,7 @@ def input_tensor_hook(input_grad: Tensor):
     ]
 
 
+# pyre-fixme[3]: Return type must be annotated.
 def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]):
     """
     Returns the max feature mask index
```
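The final group of suppressions ([53], [3], [2]) attaches to nested hook functions that capture outer state without annotations. Retiring them mostly means annotating the captured variables and closure signatures; a hedged sketch of that pattern with hypothetical names:

```python
from typing import Callable, Dict

import torch
from torch import Tensor


def make_grad_recorder() -> Callable[[Tensor], None]:
    # Annotating the captured dict and the closure's full signature is what
    # removes pyre-fixme[53]/[3] on hooks like output_tensor_hook above.
    grad_out: Dict[torch.device, Tensor] = {}

    def output_tensor_hook(output_grad: Tensor) -> None:
        grad_out[output_grad.device] = output_grad

    return output_tensor_hook


hook = make_grad_recorder()
hook(torch.ones(2))
```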
