
Commit f1eb759

Vivek Miglani authored and facebook-github-bot committed
Fix layer LRP pyre fixme issues (#1474)
Summary: Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: cyrjano

Differential Revision: D67706680
1 parent ea3ce49 commit f1eb759

File tree

2 files changed: +47 -42 lines changed
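The pattern running through both files is the same: rather than suppressing Pyre errors assignment by assignment with "# pyre-fixme[16]" comments, attributes are declared once with explicit types at class scope, so every later self.<attr> = ... assignment type-checks. A minimal sketch of that pattern, using illustrative class and attribute names that are not taken from Captum:

from typing import List

from torch.nn import Module
from torch.utils.hooks import RemovableHandle


class Attributor:
    # Class-scope declarations replace per-assignment
    # "# pyre-fixme[16]: ... has no attribute ..." suppressions.
    verbose: bool
    layers: List[Module]
    forward_handles: List[RemovableHandle]

    def attribute(self, verbose: bool = False) -> None:
        # Pyre now knows the declared type of each instance attribute.
        self.verbose = verbose
        self.layers = []
        self.forward_handles = []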


captum/attr/_core/layer/layer_lrp.py

Lines changed: 41 additions & 40 deletions
@@ -2,7 +2,9 @@
 
 # pyre-strict
 import typing
-from typing import Any, cast, List, Literal, Optional, Tuple, Union
+from typing import cast, Dict, List, Literal, Optional, Tuple, TypeVar, Union
+
+import torch
 
 from captum._utils.common import (
     _format_tensor_into_tuples,
@@ -21,8 +23,12 @@
 )
 from captum.attr._core.lrp import LRP
 from captum.attr._utils.attribution import LayerAttribution
+from captum.attr._utils.lrp_rules import PropagationRule
 from torch import Tensor
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+
+T = TypeVar("T")
 
 
 class LayerLRP(LRP, LayerAttribution):
@@ -39,6 +45,13 @@ class LayerLRP(LRP, LayerAttribution):
     Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
     """
 
+    device_ids: List[int]
+    verbose: bool
+    layers: List[Module]
+    attribute_to_layer_input: bool = False
+    backward_handles: List[RemovableHandle]
+    forward_handles: List[RemovableHandle]
+
     def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
         """
         Args:
@@ -59,7 +72,6 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
         LayerAttribution.__init__(self, model, layer)
         LRP.__init__(self, model)
         if hasattr(self.model, "device_ids"):
-            # pyre-fixme[4]: Attribute must be annotated.
             self.device_ids = cast(List[int], self.model.device_ids)
 
     @typing.overload  # type: ignore
@@ -208,56 +220,45 @@ def attribute(
             >>> attribution = layer_lrp.attribute(input, target=5)
 
         """
-        # pyre-fixme[16]: `LayerLRP` has no attribute `verbose`.
         self.verbose = verbose
-        # pyre-fixme[16]: `LayerLRP` has no attribute `_original_state_dict`.
         self._original_state_dict = self.model.state_dict()
-        # pyre-fixme[16]: `LayerLRP` has no attribute `layers`.
         self.layers = []
         self._get_layers(self.model)
         self._check_and_attach_rules()
-        # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
         self.attribute_to_layer_input = attribute_to_layer_input
-        # pyre-fixme[16]: `LayerLRP` has no attribute `backward_handles`.
         self.backward_handles = []
-        # pyre-fixme[16]: `LayerLRP` has no attribute `forward_handles`.
        self.forward_handles = []
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs_tuple)
 
         try:
             # 1. Forward pass
             output = self._compute_output_and_change_weights(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                # got `TensorOrTupleOfTensorsGeneric`.
-                inputs,
+                inputs_tuple,
                 target,
                 additional_forward_args,
             )
             self._register_forward_hooks()
             # 2. Forward pass + backward pass
             _ = compute_gradients(
-                self._forward_fn_wrapper, inputs, target, additional_forward_args
+                self._forward_fn_wrapper, inputs_tuple, target, additional_forward_args
             )
             relevances = self._get_output_relevance(output)
         finally:
             self._restore_model()
-            # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-            # `TensorOrTupleOfTensorsGeneric`.
-            undo_gradient_requirements(inputs, gradient_mask)
+            undo_gradient_requirements(inputs_tuple, gradient_mask)
 
         if return_convergence_delta:
             delta: Union[Tensor, List[Tensor]]
             if isinstance(self.layer, list):
                 delta = []
                 for relevance_layer in relevances:
                     delta.append(
-                        self.compute_convergence_delta(relevance_layer, output)
+                        self.compute_convergence_delta(
+                            cast(Union[Tensor, Tuple[Tensor, ...]], relevance_layer),
+                            output,
+                        )
                     )
             else:
                 delta = self.compute_convergence_delta(
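A note on the hunk above: the pyre-fixme[9] and pyre-fixme[6] suppressions disappear because the converted inputs are bound to a new name, inputs_tuple, instead of being reassigned to the parameter; rebinding a parameter to a value of a different type is exactly what Pyre was flagging. A minimal illustrative sketch of the idea (the helper and function names here are stand-ins, not Captum APIs):

from typing import Tuple, Union

from torch import Tensor


def format_into_tuples(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Stand-in for _format_tensor_into_tuples: normalize the input to a tuple.
    return inputs if isinstance(inputs, tuple) else (inputs,)


def attribute(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Binding the normalized value to a new name keeps the parameter's declared
    # type intact and gives later calls a precise Tuple[Tensor, ...] type.
    inputs_tuple = format_into_tuples(inputs)
    return inputs_tuple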
@@ -267,33 +268,35 @@ def attribute(
         else:
             return relevances  # type: ignore
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _get_single_output_relevance(self, layer, output):
-        # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
+    def _get_single_output_relevance(
+        self, layer: Module, output: Tensor
+    ) -> Union[Tensor, Tuple[Tensor, ...]]:
         if self.attribute_to_layer_input:
-            normalized_relevances = layer.rule.relevance_input
+            normalized_relevances = cast(
+                Dict[torch.device, Tensor],
+                cast(PropagationRule, layer.rule).relevance_input,
+            )
         else:
-            normalized_relevances = layer.rule.relevance_output
+            normalized_relevances = cast(PropagationRule, layer.rule).relevance_output
         key_list = _sort_key_list(list(normalized_relevances.keys()), self.device_ids)
-        normalized_relevances = _reduce_list(
+        normalized_relevances_reduced = _reduce_list(
             [normalized_relevances[device_id] for device_id in key_list]
         )
 
-        if isinstance(normalized_relevances, tuple):
+        if isinstance(normalized_relevances_reduced, tuple):
             return tuple(
                 normalized_relevance
                 * output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1))
-                for normalized_relevance in normalized_relevances
+                for normalized_relevance in normalized_relevances_reduced
             )
         else:
-            return normalized_relevances * output.reshape(
-                (-1,) + (1,) * (normalized_relevances.dim() - 1)
+            return normalized_relevances_reduced * output.reshape(
+                (-1,) + (1,) * (normalized_relevances_reduced.dim() - 1)
             )
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _get_output_relevance(self, output):
+    def _get_output_relevance(
+        self, output: Tensor
+    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
         if isinstance(self.layer, list):
             relevances = []
             for layer in self.layer:
@@ -303,11 +306,9 @@ def _get_output_relevance(self, output):
             return self._get_single_output_relevance(self.layer, output)
 
     @staticmethod
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def _convert_list_to_tuple(
-        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-        relevances: Union[List[Any], Tuple[Any, ...]]
-    ) -> Tuple[Any, ...]:
+        relevances: Union[List[T], Tuple[T, ...]]
+    ) -> Tuple[T, ...]:
         if isinstance(relevances, list):
             return tuple(relevances)
         else:
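A note on the _convert_list_to_tuple change: replacing Any with a TypeVar keeps the element type, so the return type at a call site reflects what was passed in rather than being erased. A short sketch of the difference, assuming only the typing module and torch (the usage lines are illustrative, not from the commit):

from typing import List, Tuple, TypeVar, Union

import torch
from torch import Tensor

T = TypeVar("T")


def convert_list_to_tuple(relevances: Union[List[T], Tuple[T, ...]]) -> Tuple[T, ...]:
    # With a TypeVar, List[Tensor] in means Tuple[Tensor, ...] out;
    # with Any, the element type would be lost at the call site.
    if isinstance(relevances, list):
        return tuple(relevances)
    return relevances


tensors: List[Tensor] = [torch.zeros(2), torch.ones(2)]
as_tuple: Tuple[Tensor, ...] = convert_list_to_tuple(tensors)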

captum/attr/_utils/lrp_rules.py

Lines changed: 6 additions & 2 deletions
@@ -3,10 +3,11 @@
 # pyre-strict
 
 from abc import ABC, abstractmethod
+from typing import cast, Dict, List, Union
 
 import torch
-
 from captum._utils.common import _format_tensor_into_tuples
+from torch import Tensor
 
 
 class PropagationRule(ABC):
@@ -15,6 +16,9 @@ class PropagationRule(ABC):
     STABILITY_FACTOR is used to assure that no zero divison occurs.
     """
 
+    relevance_input: Dict[torch.device, Union[torch.Tensor, List[torch.Tensor]]] = {}
+    relevance_output: Dict[torch.device, torch.Tensor] = {}
+
     STABILITY_FACTOR = 1e-9
 
     # pyre-fixme[3]: Return type must be annotated.
@@ -67,7 +71,7 @@ def _backward_hook_input(grad):
                 # pyre-fixme[16]: `PropagationRule` has no attribute `relevance_input`.
                 self.relevance_input[device] = relevance.data
             else:
-                self.relevance_input[device].append(relevance.data)
+                cast(List[Tensor], self.relevance_input[device]).append(relevance.data)
 
             # replace_out is needed since two hooks are set on the same tensor
             # The output of this hook is needed in backward_hook_activation
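A note on the cast in the hook above: relevance_input is now annotated as Dict[torch.device, Union[Tensor, List[Tensor]]], so the branch that appends must narrow the union before calling a list-only method. cast only informs the type checker; the runtime value must already be a list. A minimal sketch of that narrowing, with illustrative setup code:

from typing import Dict, List, Union, cast

import torch
from torch import Tensor

relevance_input: Dict[torch.device, Union[Tensor, List[Tensor]]] = {}

device = torch.device("cpu")
relevance_input[device] = [torch.zeros(3)]

# cast narrows Union[Tensor, List[Tensor]] to List[Tensor] for the type checker;
# it performs no runtime conversion, so the stored value must actually be a list.
cast(List[Tensor], relevance_input[device]).append(torch.ones(3))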
