Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


Expand Down
31 changes: 29 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ class ModelCheckpoint(Callback):
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val/loss in name, but write the metric labels in the filename yourself instead of having them
# inserted automatically (e.g. to avoid problems with Tensorboard or Neptune caused by characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val/loss',
... dirpath='my/path/',
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
... auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
Expand All @@ -156,6 +166,7 @@ def __init__(
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
auto_insert_metric_name: bool = True
):
super().__init__()
self.monitor = monitor
Expand All @@ -164,6 +175,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
Expand Down Expand Up @@ -356,6 +368,7 @@ def _format_checkpoint_name(
step: int,
metrics: Dict[str, Any],
prefix: str = "",
auto_insert_metric_name: bool = True
) -> str:
if not filename:
# filename is not set, use default name
Expand All @@ -367,7 +380,10 @@ def _format_checkpoint_name(
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)

if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)
Copy link
Contributor

@talregev talregev Apr 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filename = filename.replace('/', '_')

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @talregev
Thanks for the comment.
Could you open an issue about it, describing the need for this?


if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
Expand All @@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
Expand All @@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
'step=0.ckpt'

"""
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
filename = self._format_checkpoint_name(
self.filename,
epoch,
step,
metrics,
auto_insert_metric_name=self.auto_insert_metric_name)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))

Expand Down
9 changes: 9 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
'epoch={epoch:03d}-val_acc={val/acc}',
3,
2,
{'val/acc': 0.03},
auto_insert_metric_name=False)
assert ckpt_name == 'epoch=003-val_acc=0.03'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'
Expand Down