
Trainer.test() hangs when run from python interactive shell with multiple GPUs #13373

@jessecambon

🐛 Bug

Trainer.test() hangs when called from an interactive Python shell if the Trainer uses strategy="ddp" and gpus > 1.

To Reproduce

Run the Python script below from the command line on a machine with more than one GPU (python -i keeps the interactive Python console open after the script completes):

python -i <scriptName.py>

After the script completes, call trainer.test() again at the interactive prompt:

trainer.test(model, datamodule=dm)

The trainer.test() call inside the script runs successfully, but the same call hangs when issued from the interactive shell. Here is the console log for reference:

(ptl) python -i script.py 
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
/anaconda/envs/ptl/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/anaconda/envs/ptl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
/anaconda/envs/ptl/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                                                                           | 0/2 [00:00<?, ?it/s][W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Epoch 0:  50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 1/2 [00:00<00:00,  3.16it/s][W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.90it/s, loss=0.46, v_num=28]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]                                                                                                                                                
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
/anaconda/envs/ptl/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:330: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
  rank_zero_warn(
/anaconda/envs/ptl/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 162.63it/s]
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           -1.934678077697754
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
>>> trainer.test(model, datamodule=dm)
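The -i flag is consumed by the interpreter itself: it is visible in sys.flags.interactive but does not appear in sys.argv. The standard-library sketch below checks this by running a tiny probe script (a stand-in for the repro script, not the script itself) with python -i. A plausible but unverified consequence: if the DDP launcher starts its worker processes by re-running the script from sys.argv (an assumption, not confirmed in this report), those workers run without -i and exit when the script finishes, so a later trainer.test() typed at the prompt has no peer process to synchronize with.

```python
import ast
import os
import subprocess
import sys
import tempfile

# Tiny stand-in for the repro script: report the interactive flag and argv.
PROBE = "import sys\nprint(repr((sys.flags.interactive, sys.argv)))\n"


def probe_interactive_run() -> tuple:
    """Run PROBE with `python -i` and return (sys.flags.interactive, sys.argv)."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "script.py")
        with open(path, "w") as f:
            f.write(PROBE)
        # Empty stdin: the interactive console reaches EOF and exits after the script.
        result = subprocess.run(
            [sys.executable, "-i", path],
            input="", capture_output=True, text=True, check=True,
        )
    # The probe's own output is the first line written to stdout.
    return ast.literal_eval(result.stdout.splitlines()[0])


if __name__ == "__main__":
    interactive, argv = probe_interactive_run()
    print("sys.flags.interactive =", interactive)
    print("sys.argv =", argv)
```

The probe reports an interactive flag of 1 while argv contains only the script path, so any relaunch built from sys.argv alone would drop the interactive session.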

Expected behavior

Either the test loop should run successfully, or an error should be raised if test() is not supported with gpus > 1.
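The "run or raise" expectation can be illustrated with a standard-library analogy. Assuming (not confirmed in this report) that the interactive trainer.test() call on rank 0 ends up waiting in a collective operation for a rank-1 process that is no longer running, the situation resembles a two-party barrier that only one participant reaches. Without a timeout it blocks forever; with a timeout it fails with an explicit error. Threads stand in for DDP processes here, and wait_for_peers is a hypothetical helper, not Lightning or NCCL code.

```python
import threading


def wait_for_peers(parties: int, arrived: int, timeout: float) -> str:
    """Analogy for a collective op: `parties` ranks are expected, `arrived` show up."""
    barrier = threading.Barrier(parties, timeout=timeout)
    results = []
    results_lock = threading.Lock()

    def rank() -> None:
        try:
            barrier.wait()  # blocks until all `parties` ranks have called wait()
            outcome = "synced"
        except threading.BrokenBarrierError:
            # Raised once the timeout expires; with no timeout this would block forever.
            outcome = "broken"
        with results_lock:
            results.append(outcome)

    threads = [threading.Thread(target=rank) for _ in range(arrived)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if results and all(r == "synced" for r in results):
        return "all ranks synchronized"
    return f"error: only {arrived} of {parties} ranks reached the barrier"


if __name__ == "__main__":
    print(wait_for_peers(parties=2, arrived=2, timeout=1.0))
    print(wait_for_peers(parties=2, arrived=1, timeout=0.2))
```

The second call mirrors the reported scenario, except that the timeout turns the silent hang into an error message.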

Python Script

import torch, os, logging, sys
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
#logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class DataModule(LightningDataModule):
    def setup(self, stage=None) -> None:
        self._dataloader = DataLoader(RandomDataset(32, 64), batch_size=2)

    def train_dataloader(self):
        return self._dataloader
    
    def test_dataloader(self):
        return self._dataloader

    def val_dataloader(self):
        return self._dataloader

if __name__ == "__main__":
    model = BoringModel()
    dm = DataModule()
    trainer = Trainer(
        gpus=2,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        strategy="ddp"
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
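Until the underlying behavior is addressed, a guard along the lines of the expected behavior could be added to the repro script. ensure_safe_to_test and guarded_test below are hypothetical helpers, not Lightning APIs. They rely on a CPython detail: sys.ps1 is only defined once the interactive prompt is running. Under python -i it is therefore absent while the script executes, where trainer.test() works, and present for calls typed at >>>, where the call hangs. sys.flags.interactive would not separate the two cases, because it is 1 for the whole python -i session. The suggested alternative of a single-device Trainer comes from the DistributedSampler warning in the log above.

```python
import sys
from typing import Any, Optional


def ensure_safe_to_test(strategy: str, num_devices: int,
                        at_prompt: Optional[bool] = None) -> None:
    """Hypothetical guard: raise instead of hanging on an interactive multi-device DDP test."""
    if at_prompt is None:
        # sys.ps1 is set by the interactive loop, so under `python -i` it is
        # missing while the script runs and present for calls typed at >>>.
        at_prompt = hasattr(sys, "ps1")
    if at_prompt and strategy == "ddp" and num_devices > 1:
        raise RuntimeError(
            f"test() with strategy={strategy!r} on {num_devices} devices was "
            "called from the interactive prompt; run it inside the script or "
            "use a single-device Trainer instead."
        )


def guarded_test(trainer: Any, *args: Any, strategy: str, num_devices: int,
                 **kwargs: Any) -> Any:
    """Hypothetical wrapper: run the guard, then delegate to trainer.test()."""
    ensure_safe_to_test(strategy, num_devices)
    return trainer.test(*args, **kwargs)
```

In the script, trainer.test(model, datamodule=dm) would become guarded_test(trainer, model, datamodule=dm, strategy="ddp", num_devices=2), with the values copied from the Trainer(...) arguments. The in-script call then proceeds as before, while the same call typed at the prompt raises instead of hanging.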

Environment

  • CUDA:
    - GPU:
      - Tesla K80
      - Tesla K80
      - Tesla K80
      - Tesla K80
    - available: True
    - version: 11.3
  • Packages:
    - numpy: 1.22.3
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0
    - pytorch-lightning: 1.6.4
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.8.13
    - version: #80~18.04.1-Ubuntu SMP Wed Apr 13 02:07:09 UTC 2022

Additional context

This was run in Azure Machine Learning Studio.

cc @tchaton @rohitgr7 @justusschock @kaushikb11 @awaelchli @akihironitta @ninginthecloud
