
Commit fc6d402

fix logger creating directory structure too early in DDP (#6380)

* fix
* add simple test
* fix imports
* add changelog
* tighter test with on_fit_start hook closer to the dispatch call
* move class inside test function
* add a comment

1 parent 75c6486 commit fc6d402
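Context for the fix: with DDP (and especially ddp_spawn), worker processes are created only when training is dispatched; before that point no process knows its rank, so any logger write performed during Trainer.fit() setup runs in the parent process and creates log directories prematurely. A minimal sketch of that ordering constraint, using plain torch.multiprocessing rather than Lightning code (the worker function and paths here are illustrative, not from the commit):

    import os
    import tempfile
    import torch.multiprocessing as mp

    def worker(rank: int, log_dir: str) -> None:
        # only inside the spawned worker is the rank known,
        # so only here can rank 0 safely create the log directory
        if rank == 0:
            os.makedirs(log_dir, exist_ok=True)

    if __name__ == "__main__":
        log_dir = os.path.join(tempfile.mkdtemp(), "lightning_logs", "version_0")
        # before mp.spawn() there is no rank yet: a logger write at this point
        # would run in the parent process and create directories too early
        mp.spawn(worker, args=(log_dir,), nprocs=2)
        print(os.path.isdir(log_dir))  # True: created by rank 0, after spawn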

File tree

3 files changed: +50 -19 lines

CHANGELOG.md

Lines changed: 4 additions & 1 deletion

@@ -107,7 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))


-- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))

@@ -134,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


+- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
+
+
 ## [1.2.2] - 2021-03-02

 ### Added

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 16 deletions

@@ -381,21 +381,6 @@ def __init__(
         # Callback system
         self.on_init_end()

-    def setup_trainer(self, model: LightningModule):
-        """
-        Sanity check a few things before starting actual training or testing.
-
-        Args:
-            model: The model to run sanity test on.
-        """
-
-        # log hyper-parameters
-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            self.logger.log_hyperparams(model.hparams_initial)
-            self.logger.log_graph(model)
-            self.logger.save()
-
     def fit(
         self,
         model: LightningModule,

@@ -444,7 +429,6 @@ def fit(
         self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
-        self.setup_trainer(model)

         # ----------------------------
         # INSPECT THE CORE LOOPS

@@ -509,6 +493,13 @@ def fit(
     def pre_dispatch(self):
         self.accelerator.pre_dispatch()

+        # log hyper-parameters
+        if self.logger is not None:
+            # save exp to get started (this is where the first experiment logs are written)
+            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
+            self.logger.log_graph(self.lightning_module)
+            self.logger.save()
+
     def post_dispatch(self):
         self.accelerator.post_dispatch()
         self.accelerator.teardown()
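The diff above removes setup_trainer() and defers the first logger writes (log_hyperparams, log_graph, save) to pre_dispatch(), which runs after accelerator setup, once each process knows its rank. The change relies on constructing a logger being side-effect free: directories appear only on the first write. A small illustrative sketch of that property (LazyDirLogger is a hypothetical stand-in, not a Lightning class):

    import os
    import tempfile

    class LazyDirLogger:
        """Hypothetical logger: creates its directory only on the first write."""

        def __init__(self, save_dir: str) -> None:
            self.save_dir = save_dir  # nothing touches the disk yet

        def log_hyperparams(self, hparams: dict) -> None:
            # first real write: the directory structure appears only now
            os.makedirs(self.save_dir, exist_ok=True)
            with open(os.path.join(self.save_dir, "hparams.txt"), "w") as f:
                f.write(repr(hparams))

    save_dir = os.path.join(tempfile.mkdtemp(), "version_0")
    logger = LazyDirLogger(save_dir)
    assert not os.path.exists(save_dir)  # constructing the logger creates nothing
    logger.log_hyperparams({"lr": 0.1})  # the pre_dispatch-time call creates it
    assert os.path.exists(save_dir)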

tests/trainer/logging_/test_distributed_logging.py

Lines changed: 39 additions & 2 deletions

@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
 from unittest import mock
+from unittest.mock import Mock

-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf

@@ -66,3 +67,39 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         weights_summary=None,
     )
     trainer.fit(model)
+
+
+def test_first_logger_call_in_subprocess(tmpdir):
+    """
+    Test that the Trainer does not call the logger too early. Only when the worker processes are initialized
+    do we have access to the rank and know which one is the main process.
+    """
+
+    class LoggerCallsObserver(Callback):
+
+        def on_fit_start(self, trainer, pl_module):
+            # this hook is executed directly before Trainer.pre_dispatch
+            # logger should not write any logs until this point
+            assert not trainer.logger.method_calls
+            assert not os.listdir(trainer.logger.save_dir)
+
+        def on_train_start(self, trainer, pl_module):
+            assert trainer.logger.method_calls
+            trainer.logger.log_hyperparams.assert_called_once()
+            trainer.logger.log_graph.assert_called_once()
+
+    logger = Mock()
+    logger.version = "0"
+    logger.name = "name"
+    logger.save_dir = tmpdir
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=logger,
+        callbacks=[LoggerCallsObserver()]
+    )
+    trainer.fit(model)
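The test works because unittest.mock.Mock records every call made on the mock and its attributes in method_calls, so an empty list inside on_fit_start proves the Trainer has not yet touched the logger. A standalone sketch of that mechanism (independent of the commit):

    from unittest.mock import Mock

    logger = Mock()
    assert logger.method_calls == []      # nothing has been recorded yet

    logger.log_hyperparams({"lr": 0.1})   # calls on attributes are captured
    logger.save()
    assert len(logger.method_calls) == 2  # both calls recorded, in order
    logger.log_hyperparams.assert_called_once()
    logger.save.assert_called_once()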
