diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 421b3570148e6..3edd4864ffc78 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1311,14 +1311,14 @@ def get_device(node) -> Optional[torch.device]: for candidate in candidates: if isinstance(candidate, torch.Tensor): - if candidate.device.type == "cuda": + if candidate.device.type in ["cuda", "xpu"]: return candidate.device return torch.device("cpu") def get_sample_rng_state(device: Optional[torch.device]): - if device is not None and device.type == "cuda": - return torch.cuda.get_rng_state() + if device is not None and device.type in ["cuda", "xpu"]: + return torch.get_device_module().get_rng_state() return torch.get_rng_state() # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.