Commit b80e488

Vivek Miglani authored and facebook-github-bot committed
Fix pyre errors in Dataloader Attr (#1390)
Summary:
Pull Request resolved: #1390

Initial work on fixing Pyre errors in Dataloader Attr

Reviewed By: jsawruk

Differential Revision: D64677336

fbshipit-source-id: 3f2fe24cc21aecbadfe65c7152673e7aad6a3cd0
1 parent d09d90f commit b80e488
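
Background on the mechanics this diff relies on: under `# pyre-strict`, unannotated code is reported as an error, and a `# pyre-fixme[N]` comment suppresses the error with code N on the following line. That is why each annotation added below lets the matching suppression comment be deleted. A tiny illustrative module (not from captum) showing the before/after pattern:

```python
# pyre-strict
from torch import Tensor


# Before: the untyped parameter needs a suppression comment in strict mode.
# pyre-fixme[2]: Parameter must be annotated.
def scale_legacy(x, factor: float) -> Tensor:
    return x * factor


# After: once the parameter is annotated, the fixme comment can simply go away.
def scale(x: Tensor, factor: float) -> Tensor:
    return x * factor
```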

File tree

1 file changed (+18, -31 lines)


captum/attr/_core/dataloader_attr.py

Lines changed: 18 additions & 31 deletions
@@ -3,7 +3,7 @@
 # pyre-strict
 from collections import defaultdict
 from copy import copy
-from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -31,7 +31,7 @@ class InputRole:
 
 # default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
 # pyre-fixme[2]: Parameter must be annotated.
-def _concat_tensors(accum, cur_output, _) -> Tensor:
+def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
     return cur_output if accum is None else torch.cat([accum, cur_output])
 
 
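
A quick note on the `_concat_tensors` change above: the accumulator starts out empty, so `Optional[Tensor]` is the annotation that matches both the first call (no previous output yet) and every later call. A minimal standalone sketch of how such a reducer behaves; the driver loop is illustrative, not captum code:

```python
from typing import Optional

import torch
from torch import Tensor


def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
    # First batch: nothing accumulated yet, so return the current output.
    # Later batches: concatenate along the batch dimension (dim 0).
    return cur_output if accum is None else torch.cat([accum, cur_output])


# Fold per-batch outputs into a single tensor, the way a default reducer would.
accum: Optional[Tensor] = None
for batch_out in (torch.ones(2, 3), torch.zeros(4, 3)):
    accum = _concat_tensors(accum, batch_out, None)

assert accum is not None and accum.shape == (6, 3)
```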
@@ -61,14 +61,12 @@ def _create_perturbation_mask(
     return perturbation_mask
 
 
-# pyre-fixme[3]: Return annotation cannot contain `Any`.
 def _perturb_inputs(
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    inputs: Iterable[Any],
+    inputs: Iterable[object],
     input_roles: Tuple[int],
     baselines: Tuple[Union[int, float, Tensor], ...],
     perturbation_mask: Tuple[Union[Tensor, None], ...],
-) -> Tuple[Any, ...]:
+) -> Tuple[object, ...]:
     """
     Perturb inputs based on perturbation mask and baselines
     """
@@ -164,6 +162,8 @@ class DataLoaderAttribution(Attribution):
     e.g., Precision & Recall.
     """
 
+    attr_method: Attribution
+
     def __init__(self, attr_method: Attribution) -> None:
         r"""
         Args:
@@ -179,7 +179,6 @@ def __init__(self, attr_method: Attribution) -> None:
         super().__init__(attr_method.forward_func)
 
         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
         self.attr_method = copy(attr_method)
 
         self.attr_method.forward_func = self._forward_with_dataloader
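
Declaring `attr_method: Attribution` at class level (two hunks up) is what allows the `# pyre-fixme[4]: Attribute must be annotated.` line to be dropped here: the attribute assigned from `copy(attr_method)` now has an explicit type. A generic sketch of the same pattern, using placeholder classes rather than captum's:

```python
from copy import copy
from typing import Callable


class Method:
    # Stand-in for an Attribution-like object that exposes a forward_func.
    def __init__(self, forward_func: Callable[..., object]) -> None:
        self.forward_func = forward_func


class Wrapper:
    # Class-level annotation: the checker knows the type of self.method,
    # so no suppression comment is needed on the assignment in __init__.
    method: Method

    def __init__(self, method: Method) -> None:
        # Shallow copy, mirroring the diff: the original instance stays
        # untouched while the copy's forward_func is redirected.
        self.method = copy(method)
        self.method.forward_func = self._forward

    def _forward(self, *args: object) -> object:
        return args


base = Method(print)
wrapped = Wrapper(base)
assert base.forward_func is print                       # original untouched
assert wrapped.method.forward_func == wrapped._forward  # copy redirected
```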
@@ -352,27 +351,22 @@ def attribute(
             If return_input_shape is False, a single tensor is returned
             where each index of the last dimension represents a feature
         """
-        inputs = next(iter(dataloader))
+        inputs = cast(Union[Tensor, Tuple[Tensor, ...]], next(iter(dataloader)))
         is_inputs_tuple = True
 
+        inputs_tuple: Tuple[Tensor, ...]
         if type(inputs) is list:
             # support list as it is a common return type for dataloader in torch
-            inputs = tuple(inputs)
+            inputs_tuple = tuple(inputs)
         elif type(inputs) is not tuple:
             is_inputs_tuple = False
-            inputs = _format_tensor_into_tuples(inputs)
+            inputs_tuple = _format_tensor_into_tuples(inputs)
 
         if input_roles:
-            # pyre-fixme[6]: For 1st argument expected
-            # `pyre_extensions.ReadOnly[Sized]` but got
-            # `Optional[typing.Tuple[typing.Any, ...]]`.
-            assert len(input_roles) == len(inputs), (
+            assert len(input_roles) == len(inputs_tuple), (
                 "input_roles must have the same size as the return of the dataloader,",
                 f"length of input_roles is {len(input_roles)} ",
-                # pyre-fixme[6]: For 1st argument expected
-                # `pyre_extensions.ReadOnly[Sized]` but got
-                # `Optional[typing.Tuple[typing.Any, ...]]`.
-                f"whereas the length of dataloader return is {len(inputs)}",
+                f"whereas the length of dataloader return is {len(inputs_tuple)}",
             )
 
             assert any(role == InputRole.need_attr for role in input_roles), (
@@ -381,14 +375,11 @@ def attribute(
             )
         else:
             # by default, assume every element in the dataloader needs attribution
-            # pyre-fixme[16]: `Optional` has no attribute `__iter__`.
-            input_roles = tuple(InputRole.need_attr for _ in inputs)
+            input_roles = tuple(InputRole.need_attr for _ in inputs_tuple)
 
         attr_inputs = tuple(
             inp
-            # pyre-fixme[6]: For 2nd argument expected `Iterable[Variable[_T2]]` but
-            # got `Optional[typing.Tuple[typing.Any, ...]]`.
-            for role, inp in zip(input_roles, inputs)
+            for role, inp in zip(input_roles, inputs_tuple)
             if role == InputRole.need_attr
         )
 
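
The thread running through the `attribute` hunks is to keep the raw dataloader item in `inputs` and introduce a separately annotated `inputs_tuple: Tuple[Tensor, ...]`, so the later `len()` and `zip()` calls operate on a value the checker knows is a non-optional tuple. A simplified, self-contained sketch of that normalization step; the function name and the single-tensor fallback are illustrative, not captum's helpers:

```python
from typing import Tuple, Union, cast

import torch
from torch import Tensor


def normalize_batch(batch: object) -> Tuple[Tuple[Tensor, ...], bool]:
    # Mirror the diff's pattern: cast the raw dataloader item, then assign the
    # normalized value to a variable with an explicit tuple annotation so later
    # len()/zip() calls see a non-optional Tuple[Tensor, ...].
    item = cast(Union[Tensor, Tuple[Tensor, ...]], batch)
    is_tuple = True

    inputs_tuple: Tuple[Tensor, ...]
    if type(item) is list:
        # Lists are a common dataloader return type; convert to a tuple.
        inputs_tuple = tuple(item)
    elif type(item) is not tuple:
        # Single tensor case (the diff uses a captum formatting helper here).
        is_tuple = False
        inputs_tuple = (cast(Tensor, item),)
    else:
        inputs_tuple = item
    return inputs_tuple, is_tuple


single, was_tuple = normalize_batch(torch.zeros(2, 3))
pair, _ = normalize_batch([torch.zeros(2, 3), torch.ones(2)])
assert len(single) == 1 and not was_tuple and len(pair) == 2
```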
@@ -398,10 +389,8 @@ def attribute(
                 "Baselines must have the same size as the return of the dataloader ",
                 "that need attribution",
                 f"length of baseline is {len(baselines)} ",
-                # pyre-fixme[6]: For 1st argument expected
-                # `pyre_extensions.ReadOnly[Sized]` but got
-                # `Optional[typing.Tuple[typing.Any, ...]]`.
-                f'whereas the length of dataloader return with role "0" is {len(inputs)}',
+                'whereas the length of dataloader return with role "0" is',
+                f" {len(inputs_tuple)}",
             )
 
             for i, baseline in enumerate(baselines):
@@ -419,10 +408,8 @@ def attribute(
                 "Feature mask must have the same size as the return of the dataloader ",
                 "that need attribution",
                 f"length of feature_mask is {len(feature_mask)} ",
-                # pyre-fixme[6]: For 1st argument expected
-                # `pyre_extensions.ReadOnly[Sized]` but got
-                # `Optional[typing.Tuple[typing.Any, ...]]`.
-                f'whereas the length of dataloader return with role "0" is {len(inputs)}',
+                'whereas the length of dataloader return with role "0"',
+                f" is {len(inputs_tuple)}",
             )
 
             for i, each_mask in enumerate(feature_mask):
