Skip to content

Commit a008055

Browse files
Vivek Miglani and facebook-github-bot
authored and committed
Fix pyre errors in GradientShap (#1394)
Summary: Initial work on fixing Pyre errors in GradientSHAP. Differential Revision: D64677343
1 parent 7b507fc commit a008055

File tree

1 file changed

+25
-57
lines changed

1 file changed

+25
-57
lines changed

captum/attr/_core/gradient_shap.py

Lines changed: 25 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
# pyre-strict
44
import typing
5-
from typing import Any, Callable, Tuple, Union
5+
from typing import Any, Callable, Literal, Tuple, Union
66

77
import numpy as np
88
import torch
99
from captum._utils.common import _is_tuple
1010
from captum._utils.typing import (
1111
BaselineType,
12-
Literal,
1312
TargetType,
1413
Tensor,
1514
TensorOrTupleOfTensorsGeneric,
@@ -57,8 +56,9 @@ class GradientShap(GradientAttribution):
5756
samples and compute the expectation (smoothgrad).
5857
"""
5958

60-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
61-
def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
59+
def __init__(
60+
self, forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True
61+
) -> None:
6262
r"""
6363
Args:
6464
@@ -82,8 +82,6 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
8282
self._multiply_by_inputs = multiply_by_inputs
8383

8484
@typing.overload
85-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
86-
# arguments of overload defined on line `84`.
8785
def attribute(
8886
self,
8987
inputs: TensorOrTupleOfTensorsGeneric,
@@ -93,17 +91,12 @@ def attribute(
9391
n_samples: int = 5,
9492
stdevs: Union[float, Tuple[float, ...]] = 0.0,
9593
target: TargetType = None,
96-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
97-
additional_forward_args: Any = None,
94+
additional_forward_args: object = None,
9895
*,
99-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
100-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
10196
return_convergence_delta: Literal[True],
10297
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
10398

10499
@typing.overload
105-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
106-
# arguments of overload defined on line `99`.
107100
def attribute(
108101
self,
109102
inputs: TensorOrTupleOfTensorsGeneric,
@@ -113,10 +106,7 @@ def attribute(
113106
n_samples: int = 5,
114107
stdevs: Union[float, Tuple[float, ...]] = 0.0,
115108
target: TargetType = None,
116-
additional_forward_args: Any = None,
117-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
118-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
119-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
109+
additional_forward_args: object = None,
120110
return_convergence_delta: Literal[False] = False,
121111
) -> TensorOrTupleOfTensorsGeneric: ...
122112

@@ -132,7 +122,7 @@ def attribute(
132122
n_samples: int = 5,
133123
stdevs: Union[float, Tuple[float, ...]] = 0.0,
134124
target: TargetType = None,
135-
additional_forward_args: Any = None,
125+
additional_forward_args: object = None,
136126
return_convergence_delta: bool = False,
137127
) -> Union[
138128
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
@@ -265,20 +255,10 @@ def attribute(
265255
"""
266256
# since `baselines` is a distribution, we can generate it using a function
267257
# rather than passing it as an input argument
268-
# pyre-fixme[9]: baselines has type `Union[typing.Callable[...,
269-
# Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor,
270-
# ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
271-
# typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
272-
baselines = _format_callable_baseline(baselines, inputs)
273-
# pyre-fixme[16]: Item `Callable` of `Union[(...) ->
274-
# TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
275-
# attribute `__getitem__`.
276-
assert isinstance(baselines[0], torch.Tensor), (
258+
formatted_baselines = _format_callable_baseline(baselines, inputs)
259+
assert isinstance(formatted_baselines[0], torch.Tensor), (
277260
"Baselines distribution has to be provided in a form "
278-
# pyre-fixme[16]: Item `Callable` of `Union[(...) ->
279-
# TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
280-
# attribute `__getitem__`.
281-
"of a torch.Tensor {}.".format(baselines[0])
261+
"of a torch.Tensor {}.".format(formatted_baselines[0])
282262
)
283263

284264
input_min_baseline_x_grad = InputBaselineXGradient(
@@ -296,7 +276,7 @@ def attribute(
296276
nt_samples=n_samples,
297277
stdevs=stdevs,
298278
draw_baseline_from_distrib=True,
299-
baselines=baselines,
279+
baselines=formatted_baselines,
300280
target=target,
301281
additional_forward_args=additional_forward_args,
302282
return_convergence_delta=return_convergence_delta,
@@ -322,8 +302,11 @@ def multiplies_by_inputs(self) -> bool:
322302

323303

324304
class InputBaselineXGradient(GradientAttribution):
325-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
326-
def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
305+
_multiply_by_inputs: bool
306+
307+
def __init__(
308+
self, forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True
309+
) -> None:
327310
r"""
328311
Args:
329312
@@ -345,37 +328,26 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
345328
346329
"""
347330
GradientAttribution.__init__(self, forward_func)
348-
# pyre-fixme[4]: Attribute must be annotated.
349331
self._multiply_by_inputs = multiply_by_inputs
350332

351333
@typing.overload
352-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
353-
# arguments of overload defined on line `318`.
354334
def attribute(
355335
self,
356336
inputs: TensorOrTupleOfTensorsGeneric,
357337
baselines: BaselineType = None,
358338
target: TargetType = None,
359-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
360-
additional_forward_args: Any = None,
339+
additional_forward_args: object = None,
361340
*,
362-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
363-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
364341
return_convergence_delta: Literal[True],
365342
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
366343

367344
@typing.overload
368-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
369-
# arguments of overload defined on line `329`.
370345
def attribute(
371346
self,
372347
inputs: TensorOrTupleOfTensorsGeneric,
373348
baselines: BaselineType = None,
374349
target: TargetType = None,
375-
additional_forward_args: Any = None,
376-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
377-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
378-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
350+
additional_forward_args: object = None,
379351
return_convergence_delta: Literal[False] = False,
380352
) -> TensorOrTupleOfTensorsGeneric: ...
381353

@@ -385,37 +357,33 @@ def attribute( # type: ignore
385357
inputs: TensorOrTupleOfTensorsGeneric,
386358
baselines: BaselineType = None,
387359
target: TargetType = None,
388-
additional_forward_args: Any = None,
360+
additional_forward_args: object = None,
389361
return_convergence_delta: bool = False,
390362
) -> Union[
391363
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
392364
]:
393365
# Keeps track whether original input is a tuple or not before
394366
# converting it into a tuple.
395-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
396-
# `TensorOrTupleOfTensorsGeneric`.
397367
is_inputs_tuple = _is_tuple(inputs)
398-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
399-
# `Tuple[Tensor, ...]`.
400-
inputs, baselines = _format_input_baseline(inputs, baselines)
368+
inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
401369

402370
rand_coefficient = torch.tensor(
403-
np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
404-
device=inputs[0].device,
405-
dtype=inputs[0].dtype,
371+
np.random.uniform(0.0, 1.0, inputs_tuple[0].shape[0]),
372+
device=inputs_tuple[0].device,
373+
dtype=inputs_tuple[0].dtype,
406374
)
407375

408376
input_baseline_scaled = tuple(
409377
_scale_input(input, baseline, rand_coefficient)
410-
for input, baseline in zip(inputs, baselines)
378+
for input, baseline in zip(inputs_tuple, baselines)
411379
)
412380
grads = self.gradient_func(
413381
self.forward_func, input_baseline_scaled, target, additional_forward_args
414382
)
415383

416384
if self.multiplies_by_inputs:
417385
input_baseline_diffs = tuple(
418-
input - baseline for input, baseline in zip(inputs, baselines)
386+
input - baseline for input, baseline in zip(inputs_tuple, baselines)
419387
)
420388
attributions = tuple(
421389
input_baseline_diff * grad

0 commit comments

Comments (0)