11#!/usr/bin/env python3
22
33# pyre-strict
4- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4+ from typing import Any , Callable , cast , Dict , List , Optional , Tuple , Union
55
66import torch
77from captum ._utils .common import (
@@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):
4141
4242 def __init__ (
4343 self ,
44- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
45- forward_func : Callable ,
44+ forward_func : Callable [..., Tensor ],
4645 layer : Module ,
4746 device_ids : Union [None , List [int ]] = None ,
4847 ) -> None :
@@ -293,7 +292,7 @@ def _attribute(
293292 # Returns gradient of output with respect to hidden layer.
294293 layer_gradients , _ = compute_layer_gradients_and_eval (
295294 forward_fn = self .forward_func ,
296- layer = self .layer ,
295+ layer = cast ( Module , self .layer ) ,
297296 inputs = scaled_features_tpl ,
298297 target_ind = expanded_target ,
299298 additional_forward_args = input_additional_args ,
@@ -304,9 +303,7 @@ def _attribute(
304303 # flattening grads so that we can multiply it with step-size
305304 # calling contiguous to avoid `memory whole` problems
306305 scaled_grads = tuple (
307- # pyre-fixme[16]: `tuple` has no attribute `contiguous`.
308306 layer_grad .contiguous ().view (n_steps , - 1 )
309- # pyre-fixme[16]: `tuple` has no attribute `device`.
310307 * torch .tensor (step_sizes ).view (n_steps , 1 ).to (layer_grad .device )
311308 for layer_grad in layer_gradients
312309 )
@@ -317,8 +314,7 @@ def _attribute(
317314 scaled_grad ,
318315 n_steps ,
319316 inputs [0 ].shape [0 ],
320- # pyre-fixme[16]: `tuple` has no attribute `shape`.
321- layer_grad .shape [1 :],
317+ tuple (layer_grad .shape [1 :]),
322318 )
323319 for scaled_grad , layer_grad in zip (scaled_grads , layer_gradients )
324320 )
0 commit comments