
Commit e7c4de0

Vivek Miglani authored and facebook-github-bot committed
Fix pyre errors in NoiseTunnel (#1402)
Summary: Initial work on fixing Pyre errors in NoiseTunnel.

Reviewed By: craymichael

Differential Revision: D64677341
1 parent 6d9bba6 commit e7c4de0
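For context, NoiseTunnel wraps another attribution method and averages its attributions over several noisy copies of the input. A minimal usage sketch, assuming the public captum.attr API; the toy forward function, shapes, and parameter values are placeholders chosen for illustration, not part of this commit:

import torch
from captum.attr import IntegratedGradients, NoiseTunnel

def forward_func(x: torch.Tensor) -> torch.Tensor:
    # toy forward function: one scalar output per example in the batch
    return (x * torch.tensor([1.0, 2.0, 3.0])).sum(dim=1)

nt = NoiseTunnel(IntegratedGradients(forward_func))  # attribution_method passed to __init__

inputs = torch.randn(4, 3)  # batch size is the first dimension, as the class docstring assumes
# nt_type must be one of SUPPORTED_NOISE_TUNNEL_TYPES: "smoothgrad", "smoothgrad_sq", "vargrad"
attributions = nt.attribute(inputs, nt_type="smoothgrad", nt_samples=10, stdevs=0.2)
print(attributions.shape)  # torch.Size([4, 3])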

File tree: 1 file changed (+19, -30 lines)


captum/attr/_core/noise_tunnel.py

Lines changed: 19 additions & 30 deletions
@@ -2,7 +2,7 @@

 # pyre-strict
 from enum import Enum
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -27,8 +27,7 @@ class NoiseTunnelType(Enum):
     vargrad = 3


-# pyre-fixme[5]: Global expression must be annotated.
-SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys())
+SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys())


 class NoiseTunnel(Attribution):
@@ -58,6 +57,10 @@ class NoiseTunnel(Attribution):
     It is assumed that the batch size is the first dimension of input tensors.
     """

+    is_delta_supported: bool
+    _multiply_by_inputs: bool
+    is_gradient_method: bool
+
     def __init__(self, attribution_method: Attribution) -> None:
         r"""
         Args:
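Declaring the attribute types at class scope is what lets the per-assignment pyre-fixme[4] comments in the next hunk be dropped: Pyre reads the class-level annotations as the attributes' declared types. A minimal standalone sketch of the pattern (hypothetical class, not the Captum code):

class Wrapper:
    # class-level annotations declare the attribute types once
    is_delta_supported: bool
    _multiply_by_inputs: bool

    def __init__(self, supports_delta: bool, multiplies: bool) -> None:
        # the assignments below no longer need inline annotations or fixme comments
        self.is_delta_supported = supports_delta
        self._multiply_by_inputs = multiplies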
@@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None:
                                 Conductance or Saliency.
         """
         self.attribution_method = attribution_method
-        # pyre-fixme[4]: Attribute must be annotated.
         self.is_delta_supported = self.attribution_method.has_convergence_delta()
-        # pyre-fixme[4]: Attribute must be annotated.
         self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs
-        # pyre-fixme[4]: Attribute must be annotated.
         self.is_gradient_method = isinstance(
             self.attribution_method, GradientAttribution
         )
         Attribution.__init__(self, self.attribution_method.forward_func)

     @property
-    # pyre-fixme[3]: Return type must be annotated.
-    def multiplies_by_inputs(self):
+    def multiplies_by_inputs(self) -> bool:
         return self._multiply_by_inputs

     @log_usage()
@@ -205,9 +204,10 @@ def attribute(
             nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib
         )

-        sum_attributions: List[Union[None, Tensor]] = []
-        sum_attributions_sq: List[Union[None, Tensor]] = []
+        sum_attributions: Sequence[Union[None, Tensor]] = []
+        sum_attributions_sq: Sequence[Union[None, Tensor]] = []
         delta_partial_list: List[Tensor] = []
+        is_attrib_tuple = is_inputs_tuple

         for _ in range(nt_samples_partition):
             inputs_with_noise = self._add_noise_to_inputs(
@@ -225,11 +225,7 @@
             )

             if len(sum_attributions) == 0:
-                # pyre-fixme[9]: sum_attributions has type
-                # `List[Optional[Tensor]]`; used as `List[None]`.
                 sum_attributions = [None] * len(attributions_partial)
-                # pyre-fixme[9]: sum_attributions_sq has type
-                # `List[Optional[Tensor]]`; used as `List[None]`.
                 sum_attributions_sq = [None] * len(attributions_partial)

             self._update_partial_attribution_and_delta(
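Retyping the accumulators as Sequence in the earlier hunk is what makes the `[None] * len(...)` assignments above type-check without the pyre-fixme[9] notes: Pyre infers the right-hand side as `List[None]`, and List is invariant while Sequence is covariant in its element type. A small standalone illustration, using a hypothetical helper to force the `List[None]` inference (not Captum code):

from typing import List, Optional, Sequence

def all_nones(n: int) -> List[None]:
    return [None] * n

xs: List[Optional[int]] = []
# xs = all_nones(3)   # rejected by Pyre: List[None] is not List[Optional[int]] (List is invariant)

ys: Sequence[Optional[int]] = []
ys = all_nones(3)      # accepted: List[None] <: Sequence[None] <: Sequence[Optional[int]]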
@@ -297,7 +293,6 @@ def attribute(

         return self._apply_checks_and_return_attributions(
             attributions,
-            # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined.
             is_attrib_tuple,
             return_convergence_delta,
             delta,
@@ -348,9 +343,7 @@ def _add_noise_to_input(
         bsz = input.shape[0]

         # expand input size by the number of drawn samples
-        # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
-        # and `Size`.
-        input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:]
+        input_expanded_size = (bsz * nt_samples_partition,) + tuple(input.shape[1:])

         # expand stdev for the shape of the input and number of drawn samples
         stdev_expanded = torch.tensor(stdev, device=input.device).repeat(
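Runtime behaviour is unchanged by this hunk: torch.Size subclasses tuple, so the old concatenation already worked, but Pyre's stubs reject `Tuple[int] + Size`, and wrapping the slice in `tuple()` produces a plain `Tuple[int, ...]`. A tiny standalone check with a toy tensor (illustration only, not Captum code):

import torch

x = torch.randn(4, 3, 5)      # toy input: batch of 4
nt_samples_partition = 2
expanded = (x.shape[0] * nt_samples_partition,) + tuple(x.shape[1:])
assert expanded == (8, 3, 5)  # a plain tuple of ints, which the checker accepts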
@@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq(
         bsz = attribution.shape[0] // nt_samples_batch_size_inter
         attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter))
         if len(attribution.shape) > 1:
-            # pyre-fixme[22]: The cast is redundant.
-            attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:]))
+            attribution_shape += tuple(attribution.shape[1:])

         attribution = attribution.view(attribution_shape)
         current_attribution_sum = attribution.sum(dim=1, keepdim=False)
-        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
-        # `int`.
-        current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False)
+        current_attribution_sq = torch.sum(
+            torch.pow(attribution, 2), dim=1, keepdim=False
+        )

         sum_attribution[i] = (
             current_attribution_sum
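The per-sample sum and sum of squares accumulated in this method feed the smoothgrad, smoothgrad_sq, and vargrad estimators (mean, mean of squares, and variance of the noisy attributions). A standalone sketch of that arithmetic, as an illustration of the formulas rather than the Captum implementation:

import torch

nt_samples = 8
# per-sample attributions reshaped to (batch, nt_samples, *feature_dims), as in the view() above
attr = torch.randn(4, nt_samples, 3)

attr_sum = attr.sum(dim=1)                           # analogue of current_attribution_sum
attr_sq_sum = torch.sum(torch.pow(attr, 2), dim=1)   # analogue of current_attribution_sq

smoothgrad = attr_sum / nt_samples                   # mean attribution
smoothgrad_sq = attr_sq_sum / nt_samples             # mean squared attribution
vargrad = smoothgrad_sq - smoothgrad * smoothgrad    # variance: E[a^2] - (E[a])^2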
@@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq(
     def _compute_partial_attribution(
         self,
         inputs_with_noise_partition: Tuple[Tensor, ...],
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        kwargs_partition: Any,
+        kwargs_partition: object,
         is_inputs_tuple: bool,
         return_convergence_delta: bool,
     ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]:
@@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions(
     ) -> Union[
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
     ]:
-        # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor,
-        # ...]]`, expected a tuple.
-        attributions = _format_output(is_attrib_tuple, attributions)
+        attributions_tuple = _format_output(is_attrib_tuple, attributions)

         ret = (
-            (attributions, cast(Tensor, delta))
+            (attributions_tuple, cast(Tensor, delta))
             if self.is_delta_supported and return_convergence_delta
-            else attributions
+            else attributions_tuple
         )
         ret = cast(
             # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:

0 commit comments
