Enable ZeRO tests for CI, fix to/half function calls for LightningDistributedWrapper #6070
Merged
Changes from all commits (20 commits):
655969f  Enable ZeRO optimization, and make sure that the lightning module hoo…
75a54e2  Added test, update to function
c5413ab  Use device type mixin
1c1c114  Add precision
1d1a1e1  Turn off zero for checking optimizers are correct
d3fcc09  Remove import
f5d25fd  Use FP16 Wrapper
09289ea  Merge branch 'master' into fix/enable_zero_fix_dtype
cf6bd94  Move mixin to the base class
8969d17  Better name for the test, test precision move
7795fcb  Added CHANGELOG.md
98c152d  Revert "Added CHANGELOG.md"
47e606d  Move precision check into a separate test that requires cuda
09d7667  Merge branch 'master' into fix/enable_zero_fix_dtype
23824aa  Provide ZeRO config
7a6cd1e  Revert "Revert "Added CHANGELOG.md""
49ec362  Support torch device as input to cuda, as is with upstream pytorch
71ea9bf  Modify test to include all possible cuda variations
0f1b325  Merge branch 'master' into fix/enable_zero_fix_dtype
17715c7  Trigger Build
tests/plugins/test_deepspeed_plugin.py
@@ -8,11 +8,52 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 
 
+def test_deepspeed_lightning_module(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.
+    """
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_deepspeed_lightning_module_precision(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16.
+    """
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.cuda().half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    x = torch.randn((1, 32), dtype=torch.float).cuda()
+    out = module(x)
+
+    assert out.dtype == torch.half
[Review comment on the line above: "I am really surprised we didn't add this kind of test before."]
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
 @pytest.fixture
 def deepspeed_config():
     return {
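For orientation, here is a minimal sketch of the pattern the two new tests exercise: a wrapper module that keeps a `dtype` attribute in sync when `half()` or `to(dtype)` is called, so the wrapper reports the move the same way the wrapped LightningModule does. This is an illustrative stand-in, not Lightning's actual `LightningDeepSpeedModule` or its device/dtype mixin; the class name and bookkeeping are assumptions.

import torch
from torch import nn


class _DtypeAwareWrapper(nn.Module):
    """Illustrative wrapper: tracks the dtype the module was last moved to,
    so `wrapper.dtype` reflects `half()` / `to(dtype)` calls as the tests expect."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._dtype = torch.get_default_dtype()

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def half(self) -> nn.Module:
        self._dtype = torch.half
        return super().half()  # nn.Module.half() casts all floating-point params

    def to(self, *args, **kwargs) -> nn.Module:
        # Only the dtype bookkeeping is sketched here; a real mixin also parses
        # device arguments out of the many to() signatures.
        for arg in list(args) + list(kwargs.values()):
            if isinstance(arg, torch.dtype):
                self._dtype = arg
        return super().to(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)

With a LightningModule such as BoringModel inside, `wrapper.half()` updates both the wrapper's `dtype` and, via the mixin on the LightningModule itself, `model.dtype`, which is exactly what the first test asserts.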
@@ -34,6 +75,11 @@ def deepspeed_config():
     }
 
 
+@pytest.fixture
+def deepspeed_zero_config(deepspeed_config):
+    return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}}
+
+
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
 def test_deepspeed_plugin_string(tmpdir):
     """
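The `deepspeed_zero_config` fixture above composes configs by dict unpacking. A small sketch of the resulting shape: the ZeRO keys are taken verbatim from the fixture, while the base optimizer/scheduler values are placeholders (the real `deepspeed_config` fixture body is defined earlier in the file and only partially visible in this diff).

# Illustrative merge; base values are assumed, ZeRO keys are from the fixture.
base_config = {
    "optimizer": {"type": "SGD", "params": {"lr": 3e-5}},                      # assumed
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},    # assumed
}

zero_config = {
    **base_config,
    "zero_allow_untested_optimizer": True,  # let ZeRO wrap optimizers it hasn't validated
    "zero_optimization": {"stage": 2},      # stage 2: partition optimizer state + gradients
}

assert zero_config["zero_optimization"]["stage"] == 2
assert "optimizer" in zero_config  # base keys survive the merge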
@@ -179,12 +225,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
             return loss.backward()
 
     model = TestModel()
-    trainer = Trainer(
-        fast_dev_run=True,
-        default_root_dir=tmpdir,
-        plugins=DeepSpeedPlugin(zero_optimization=False),
-        gpus=1,
-    )
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=DeepSpeedPlugin(), gpus=1, precision=16)
     with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
         trainer.fit(model)
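The rewritten one-liner drops `zero_optimization=False` and adds `precision=16`: with ZeRO left at the plugin's default (enabled), DeepSpeed trains in fp16, so 16-bit precision must be requested. A spelled-out version of the same call, with an assumed root dir standing in for pytest's tmpdir:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = Trainer(
    fast_dev_run=True,            # run a single batch as a smoke test
    default_root_dir="/tmp/ds",   # assumed path; the test passes pytest's tmpdir
    plugins=DeepSpeedPlugin(),    # ZeRO optimization now left enabled (the default)
    gpus=1,
    precision=16,                 # fp16 is required for the ZeRO stage 2 optimizer
)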
@@ -203,17 +244,21 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
     class TestModel(BoringModel):
 
         def on_train_start(self) -> None:
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
+            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
             assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)
 
     model = TestModel()
     trainer = Trainer(
-        plugins=DeepSpeedPlugin(zero_optimization=False),
+        plugins=DeepSpeedPlugin(),  # disable ZeRO so our optimizers are not wrapped
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )
 
     trainer.fit(model)
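Because ZeRO is now enabled, DeepSpeed returns its own optimizer object and keeps the user's SGD instance one level down, which is why the assertions changed. A tiny helper capturing the access pattern the test uses (`unwrap_optimizer` is hypothetical, not part of Lightning or DeepSpeed):

def unwrap_optimizer(optimizer):
    """Return the underlying torch optimizer.

    FP16_DeepSpeedZeroOptimizer exposes the wrapped instance as `.optimizer`;
    a plain torch optimizer is returned unchanged.
    """
    return getattr(optimizer, "optimizer", optimizer)

# e.g. inside on_train_start:
#   assert isinstance(unwrap_optimizer(self.trainer.optimizers[0]), torch.optim.SGD)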
@@ -226,7 +271,7 @@ def on_train_start(self) -> None:
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
-def test_deepspeed_config(tmpdir, deepspeed_config):
+def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     """
     Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
     and saves the model weights to load correctly.
@@ -235,18 +280,22 @@ def test_deepspeed_config(tmpdir, deepspeed_config):
     class TestModel(BoringModel):
 
         def on_train_start(self) -> None:
-            import deepspeed
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.lr_schedules import WarmupLR
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
-            assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
-            assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
+            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
+            assert isinstance(self.trainer.model.lr_scheduler, WarmupLR)
 
     model = TestModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(config=deepspeed_config)],
+        plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )
 
     trainer.fit(model)
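The test hands the merged fixture dict straight to `DeepSpeedPlugin(config=...)`. A sketch of the two ways a config could be supplied, assuming (as the plugin's signature in this Lightning version suggests) that `config` also accepts a path to a JSON file; the file name below is hypothetical:

import json

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

ds_config = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {"stage": 2},
}

# 1) Pass the dict directly, as the test does via the fixture.
trainer = Trainer(plugins=[DeepSpeedPlugin(config=ds_config)], gpus=1, precision=16)

# 2) Or serialize it to JSON and pass the path (assumed to be supported).
with open("/tmp/deepspeed_config.json", "w") as f:
    json.dump(ds_config, f)
trainer = Trainer(plugins=[DeepSpeedPlugin(config="/tmp/deepspeed_config.json")], gpus=1, precision=16)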
@@ -267,7 +316,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
     """
     model = BoringModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(zero_optimization=False)],
+        plugins=[DeepSpeedPlugin()],
         default_root_dir=tmpdir,
         gpus=2,
         fast_dev_run=True,
@@ -285,8 +334,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
     # carry out the check only on rank 0
     if trainer.global_rank == 0:
         saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
-        saved_model = saved_model.float()
-        model = model.float().cpu()
+        if model.dtype == torch.half:
+            saved_model = saved_model.half()  # model is loaded in float32 as default, move it to float16
+        model = model.cpu()
         # Assert model parameters are identical after loading
         for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
             assert torch.equal(orig_param, trained_model_param)
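The change to this helper stops casting everything to float32 for comparison; instead the reloaded checkpoint (float32 by default) is cast to half only when the live model trained in fp16, so the comparison is exact in the trained precision. A standalone sketch of the same idea for plain nn.Modules (`assert_params_equal` is a hypothetical name; it reads the dtype from the first parameter instead of Lightning's `model.dtype` attribute):

import torch
from torch import nn


def assert_params_equal(trained: nn.Module, reloaded: nn.Module) -> None:
    # Checkpoints load in float32 (on CPU) by default; match the trained precision first.
    if next(trained.parameters()).dtype == torch.half:
        reloaded = reloaded.half()
    trained = trained.cpu()
    for orig_param, loaded_param in zip(trained.parameters(), reloaded.parameters()):
        assert torch.equal(orig_param, loaded_param)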