44
55import typing
66from collections import defaultdict
7- from typing import Any , Callable , cast , List , Tuple , Union
7+ from typing import Any , Callable , cast , Dict , List , Literal , Tuple , Union
88
99import torch .nn as nn
1010from captum ._utils .common import (
1818 apply_gradient_requirements ,
1919 undo_gradient_requirements ,
2020)
21- from captum ._utils .typing import Literal , TargetType , TensorOrTupleOfTensorsGeneric
21+ from captum ._utils .typing import TargetType , TensorOrTupleOfTensorsGeneric
2222from captum .attr ._utils .attribution import GradientAttribution
2323from captum .attr ._utils .common import _sum_rows
2424from captum .attr ._utils .custom_modules import Addition_Module
@@ -43,6 +43,12 @@ class LRP(GradientAttribution):
4343 Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
4444 """
4545
46+ verbose : bool = False
47+ _original_state_dict : Dict [str , Any ] = {}
48+ layers : List [Module ] = []
49+ backward_handles : List [RemovableHandle ] = []
50+ forward_handles : List [RemovableHandle ] = []
51+
4652 def __init__ (self , model : Module ) -> None :
4753 r"""
4854 Args:
@@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool:
6268 return True
6369
6470 @typing .overload
65- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
66- # arguments of overload defined on line `75`.
6771 def attribute (
6872 self ,
6973 inputs : TensorOrTupleOfTensorsGeneric ,
7074 target : TargetType = None ,
71- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
72- additional_forward_args : Any = None ,
75+ additional_forward_args : object = None ,
7376 * ,
74- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
75- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
7677 return_convergence_delta : Literal [True ],
7778 verbose : bool = False ,
7879 ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]: ...
7980
8081 @typing .overload
81- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
82- # arguments of overload defined on line `65`.
8382 def attribute (
8483 self ,
8584 inputs : TensorOrTupleOfTensorsGeneric ,
8685 target : TargetType = None ,
87- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
88- additional_forward_args : Any = None ,
89- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
90- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
91- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
86+ additional_forward_args : object = None ,
9287 return_convergence_delta : Literal [False ] = False ,
9388 verbose : bool = False ,
9489 ) -> TensorOrTupleOfTensorsGeneric : ...
@@ -100,7 +95,7 @@ def attribute(
10095 self ,
10196 inputs : TensorOrTupleOfTensorsGeneric ,
10297 target : TargetType = None ,
103- additional_forward_args : Any = None ,
98+ additional_forward_args : object = None ,
10499 return_convergence_delta : bool = False ,
105100 verbose : bool = False ,
106101 ) -> Union [
@@ -199,43 +194,30 @@ def attribute(
199194 >>> attribution = lrp.attribute(input, target=5)
200195
201196 """
202- # pyre-fixme[16]: `LRP` has no attribute `verbose`.
203197 self .verbose = verbose
204- # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`.
205198 self ._original_state_dict = self .model .state_dict ()
206- # pyre-fixme[16]: `LRP` has no attribute `layers`.
207- self .layers : List [Module ] = []
199+ self .layers = []
208200 self ._get_layers (self .model )
209201 self ._check_and_attach_rules ()
210- # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
211202 self .backward_handles : List [RemovableHandle ] = []
212- # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
213203 self .forward_handles : List [RemovableHandle ] = []
214204
215- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
216- # `TensorOrTupleOfTensorsGeneric`.
217205 is_inputs_tuple = _is_tuple (inputs )
218- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
219- # `Tuple[Tensor, ...]`.
220- inputs = _format_tensor_into_tuples (inputs )
221- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
222- # `TensorOrTupleOfTensorsGeneric`.
223- gradient_mask = apply_gradient_requirements (inputs )
206+ input_tuple = _format_tensor_into_tuples (inputs )
207+ gradient_mask = apply_gradient_requirements (input_tuple )
224208
225209 try :
226210 # 1. Forward pass: Change weights of layers according to selected rules.
227211 output = self ._compute_output_and_change_weights (
228- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
229- # got `TensorOrTupleOfTensorsGeneric`.
230- inputs ,
212+ input_tuple ,
231213 target ,
232214 additional_forward_args ,
233215 )
234216 # 2. Forward pass + backward pass: Register hooks to configure relevance
235217 # propagation and execute back-propagation.
236218 self ._register_forward_hooks ()
237219 normalized_relevances = self .gradient_func (
238- self ._forward_fn_wrapper , inputs , target , additional_forward_args
220+ self ._forward_fn_wrapper , input_tuple , target , additional_forward_args
239221 )
240222 relevances = tuple (
241223 normalized_relevance
@@ -245,9 +227,7 @@ def attribute(
245227 finally :
246228 self ._restore_model ()
247229
248- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
249- # `TensorOrTupleOfTensorsGeneric`.
250- undo_gradient_requirements (inputs , gradient_mask )
230+ undo_gradient_requirements (input_tuple , gradient_mask )
251231
252232 if return_convergence_delta :
253233 # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
@@ -310,13 +290,11 @@ def compute_convergence_delta(
310290 def _get_layers (self , model : Module ) -> None :
311291 for layer in model .children ():
312292 if len (list (layer .children ())) == 0 :
313- # pyre-fixme[16]: `LRP` has no attribute `layers`.
314293 self .layers .append (layer )
315294 else :
316295 self ._get_layers (layer )
317296
318297 def _check_and_attach_rules (self ) -> None :
319- # pyre-fixme[16]: `LRP` has no attribute `layers`.
320298 for layer in self .layers :
321299 if hasattr (layer , "rule" ):
322300 layer .activations = {} # type: ignore
@@ -355,50 +333,41 @@ def _check_rules(self) -> None:
355333 )
356334
357335 def _register_forward_hooks (self ) -> None :
358- # pyre-fixme[16]: `LRP` has no attribute `layers`.
359336 for layer in self .layers :
360337 if type (layer ) in SUPPORTED_NON_LINEAR_LAYERS :
361338 backward_handles = _register_backward_hook (
362339 layer , PropagationRule .backward_hook_activation , self
363340 )
364- # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
365341 self .backward_handles .extend (backward_handles )
366342 else :
367343 forward_handle = layer .register_forward_hook (
368344 layer .rule .forward_hook # type: ignore
369345 )
370- # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
371346 self .forward_handles .append (forward_handle )
372- # pyre-fixme[16]: `LRP` has no attribute `verbose`.
373347 if self .verbose :
374348 print (f"Applied { layer .rule } on layer { layer } " )
375349
376350 def _register_weight_hooks (self ) -> None :
377- # pyre-fixme[16]: `LRP` has no attribute `layers`.
378351 for layer in self .layers :
379352 if layer .rule is not None :
380353 forward_handle = layer .register_forward_hook (
381354 layer .rule .forward_hook_weights # type: ignore
382355 )
383- # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
384356 self .forward_handles .append (forward_handle )
385357
386358 def _register_pre_hooks (self ) -> None :
387- # pyre-fixme[16]: `LRP` has no attribute `layers`.
388359 for layer in self .layers :
389360 if layer .rule is not None :
390361 forward_handle = layer .register_forward_pre_hook (
391362 layer .rule .forward_pre_hook_activations # type: ignore
392363 )
393- # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
394364 self .forward_handles .append (forward_handle )
395365
396366 def _compute_output_and_change_weights (
397367 self ,
398368 inputs : Tuple [Tensor , ...],
399369 target : TargetType ,
400- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
401- additional_forward_args : Any ,
370+ additional_forward_args : object ,
402371 ) -> Tensor :
403372 try :
404373 self ._register_weight_hooks ()
@@ -416,15 +385,12 @@ def _compute_output_and_change_weights(
416385 return cast (Tensor , output )
417386
418387 def _remove_forward_hooks (self ) -> None :
419- # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
420388 for forward_handle in self .forward_handles :
421389 forward_handle .remove ()
422390
423391 def _remove_backward_hooks (self ) -> None :
424- # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
425392 for backward_handle in self .backward_handles :
426393 backward_handle .remove ()
427- # pyre-fixme[16]: `LRP` has no attribute `layers`.
428394 for layer in self .layers :
429395 if hasattr (layer .rule , "_handle_input_hooks" ):
430396 for handle in layer .rule ._handle_input_hooks : # type: ignore
@@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None:
433399 layer .rule ._handle_output_hook .remove () # type: ignore
434400
435401 def _remove_rules (self ) -> None :
436- # pyre-fixme[16]: `LRP` has no attribute `layers`.
437402 for layer in self .layers :
438403 if hasattr (layer , "rule" ):
439404 del layer .rule
440405
441406 def _clear_properties (self ) -> None :
442- # pyre-fixme[16]: `LRP` has no attribute `layers`.
443407 for layer in self .layers :
444408 if hasattr (layer , "activation" ):
445409 del layer .activation
0 commit comments