
Commit d7dce20

Vivek Miglani authored and facebook-github-bot committed

Fix pyre errors in LRP (#1401)

Summary: Initial work on fixing Pyre errors in LRP
Reviewed By: craymichael
Differential Revision: D64677351

1 parent e7c4de0 · commit d7dce20

File tree

1 file changed

captum/attr/_core/lrp.py

Lines changed: 18 additions & 54 deletions
@@ -4,7 +4,7 @@
 
 import typing
 from collections import defaultdict
-from typing import Any, Callable, cast, List, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union
 
 import torch.nn as nn
 from captum._utils.common import (
@@ -18,7 +18,7 @@
     apply_gradient_requirements,
     undo_gradient_requirements,
 )
-from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import GradientAttribution
 from captum.attr._utils.common import _sum_rows
 from captum.attr._utils.custom_modules import Addition_Module
@@ -43,6 +43,12 @@ class LRP(GradientAttribution):
     Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
     """
 
+    verbose: bool = False
+    _original_state_dict: Dict[str, Any] = {}
+    layers: List[Module] = []
+    backward_handles: List[RemovableHandle] = []
+    forward_handles: List[RemovableHandle] = []
+
     def __init__(self, model: Module) -> None:
         r"""
         Args:
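Most of the lines deleted in this commit were pyre-fixme[16] suppressions (`LRP` has no attribute ...), and the five class-level declarations above are what make them unnecessary: once an attribute is declared on the class, Pyre knows its type everywhere, even when the first assignment happens outside `__init__`. A minimal sketch of the pattern, with hypothetical class and field names:

```python
from typing import List


class BeforeFix:
    def configure(self) -> None:
        # Pyre (strict mode): `BeforeFix` has no attribute `layers` [16],
        # because the attribute is first created outside __init__.
        self.layers = []


class AfterFix:
    # Declaring the attribute at class level gives Pyre its type up front.
    layers: List[str] = []

    def configure(self) -> None:
        self.layers = []  # now type-checked against List[str]
```

Note that a mutable class-level default such as `[]` is shared across instances until an instance assignment shadows it, so the per-call `self.layers = []` reset in `attribute()` below is still doing real work.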
@@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool:
         return True
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `75`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         verbose: bool = False,
     ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `65`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+        additional_forward_args: object = None,
         return_convergence_delta: Literal[False] = False,
         verbose: bool = False,
     ) -> TensorOrTupleOfTensorsGeneric: ...
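The two overloads above also switch `Literal` from the `captum._utils.typing` compatibility shim to `typing.Literal`, which resolves the pyre-fixme[24]/[31] complaints about `Literal[True]` and `Literal[False]` not being valid types. The point of the pattern is that the checker can pick the return type from the flag's value. A self-contained sketch with a hypothetical function (plain `int`/`float` stand in for tensors):

```python
import typing
from typing import Literal, Tuple, Union


@typing.overload
def attribute(*, return_convergence_delta: Literal[True]) -> Tuple[int, float]: ...


@typing.overload
def attribute(return_convergence_delta: Literal[False] = False) -> int: ...


def attribute(
    return_convergence_delta: bool = False,
) -> Union[int, Tuple[int, float]]:
    # The implementation accepts a plain bool and returns either shape.
    if return_convergence_delta:
        return 1, 0.0
    return 1


with_delta = attribute(return_convergence_delta=True)  # inferred: Tuple[int, float]
plain = attribute()                                     # inferred: int
```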
@@ -100,7 +95,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         return_convergence_delta: bool = False,
         verbose: bool = False,
     ) -> Union[
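Across all three signatures, `additional_forward_args` changes from `Any` to `object`, eliminating the pyre-fixme[2] suppressions ("Parameter annotation cannot be `Any`"). `object` is the honest type for a value that is only stored and forwarded: anything can be passed in, but nothing type-specific can be done with it without a cast. A tiny hypothetical sketch:

```python
def call_with_extras(additional_forward_args: object = None) -> None:
    # Storing or forwarding the value is fine under `object`...
    extras = additional_forward_args
    print(extras)
    # ...but type-specific use is rejected, unlike with `Any`:
    # extras.shape  # Pyre error: `object` has no attribute `shape`
```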
@@ -199,43 +194,30 @@ def attribute(
             >>> attribution = lrp.attribute(input, target=5)
 
         """
-        # pyre-fixme[16]: `LRP` has no attribute `verbose`.
         self.verbose = verbose
-        # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`.
         self._original_state_dict = self.model.state_dict()
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
-        self.layers: List[Module] = []
+        self.layers = []
         self._get_layers(self.model)
         self._check_and_attach_rules()
-        # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
         self.backward_handles: List[RemovableHandle] = []
-        # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
         self.forward_handles: List[RemovableHandle] = []
 
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        input_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(input_tuple)
 
         try:
             # 1. Forward pass: Change weights of layers according to selected rules.
             output = self._compute_output_and_change_weights(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `TensorOrTupleOfTensorsGeneric`.
-                inputs,
+                input_tuple,
                 target,
                 additional_forward_args,
             )
             # 2. Forward pass + backward pass: Register hooks to configure relevance
             #    propagation and execute back-propagation.
             self._register_forward_hooks()
             normalized_relevances = self.gradient_func(
-                self._forward_fn_wrapper, inputs, target, additional_forward_args
+                self._forward_fn_wrapper, input_tuple, target, additional_forward_args
             )
             relevances = tuple(
                 normalized_relevance
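The bulk of this hunk replaces re-binding `inputs` with a fresh local, `input_tuple`. Re-assigning a `Tuple[Tensor, ...]` to a parameter declared as `TensorOrTupleOfTensorsGeneric` is what produced the pyre-fixme[9] and pyre-fixme[6] suppressions at every later use; a new name keeps both types precise. The same idea with simple types (hypothetical function):

```python
from typing import Tuple, Union


def normalize(inputs: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
    # Re-binding `inputs` itself (inputs = (inputs,)) would conflict with the
    # parameter's declared type and taint every later use of the name.
    input_tuple: Tuple[int, ...] = (
        inputs if isinstance(inputs, tuple) else (inputs,)
    )
    return input_tuple
```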
@@ -245,9 +227,7 @@ def attribute(
         finally:
             self._restore_model()
 
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        undo_gradient_requirements(inputs, gradient_mask)
+        undo_gradient_requirements(input_tuple, gradient_mask)
 
         if return_convergence_delta:
             # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
@@ -310,13 +290,11 @@ def compute_convergence_delta(
     def _get_layers(self, model: Module) -> None:
         for layer in model.children():
             if len(list(layer.children())) == 0:
-                # pyre-fixme[16]: `LRP` has no attribute `layers`.
                 self.layers.append(layer)
             else:
                 self._get_layers(layer)
 
     def _check_and_attach_rules(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if hasattr(layer, "rule"):
                 layer.activations = {}  # type: ignore
@@ -355,50 +333,41 @@ def _check_rules(self) -> None:
             )
 
     def _register_forward_hooks(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
                 backward_handles = _register_backward_hook(
                     layer, PropagationRule.backward_hook_activation, self
                 )
-                # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
                 self.backward_handles.extend(backward_handles)
             else:
                 forward_handle = layer.register_forward_hook(
                     layer.rule.forward_hook  # type: ignore
                 )
-                # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
                 self.forward_handles.append(forward_handle)
-                # pyre-fixme[16]: `LRP` has no attribute `verbose`.
                 if self.verbose:
                     print(f"Applied {layer.rule} on layer {layer}")
 
     def _register_weight_hooks(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if layer.rule is not None:
                 forward_handle = layer.register_forward_hook(
                     layer.rule.forward_hook_weights  # type: ignore
                 )
-                # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
                 self.forward_handles.append(forward_handle)
 
     def _register_pre_hooks(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if layer.rule is not None:
                 forward_handle = layer.register_forward_pre_hook(
                     layer.rule.forward_pre_hook_activations  # type: ignore
                 )
-                # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
                 self.forward_handles.append(forward_handle)
 
     def _compute_output_and_change_weights(
         self,
         inputs: Tuple[Tensor, ...],
         target: TargetType,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any,
+        additional_forward_args: object,
     ) -> Tensor:
         try:
             self._register_weight_hooks()
@@ -416,15 +385,12 @@ def _compute_output_and_change_weights(
         return cast(Tensor, output)
 
     def _remove_forward_hooks(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
         for forward_handle in self.forward_handles:
             forward_handle.remove()
 
     def _remove_backward_hooks(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
         for backward_handle in self.backward_handles:
             backward_handle.remove()
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if hasattr(layer.rule, "_handle_input_hooks"):
                 for handle in layer.rule._handle_input_hooks:  # type: ignore
@@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None:
                 layer.rule._handle_output_hook.remove()  # type: ignore
 
     def _remove_rules(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if hasattr(layer, "rule"):
                 del layer.rule
 
     def _clear_properties(self) -> None:
-        # pyre-fixme[16]: `LRP` has no attribute `layers`.
         for layer in self.layers:
             if hasattr(layer, "activation"):
                 del layer.activation
