Skip to content

Commit 8cb3817

Browse files
authored
Merge 34f5b04 into e1f5eac
2 parents e1f5eac + 34f5b04 commit 8cb3817

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
2525

2626

27+
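- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` to control whether metric names are automatically inserted (as `name=value`) into the checkpoint filename ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))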
28+
29+
2730
- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
2831

2932

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@ class ModelCheckpoint(Callback):
131131
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
132132
... )
133133
134+
# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with TensorBoard
135+
# or Neptune, due to the presence of characters like '=' or '/')
136+
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
137+
>>> checkpoint_callback = ModelCheckpoint(
138+
... monitor='val/loss',
139+
... dirpath='my/path/',
140+
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
141+
... auto_insert_metric_name=False
142+
... )
143+
134144
# retrieve the best checkpoint after training
135145
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
136146
trainer = Trainer(callbacks=[checkpoint_callback])
@@ -156,6 +166,7 @@ def __init__(
156166
save_weights_only: bool = False,
157167
mode: str = "min",
158168
period: int = 1,
169+
auto_insert_metric_name: bool = True
159170
):
160171
super().__init__()
161172
self.monitor = monitor
@@ -164,6 +175,7 @@ def __init__(
164175
self.save_top_k = save_top_k
165176
self.save_weights_only = save_weights_only
166177
self.period = period
178+
self.auto_insert_metric_name = auto_insert_metric_name
167179
self._last_global_step_saved = -1
168180
self.current_score = None
169181
self.best_k_models = {}
@@ -356,6 +368,7 @@ def _format_checkpoint_name(
356368
step: int,
357369
metrics: Dict[str, Any],
358370
prefix: str = "",
371+
auto_insert_metric_name: bool = True
359372
) -> str:
360373
if not filename:
361374
# filename is not set, use default name
@@ -367,7 +380,10 @@ def _format_checkpoint_name(
367380
metrics.update({"epoch": epoch, 'step': step})
368381
for group in groups:
369382
name = group[1:]
370-
filename = filename.replace(group, name + "={" + name)
383+
384+
if auto_insert_metric_name:
385+
filename = filename.replace(group, name + "={" + name)
386+
371387
if name not in metrics:
372388
metrics[name] = 0
373389
filename = filename.format(**metrics)
@@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
392408
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
393409
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
394410
'epoch=2-val_loss=0.12.ckpt'
411+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
412+
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
413+
... auto_insert_metric_name=False)
414+
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
415+
'epoch=2-validation_loss=0.12.ckpt'
395416
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
396417
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
397418
'missing=0.ckpt'
@@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
400421
'step=0.ckpt'
401422
402423
"""
403-
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
424+
filename = self._format_checkpoint_name(
425+
self.filename,
426+
epoch,
427+
step,
428+
metrics,
429+
auto_insert_metric_name=self.auto_insert_metric_name)
430+
404431
if ver is not None:
405432
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
406433

tests/checkpointing/test_model_checkpoint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
425425
ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
426426
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'
427427

428+
# auto_insert_metric_name=False
429+
ckpt_name = ModelCheckpoint._format_checkpoint_name(
430+
'epoch={epoch:03d}-val_acc={val/acc}',
431+
3,
432+
2,
433+
{'val/acc': 0.03},
434+
auto_insert_metric_name=False)
435+
assert ckpt_name == 'epoch=003-val_acc=0.03'
436+
428437

429438
class ModelCheckpointExtensionTest(ModelCheckpoint):
430439
FILE_EXTENSION = '.tpkc'

0 commit comments

Comments
 (0)