22
33# pyre-strict
44import typing
5- from typing import Any , cast , List , Literal , Optional , Tuple , Union
5+ from typing import cast , Dict , List , Literal , Optional , Tuple , TypeVar , Union
6+
7+ import torch
68
79from captum ._utils .common import (
810 _format_tensor_into_tuples ,
2123)
2224from captum .attr ._core .lrp import LRP
2325from captum .attr ._utils .attribution import LayerAttribution
26+ from captum .attr ._utils .lrp_rules import PropagationRule
2427from torch import Tensor
2528from torch .nn import Module
29+ from torch .utils .hooks import RemovableHandle
30+
31+ T = TypeVar ("T" )
2632
2733
2834class LayerLRP (LRP , LayerAttribution ):
@@ -39,6 +45,13 @@ class LayerLRP(LRP, LayerAttribution):
3945 Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
4046 """
4147
48+ device_ids : List [int ]
49+ verbose : bool
50+ layers : List [Module ]
51+ attribute_to_layer_input : bool = False
52+ backward_handles : List [RemovableHandle ]
53+ forward_handles : List [RemovableHandle ]
54+
4255 def __init__ (self , model : Module , layer : ModuleOrModuleList ) -> None :
4356 """
4457 Args:
@@ -59,7 +72,6 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
5972 LayerAttribution .__init__ (self , model , layer )
6073 LRP .__init__ (self , model )
6174 if hasattr (self .model , "device_ids" ):
62- # pyre-fixme[4]: Attribute must be annotated.
6375 self .device_ids = cast (List [int ], self .model .device_ids )
6476
6577 @typing .overload # type: ignore
@@ -208,56 +220,45 @@ def attribute(
208220 >>> attribution = layer_lrp.attribute(input, target=5)
209221
210222 """
211- # pyre-fixme[16]: `LayerLRP` has no attribute `verbose`.
212223 self .verbose = verbose
213- # pyre-fixme[16]: `LayerLRP` has no attribute `_original_state_dict`.
214224 self ._original_state_dict = self .model .state_dict ()
215- # pyre-fixme[16]: `LayerLRP` has no attribute `layers`.
216225 self .layers = []
217226 self ._get_layers (self .model )
218227 self ._check_and_attach_rules ()
219- # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
220228 self .attribute_to_layer_input = attribute_to_layer_input
221- # pyre-fixme[16]: `LayerLRP` has no attribute `backward_handles`.
222229 self .backward_handles = []
223- # pyre-fixme[16]: `LayerLRP` has no attribute `forward_handles`.
224230 self .forward_handles = []
225231
226- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
227- # `Tuple[Tensor, ...]`.
228- inputs = _format_tensor_into_tuples (inputs )
229- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
230- # `TensorOrTupleOfTensorsGeneric`.
231- gradient_mask = apply_gradient_requirements (inputs )
232+ inputs_tuple = _format_tensor_into_tuples (inputs )
233+ gradient_mask = apply_gradient_requirements (inputs_tuple )
232234
233235 try :
234236 # 1. Forward pass
235237 output = self ._compute_output_and_change_weights (
236- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
237- # got `TensorOrTupleOfTensorsGeneric`.
238- inputs ,
238+ inputs_tuple ,
239239 target ,
240240 additional_forward_args ,
241241 )
242242 self ._register_forward_hooks ()
243243 # 2. Forward pass + backward pass
244244 _ = compute_gradients (
245- self ._forward_fn_wrapper , inputs , target , additional_forward_args
245+ self ._forward_fn_wrapper , inputs_tuple , target , additional_forward_args
246246 )
247247 relevances = self ._get_output_relevance (output )
248248 finally :
249249 self ._restore_model ()
250- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
251- # `TensorOrTupleOfTensorsGeneric`.
252- undo_gradient_requirements (inputs , gradient_mask )
250+ undo_gradient_requirements (inputs_tuple , gradient_mask )
253251
254252 if return_convergence_delta :
255253 delta : Union [Tensor , List [Tensor ]]
256254 if isinstance (self .layer , list ):
257255 delta = []
258256 for relevance_layer in relevances :
259257 delta .append (
260- self .compute_convergence_delta (relevance_layer , output )
258+ self .compute_convergence_delta (
259+ cast (Union [Tensor , Tuple [Tensor , ...]], relevance_layer ),
260+ output ,
261+ )
261262 )
262263 else :
263264 delta = self .compute_convergence_delta (
@@ -267,33 +268,35 @@ def attribute(
267268 else :
268269 return relevances # type: ignore
269270
270- # pyre-fixme[3]: Return type must be annotated.
271- # pyre-fixme[2]: Parameter must be annotated.
272- def _get_single_output_relevance (self , layer , output ):
273- # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
271+ def _get_single_output_relevance (
272+ self , layer : Module , output : Tensor
273+ ) -> Union [Tensor , Tuple [Tensor , ...]]:
274274 if self .attribute_to_layer_input :
275- normalized_relevances = layer .rule .relevance_input
275+ normalized_relevances = cast (
276+ Dict [torch .device , Tensor ],
277+ cast (PropagationRule , layer .rule ).relevance_input ,
278+ )
276279 else :
277- normalized_relevances = layer .rule .relevance_output
280+ normalized_relevances = cast ( PropagationRule , layer .rule ) .relevance_output
278281 key_list = _sort_key_list (list (normalized_relevances .keys ()), self .device_ids )
279- normalized_relevances = _reduce_list (
282+ normalized_relevances_reduced = _reduce_list (
280283 [normalized_relevances [device_id ] for device_id in key_list ]
281284 )
282285
283- if isinstance (normalized_relevances , tuple ):
286+ if isinstance (normalized_relevances_reduced , tuple ):
284287 return tuple (
285288 normalized_relevance
286289 * output .reshape ((- 1 ,) + (1 ,) * (normalized_relevance .dim () - 1 ))
287- for normalized_relevance in normalized_relevances
290+ for normalized_relevance in normalized_relevances_reduced
288291 )
289292 else :
290- return normalized_relevances * output .reshape (
291- (- 1 ,) + (1 ,) * (normalized_relevances .dim () - 1 )
293+ return normalized_relevances_reduced * output .reshape (
294+ (- 1 ,) + (1 ,) * (normalized_relevances_reduced .dim () - 1 )
292295 )
293296
294- # pyre-fixme[3]: Return type must be annotated.
295- # pyre-fixme[2]: Parameter must be annotated.
296- def _get_output_relevance ( self , output ) :
297+ def _get_output_relevance (
298+ self , output : Tensor
299+ ) -> Union [ Tensor , Tuple [ Tensor , ...], List [ Union [ Tensor , Tuple [ Tensor , ...]]]] :
297300 if isinstance (self .layer , list ):
298301 relevances = []
299302 for layer in self .layer :
@@ -303,11 +306,9 @@ def _get_output_relevance(self, output):
303306 return self ._get_single_output_relevance (self .layer , output )
304307
305308 @staticmethod
306- # pyre-fixme[3]: Return annotation cannot contain `Any`.
307309 def _convert_list_to_tuple (
308- # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
309- relevances : Union [List [Any ], Tuple [Any , ...]]
310- ) -> Tuple [Any , ...]:
310+ relevances : Union [List [T ], Tuple [T , ...]]
311+ ) -> Tuple [T , ...]:
311312 if isinstance (relevances , list ):
312313 return tuple (relevances )
313314 else :
0 commit comments