diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f67eec28deeeb..0b5a653b68835 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Fixed + +- Fix XLA strategy to add support for for global_ordinal, local_ordinal, world_size which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852)) --- diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py index a227d2322b9a3..b8350872f22d9 100644 --- a/src/lightning/fabric/plugins/environments/xla.py +++ b/src/lightning/fabric/plugins/environments/xla.py @@ -66,6 +66,11 @@ def world_size(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.world_size() + import torch_xla.core.xla_model as xm return xm.xrt_world_size() @@ -82,6 +87,11 @@ def global_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.global_ordinal() + import torch_xla.core.xla_model as xm return xm.get_ordinal() @@ -98,6 +108,11 @@ def local_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.local_ordinal() + import torch_xla.core.xla_model as xm return xm.get_local_ordinal() diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7e33d5db87dd4..f6a24792a4316 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -97,3 +97,31 @@ def test_detect(monkeypatch): monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True) assert XLAEnvironment.detect() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch): + """Test XLA environment attributes when using XLA runtime >= 2.1.""" + + env = XLAEnvironment() + + with ( + mock.patch("torch_xla.runtime.world_size", return_value=4), + mock.patch("torch_xla.runtime.global_ordinal", return_value=2), + mock.patch("torch_xla.runtime.local_ordinal", return_value=1), + ): + env.world_size.cache_clear() + env.global_rank.cache_clear() + env.local_rank.cache_clear() + + assert env.world_size() == 4 + assert env.global_rank() == 2 + assert env.local_rank() == 1 + + env.set_world_size(100) + assert env.world_size() == 4 + + env.set_global_rank(100) + assert env.global_rank() == 2