diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py
index efa437649c..512c910f08 100644
--- a/captum/_utils/typing.py
+++ b/captum/_utils/typing.py
@@ -24,6 +24,7 @@
 TupleOrTensorOrBoolGeneric = TypeVar(
     "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
 )
+PassThroughOutputType = TypeVar("PassThroughOutputType")
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index 6eaf58e5d7..77584594a9 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future
 
@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where a model layer returns an output
+    that is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer's output is unused in the forward func.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
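
Usage sketch (not part of the diff): a minimal illustration of how the new BasicModel_GradientLayerAttribution helper might be exercised with a gradient-based layer attribution; the choice of LayerGradientXActivation, the sample inputs, and the list passed as unsupported_layer_output are illustrative assumptions, not taken from this patch.

    # Hypothetical example; assumes the helper classes added in this diff.
    import torch

    from captum.attr import LayerGradientXActivation
    from captum.testing.helpers.basic_models import BasicModel_GradientLayerAttribution

    # unsupportedLayer receives a plain Python list, i.e. an output type that
    # gradient computation cannot handle; relu's output remains supported.
    model = BasicModel_GradientLayerAttribution(unsupported_layer_output=[1, 2, 3])
    inputs = torch.randn(2, 3)

    # Attribution w.r.t. the supported relu layer still works as usual.
    layer_attr = LayerGradientXActivation(model, model.relu)
    attributions = layer_attr.attribute(inputs, target=0)
    print(attributions.shape)  # (2, 4): one score per relu unit for each input row

Presumably the accompanying tests instead target model.unsupportedLayer to exercise the unsupported-output error path described in the GradientUnsupportedLayerOutput docstring.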