From 66b4bd6d2b71637839a027c8bb5d32e4f4e7ef55 Mon Sep 17 00:00:00 2001
From: Samuel Yusuf
Date: Wed, 5 Mar 2025 13:47:00 -0800
Subject: [PATCH] Adding a new Module for unsupported layers. Adding tests for
 unsupported layers. Simple logging for unsupported layers (#1505)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1505

We are adding a test for unsupported gradient layers. Open to ideas if there
is a better way to structure the test.

A bit uncomfortable with removing pyre type validations, as we allow anything
to be passed into the GradientUnsupportedLayerOutput class.

Reviewed By: craymichael

Differential Revision: D69792994
---
 captum/_utils/typing.py                |  1 +
 captum/testing/helpers/basic_models.py | 78 +++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py
index efa437649c..512c910f08 100644
--- a/captum/_utils/typing.py
+++ b/captum/_utils/typing.py
@@ -24,6 +24,7 @@
 TupleOrTensorOrBoolGeneric = TypeVar(
     "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
 )
+PassThroughOutputType = TypeVar("PassThroughOutputType")
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index 6eaf58e5d7..77584594a9 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future
 
@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer that
+    is not supported by the gradient computation.
+ """ + + def __init__(self) -> None: + super().__init__() + + @no_type_check + def forward( + self, unsupported_layer_output: PassThroughOutputType + ) -> PassThroughOutputType: + return unsupported_layer_output + + +class BasicModel_GradientLayerAttribution(nn.Module): + def __init__( + self, + inplace: bool = False, + unsupported_layer_output: PassThroughOutputType = None, + ) -> None: + super().__init__() + # Linear 0 is simply identity transform + self.unsupported_layer_output = unsupported_layer_output + self.linear0 = nn.Linear(3, 3) + self.linear0.weight = nn.Parameter(torch.eye(3)) + self.linear0.bias = nn.Parameter(torch.zeros(3)) + self.linear1 = nn.Linear(3, 4) + self.linear1.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + + self.linear1_alt = nn.Linear(3, 4) + self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + + self.relu = nn.ReLU(inplace=inplace) + self.relu_alt = nn.ReLU(inplace=False) + self.unsupportedLayer = GradientUnsupportedLayerOutput() + + self.linear2 = nn.Linear(4, 2) + self.linear2.weight = nn.Parameter(torch.ones(2, 4)) + self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) + + self.linear3 = nn.Linear(4, 2) + self.linear3.weight = nn.Parameter(torch.ones(2, 4)) + self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) + + @no_type_check + def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor: + input = x if add_input is None else x + add_input + lin0_out = self.linear0(input) + lin1_out = self.linear1(lin0_out) + lin1_out_alt = self.linear1_alt(lin0_out) + + if self.unsupported_layer_output is not None: + self.unsupportedLayer(self.unsupported_layer_output) + # unsupportedLayer is unused in the forward func. + self.relu_alt( + lin1_out_alt + ) # relu_alt's output is supported but it's unused in the forward func. + + relu_out = self.relu(lin1_out) + lin2_out = self.linear2(relu_out) + + lin3_out = self.linear3(lin1_out_alt).to(torch.int64) + + return torch.cat((lin2_out, lin3_out), dim=1) + + class MultiRelu(nn.Module): def __init__(self, inplace: bool = False) -> None: super().__init__() @@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]: class BasicModel_MultiLayer(nn.Module): - def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None: + def __init__( + self, + inplace: bool = False, + multi_input_module: bool = False, + ) -> None: super().__init__() # Linear 0 is simply identity transform self.multi_input_module = multi_input_module @@ -461,6 +536,7 @@ def forward( input = x if add_input is None else x + add_input lin0_out = self.linear0(input) lin1_out = self.linear1(lin0_out) + if self.multi_input_module: relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input)) relu_out = relu_out1 + relu_out2