We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 871cc5e commit 624ad9bCopy full SHA for 624ad9b
torch/optim/optimizer.py
@@ -217,7 +217,7 @@ def _get_scalar_dtype(is_fused=None):
217
218
def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
219
r"""Return the device type list that supports capturable optimizer."""
220
- capturable_supported_devices = ["cuda"]
+ capturable_supported_devices = ["cuda", "xpu", "hpu"]
221
if not torch.jit.is_scripting():
222
capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
223
if supports_xla:
0 commit comments