From 0637172b87a47f0c18c6656df588a8c187324b21 Mon Sep 17 00:00:00 2001 From: Samuel Yusuf Date: Mon, 17 Mar 2025 08:17:09 -0700 Subject: [PATCH] Control for when output from model is a scalar or a 1D tensor (#1521) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1521 This is to make sure that we control for when the output is not a 2D tensor We also include an output accessor that parses a dictionary model output to get final output. Differential Revision: D69876980 --- captum/testing/helpers/basic_models.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 77584594a9..72a26607ab 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import no_type_check, Optional, Tuple, Union +from typing import Dict, no_type_check, Optional, Tuple, Union import torch import torch.nn as nn @@ -467,7 +467,9 @@ def __init__( self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) @no_type_check - def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: Tensor, add_input: Optional[Tensor] = None + ) -> Dict[str, Tensor]: input = x if add_input is None else x + add_input lin0_out = self.linear0(input) lin1_out = self.linear1(lin0_out) @@ -485,7 +487,14 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor: lin3_out = self.linear3(lin1_out_alt).to(torch.int64) - return torch.cat((lin2_out, lin3_out), dim=1) + output_tensors = torch.cat((lin2_out, lin3_out), dim=1) + + # we return a dictionary of tensors as an output to test the case + # where an output accessor is required + return { + "task {}".format(i + 1): output_tensors[:, i] + for i in range(output_tensors.shape[1]) + } class MultiRelu(nn.Module):