From 7b08d27eeca945658ad73a563f5a6526fc575051 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 20 Nov 2024 10:05:23 -0800 Subject: [PATCH] Prepare for "Fix type-safety of `torch.nn.Module` instances": fbcode/p* (#1448) Summary: See D52890934 Reviewed By: r-barnes Differential Revision: D66235323 --- captum/attr/_core/feature_ablation.py | 5 ++++- captum/attr/_core/occlusion.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index dad8f47568..e5e60bb465 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -808,7 +808,10 @@ def _construct_ablated_input( current_mask = current_mask.to(expanded_input.device) assert baseline is not None, "baseline must be provided" ablated_tensor = ( - expanded_input * (1 - current_mask).to(expanded_input.dtype) + expanded_input + * (1 - current_mask).to(expanded_input.dtype) + # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, + # Tensor]` and `Tensor`. ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index fe5105c96a..f6bfcbe8a8 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -323,6 +323,8 @@ def _construct_ablated_input( torch.ones(1, dtype=torch.long, device=expanded_input.device) - input_mask ).to(expanded_input.dtype) + # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, + # Tensor]` and `Tensor`. ) + (baseline * input_mask.to(expanded_input.dtype)) return ablated_tensor, input_mask