From 00f7535a8584d13219c6a616e82792031184ea83 Mon Sep 17 00:00:00 2001 From: frost-intel Date: Tue, 1 Jul 2025 19:07:45 +0000 Subject: [PATCH] Preserve RNG state for XPU in checkpointing --- torch/_functorch/partitioners.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 421b3570148e6..3edd4864ffc78 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1311,14 +1311,14 @@ def get_device(node) -> Optional[torch.device]: for candidate in candidates: if isinstance(candidate, torch.Tensor): - if candidate.device.type == "cuda": + if candidate.device.type in ["cuda", "xpu"]: return candidate.device return torch.device("cpu") def get_sample_rng_state(device: Optional[torch.device]): - if device is not None and device.type == "cuda": - return torch.cuda.get_rng_state() + if device is not None and device.type in ["cuda", "xpu"]: + return torch.get_device_module().get_rng_state() return torch.get_rng_state() # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.