@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 from captum._utils.common import (
     _format_additional_forward_args,
@@ -24,8 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -186,11 +185,10 @@ def attribute(
         if isinstance(self.layer, Module):
             return _format_output(
                 len(layer_evals) > 1,
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `List[typing.Tuple[Tensor, ...]]`.
-                # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
-                #  got `List[typing.Tuple[Tensor, ...]]`.
-                self.multiply_gradient_acts(layer_gradients, layer_evals),
+                self.multiply_gradient_acts(
+                    cast(Tuple[Tensor, ...], layer_gradients),
+                    cast(Tuple[Tensor, ...], layer_evals),
+                ),
             )
         else:
             return [
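
For context, a minimal standalone sketch of the two typing constructs this commit leans on (illustrative only, not part of the Captum source; the helper call_forward is hypothetical). Callable[..., Tensor] constrains only the return type of forward_func, and typing.cast narrows the static type while being a no-op at runtime, which is why it can replace the pyre-fixme suppressions without changing behavior:

from typing import Callable, cast, List, Tuple

import torch
from torch import Tensor


def call_forward(forward_func: Callable[..., Tensor], x: Tensor) -> Tensor:
    # Callable[..., Tensor] accepts any argument list but pins the return
    # type to Tensor, which is all the attribution code relies on.
    return forward_func(x)


out = call_forward(lambda x: x * 2, torch.ones(2))

# cast() performs no conversion at runtime: it returns its argument
# unchanged and only narrows the type seen by the checker (pyre/mypy).
grads: List[Tuple[Tensor, ...]] = [(torch.zeros(2, 3),)]
narrowed = cast(Tuple[Tensor, ...], grads)
assert narrowed is grads  # same object, no copy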