Commit 27d9125
Cast input before moving to device for all strategies (#18264)
1 parent efa7b2f commit 27d9125
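In short, this commit moves the precision cast of the input batch out of the per-strategy `*_step` wrappers (and out of DeepSpeed's `batch_to_device` override) and into the loops, so the batch is cast to the configured precision before any transfer-to-device hook runs, for all strategies. A minimal sketch of the resulting per-batch order in the loops, using the names that appear in the hunks below:

batch = trainer.precision_plugin.convert_input(batch)  # cast to the configured precision first
batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx)
batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx)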

8 files changed (+26, -27 lines)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -138,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Increased the minimum supported `wandb` version for `WandbLogger` from 0.12.0 to 0.12.10 ([#18171](https://github.com/Lightning-AI/lightning/pull/18171))


+- The input tensors now get cast to the right precision type before transfer to the device ([#18264](https://github.com/Lightning-AI/lightning/pull/18264))
+
+
 ### Deprecated

 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 1 addition & 0 deletions
@@ -382,6 +382,7 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         trainer = self.trainer

+        batch = trainer.precision_plugin.convert_input(batch)
         batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx)
         batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx)

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:

         """
         trainer = self.trainer
+        batch = trainer.precision_plugin.convert_input(batch)
         batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx)
         batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx)

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 1 addition & 0 deletions
@@ -194,6 +194,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         self.batch_progress.is_last_batch = data_fetcher.done

         trainer = self.trainer
+        batch = trainer.precision_plugin.convert_input(batch)
         batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=0)
         batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=0)

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 0 additions & 7 deletions
@@ -895,10 +895,3 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             offload_params_device="nvme",
             offload_optimizer_device="nvme",
         )
-
-    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
-        # The strategy casts the input before moving to the device
-        # In all other strategies, the input gets converted in the `Strategy.*_step` methods
-        # TODO: standardize this for all strategies
-        batch = self.precision_plugin.convert_input(batch)
-        return super().batch_to_device(batch, device, dataloader_idx)
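The removed comment above points at the inconsistency this commit resolves: DeepSpeed used to be the only strategy that cast the batch inside `batch_to_device`, while the others cast it inside their `*_step` wrappers (removed in the `strategy.py` hunks below); the cast now happens once, in the loops. Conceptually, `convert_input` casts floating-point tensors in the batch to the plugin's target dtype; a rough illustrative sketch of that idea for a single tensor, not the library's actual implementation:

import torch


def convert_input_sketch(batch: torch.Tensor, dst_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative only: cast floating-point tensors to the target dtype and
    # leave non-float data (e.g. integer token ids) untouched.
    if torch.is_floating_point(batch):
        return batch.to(dst_dtype)
    return batch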

src/lightning/pytorch/strategies/strategy.py

Lines changed: 0 additions & 4 deletions
@@ -372,7 +372,6 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         See :meth:`~lightning.pytorch.core.module.LightningModule.training_step` for more details

         """
-        args, kwargs = self.precision_plugin.convert_input((args, kwargs))
         assert self.lightning_module is not None
         assert self.model is not None
         with self.precision_plugin.train_step_context():
@@ -394,7 +393,6 @@ def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         See :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` for more details

         """
-        args, kwargs = self.precision_plugin.convert_input((args, kwargs))
         assert self.lightning_module is not None
         assert self.model is not None
         with self.precision_plugin.val_step_context():
@@ -408,7 +406,6 @@ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         See :meth:`~lightning.pytorch.core.module.LightningModule.test_step` for more details

         """
-        args, kwargs = self.precision_plugin.convert_input((args, kwargs))
         assert self.lightning_module is not None
         assert self.model is not None
         with self.precision_plugin.test_step_context():
@@ -422,7 +419,6 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any:
         See :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` for more details

         """
-        args, kwargs = self.precision_plugin.convert_input((args, kwargs))
         assert self.lightning_module is not None
         assert self.model is not None
         with self.precision_plugin.predict_step_context():

tests/tests_pytorch/models/test_gpu.py

Lines changed: 20 additions & 0 deletions
@@ -210,3 +210,23 @@ def to(self, *args, **kwargs):
     with patch.object(batch, "to", wraps=batch.to) as mocked:
         batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0"))
         mocked.assert_called_with(torch.device("cuda", 0))
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize(
+    ("strategy", "precision", "expected_dtype"),
+    [
+        ("auto", "16-mixed", torch.float32),
+        ("auto", "16-true", torch.float16),
+        pytest.param("deepspeed", "bf16-true", torch.bfloat16, marks=RunIf(deepspeed=True, bf16_cuda=True)),
+    ],
+)
+def test_input_tensors_cast_before_transfer_to_device(strategy, precision, expected_dtype):
+    class CustomBoringModel(BoringModel):
+        def transfer_batch_to_device(self, batch, *args, **kwargs):
+            assert batch.dtype == expected_dtype
+            return super().transfer_batch_to_device(batch, *args, **kwargs)
+
+    model = CustomBoringModel()
+    trainer = Trainer(strategy=strategy, devices=1, precision=precision, barebones=True, max_steps=2)
+    trainer.fit(model)
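To observe the same behavior outside the test suite, a minimal standalone sketch (assuming a lightning release that includes this commit and a single CUDA GPU) that overrides `transfer_batch_to_device` and checks the dtype, mirroring the test above:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class DtypeCheckingModel(BoringModel):
    def transfer_batch_to_device(self, batch, *args, **kwargs):
        # With precision="16-true", the loop casts the batch before this hook runs.
        assert batch.dtype == torch.float16
        return super().transfer_batch_to_device(batch, *args, **kwargs)


if __name__ == "__main__":
    trainer = Trainer(accelerator="cuda", devices=1, precision="16-true", barebones=True, max_steps=2)
    trainer.fit(DtypeCheckingModel())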

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 0 additions & 16 deletions
@@ -1256,22 +1256,6 @@ def configure_optimizers(self):
     trainer.fit(model)


-@RunIf(min_cuda_gpus=1, deepspeed=True)
-def test_deepspeed_tensors_cast_to_fp16_before_hosted_on_device():
-    class CustomBoringModel(BoringModel):
-        def transfer_batch_to_device(self, batch, *args, **kwargs):
-            assert batch.dtype is torch.float16
-            return super().transfer_batch_to_device(batch, *args, **kwargs)
-
-    model = CustomBoringModel()
-    trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision="16-mixed")
-    trainer.strategy.connect(model)
-    batch = torch.zeros(1, dtype=torch.float32)
-    batch = trainer.strategy.batch_to_device(batch)
-    assert batch.is_cuda
-    assert batch.dtype is torch.float16
-
-
 @RunIf(deepspeed=True)
 @pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
 def test_validate_parallel_devices_indices(device_indices):