-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Feature/double precision #6595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kaushikb11
merged 35 commits into
Lightning-AI:master
from
ethanwharris:feature/double_precision
Mar 24, 2021
Merged
Feature/double precision #6595
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
b224197
Add support for double precision `precision=64`.
ethanwharris d7a2098
Update CHANGELOG.md
ethanwharris 27c3e72
Minor changes
ethanwharris 55ced9e
Fix typings
ethanwharris 07bdb23
Switch to static methods
ethanwharris 564ad70
Use functools.wraps
ethanwharris dd79106
Update test
ethanwharris 0522ad8
Add teardown and pickle test
ethanwharris 4103196
Minor doc fix
ethanwharris 3006fab
Add copyright notice to test file
ethanwharris 3b53c81
Update error message in accelerator_connector.py
ethanwharris b1b8858
Add testfor training_step etc.
ethanwharris cf12a59
Switch patch logic to seperate class, and patch additional methods
ethanwharris df6d847
Switch to `.double()`
ethanwharris 72c9be4
Add check for original float32 data
ethanwharris 423302f
Enhance tests for double precision
ethanwharris b654be2
Update tests/plugins/test_double_plugin.py
ethanwharris b9c662b
Update tests/plugins/test_double_plugin.py
ethanwharris dd608b3
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 982767a
Update pytorch_lightning/plugins/precision/double.py
ethanwharris f92dd2c
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 68dce05
Update pytorch_lightning/plugins/precision/double.py
ethanwharris e8af281
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 9a8c021
Move `RandomFloatIntDataset`
ethanwharris 6489776
Fix type hint
ethanwharris f527a41
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 3da2d05
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 2e74cff
Update pytorch_lightning/plugins/precision/double.py
ethanwharris a7507ad
Update pytorch_lightning/plugins/precision/double.py
ethanwharris fa323a2
Update pytorch_lightning/plugins/precision/double.py
ethanwharris 23b21c5
Add type hints to args and kwargs
ethanwharris 210fd87
Fix failing tests
ethanwharris e7b6c7f
Switch `predict` to `predict_step`
ethanwharris 925d109
Merge branch 'master' into feature/double_precision
ethanwharris 59c093b
Remove line from test no longer needed
ethanwharris File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import wraps | ||
from typing import Any, Sequence, Tuple, TYPE_CHECKING, List | ||
|
||
import torch | ||
|
||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin | ||
from pytorch_lightning.utilities.apply_func import apply_to_collection | ||
|
||
if TYPE_CHECKING: | ||
from torch.nn import Module | ||
from torch.optim import Optimizer | ||
|
||
|
||
class _DoublePrecisionPatch: | ||
"""Class to handle patching of methods in the ``LightningModule`` and subsequent teardown.""" | ||
|
||
def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None: | ||
self.model = model | ||
self.method_name = method_name | ||
self.old_method = old_method | ||
|
||
def teardown(self) -> None: | ||
setattr(self.model, self.method_name, self.old_method) | ||
|
||
@staticmethod | ||
def _to_double_precision(data: torch.Tensor) -> torch.Tensor: | ||
if data.is_floating_point(): | ||
return data.double() | ||
return data | ||
|
||
@staticmethod | ||
def _move_float_tensors_to_double(collection: Any) -> Any: | ||
return apply_to_collection( | ||
collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision | ||
) | ||
|
||
@classmethod | ||
def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch': | ||
old_method = getattr(model, method_name) | ||
|
||
@wraps(old_method) | ||
def new_method(*args: Any, **kwargs: Any) -> Any: | ||
return old_method( | ||
*_DoublePrecisionPatch._move_float_tensors_to_double(args), | ||
**_DoublePrecisionPatch._move_float_tensors_to_double(kwargs) | ||
) | ||
|
||
setattr(model, method_name, new_method if callable(old_method) else old_method) | ||
return cls(model, method_name, old_method) | ||
|
||
|
||
class DoublePrecisionPlugin(PrecisionPlugin): | ||
"""Plugin for training with double (``torch.float64``) precision.""" | ||
|
||
precision: int = 64 | ||
|
||
def __init__(self) -> None: | ||
self.patches: List[_DoublePrecisionPatch] = [] | ||
|
||
def connect( | ||
self, | ||
model: 'Module', | ||
optimizers: Sequence['Optimizer'], | ||
lr_schedulers: Sequence[Any], | ||
) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: | ||
"""Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`, | ||
`predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter | ||
`optimizers` or `lr_schedulers`.""" | ||
model = model.to(dtype=torch.float64) | ||
if isinstance(model, LightningModule): | ||
self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step')) | ||
self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step')) | ||
self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step')) | ||
self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step')) | ||
self.patches.append(_DoublePrecisionPatch.patch(model, 'forward')) | ||
|
||
return super().connect(model, optimizers, lr_schedulers) | ||
|
||
def post_dispatch(self) -> None: | ||
while len(self.patches) > 0: | ||
self.patches.pop().teardown() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
|
||
import torch | ||
from torch.utils.data import DataLoader, Dataset | ||
|
||
from pytorch_lightning import Trainer | ||
from tests.helpers.boring_model import BoringModel, RandomDataset | ||
|
||
|
||
class RandomFloatIntDataset(Dataset): | ||
|
||
def __init__(self, size, length): | ||
self.len = length | ||
self.float_data = torch.randn(length, size) | ||
self.int_data = torch.randint(10, (length, 1)) | ||
|
||
def __getitem__(self, index): | ||
return self.float_data[index], self.int_data[index] | ||
|
||
def __len__(self): | ||
return self.len | ||
|
||
|
||
class DoublePrecisionBoringModel(BoringModel): | ||
|
||
def training_step(self, batch, batch_idx): | ||
float_data, int_data = batch | ||
assert float_data.dtype == torch.float64 | ||
output = self(float_data) | ||
loss = self.loss(batch, output) | ||
return {"loss": loss} | ||
|
||
def validation_step(self, batch, batch_idx): | ||
assert batch.dtype == torch.float64 | ||
output = self(batch) | ||
loss = self.loss(batch, output) | ||
return {"x": loss} | ||
|
||
def test_step(self, batch, batch_idx): | ||
assert batch.dtype == torch.float64 | ||
output = self(batch) | ||
loss = self.loss(batch, output) | ||
return {"y": loss} | ||
|
||
def predict_step(self, batch, batch_idx, dataloader_idx=None): | ||
assert batch.dtype == torch.float64 | ||
return self(batch) | ||
|
||
def on_fit_start(self): | ||
assert self.layer.weight.dtype == torch.float64 | ||
|
||
def on_after_backward(self): | ||
assert self.layer.weight.grad.dtype == torch.float64 | ||
|
||
def train_dataloader(self): | ||
dataset = RandomFloatIntDataset(32, 64) | ||
assert dataset.float_data.dtype == torch.float32 # Don't start with double data | ||
return DataLoader(dataset) | ||
|
||
def predict_dataloader(self): | ||
return DataLoader(RandomDataset(32, 64)) | ||
|
||
|
||
class DoublePrecisionBoringModelNoForward(BoringModel): | ||
|
||
def training_step(self, batch, batch_idx): | ||
assert batch.dtype == torch.float64 | ||
output = self.layer(batch) | ||
assert output.dtype == torch.float64 | ||
loss = self.loss(batch, output) | ||
return {"loss": loss} | ||
|
||
def validation_step(self, batch, batch_idx): | ||
assert batch.dtype == torch.float64 | ||
output = self.layer(batch) | ||
assert output.dtype == torch.float64 | ||
loss = self.loss(batch, output) | ||
return {"x": loss} | ||
|
||
def test_step(self, batch, batch_idx): | ||
assert batch.dtype == torch.float64 | ||
output = self.layer(batch) | ||
assert output.dtype == torch.float64 | ||
loss = self.loss(batch, output) | ||
return {"y": loss} | ||
|
||
def predict_step(self, batch, batch_idx, dataloader_idx=None): | ||
assert batch.dtype == torch.float64 | ||
output = self.layer(batch) | ||
assert output.dtype == torch.float64 | ||
return output | ||
|
||
def predict_dataloader(self): | ||
return DataLoader(RandomDataset(32, 64)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'boring_model', | ||
(DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward) | ||
) | ||
def test_double_precision(tmpdir, boring_model): | ||
model = boring_model() | ||
original_training_step = model.training_step | ||
|
||
trainer = Trainer( | ||
max_epochs=2, | ||
default_root_dir=tmpdir, | ||
fast_dev_run=2, | ||
precision=64, | ||
log_every_n_steps=1, | ||
) | ||
trainer.fit(model) | ||
trainer.test(model) | ||
trainer.predict(model) | ||
|
||
assert model.training_step == original_training_step |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.