Skip to content

Commit 41cfa33

Browse files
authored
Address feedback for Fabric.init_module() (2/4) (#17722)
1 parent 88cd100 commit 41cfa33

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

src/lightning/fabric/fabric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None
898898
@contextmanager
899899
def _old_sharded_model_context(strategy: Strategy) -> Generator:
900900
if isinstance(strategy, _Sharded):
901-
with strategy.init_sharded_context():
901+
with strategy.module_sharded_context():
902902
yield
903903
else:
904904
yield

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
340340
raise NotImplementedError(self._err_msg_joint_setup_required())
341341

342342
@contextmanager
343-
def init_sharded_context(self) -> Generator[None, None, None]:
343+
def module_sharded_context(self) -> Generator[None, None, None]:
344344
# Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
345345
# manager, which happens at the start of `Fabric.run()`. Later modifications through e.g. `Fabric.setup()`
346346
# won't have an effect here.

src/lightning/fabric/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def init_context(self) -> Generator[None, None, None]:
260260
yield
261261

262262
@contextmanager
263-
def init_sharded_context(self) -> Generator:
263+
def module_sharded_context(self) -> Generator:
264264
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
265265
from torch.distributed.fsdp.wrap import enable_wrap
266266

src/lightning/fabric/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ class _Sharded(ABC):
396396

397397
@abstractmethod
398398
@contextmanager
399-
def init_sharded_context(self) -> Generator:
399+
def module_sharded_context(self) -> Generator:
400400
"""A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding
401401
of parameters on creation.
402402

tests/tests_fabric/test_fabric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -718,15 +718,15 @@ def test_module_sharding_context():
718718
"""Test that the sharding context manager gets applied when the strategy supports it and is a no-op
719719
otherwise."""
720720
fabric = Fabric()
721-
fabric._strategy = MagicMock(spec=DDPStrategy, init_sharded_context=Mock())
721+
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
722722
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
723723
pass
724-
fabric._strategy.init_sharded_context.assert_not_called()
724+
fabric._strategy.module_sharded_context.assert_not_called()
725725

726726
fabric._strategy = MagicMock(spec=_Sharded)
727727
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
728728
pass
729-
fabric._strategy.init_sharded_context.assert_called_once()
729+
fabric._strategy.module_sharded_context.assert_called_once()
730730

731731

732732
def test_init_module_context(monkeypatch):

0 commit comments

Comments
 (0)