4 changes: 4 additions & 0 deletions captum/_utils/common.py
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
return isinstance(inputs, tuple)
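
For context (not part of the diff), a minimal sketch of what the added overload enables: code that is itself generic over TensorOrTupleOfTensorsGeneric can now call _is_tuple without the checker rejecting the TypeVar-typed argument. The helper name `describe` below is hypothetical.

    import torch

    from captum._utils.common import _is_tuple
    from captum._utils.typing import TensorOrTupleOfTensorsGeneric


    def describe(inputs: TensorOrTupleOfTensorsGeneric) -> str:
        # Before this overload, a TypeVar-typed argument matched neither the
        # Tensor nor the Tuple[Tensor, ...] signature; now it resolves to bool.
        if _is_tuple(inputs):
            return f"tuple of {len(inputs)} tensors"
        return "single tensor"


    print(describe(torch.zeros(2)))                    # single tensor
    print(describe((torch.zeros(2), torch.ones(2))))   # tuple of 2 tensors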

16 changes: 1 addition & 15 deletions captum/_utils/typing.py
@@ -2,25 +2,11 @@

# pyre-strict

from typing import (
List,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union

from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
from typing import Literal
else:
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
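
For context (not part of the diff): importing Literal unconditionally from typing removes the need for the dict placeholder above, which only stood in for Literal at runtime when TYPE_CHECKING was false. A minimal sketch of the pattern this enables, using a hypothetical function `pack`:

    from typing import Literal, Tuple, Union, overload

    from torch import Tensor


    @overload
    def pack(x: Tensor, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...

    @overload
    def pack(x: Tensor, as_tuple: Literal[False]) -> Tensor: ...


    def pack(x: Tensor, as_tuple: bool) -> Union[Tensor, Tuple[Tensor, ...]]:
        # typing.Literal is available on Python 3.8+, so the annotations
        # evaluate at runtime without any placeholder object.
        return (x,) if as_tuple else x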
72 changes: 18 additions & 54 deletions captum/attr/_core/lrp.py
@@ -4,7 +4,7 @@

import typing
from collections import defaultdict
from typing import Any, Callable, cast, List, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union

import torch.nn as nn
from captum._utils.common import (
@@ -18,7 +18,7 @@
apply_gradient_requirements,
undo_gradient_requirements,
)
from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.attr._utils.common import _sum_rows
from captum.attr._utils.custom_modules import Addition_Module
@@ -43,6 +43,12 @@ class LRP(GradientAttribution):
Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
"""

verbose: bool = False
_original_state_dict: Dict[str, Any] = {}
layers: List[Module] = []
backward_handles: List[RemovableHandle] = []
forward_handles: List[RemovableHandle] = []

def __init__(self, model: Module) -> None:
r"""
Args:
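
An illustrative sketch (class and attribute names hypothetical) of why the class-level declarations above help: the checker learns each attribute's type from the class body, so the later instance assignments in attribute() no longer need pyre-fixme[16] suppressions.

    from typing import List


    class Tracker:
        # Declared here with its type, so pyre/mypy know every Tracker
        # instance has `handles`, even before run() assigns it.
        handles: List[str] = []

        def run(self) -> None:
            self.handles = []          # rebinds on the instance
            self.handles.append("h1")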
@@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool:
return True

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `75`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
verbose: bool = False,
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `65`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
additional_forward_args: object = None,
return_convergence_delta: Literal[False] = False,
verbose: bool = False,
) -> TensorOrTupleOfTensorsGeneric: ...
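
A usage sketch of how these overloads resolve for callers (the model and input below are placeholders, following the pattern in the docstring further down): return_convergence_delta=True selects the Tuple[attributions, delta] overload, while the default False selects the plain-attributions overload.

    import torch
    import torch.nn as nn

    from captum.attr import LRP

    model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 10))
    lrp = LRP(model)
    inp = torch.randn(1, 3)

    # Literal[False] overload (default): a single attributions value.
    attributions = lrp.attribute(inp, target=5)

    # Literal[True] overload: attributions plus a convergence-delta tensor.
    attributions, delta = lrp.attribute(inp, target=5, return_convergence_delta=True)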
@@ -100,7 +95,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: object = None,
return_convergence_delta: bool = False,
verbose: bool = False,
) -> Union[
@@ -199,43 +194,30 @@ def attribute(
>>> attribution = lrp.attribute(input, target=5)

"""
# pyre-fixme[16]: `LRP` has no attribute `verbose`.
self.verbose = verbose
# pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`.
self._original_state_dict = self.model.state_dict()
# pyre-fixme[16]: `LRP` has no attribute `layers`.
self.layers: List[Module] = []
self.layers = []
self._get_layers(self.model)
self._check_and_attach_rules()
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
self.backward_handles: List[RemovableHandle] = []
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles: List[RemovableHandle] = []

# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
input_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(input_tuple)

try:
# 1. Forward pass: Change weights of layers according to selected rules.
output = self._compute_output_and_change_weights(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `TensorOrTupleOfTensorsGeneric`.
inputs,
input_tuple,
target,
additional_forward_args,
)
# 2. Forward pass + backward pass: Register hooks to configure relevance
# propagation and execute back-propagation.
self._register_forward_hooks()
normalized_relevances = self.gradient_func(
self._forward_fn_wrapper, inputs, target, additional_forward_args
self._forward_fn_wrapper, input_tuple, target, additional_forward_args
)
relevances = tuple(
normalized_relevance
Expand All @@ -245,9 +227,7 @@ def attribute(
finally:
self._restore_model()

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(input_tuple, gradient_mask)

if return_convergence_delta:
# pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
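
For context, a sketch of the refactor applied in this hunk (the wrapper function `normalize_inputs` is hypothetical; the helpers are captum's): binding the normalized tuple to a new variable keeps the TypeVar-typed `inputs` unchanged, which is what removed the pyre-fixme[6]/[9] suppressions around `_format_tensor_into_tuples` and `apply_gradient_requirements`.

    from typing import Tuple

    import torch
    from torch import Tensor

    from captum._utils.common import _format_tensor_into_tuples, _is_tuple
    from captum._utils.typing import TensorOrTupleOfTensorsGeneric


    def normalize_inputs(inputs: TensorOrTupleOfTensorsGeneric) -> Tuple[Tensor, ...]:
        is_inputs_tuple = _is_tuple(inputs)               # inputs keeps its generic type
        input_tuple = _format_tensor_into_tuples(inputs)  # concrete Tuple[Tensor, ...]
        # Downstream code uses input_tuple where a tuple is required and
        # is_inputs_tuple to decide how to format the final output.
        return input_tuple


    normalize_inputs(torch.zeros(2, 3))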
@@ -310,13 +290,11 @@ def compute_convergence_delta(
def _get_layers(self, model: Module) -> None:
for layer in model.children():
if len(list(layer.children())) == 0:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
self.layers.append(layer)
else:
self._get_layers(layer)

def _check_and_attach_rules(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "rule"):
layer.activations = {} # type: ignore
@@ -355,50 +333,41 @@ def _check_rules(self) -> None:
)

def _register_forward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
backward_handles = _register_backward_hook(
layer, PropagationRule.backward_hook_activation, self
)
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
self.backward_handles.extend(backward_handles)
else:
forward_handle = layer.register_forward_hook(
layer.rule.forward_hook # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)
# pyre-fixme[16]: `LRP` has no attribute `verbose`.
if self.verbose:
print(f"Applied {layer.rule} on layer {layer}")

def _register_weight_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if layer.rule is not None:
forward_handle = layer.register_forward_hook(
layer.rule.forward_hook_weights # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)

def _register_pre_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if layer.rule is not None:
forward_handle = layer.register_forward_pre_hook(
layer.rule.forward_pre_hook_activations # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)

def _compute_output_and_change_weights(
self,
inputs: Tuple[Tensor, ...],
target: TargetType,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
additional_forward_args: object,
) -> Tensor:
try:
self._register_weight_hooks()
@@ -416,15 +385,12 @@ def _compute_output_and_change_weights(
return cast(Tensor, output)

def _remove_forward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
for forward_handle in self.forward_handles:
forward_handle.remove()

def _remove_backward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
for backward_handle in self.backward_handles:
backward_handle.remove()
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer.rule, "_handle_input_hooks"):
for handle in layer.rule._handle_input_hooks: # type: ignore
Expand All @@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None:
layer.rule._handle_output_hook.remove() # type: ignore

def _remove_rules(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "rule"):
del layer.rule

def _clear_properties(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "activation"):
del layer.activation