22
33# pyre-strict
44from enum import Enum
5- from typing import Any , Callable , cast , Dict , List , Optional , Tuple , Union
5+ from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , Union
66
77import torch
88from captum ._utils .common import (
@@ -27,8 +27,7 @@ class NoiseTunnelType(Enum):
2727 vargrad = 3
2828
2929
30- # pyre-fixme[5]: Global expression must be annotated.
31- SUPPORTED_NOISE_TUNNEL_TYPES = list (NoiseTunnelType .__members__ .keys ())
30+ SUPPORTED_NOISE_TUNNEL_TYPES : List [str ] = list (NoiseTunnelType .__members__ .keys ())
3231
3332
3433class NoiseTunnel (Attribution ):
@@ -58,6 +57,10 @@ class NoiseTunnel(Attribution):
5857 It is assumed that the batch size is the first dimension of input tensors.
5958 """
6059
60+ is_delta_supported : bool
61+ _multiply_by_inputs : bool
62+ is_gradient_method : bool
63+
6164 def __init__ (self , attribution_method : Attribution ) -> None :
6265 r"""
6366 Args:
@@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None:
6669 Conductance or Saliency.
6770 """
6871 self .attribution_method = attribution_method
69- # pyre-fixme[4]: Attribute must be annotated.
7072 self .is_delta_supported = self .attribution_method .has_convergence_delta ()
71- # pyre-fixme[4]: Attribute must be annotated.
7273 self ._multiply_by_inputs = self .attribution_method .multiplies_by_inputs
73- # pyre-fixme[4]: Attribute must be annotated.
7474 self .is_gradient_method = isinstance (
7575 self .attribution_method , GradientAttribution
7676 )
7777 Attribution .__init__ (self , self .attribution_method .forward_func )
7878
7979 @property
80- # pyre-fixme[3]: Return type must be annotated.
81- def multiplies_by_inputs (self ):
80+ def multiplies_by_inputs (self ) -> bool :
8281 return self ._multiply_by_inputs
8382
8483 @log_usage ()
@@ -205,9 +204,10 @@ def attribute(
205204 nt_samples_batch_size , kwargs_copy , inputs , draw_baseline_from_distrib
206205 )
207206
208- sum_attributions : List [Union [None , Tensor ]] = []
209- sum_attributions_sq : List [Union [None , Tensor ]] = []
207+ sum_attributions : Sequence [Union [None , Tensor ]] = []
208+ sum_attributions_sq : Sequence [Union [None , Tensor ]] = []
210209 delta_partial_list : List [Tensor ] = []
210+ is_attrib_tuple = is_inputs_tuple
211211
212212 for _ in range (nt_samples_partition ):
213213 inputs_with_noise = self ._add_noise_to_inputs (
@@ -225,11 +225,7 @@ def attribute(
225225 )
226226
227227 if len (sum_attributions ) == 0 :
228- # pyre-fixme[9]: sum_attributions has type
229- # `List[Optional[Tensor]]`; used as `List[None]`.
230228 sum_attributions = [None ] * len (attributions_partial )
231- # pyre-fixme[9]: sum_attributions_sq has type
232- # `List[Optional[Tensor]]`; used as `List[None]`.
233229 sum_attributions_sq = [None ] * len (attributions_partial )
234230
235231 self ._update_partial_attribution_and_delta (
@@ -297,7 +293,6 @@ def attribute(
297293
298294 return self ._apply_checks_and_return_attributions (
299295 attributions ,
300- # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined.
301296 is_attrib_tuple ,
302297 return_convergence_delta ,
303298 delta ,
@@ -348,9 +343,7 @@ def _add_noise_to_input(
348343 bsz = input .shape [0 ]
349344
350345 # expand input size by the number of drawn samples
351- # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
352- # and `Size`.
353- input_expanded_size = (bsz * nt_samples_partition ,) + input .shape [1 :]
346+ input_expanded_size = (bsz * nt_samples_partition ,) + tuple (input .shape [1 :])
354347
355348 # expand stdev for the shape of the input and number of drawn samples
356349 stdev_expanded = torch .tensor (stdev , device = input .device ).repeat (
@@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq(
375368 bsz = attribution .shape [0 ] // nt_samples_batch_size_inter
376369 attribution_shape = cast (Tuple [int , ...], (bsz , nt_samples_batch_size_inter ))
377370 if len (attribution .shape ) > 1 :
378- # pyre-fixme[22]: The cast is redundant.
379- attribution_shape += cast (Tuple [int , ...], tuple (attribution .shape [1 :]))
371+ attribution_shape += tuple (attribution .shape [1 :])
380372
381373 attribution = attribution .view (attribution_shape )
382374 current_attribution_sum = attribution .sum (dim = 1 , keepdim = False )
383- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
384- # `int`.
385- current_attribution_sq = torch . sum ( attribution ** 2 , dim = 1 , keepdim = False )
375+ current_attribution_sq = torch . sum (
376+ torch . pow ( attribution , 2 ), dim = 1 , keepdim = False
377+ )
386378
387379 sum_attribution [i ] = (
388380 current_attribution_sum
@@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq(
398390 def _compute_partial_attribution (
399391 self ,
400392 inputs_with_noise_partition : Tuple [Tensor , ...],
401- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
402- kwargs_partition : Any ,
393+ kwargs_partition : object ,
403394 is_inputs_tuple : bool ,
404395 return_convergence_delta : bool ,
405396 ) -> Tuple [Tuple [Tensor , ...], bool , Union [None , Tensor ]]:
@@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions(
505496 ) -> Union [
506497 TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
507498 ]:
508- # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor,
509- # ...]]`, expected a tuple.
510- attributions = _format_output (is_attrib_tuple , attributions )
499+ attributions_tuple = _format_output (is_attrib_tuple , attributions )
511500
512501 ret = (
513- (attributions , cast (Tensor , delta ))
502+ (attributions_tuple , cast (Tensor , delta ))
514503 if self .is_delta_supported and return_convergence_delta
515- else attributions
504+ else attributions_tuple
516505 )
517506 ret = cast (
518507 # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:
0 commit comments