
Commit dc400ac

styusuf authored and facebook-github-bot committed
Adding a new Module for unsupported layers. Adding tests for unsupported layers. Simple logging for unsupported layers (#1505)
Summary: We are adding a test for unsupported gradient layers. Open to ideas if there is a better way to structure the test. A bit uncomfortable with removing pyre type validations, as we allow anything to be passed into the GradientUnsupportedLayerOutput class. Differential Revision: D69792994
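For context, a minimal sketch of how a test might exercise the new helper model (my own illustration, not the actual test from D69792994; the int64 payload and the choice of LayerGradientXActivation as the gradient-based method are assumptions):

import torch

from captum.attr import LayerGradientXActivation
from captum.testing.helpers.basic_models import BasicModel_GradientLayerAttribution

# Configure the model so unsupportedLayer re-emits a non-differentiable
# output (here, a list holding an int64 tensor): the case the new logging
# is meant to flag. The payload itself is a hypothetical example.
model = BasicModel_GradientLayerAttribution(
    unsupported_layer_output=[torch.ones(2, 3, dtype=torch.int64)]
)

layer_ga = LayerGradientXActivation(model, model.unsupportedLayer)
# Depending on how the attribution method handles the unsupported output,
# this call may log a warning or fail during gradient computation.
attributions = layer_ga.attribute(torch.randn(2, 3), target=0)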
1 parent a799dfd · commit dc400ac

File tree

1 file changed (+78, -2)


captum/testing/helpers/basic_models.py

Lines changed: 78 additions & 2 deletions
@@ -2,7 +2,7 @@

 # pyre-strict

-from typing import no_type_check, Optional, Tuple, Union
+from typing import Iterable, no_type_check, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -417,6 +417,77 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer that
+    is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: Optional[Iterable[Tensor]]
+    ) -> Optional[Iterable[Tensor]]:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: Optional[Iterable[Tensor]] = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=inplace)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    # pyre-fixme[3]: Return type must be annotated.
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None):
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer is unused in the forward func. Used only to check whether layer output is supported.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
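As a quick sanity check on the new model (again my own sketch, not part of the commit): the linear3 branch is cast to int64, which severs it from the autograd graph before the final concatenation, so only the linear2 logits carry gradients.

import torch

from captum.testing.helpers.basic_models import BasicModel_GradientLayerAttribution

model = BasicModel_GradientLayerAttribution()
out = model(torch.randn(2, 3))  # linear0 (3->3) -> linear1 (3->4) -> relu -> linear2 (4->2)
print(out.shape)  # torch.Size([2, 4]): two linear2 logits, then two int64-cast linear3 logits
# Recent PyTorch versions promote the int64 half back to float inside
# torch.cat, but the .to(torch.int64) cast has already detached that branch,
# so gradients flow only through the first two columns.
print(out.requires_grad)  # True, via the linear2 branch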
