Commit 7ee195d

Jerry-Master authored and facebook-github-bot committed
mps fix (#1227)
Summary: Adding a single conversion from float64 to float32 makes integrated gradients fully compatible with the MPS backend, which it wasn't previously. The change is safe for any backend, since float64 precision is rarely used in deep learning today; the float64 default that torch.tensor infers for Python floats simply hasn't kept up with current practice.

Pull Request resolved: #1227

Reviewed By: cyrjano

Differential Revision: D54047349

Pulled By: vivekmig

fbshipit-source-id: 1ffda83f065a3f14fa3c5b0229fe0feb5035cc99
1 parent f31b0be · commit 7ee195d

File tree

1 file changed: +1 −1 lines changed

captum/attr/_core/integrated_gradients.py

Lines changed: 1 addition & 1 deletion
@@ -359,7 +359,7 @@ def _attribute(
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = [
             grad.contiguous().view(n_steps, -1)
-            * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
+            * torch.tensor(step_sizes).float().view(n_steps, 1).to(grad.device)
             for grad in grads
         ]
 
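A minimal sketch of the dtype behavior behind the fix, assuming a PyTorch build new enough to expose torch.backends.mps; the step-size values below are illustrative, standing in for the step_sizes list used by integrated gradients:

import torch

# Plain Python floats are inferred as float64 by torch.tensor.
step_sizes = [0.5, 0.25, 0.25]

t64 = torch.tensor(step_sizes)
print(t64.dtype)  # torch.float64 -- the default inference for Python floats

t32 = torch.tensor(step_sizes).float()
print(t32.dtype)  # torch.float32 -- the dtype the MPS backend supports

# MPS cannot represent float64 tensors, so moving t64 to "mps" raises an
# error, while the explicitly cast float32 tensor moves cleanly.
if torch.backends.mps.is_available():
    t32 = t32.to("mps")  # fine: float32 is supported on MPS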
