3333)
3434from captum .log import log_usage
3535from torch import Tensor
36+ from torch .nn import Module
3637from torch .nn .parallel .scatter_gather import scatter
3738
3839
@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
5859
5960 def __init__ (
6061 self ,
61- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
62- forward_func : Callable ,
62+ forward_func : Callable [..., Tensor ],
6363 layer : ModuleOrModuleList ,
6464 device_ids : Union [None , List [int ]] = None ,
6565 multiply_by_inputs : bool = True ,
@@ -128,8 +128,7 @@ def _make_gradient_func(
128128 ) -> Callable [..., Tuple [Tensor , ...]]:
129129
130130 def _gradient_func (
131- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
132- forward_fn : Callable ,
131+ forward_fn : Callable [..., Tensor ],
133132 inputs : Union [Tensor , Tuple [Tensor , ...]],
134133 target_ind : TargetType = None ,
135134 additional_forward_args : Optional [object ] = None ,
@@ -146,28 +145,21 @@ def _gradient_func(
146145 target_gpus = self .device_ids ,
147146 )
148147
149- scattered_inputs_dict = {
148+ scattered_inputs_dict : Dict [
149+ torch .device , Union [Tensor , Tuple [Tensor , ...]]
150+ ] = {
150151 scattered_input [0 ].device : scattered_input
151152 for scattered_input in scattered_inputs
152153 }
153154
154155 with torch .autograd .set_grad_enabled (True ):
155156
156- # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
157- # annotated.
158- # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
159- # annotated.
160- # pyre-fixme[3]: Return type must be annotated.
161157 def layer_forward_hook (
162- # pyre-fixme[2]: Parameter must be annotated.
163- module ,
164- # pyre-fixme[2]: Parameter must be annotated.
165- hook_inputs ,
166- # pyre-fixme[2]: Parameter must be annotated.
167- hook_outputs = None ,
168- # pyre-fixme[2]: Parameter must be annotated.
169- layer_idx = 0 ,
170- ):
158+ module : Module ,
159+ hook_inputs : Union [Tensor , Tuple [Tensor , ...]],
160+ hook_outputs : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
161+ layer_idx : int = 0 ,
162+ ) -> Union [Tensor , Tuple [Tensor , ...]]:
171163 device = _extract_device (module , hook_inputs , hook_outputs )
172164 is_layer_tuple = (
173165 isinstance (hook_outputs , tuple )
@@ -177,11 +169,14 @@ def layer_forward_hook(
177169 )
178170
179171 if is_layer_tuple :
180- return scattered_inputs_dict [device ][
181- num_outputs_cumsum [layer_idx ] : num_outputs_cumsum [
182- layer_idx + 1
183- ]
184- ]
172+ return cast (
173+ Union [Tensor , Tuple [Tensor , ...]],
174+ scattered_inputs_dict [device ][
175+ num_outputs_cumsum [layer_idx ] : num_outputs_cumsum [
176+ layer_idx + 1
177+ ]
178+ ],
179+ )
185180
186181 return scattered_inputs_dict [device ][num_outputs_cumsum [layer_idx ]]
187182
@@ -502,11 +497,22 @@ def attribute(
502497 additional_forward_args
503498 )
504499
505- # pyre-fixme[3]: Return type must be annotated.
506- # pyre-fixme[2]: Parameter must be annotated.
507- def flatten_tuple (tup ):
500+ def flatten_tuple (tup : List [Tuple [Tensor , ...]]) -> Tuple [Tensor , ...]:
508501 return tuple (
509- sum ((list (x ) if isinstance (x , (tuple , list )) else [x ] for x in tup ), [])
502+ cast (
503+ List [Tensor ],
504+ sum (
505+ (
506+ (
507+ list (x )
508+ if isinstance (x , (tuple , list ))
509+ else cast (List [Tensor ], [x ])
510+ )
511+ for x in tup
512+ ),
513+ [],
514+ ),
515+ )
510516 )
511517
512518 if self .device_ids is None :
@@ -520,16 +526,18 @@ def flatten_tuple(tup):
520526 additional_forward_args = additional_forward_args ,
521527 attribute_to_layer_input = attribute_to_layer_input ,
522528 )
523-
529+ input_layer_list : List [ Tuple [ Tensor , ...]]
524530 # if we have one output
525531 if not isinstance (self .layer , list ):
526- inputs_layer = (inputs_layer ,)
532+ input_layer_list = [cast (Tuple [Tensor , ...], inputs_layer )]
533+ else :
534+ input_layer_list = inputs_layer
527535
528- num_outputs = [1 if isinstance (x , Tensor ) else len (x ) for x in inputs_layer ]
536+ num_outputs = [1 if isinstance (x , Tensor ) else len (x ) for x in input_layer_list ]
529537 num_outputs_cumsum = torch .cumsum (
530538 torch .IntTensor ([0 ] + num_outputs ), dim = 0 # type: ignore
531539 )
532- inputs_layer = flatten_tuple (inputs_layer )
540+ inputs_layer = flatten_tuple (input_layer_list )
533541
534542 baselines_layer = _forward_layer_eval (
535543 self .forward_func ,
0 commit comments