diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index e9afb8dc3f..0333637744 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -364,7 +364,7 @@ def _find_output_mode_and_verify( "returns a scalar." ) else: - agg_output_mode = False + agg_output_mode = perturbations_per_eval == 1 if not allow_multi_outputs: assert ( isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1 diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index e2dde0fddb..c987858bec 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -806,6 +806,30 @@ def func_future(*inp): lambda *inp: func_to_use(*inp), use_future=use_future ) + @parameterized.expand([True, False]) + def test_mutli_inp_shapley_batch_scalar_tensor_expanded(self, use_future) -> None: + def func(*inp): + sum_val = torch.sum(net(*inp)).item() + return torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0]) + + def func_future(*inp): + temp = net_fut(*inp) + temp.wait() + sum_val = torch.sum(temp.value()).item() + fut = Future() + fut.set_result(torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0])) + return fut + + if use_future: + net_fut = BasicModel_MultiLayer_MultiInput_with_Future() + func_to_use = func_future + else: + net = BasicModel_MultiLayer_MultiInput() + func_to_use = func + self._multi_input_batch_scalar_shapley_assert( + lambda *inp: func_to_use(*inp), use_future=use_future, expanded_output=True + ) + @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) def test_shapley_sampling_with_show_progress(self, mock_stderr) -> None: net = BasicModel_MultiLayer() @@ -947,7 +971,7 @@ def _single_int_input_multi_sample_batch_scalar_shapley_assert( ) def _multi_input_batch_scalar_shapley_assert( - self, func: Callable, use_future: bool = False + self, func: Callable, use_future: bool = False, expanded_output: bool = False ) -> None: inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) @@ -955,10 +979,11 @@ def _multi_input_batch_scalar_shapley_assert( mask1 = torch.tensor([[1, 1, 1]]) mask2 = torch.tensor([[0, 1, 2]]) mask3 = torch.tensor([[0, 1, 2]]) + out_mult = 3 if expanded_output else 1 expected = ( - [[3850.6666, 3850.6666, 3850.6666]], - [[306.6666, 3850.6666, 410.6666]], - [[306.6666, 3850.6666, 410.6666]], + [[3850.6666, 3850.6666, 3850.6666]] * out_mult, + [[306.6666, 3850.6666, 410.6666]] * out_mult, + [[306.6666, 3850.6666, 410.6666]] * out_mult, ) if use_future: self._shapley_test_assert_future(