22
33# pyre-strict
44import typing
5- from typing import Any , Callable , Tuple , Union
5+ from typing import Any , Callable , Literal , Tuple , Union
66
77import numpy as np
88import torch
99from captum ._utils .common import _is_tuple
1010from captum ._utils .typing import (
1111 BaselineType ,
12- Literal ,
1312 TargetType ,
1413 Tensor ,
1514 TensorOrTupleOfTensorsGeneric ,
@@ -57,8 +56,9 @@ class GradientShap(GradientAttribution):
5756 samples and compute the expectation (smoothgrad).
5857 """
5958
60- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
61- def __init__ (self , forward_func : Callable , multiply_by_inputs : bool = True ) -> None :
59+ def __init__ (
60+ self , forward_func : Callable [..., Tensor ], multiply_by_inputs : bool = True
61+ ) -> None :
6262 r"""
6363 Args:
6464
@@ -82,8 +82,6 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
8282 self ._multiply_by_inputs = multiply_by_inputs
8383
8484 @typing .overload
85- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
86- # arguments of overload defined on line `84`.
8785 def attribute (
8886 self ,
8987 inputs : TensorOrTupleOfTensorsGeneric ,
@@ -93,17 +91,12 @@ def attribute(
9391 n_samples : int = 5 ,
9492 stdevs : Union [float , Tuple [float , ...]] = 0.0 ,
9593 target : TargetType = None ,
96- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
97- additional_forward_args : Any = None ,
94+ additional_forward_args : object = None ,
9895 * ,
99- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
100- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
10196 return_convergence_delta : Literal [True ],
10297 ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]: ...
10398
10499 @typing .overload
105- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
106- # arguments of overload defined on line `99`.
107100 def attribute (
108101 self ,
109102 inputs : TensorOrTupleOfTensorsGeneric ,
@@ -113,10 +106,7 @@ def attribute(
113106 n_samples : int = 5 ,
114107 stdevs : Union [float , Tuple [float , ...]] = 0.0 ,
115108 target : TargetType = None ,
116- additional_forward_args : Any = None ,
117- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
118- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
119- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
109+ additional_forward_args : object = None ,
120110 return_convergence_delta : Literal [False ] = False ,
121111 ) -> TensorOrTupleOfTensorsGeneric : ...
122112
@@ -132,7 +122,7 @@ def attribute(
132122 n_samples : int = 5 ,
133123 stdevs : Union [float , Tuple [float , ...]] = 0.0 ,
134124 target : TargetType = None ,
135- additional_forward_args : Any = None ,
125+ additional_forward_args : object = None ,
136126 return_convergence_delta : bool = False ,
137127 ) -> Union [
138128 TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
@@ -265,20 +255,10 @@ def attribute(
265255 """
266256 # since `baselines` is a distribution, we can generate it using a function
267257 # rather than passing it as an input argument
268- # pyre-fixme[9]: baselines has type `Union[typing.Callable[...,
269- # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor,
270- # ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
271- # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
272- baselines = _format_callable_baseline (baselines , inputs )
273- # pyre-fixme[16]: Item `Callable` of `Union[(...) ->
274- # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
275- # attribute `__getitem__`.
276- assert isinstance (baselines [0 ], torch .Tensor ), (
258+ formatted_baselines = _format_callable_baseline (baselines , inputs )
259+ assert isinstance (formatted_baselines [0 ], torch .Tensor ), (
277260 "Baselines distribution has to be provided in a form "
278- # pyre-fixme[16]: Item `Callable` of `Union[(...) ->
279- # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
280- # attribute `__getitem__`.
281- "of a torch.Tensor {}." .format (baselines [0 ])
261+ "of a torch.Tensor {}." .format (formatted_baselines [0 ])
282262 )
283263
284264 input_min_baseline_x_grad = InputBaselineXGradient (
@@ -296,7 +276,7 @@ def attribute(
296276 nt_samples = n_samples ,
297277 stdevs = stdevs ,
298278 draw_baseline_from_distrib = True ,
299- baselines = baselines ,
279+ baselines = formatted_baselines ,
300280 target = target ,
301281 additional_forward_args = additional_forward_args ,
302282 return_convergence_delta = return_convergence_delta ,
@@ -322,8 +302,11 @@ def multiplies_by_inputs(self) -> bool:
322302
323303
324304class InputBaselineXGradient (GradientAttribution ):
325- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
326- def __init__ (self , forward_func : Callable , multiply_by_inputs : bool = True ) -> None :
305+ _multiply_by_inputs : bool
306+
307+ def __init__ (
308+ self , forward_func : Callable [..., Tensor ], multiply_by_inputs : bool = True
309+ ) -> None :
327310 r"""
328311 Args:
329312
@@ -345,37 +328,26 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
345328
346329 """
347330 GradientAttribution .__init__ (self , forward_func )
348- # pyre-fixme[4]: Attribute must be annotated.
349331 self ._multiply_by_inputs = multiply_by_inputs
350332
351333 @typing .overload
352- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
353- # arguments of overload defined on line `318`.
354334 def attribute (
355335 self ,
356336 inputs : TensorOrTupleOfTensorsGeneric ,
357337 baselines : BaselineType = None ,
358338 target : TargetType = None ,
359- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
360- additional_forward_args : Any = None ,
339+ additional_forward_args : object = None ,
361340 * ,
362- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
363- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
364341 return_convergence_delta : Literal [True ],
365342 ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]: ...
366343
367344 @typing .overload
368- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
369- # arguments of overload defined on line `329`.
370345 def attribute (
371346 self ,
372347 inputs : TensorOrTupleOfTensorsGeneric ,
373348 baselines : BaselineType = None ,
374349 target : TargetType = None ,
375- additional_forward_args : Any = None ,
376- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
377- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
378- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
350+ additional_forward_args : object = None ,
379351 return_convergence_delta : Literal [False ] = False ,
380352 ) -> TensorOrTupleOfTensorsGeneric : ...
381353
@@ -385,37 +357,33 @@ def attribute( # type: ignore
385357 inputs : TensorOrTupleOfTensorsGeneric ,
386358 baselines : BaselineType = None ,
387359 target : TargetType = None ,
388- additional_forward_args : Any = None ,
360+ additional_forward_args : object = None ,
389361 return_convergence_delta : bool = False ,
390362 ) -> Union [
391363 TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
392364 ]:
393365 # Keeps track whether original input is a tuple or not before
394366 # converting it into a tuple.
395- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
396- # `TensorOrTupleOfTensorsGeneric`.
397367 is_inputs_tuple = _is_tuple (inputs )
398- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
399- # `Tuple[Tensor, ...]`.
400- inputs , baselines = _format_input_baseline (inputs , baselines )
368+ inputs_tuple , baselines = _format_input_baseline (inputs , baselines )
401369
402370 rand_coefficient = torch .tensor (
403- np .random .uniform (0.0 , 1.0 , inputs [0 ].shape [0 ]),
404- device = inputs [0 ].device ,
405- dtype = inputs [0 ].dtype ,
371+ np .random .uniform (0.0 , 1.0 , inputs_tuple [0 ].shape [0 ]),
372+ device = inputs_tuple [0 ].device ,
373+ dtype = inputs_tuple [0 ].dtype ,
406374 )
407375
408376 input_baseline_scaled = tuple (
409377 _scale_input (input , baseline , rand_coefficient )
410- for input , baseline in zip (inputs , baselines )
378+ for input , baseline in zip (inputs_tuple , baselines )
411379 )
412380 grads = self .gradient_func (
413381 self .forward_func , input_baseline_scaled , target , additional_forward_args
414382 )
415383
416384 if self .multiplies_by_inputs :
417385 input_baseline_diffs = tuple (
418- input - baseline for input , baseline in zip (inputs , baselines )
386+ input - baseline for input , baseline in zip (inputs_tuple , baselines )
419387 )
420388 attributions = tuple (
421389 input_baseline_diff * grad
0 commit comments