33# pyre-strict
44from collections import defaultdict
55from copy import copy
6- from typing import Any , Callable , cast , Dict , Iterable , List , Optional , Tuple , Union
6+ from typing import Callable , cast , Dict , Iterable , List , Optional , Tuple , Union
77
88import torch
99from captum ._utils .common import (
@@ -31,7 +31,7 @@ class InputRole:
3131
3232# default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
3333# pyre-fixme[2]: Parameter must be annotated.
34- def _concat_tensors (accum , cur_output , _ ) -> Tensor :
34+ def _concat_tensors (accum : Optional [ Tensor ] , cur_output : Tensor , _ ) -> Tensor :
3535 return cur_output if accum is None else torch .cat ([accum , cur_output ])
3636
3737
@@ -61,14 +61,12 @@ def _create_perturbation_mask(
6161 return perturbation_mask
6262
6363
64- # pyre-fixme[3]: Return annotation cannot contain `Any`.
6564def _perturb_inputs (
66- # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
67- inputs : Iterable [Any ],
65+ inputs : Iterable [object ],
6866 input_roles : Tuple [int ],
6967 baselines : Tuple [Union [int , float , Tensor ], ...],
7068 perturbation_mask : Tuple [Union [Tensor , None ], ...],
71- ) -> Tuple [Any , ...]:
69+ ) -> Tuple [object , ...]:
7270 """
7371 Perturb inputs based on perturbation mask and baselines
7472 """
@@ -164,6 +162,8 @@ class DataLoaderAttribution(Attribution):
164162 e.g., Precision & Recall.
165163 """
166164
165+ attr_method : Attribution
166+
167167 def __init__ (self , attr_method : Attribution ) -> None :
168168 r"""
169169 Args:
@@ -179,7 +179,6 @@ def __init__(self, attr_method: Attribution) -> None:
179179 super ().__init__ (attr_method .forward_func )
180180
181181 # shallow copy is enough to avoid modifying original instance
182- # pyre-fixme[4]: Attribute must be annotated.
183182 self .attr_method = copy (attr_method )
184183
185184 self .attr_method .forward_func = self ._forward_with_dataloader
@@ -352,27 +351,22 @@ def attribute(
352351 If return_input_shape is False, a single tensor is returned
353352 where each index of the last dimension represents a feature
354353 """
355- inputs = next (iter (dataloader ))
354+ inputs = cast ( Union [ Tensor , Tuple [ Tensor , ...]], next (iter (dataloader ) ))
356355 is_inputs_tuple = True
357356
357+ inputs_tuple : Tuple [Tensor , ...]
358358 if type (inputs ) is list :
359359 # support list as it is a common return type for dataloader in torch
360- inputs = tuple (inputs )
360+ inputs_tuple = tuple (inputs )
361361 elif type (inputs ) is not tuple :
362362 is_inputs_tuple = False
363- inputs = _format_tensor_into_tuples (inputs )
363+ inputs_tuple = _format_tensor_into_tuples (inputs )
364364
365365 if input_roles :
366- # pyre-fixme[6]: For 1st argument expected
367- # `pyre_extensions.ReadOnly[Sized]` but got
368- # `Optional[typing.Tuple[typing.Any, ...]]`.
369- assert len (input_roles ) == len (inputs ), (
366+ assert len (input_roles ) == len (inputs_tuple ), (
370367 "input_roles must have the same size as the return of the dataloader," ,
371368 f"length of input_roles is { len (input_roles )} " ,
372- # pyre-fixme[6]: For 1st argument expected
373- # `pyre_extensions.ReadOnly[Sized]` but got
374- # `Optional[typing.Tuple[typing.Any, ...]]`.
375- f"whereas the length of dataloader return is { len (inputs )} " ,
369+ f"whereas the length of dataloader return is { len (inputs_tuple )} " ,
376370 )
377371
378372 assert any (role == InputRole .need_attr for role in input_roles ), (
@@ -381,14 +375,11 @@ def attribute(
381375 )
382376 else :
383377 # by default, assume every element in the dataloader needs attribution
384- # pyre-fixme[16]: `Optional` has no attribute `__iter__`.
385- input_roles = tuple (InputRole .need_attr for _ in inputs )
378+ input_roles = tuple (InputRole .need_attr for _ in inputs_tuple )
386379
387380 attr_inputs = tuple (
388381 inp
389- # pyre-fixme[6]: For 2nd argument expected `Iterable[Variable[_T2]]` but
390- # got `Optional[typing.Tuple[typing.Any, ...]]`.
391- for role , inp in zip (input_roles , inputs )
382+ for role , inp in zip (input_roles , inputs_tuple )
392383 if role == InputRole .need_attr
393384 )
394385
@@ -398,10 +389,8 @@ def attribute(
398389 "Baselines must have the same size as the return of the dataloader " ,
399390 "that need attribution" ,
400391 f"length of baseline is { len (baselines )} " ,
401- # pyre-fixme[6]: For 1st argument expected
402- # `pyre_extensions.ReadOnly[Sized]` but got
403- # `Optional[typing.Tuple[typing.Any, ...]]`.
404- f'whereas the length of dataloader return with role "0" is { len (inputs )} ' ,
392+ 'whereas the length of dataloader return with role "0" is' ,
393+ f" { len (inputs_tuple )} " ,
405394 )
406395
407396 for i , baseline in enumerate (baselines ):
@@ -419,10 +408,8 @@ def attribute(
419408 "Feature mask must have the same size as the return of the dataloader " ,
420409 "that need attribution" ,
421410 f"length of feature_mask is { len (feature_mask )} " ,
422- # pyre-fixme[6]: For 1st argument expected
423- # `pyre_extensions.ReadOnly[Sized]` but got
424- # `Optional[typing.Tuple[typing.Any, ...]]`.
425- f'whereas the length of dataloader return with role "0" is { len (inputs )} ' ,
411+ 'whereas the length of dataloader return with role "0"' ,
412+ f" is { len (inputs_tuple )} " ,
426413 )
427414
428415 for i , each_mask in enumerate (feature_mask ):
0 commit comments