Commit 2fa1018

styusuf authored and facebook-github-bot committed
Adding a new Module for unsupported layer. Adding test for unsupported layers. Simple logging for unsupported layers (#1505)
Summary:
Pull Request resolved: #1505

We are adding a test for unsupported gradient layers. Open to ideas if there is a better way to structure the test. A bit uncomfortable with removing pyre type validations, as we allow anything to be passed into the GradientUnsupportedLayerOutput class.

Differential Revision: D69792994
1 parent a799dfd commit 2fa1018

File tree

2 files changed: +78 −1 lines changed

captum/_utils/typing.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 TupleOrTensorOrBoolGeneric = TypeVar(
     "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
 )
+PassThroughOutputType = TypeVar("PassThroughOutputType")
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
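
For context, PassThroughOutputType is an unconstrained TypeVar: annotating a function's argument and its return value with the same TypeVar tells the type checker that whatever type goes in is exactly what comes back out, without restricting what may be passed. A minimal sketch of that behavior (identity_passthrough is a hypothetical helper, not part of this commit):

from typing import Tuple, TypeVar

import torch
from torch import Tensor

PassThroughOutputType = TypeVar("PassThroughOutputType")


def identity_passthrough(value: PassThroughOutputType) -> PassThroughOutputType:
    # The annotation tells the checker the return type mirrors the argument
    # type; at runtime this is a plain identity function.
    return value


t: Tensor = identity_passthrough(torch.zeros(2, 3))
pair: Tuple[Tensor, str] = identity_passthrough((torch.ones(1), "meta"))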

captum/testing/helpers/basic_models.py

Lines changed: 77 additions & 1 deletion
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future

@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer that
+    is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer is unused in the forward func.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
        if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
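
For reference, the new helpers are meant to let a test route an arbitrary, non-Tensor object through a submodule so that layer attribution code encounters an unsupported layer output. The actual test is not part of this diff; the snippet below is only a usage sketch, assuming the classes above are importable from captum.testing.helpers.basic_models:

import torch

from captum.testing.helpers.basic_models import (
    BasicModel_GradientLayerAttribution,
    GradientUnsupportedLayerOutput,
)

# The pass-through layer accepts any object and returns it unchanged.
passthrough = GradientUnsupportedLayerOutput()
assert passthrough({"arbitrary": "payload"}) == {"arbitrary": "payload"}

# When unsupported_layer_output is provided, the model feeds it through
# self.unsupportedLayer during forward (the result is deliberately unused),
# so a hook registered on that submodule observes a non-Tensor output.
model = BasicModel_GradientLayerAttribution(
    unsupported_layer_output={"arbitrary": "payload"}  # anything is accepted
)
out = model(torch.randn(2, 3))
print(out.shape)  # (2, 4): the two linear heads concatenated along dim=1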
