Skip to content

Commit 30ac908

Browse files
ezyangfacebook-github-bot
authored andcommitted
Add pyre fixme for downstream type errors from D65753120 (#1439)
Summary: Pull Request resolved: #1439 X-link: facebook/Ax#3055 X-link: ctrl-labs/src2#38515 X-link: ctrl-labs/src2#38514 bypass-github-export-checks "The check is bugged, I exported all the required exports" Reviewed By: jermenkoo Differential Revision: D65826205 fbshipit-source-id: b08f00caaac8f3cc235bc280c4fe3f99089a0753
1 parent e3a3574 commit 30ac908

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -806,9 +806,11 @@ def _construct_ablated_input(
806806
dim=0,
807807
).long()
808808
current_mask = current_mask.to(expanded_input.device)
809-
assert baseline is not None, "baseline must be provided"
810809
ablated_tensor = (
811-
expanded_input * (1 - current_mask).to(expanded_input.dtype)
810+
expanded_input
811+
* (1 - current_mask).to(expanded_input.dtype)
812+
# pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
813+
# Tensor]` and `Tensor`.
812814
) + (baseline * current_mask.to(expanded_input.dtype))
813815
return ablated_tensor, current_mask
814816

captum/attr/_core/occlusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,14 @@ def _construct_ablated_input(
316316
],
317317
dim=0,
318318
).long()
319-
assert baseline is not None, "baseline should not be None"
320319
ablated_tensor = (
321320
expanded_input
322321
* (
323322
torch.ones(1, dtype=torch.long, device=expanded_input.device)
324323
- input_mask
325324
).to(expanded_input.dtype)
325+
# pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
326+
# Tensor]` and `Tensor`.
326327
) + (baseline * input_mask.to(expanded_input.dtype))
327328
return ablated_tensor, input_mask
328329

captum/module/gaussian_stochastic_gates.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def _get_gate_active_probs(self) -> Tensor:
133133
probs (Tensor): probabilities tensor of the gates are active
134134
in shape(n_gates)
135135
"""
136-
std = self.std
137-
assert std is not None, "std should not be None"
138-
x = self.mu / std
136+
# pyre-fixme[58]: `/` is not supported for operand types `Parameter` and
137+
# `Optional[float]`.
138+
x = self.mu / self.std
139139
return 0.5 * (1 + torch.erf(x / math.sqrt(2)))
140140

141141
@classmethod

tests/attr/test_input_x_gradient.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def _input_x_gradient_classification_assert(self, nt_type: str = "vanilla") -> N
117117
attributions = input_x_grad.attribute(input, target)
118118
output = model(input)[:, target]
119119
output.backward()
120-
input_grad = input.grad
121-
assert input_grad is not None
122-
expected = input_grad * input
120+
# pyre-fixme[58]: `*` is not supported for operand types
121+
# `Optional[Tensor]` and `Tensor`.
122+
expected = input.grad * input
123123
assertTensorAlmostEqual(self, attributions, expected, 0.00001, "max")
124124
else:
125125
nt = NoiseTunnel(input_x_grad)

0 commit comments

Comments
 (0)