11#!/usr/bin/env python3
22
33# pyre-strict
4- from typing import Any , Callable , cast , List , Tuple , Type , Union
4+ from typing import Any , Callable , cast , Dict , List , Tuple , Type , Union
55
66import torch
77from captum ._utils .common import (
1818from captum .attr ._core .feature_permutation import FeaturePermutation
1919from captum .attr ._utils .attribution import LayerAttribution
2020from captum .log import log_usage
21- from torch import Tensor
21+ from torch import device , Tensor
2222from torch .nn import Module
2323from torch .nn .parallel .scatter_gather import scatter
2424
@@ -32,8 +32,7 @@ class LayerFeaturePermutation(LayerAttribution, FeaturePermutation):
3232
3333 def __init__ (
3434 self ,
35- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
36- forward_func : Callable ,
35+ forward_func : Callable [..., Tensor ],
3736 layer : Module ,
3837 device_ids : Union [None , List [int ]] = None ,
3938 ) -> None :
@@ -64,8 +63,7 @@ def attribute(
6463 self ,
6564 inputs : Union [Tensor , Tuple [Tensor , ...]],
6665 target : TargetType = None ,
67- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
68- additional_forward_args : Any = None ,
66+ additional_forward_args : object = None ,
6967 layer_mask : Union [None , TensorOrTupleOfTensorsGeneric ] = None ,
7068 perturbations_per_eval : int = 1 ,
7169 ) -> Union [Tensor , Tuple [Tensor , ...]]:
@@ -159,28 +157,33 @@ def attribute(
159157 otherwise a single tensor is returned.
160158 """
161159
162- # pyre-fixme[2]: Parameter must be annotated.
163- def layer_forward_func (* args ) -> Tensor :
164- layer_length = args [- 1 ]
165- layer_input = args [:layer_length ]
166- original_inputs = args [layer_length :- 1 ]
160+ def layer_forward_func (* args : Any ) -> Tensor :
161+ r"""
162+ Args:
163+ args (Any): Tensors comprising the layer input and the original
164+ inputs, and an int representing the length of the layer input
165+ """
166+ layer_length : int = args [- 1 ]
167+ layer_input : Tuple [Tensor , ...] = args [:layer_length ]
168+ original_inputs : Tuple [Tensor , ...] = args [layer_length :- 1 ]
167169
168170 device_ids = self .device_ids
169171 if device_ids is None :
170172 device_ids = getattr (self .forward_func , "device_ids" , None )
171173
172- all_layer_inputs = {}
174+ all_layer_inputs : Dict [ device , Tuple [ Tensor , ...]] = {}
173175 if device_ids is not None :
174176 scattered_layer_input = scatter (layer_input , target_gpus = device_ids )
175177 for device_tensors in scattered_layer_input :
176178 all_layer_inputs [device_tensors [0 ].device ] = device_tensors
177179 else :
178180 all_layer_inputs [layer_input [0 ].device ] = layer_input
179181
180- # pyre-fixme[53]: Captured variable `all_layer_inputs` is not annotated.
181- # pyre-fixme[3]: Return type must be annotated.
182- # pyre-fixme[2]: Parameter must be annotated.
183- def forward_hook (module , inp , out = None ):
182+ def forward_hook (
183+ module : Module ,
184+ inp : Union [None , Tensor , Tuple [Tensor , ...]],
185+ out : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
186+ ) -> Union [Tensor , Tuple [Tensor , ...]]:
184187 device = _extract_device (module , inp , out )
185188 is_layer_tuple = (
186189 isinstance (out , tuple )
0 commit comments