Skip to content

Commit 8351b3f

Browse files
authored
Merge branch 'master' into export-D64624875
2 parents e4e23f1 + b80e488 commit 8351b3f

19 files changed

+316
-621
lines changed

captum/_utils/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
8686
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
8787

8888

89+
@typing.overload
90+
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
91+
92+
8993
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
9094
return isinstance(inputs, tuple)
9195

captum/_utils/typing.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,11 @@
22

33
# pyre-strict
44

5-
from typing import (
6-
List,
7-
Optional,
8-
overload,
9-
Protocol,
10-
Tuple,
11-
TYPE_CHECKING,
12-
TypeVar,
13-
Union,
14-
)
5+
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
156

167
from torch import Tensor
178
from torch.nn import Module
189

19-
if TYPE_CHECKING:
20-
from typing import Literal
21-
else:
22-
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
23-
2410
TensorOrTupleOfTensorsGeneric = TypeVar(
2511
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
2612
)

captum/attr/_core/dataloader_attr.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
from collections import defaultdict
55
from copy import copy
6-
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
6+
from typing import Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
77

88
import torch
99
from captum._utils.common import (
@@ -31,7 +31,7 @@ class InputRole:
3131

3232
# default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
3333
# pyre-fixme[2]: Parameter must be annotated.
34-
def _concat_tensors(accum, cur_output, _) -> Tensor:
34+
def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
3535
return cur_output if accum is None else torch.cat([accum, cur_output])
3636

3737

@@ -61,14 +61,12 @@ def _create_perturbation_mask(
6161
return perturbation_mask
6262

6363

64-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
6564
def _perturb_inputs(
66-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
67-
inputs: Iterable[Any],
65+
inputs: Iterable[object],
6866
input_roles: Tuple[int],
6967
baselines: Tuple[Union[int, float, Tensor], ...],
7068
perturbation_mask: Tuple[Union[Tensor, None], ...],
71-
) -> Tuple[Any, ...]:
69+
) -> Tuple[object, ...]:
7270
"""
7371
Perturb inputs based on perturbation mask and baselines
7472
"""
@@ -164,6 +162,8 @@ class DataLoaderAttribution(Attribution):
164162
e.g., Precision & Recall.
165163
"""
166164

165+
attr_method: Attribution
166+
167167
def __init__(self, attr_method: Attribution) -> None:
168168
r"""
169169
Args:
@@ -179,7 +179,6 @@ def __init__(self, attr_method: Attribution) -> None:
179179
super().__init__(attr_method.forward_func)
180180

181181
# shallow copy is enough to avoid modifying original instance
182-
# pyre-fixme[4]: Attribute must be annotated.
183182
self.attr_method = copy(attr_method)
184183

185184
self.attr_method.forward_func = self._forward_with_dataloader
@@ -352,27 +351,22 @@ def attribute(
352351
If return_input_shape is False, a single tensor is returned
353352
where each index of the last dimension represents a feature
354353
"""
355-
inputs = next(iter(dataloader))
354+
inputs = cast(Union[Tensor, Tuple[Tensor, ...]], next(iter(dataloader)))
356355
is_inputs_tuple = True
357356

357+
inputs_tuple: Tuple[Tensor, ...]
358358
if type(inputs) is list:
359359
# support list as it is a common return type for dataloader in torch
360-
inputs = tuple(inputs)
360+
inputs_tuple = tuple(inputs)
361361
elif type(inputs) is not tuple:
362362
is_inputs_tuple = False
363-
inputs = _format_tensor_into_tuples(inputs)
363+
inputs_tuple = _format_tensor_into_tuples(inputs)
364364

365365
if input_roles:
366-
# pyre-fixme[6]: For 1st argument expected
367-
# `pyre_extensions.ReadOnly[Sized]` but got
368-
# `Optional[typing.Tuple[typing.Any, ...]]`.
369-
assert len(input_roles) == len(inputs), (
366+
assert len(input_roles) == len(inputs_tuple), (
370367
"input_roles must have the same size as the return of the dataloader,",
371368
f"length of input_roles is {len(input_roles)} ",
372-
# pyre-fixme[6]: For 1st argument expected
373-
# `pyre_extensions.ReadOnly[Sized]` but got
374-
# `Optional[typing.Tuple[typing.Any, ...]]`.
375-
f"whereas the length of dataloader return is {len(inputs)}",
369+
f"whereas the length of dataloader return is {len(inputs_tuple)}",
376370
)
377371

378372
assert any(role == InputRole.need_attr for role in input_roles), (
@@ -381,14 +375,11 @@ def attribute(
381375
)
382376
else:
383377
# by default, assume every element in the dataloader needs attribution
384-
# pyre-fixme[16]: `Optional` has no attribute `__iter__`.
385-
input_roles = tuple(InputRole.need_attr for _ in inputs)
378+
input_roles = tuple(InputRole.need_attr for _ in inputs_tuple)
386379

387380
attr_inputs = tuple(
388381
inp
389-
# pyre-fixme[6]: For 2nd argument expected `Iterable[Variable[_T2]]` but
390-
# got `Optional[typing.Tuple[typing.Any, ...]]`.
391-
for role, inp in zip(input_roles, inputs)
382+
for role, inp in zip(input_roles, inputs_tuple)
392383
if role == InputRole.need_attr
393384
)
394385

@@ -398,10 +389,8 @@ def attribute(
398389
"Baselines must have the same size as the return of the dataloader ",
399390
"that need attribution",
400391
f"length of baseline is {len(baselines)} ",
401-
# pyre-fixme[6]: For 1st argument expected
402-
# `pyre_extensions.ReadOnly[Sized]` but got
403-
# `Optional[typing.Tuple[typing.Any, ...]]`.
404-
f'whereas the length of dataloader return with role "0" is {len(inputs)}',
392+
'whereas the length of dataloader return with role "0" is',
393+
f" {len(inputs_tuple)}",
405394
)
406395

407396
for i, baseline in enumerate(baselines):
@@ -419,10 +408,8 @@ def attribute(
419408
"Feature mask must have the same size as the return of the dataloader ",
420409
"that need attribution",
421410
f"length of feature_mask is {len(feature_mask)} ",
422-
# pyre-fixme[6]: For 1st argument expected
423-
# `pyre_extensions.ReadOnly[Sized]` but got
424-
# `Optional[typing.Tuple[typing.Any, ...]]`.
425-
f'whereas the length of dataloader return with role "0" is {len(inputs)}',
411+
'whereas the length of dataloader return with role "0"',
412+
f" is {len(inputs_tuple)}",
426413
)
427414

428415
for i, each_mask in enumerate(feature_mask):

0 commit comments

Comments
 (0)