9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.3.1] - 2021-05-11

### Fixed

- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))
- Fixed `Trainer.current_epoch` not getting restored after tuning ([#7434](https://github.com/PyTorchLightning/pytorch-lightning/pull/7434))
- Fixed local rank displayed in console log ([#7395](https://github.com/PyTorchLightning/pytorch-lightning/pull/7395))


## [1.3.0] - 2021-05-06

### Added
6 changes: 4 additions & 2 deletions dockers/nvidia/Dockerfile
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel_21-03.html#rel_21-03
FROM nvcr.io/nvidia/pytorch:21.03-py3
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
FROM nvcr.io/nvidia/pytorch:21.04-py3

MAINTAINER PyTorchLightning <https://github.com/PyTorchLightning>

@@ -46,6 +46,8 @@ RUN \
rm -rf pytorch-lightning && \
pip list

RUN pip install lightning-grid -U

ENV PYTHONPATH="/workspace"

RUN \
2 changes: 0 additions & 2 deletions docs/source/governance.rst
@@ -38,5 +38,3 @@ Alumni
- Jeff Ling (`jeffling <https://github.com/jeffling>`_)
- Teddy Koker (`teddykoker <https://github.com/teddykoker>`_)
- Nate Raw (`nateraw <https://github.com/nateraw>`_)


2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.3.0'
__version__ = '1.3.1'
__author__ = 'William Falcon et al.'
__author_email__ = '[email protected]'
__license__ = 'Apache-2.0'
6 changes: 3 additions & 3 deletions pytorch_lightning/accelerators/gpu.py
@@ -36,7 +36,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
self.set_nvidia_flags()
self.set_nvidia_flags(trainer.local_rank)
torch.cuda.set_device(self.root_device)
return super().setup(trainer, model)

@@ -55,12 +55,12 @@ def teardown(self) -> None:
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags() -> None:
def set_nvidia_flags(local_rank: int) -> None:
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, batch: Any) -> Any:
# no need to transfer batch to device in DP mode
27 changes: 24 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,6 +88,7 @@ def __init__(
allgather_bucket_size: int = 2e8,
reduce_bucket_size: int = 2e8,
zero_allow_untested_optimizer: bool = True,
logging_batch_size_per_gpu: Union[str, int] = "auto",
config: Optional[Union[Path, str, dict]] = None,
logging_level: int = logging.WARN,
num_nodes: int = 1,
@@ -148,6 +149,13 @@ def __init__(
zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
DeepSpeed supported optimizer when using ZeRO (default: True)

logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
on a samples-per-second basis (only displayed if logging=logging.INFO).
If set to "auto", the plugin tries to infer this from
the train DataLoader's BatchSampler, else defaults to 1.
To obtain accurate logs when using datasets that do not support batch samplers,
set this to the actual per-GPU batch size (trainer.batch_size).

config: Pass in a deepspeed formatted config dict,
or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
All defaults will be ignored if a config is passed in. (Default: ``None``)
@@ -182,6 +190,7 @@ def __init__(
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
rather than individual sharded weight files.
Disable to save sharded states individually. (Default: True)

"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -197,6 +206,7 @@ def __init__(
self.config = self._create_default_config(
zero_optimization,
zero_allow_untested_optimizer,
logging_batch_size_per_gpu,
partition_activations=partition_activations,
cpu_checkpointing=cpu_checkpointing,
contiguous_memory_optimization=contiguous_memory_optimization,
@@ -409,14 +419,22 @@ def _format_batch_size_and_grad_accum_config(self):
" as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
)
if "train_micro_batch_size_per_gpu" not in self.config:
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
batch_size = 1
if hasattr(self.lightning_module, 'train_dataloader'):
train_dataloader = self.lightning_module.train_dataloader()
if hasattr(train_dataloader, 'batch_sampler'):
batch_size = train_dataloader.batch_sampler.batch_size
return batch_size

def _format_precision_config(self):
amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
@@ -446,6 +464,7 @@ def _create_default_config(
self,
zero_optimization: bool,
zero_allow_untested_optimizer: bool,
logging_batch_size_per_gpu: Union[str, int],
partition_activations: bool,
cpu_checkpointing: bool,
contiguous_memory_optimization: bool,
@@ -466,6 +485,8 @@
"zero_optimization": zero_kwargs,
**cfg
}
if logging_batch_size_per_gpu != 'auto':
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

def _filepath_to_dir(self, filepath: str) -> str:
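The `logging_batch_size_per_gpu` argument documented in the diff above only affects DeepSpeed's throughput logging, never the actual batching. A minimal usage sketch, mirroring the plugin construction used in the tests further down; the `StreamingDataset` class is a hypothetical stand-in for any `IterableDataset`-backed loader, which has no `batch_sampler` and would therefore hit the `_auto_select_batch_size` fallback of 1:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class StreamingDataset(IterableDataset):
    """Hypothetical iterable dataset: its DataLoader exposes no BatchSampler."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


# With an iterable dataset, "auto" cannot read a batch size and falls back to 1,
# so pass the real per-GPU batch size to keep the samples-per-second logs accurate.
trainer = Trainer(
    gpus=1,
    plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=10),
)

# trainer.fit(model, DataLoader(StreamingDataset(), batch_size=10))
```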
2 changes: 2 additions & 0 deletions pytorch_lightning/tuner/lr_finder.py
@@ -288,6 +288,7 @@ def __lr_finder_dump_params(trainer, model):
'logger': trainer.logger,
'max_steps': trainer.max_steps,
'checkpoint_callback': trainer.checkpoint_callback,
'current_epoch': trainer.current_epoch,
'configure_optimizers': model.configure_optimizers,
}

@@ -297,6 +298,7 @@ def __lr_finder_restore_params(trainer, model):
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.max_steps = trainer.__dumped_params['max_steps']
trainer.current_epoch = trainer.__dumped_params['current_epoch']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
del trainer.__dumped_params

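The `current_epoch` dump/restore above is the changelog's "Fixed `Trainer.current_epoch` not getting restored after tuning". A small sketch of the expected behaviour, assuming nothing beyond the public tuner API; `TinyModel` is an illustrative LightningModule, not taken from this PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Minimal module so the sketch is self-contained (not from the PR)."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, = batch
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


model = TinyModel()
trainer = Trainer(max_steps=20)

epoch_before = trainer.current_epoch
trainer.tuner.lr_find(model, num_training=5)

# With the fix, the short tuning run no longer leaks its epoch counter
# into the trainer state used for the real fit.
assert trainer.current_epoch == epoch_before
```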
2 changes: 1 addition & 1 deletion requirements/docs.txt
@@ -1,4 +1,4 @@
sphinx>=3.0, !=3.5 # fails with sphinx.ext.viewcode
sphinx>=3.0, <3.5 # fails with sphinx.ext.viewcode # fails with sphinx_paramlinks with 4.0.0
recommonmark # fails with badges
m2r # fails with multi-line text
nbsphinx>=0.8
41 changes: 40 additions & 1 deletion tests/plugins/test_deepspeed_plugin.py
@@ -7,14 +7,15 @@
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf

@@ -234,6 +235,44 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
@pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10),
(RandomIterableDataset, "auto"), (RandomIterableDataset, 10)])
def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value):
"""Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""

class TestModel(BoringModel):

def train_dataloader(self):
return DataLoader(dataset_cls(32, 64))

class AssertCallback(Callback):

def on_train_start(self, trainer, pl_module) -> None:
assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
config = trainer.accelerator.training_type_plugin.config

# int value overrides auto mode
expected_value = value if isinstance(value, int) else 1
if dataset_cls == RandomDataset:
expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value

assert config['train_micro_batch_size_per_gpu'] == expected_value
raise SystemExit

ck = AssertCallback()
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=ck,
gpus=1,
plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=value, zero_optimization=False),
)
with pytest.raises(SystemExit):
trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_run_configure_optimizers(tmpdir):
"""
6 changes: 5 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -803,8 +803,12 @@ def _user_worker_init_fn(_):
pass


@RunIf(max_torch="1.8.9")
def test_missing_worker_init_fn():
""" Test that naive worker seed initialization leads to undesired random state in subprocesses. """
"""
Test that naive worker seed initialization leads to undesired random state in subprocesses.
PyTorch 1.9+ does not have this issue.
"""
dataset = NumpyRandomDataset()

seed_everything(0)
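The docstring above refers to the classic NumPy-seeding pitfall: without a `worker_init_fn`, DataLoader worker processes can inherit identical NumPy random state (PyTorch 1.9+ seeds workers differently, hence the `max_torch` guard). A minimal sketch of the kind of per-worker seeding the test is about, using only public `torch.utils.data` APIs; the dataset and function names here are illustrative, not the helpers used in the test suite:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class NumpyRandomDataset(Dataset):
    """Returns NumPy random numbers, which expose the shared-seed problem."""

    def __len__(self):
        return 64

    def __getitem__(self, index):
        return np.random.rand(1)


def seed_numpy_per_worker(worker_id: int) -> None:
    # torch already seeds each worker differently; derive the NumPy seed from it
    # so workers do not produce identical "random" samples.
    worker_seed = get_worker_info().seed % 2 ** 32
    np.random.seed(worker_seed)


loader = DataLoader(
    NumpyRandomDataset(),
    num_workers=2,
    batch_size=8,
    worker_init_fn=seed_numpy_per_worker,
)
```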
8 changes: 7 additions & 1 deletion tests/tuner/test_lr_finder.py
@@ -77,7 +77,13 @@ def test_trainer_reset_correctly(tmpdir):
)

changed_attributes = [
'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback'
'accumulate_grad_batches',
'auto_lr_find',
'callbacks',
'checkpoint_callback',
'current_epoch',
'logger',
'max_steps',
]
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
trainer.tuner.lr_find(model, num_training=5)
8 changes: 4 additions & 4 deletions tests/tuner/test_scale_batch_size.py
@@ -111,13 +111,13 @@ def test_trainer_reset_correctly(tmpdir):
)

changed_attributes = [
'max_steps',
'weights_summary',
'logger',
'callbacks',
'checkpoint_callback',
'limit_train_batches',
'current_epoch',
'limit_train_batches',
'logger',
'max_steps',
'weights_summary',
]
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
trainer.tuner.scale_batch_size(model, max_trials=5)