Commit 28fad41

Merge branch 'master' into better-err-message
2 parents 39da91a + 0e45220 commit 28fad41

12 files changed: +168 -169 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -9,9 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
 - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
 
 
+- Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
+
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
 
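
The #6417 entry corresponds to the epoch_result_store.py changes further down in this diff: when a plain (non-Metric) value is logged in a distributed run and the processes disagree on it, a warning is now emitted. Below is a minimal sketch, assuming a hypothetical `ToyModule` not part of this commit, of the logging pattern the warning targets and of the `sync_dist=True` remedy suggested by its hint:

# Minimal sketch (hypothetical module, not part of this commit) of the situation
# the new warning targets.
import torch
from torch import nn
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Each DDP process computes its own loss; logging it unreduced is what
        # the new warning flags when the values differ across ranks.
        self.log("train_loss", loss)
        # The remedy from the warning's hint: reduce across processes before logging.
        self.log("train_loss_synced", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)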

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ jobs:
     displayName: 'Testing: standard'
 
   - bash: |
-      sh tests/special_tests.sh
+      bash tests/special_tests.sh
     displayName: 'Testing: special'
 
   - bash: |

benchmarks/test_sharded_parity.py

Lines changed: 43 additions & 107 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import time
 from typing import Type
 
@@ -21,113 +20,13 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
 
-@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_one_gpu():
-    plugin_parity_test(
-        gpus=1,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_sharded_plugin_correctness_amp_one_gpu():
-    plugin_parity_test(
-        gpus=1,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, fairscale=True)
-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
-@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32")
-def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
-        gpus=args.gpus,
-        precision=args.precision,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@RunIf(min_gpus=2, fairscale=True)
-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
-@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
-def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
-        gpus=args.gpus,
-        precision=args.precision,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
-    """
-    Ensures same results using multiple optimizers across multiple GPUs
-    """
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderMultipleOptimizersModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
-    """
-    Ensures using multiple optimizers across multiple GPUs with manual optimization
-    """
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderManualModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
 class SeedTrainLoaderModel(BoringModel):
     """
-    Overrides training loader to ensure we enforce the same seed for all DDP processes.
+    Overrides training loader to ensure we enforce the same seed for all DDP processes.
     """
 
     def train_dataloader(self):
@@ -177,7 +76,7 @@ class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel):
     def training_step(self, batch, batch_idx, optimizer_idx):
         output = self.layer(batch)
         loss = self.loss(batch, output)
-        return {"loss": loss}
+        return {'loss': loss}
 
     def training_epoch_end(self, outputs) -> None:
         # outputs should be an array with an entry per optimizer
@@ -279,11 +178,48 @@ def plugin_parity_test(
     # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold
     percent_diff = (custom_model_time - ddp_time) / custom_model_time
 
-    assert percent_diff <= max_percent_speed_diff, \
-        f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'
+    assert (
+        percent_diff <= max_percent_speed_diff
+    ), f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'
 
     if use_cuda:
         # Assert CUDA memory parity
-        assert max_memory_custom <= max_memory_ddp, \
-            f'Custom plugin used too much memory compared to DDP,' \
+        assert max_memory_custom <= max_memory_ddp, (
+            'Custom plugin used too much memory compared to DDP, '
             f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}'
+        )
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@pytest.mark.parametrize(
+    'kwargs',
+    [
+        pytest.param(dict(gpus=1, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1)),
+        pytest.param(
+            dict(gpus=1, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1, amp_native=True)
+        ),
+        pytest.param(dict(gpus=2, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2)),
+        pytest.param(
+            dict(gpus=2, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2, amp_native=True)
+        ),
+        pytest.param(
+            dict(gpus=2, model_cls=SeedTrainLoaderMultipleOptimizersModel),
+            marks=[
+                RunIf(min_gpus=2),
+                pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'),
+            ],
+        ),
+        pytest.param(
+            dict(gpus=2, model_cls=SeedTrainLoaderManualModel),
+            marks=[
+                RunIf(min_gpus=2),
+                pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'),
+            ],
+        ),
+    ],
+)
+def test_ddp_spawn_sharded_plugin(kwargs):
+    if kwargs['gpus'] > 1:
+        # TODO: decrease speed diff since only 2 GPUs sharding 2 optimizers
+        kwargs['max_percent_speed_diff'] = 0.25
+    plugin_parity_test(**kwargs)
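
The rewrite above folds the separate one-GPU, multi-GPU, and AMP variants into a single test_ddp_spawn_sharded_plugin parametrized over `kwargs`, attaching the GPU, AMP, and skip requirements per case through `pytest.param(..., marks=...)`. A generic sketch of that per-case marking pattern, using a hypothetical test unrelated to this repository:

# Generic illustration (hypothetical test) of attaching marks to individual
# parametrize cases, the pattern the consolidated test above relies on.
import sys

import pytest


@pytest.mark.parametrize(
    'n',
    [
        pytest.param(1),
        pytest.param(2, marks=pytest.mark.skipif(sys.platform == 'win32', reason='not run on Windows')),
        pytest.param(3, marks=pytest.mark.skip(reason='known issue')),
    ],
)
def test_double(n):
    assert n * 2 == n + n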

pytorch_lightning/core/step_result.py

Lines changed: 6 additions & 0 deletions
@@ -633,6 +633,12 @@ def rename_keys(self, map_dict: dict):
                 meta[dest] = meta[source]
                 del meta[source]
 
+    def get_non_metrics_keys(self):
+        """
+        This function is used to filter metric keys for which the value isn't a Metric
+        """
+        return [k for k, v in self.items() if not isinstance(v, Metric)]
+
 
 def choose_last(x):
     if isinstance(x, (torch.Tensor, list)):
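
`get_non_metrics_keys` returns the keys of a `Result` whose values are not `torchmetrics.Metric` instances; only those need the new cross-process consistency check, since `Metric` objects synchronize themselves. A standalone sketch of the same filtering applied to an ordinary dict (hypothetical values; `MeanMetric` assumes a reasonably recent torchmetrics release):

# Standalone illustration (hypothetical data) of the filtering logic added above:
# plain tensors are kept, Metric instances are skipped because they handle their
# own reduction across processes.
import torch
from torchmetrics import MeanMetric, Metric

logged = {
    "train_loss": torch.tensor(0.25),  # plain tensor -> candidate for the new warning
    "train_acc": MeanMetric(),         # Metric -> already synchronized, ignored
}

non_metric_keys = [k for k, v in logged.items() if not isinstance(v, Metric)]
print(non_metric_keys)  # ['train_loss']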

pytorch_lightning/profiler/profilers.py

Lines changed: 3 additions & 1 deletion
@@ -148,7 +148,9 @@ def describe(self) -> None:
         # so to avoid them, we open and close the files within this function
         # by calling `_prepare_streams` and `teardown`
         self._prepare_streams()
-        self._write_stream(self.summary())
+        summary = self.summary()
+        if summary:
+            self._write_stream(summary)
         if self._output_file is not None:
             self._output_file.flush()
         self.teardown(stage=self._stage)

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 40 additions & 3 deletions
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from weakref import proxy
 
 import torch
@@ -21,6 +22,19 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import DistributedType, LightningEnum
+from pytorch_lightning.utilities.warnings import WarningCache
+
+log = logging.getLogger(__name__)
+
+
+class MetricWarningCache(WarningCache):
+
+    def __init__(self):
+        super().__init__()
+        self.warned_metrics = []
+
+
+warning_cache = MetricWarningCache()
 
 
 class ResultStoreType(LightningEnum):
@@ -52,8 +66,10 @@ class HookResultStore:
     Those data structures enables us to reduce properly Result object when batch loop is finished.
     """
 
-    def __init__(self, fx_name: str) -> None:
+    def __init__(self, fx_name: str, all_gather_fn: Callable, should_warn: bool) -> None:
         self._fx_name = fx_name
+        self._all_gather_fn = all_gather_fn
+        self._should_warn = should_warn
         self._internals = {}
         self._internals_reduced = {}
         self._internal_type = None
@@ -109,6 +125,20 @@ def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None:
 
         func = getattr(opt_metric, func_name)
         metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs)
+        if self._should_warn:
+            for non_metric_key in opt_metric.get_non_metrics_keys():
+                if non_metric_key in metrics_to_log and non_metric_key not in warning_cache.warned_metrics:
+                    metric = self._all_gather_fn(metrics_to_log[non_metric_key])
+                    if any(metric[0] != m for m in metric[1:]):
+                        warning_cache.warn(
+                            f"The value associated to the key {non_metric_key}: {metric.cpu().tolist()} "
+                            "doesn't appear to be the same accross all processes. "
+                            "HINT: One could either do: `self.log(..., sync_dist=True, sync_fn=torch.mean)`"
+                            " to force mean reduction across processes which can be inaccurate or implement"
+                            " a `torchmetrics.Metric`"
+                        )
+                        warning_cache.warned_metrics.append(non_metric_key)
+
         results.append(metrics_to_log)
 
     def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -227,6 +257,12 @@ class EpochResultStore:
 
     def __init__(self, trainer: 'pl.Trainer') -> None:
        self.trainer = proxy(trainer)
+
+        # Add warning only for distributed (expect rpc as main worker is running the code).
+        _should_warn = trainer.accelerator_connector.is_distributed
+        _should_warn &= not trainer.training_type_plugin.rpc_enabled
+        self._should_warn = _should_warn
+
         self.reset()
 
     def __getitem__(self, key: str) -> Any:
@@ -278,7 +314,8 @@ def cache_result(self) -> None:
         info = self.info
         fx_name = info["fx_name"]
 
-        self._internals.setdefault(fx_name, HookResultStore(fx_name))
+        all_gather_fn = self.trainer.lightning_module.all_gather
+        self._internals.setdefault(fx_name, HookResultStore(fx_name, all_gather_fn, self._should_warn))
 
         # attach capture batch_size
         Result.attach_batch_size(self._batch_size, hook_result)
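
The heart of the new check in run_epoch_func is an all-gather followed by a rank-to-rank comparison: the warning only fires when the gathered copies of a logged value actually differ. A small sketch of that comparison in isolation, using a hypothetical helper rather than the module's own code:

# Hypothetical helper mirroring the `any(metric[0] != m for m in metric[1:])`
# check above; `gathered` stands in for the result of all_gather, one entry per rank.
import torch


def ranks_disagree(gathered: torch.Tensor) -> bool:
    return bool((gathered != gathered[0]).any())


print(ranks_disagree(torch.tensor([0.5, 0.7])))  # True  -> warning would fire
print(ranks_disagree(torch.tensor([0.5, 0.5])))  # False -> values are consistent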

tests/accelerators/__init__.py

Lines changed: 0 additions & 12 deletions
@@ -1,12 +0,0 @@
-try:
-    from dtrun.launcher import DDPLauncher
-except ImportError:
-
-    class DDPLauncher:
-
-        def run(cmd_line, **kwargs):
-
-            def inner(func):
-                pass
-
-            return inner

tests/accelerators/test_ddp.py

Lines changed: 1 addition & 14 deletions
@@ -20,7 +20,7 @@
 import torch
 
 from pytorch_lightning import Trainer
-from tests.accelerators import ddp_model, DDPLauncher
+from tests.accelerators import ddp_model
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 from tests.utilities.distributed import call_training_script
@@ -71,19 +71,6 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir):
     assert out['test_acc'] > 0.7
 
 
-@RunIf(min_gpus=2)
-@DDPLauncher.run(
-    "--max_epochs [max_epochs] --gpus 2 --accelerator [accelerator]",
-    max_epochs=["1"],
-    accelerator=["ddp", "ddp_spawn"]
-)
-def test_cli_to_pass(tmpdir, args=None):
-    """
-    This test verify we can call function using test_cli name
-    """
-    return '1'
-
-
 @RunIf(skip_windows=True)
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine")
 def test_torch_distributed_backend_env_variables(tmpdir):

tests/accelerators/test_multi_nodes_gpu.py

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,7 @@
 import sys
 from unittest import mock
 
+import pytest
 import torch
 
 from tests.helpers.runif import RunIf
@@ -28,6 +29,9 @@
 from tests.helpers.boring_model import BoringModel  # noqa: E402
 
 
+# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml)
+# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)`
+@pytest.mark.skip("Multi-node testing is currently disabled")
 @RunIf(special=True)
 def test_logging_sync_dist_true_ddp(tmpdir):
     """
@@ -65,6 +69,9 @@ def validation_step(self, batch, batch_idx):
     assert trainer.logged_metrics['bar'] == fake_result
 
 
+# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml)
+# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)`
+@pytest.mark.skip("Multi-node testing is currently disabled")
 @RunIf(special=True)
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test__validation_step__log(tmpdir):
