Skip to content

Commit 349d3f7

Browse files
ezyangfacebook-github-bot
authored andcommitted
Prepare for "Fix type-safety of torch.nn.Module instances": fbcode/p* (meta-pytorch#1448)
Summary: See D52890934 Differential Revision: D66235323
1 parent 72a32af commit 349d3f7

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,10 @@ def _construct_ablated_input(
808808
current_mask = current_mask.to(expanded_input.device)
809809
assert baseline is not None, "baseline must be provided"
810810
ablated_tensor = (
811-
expanded_input * (1 - current_mask).to(expanded_input.dtype)
811+
expanded_input
812+
* (1 - current_mask).to(expanded_input.dtype)
813+
# pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
814+
# Tensor]` and `Tensor`.
812815
) + (baseline * current_mask.to(expanded_input.dtype))
813816
return ablated_tensor, current_mask
814817

captum/attr/_core/occlusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ def _construct_ablated_input(
323323
torch.ones(1, dtype=torch.long, device=expanded_input.device)
324324
- input_mask
325325
).to(expanded_input.dtype)
326+
# pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
327+
# Tensor]` and `Tensor`.
326328
) + (baseline * input_mask.to(expanded_input.dtype))
327329
return ablated_tensor, input_mask
328330

0 commit comments

Comments
 (0)