26 changes: 26 additions & 0 deletions docs/source-pytorch/common/trainer.rst
@@ -1077,6 +1077,32 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your model
trainer = Trainer(inference_mode=False)
trainer.validate(model)

enable_autolog_hparams
^^^^^^^^^^^^^^^^^^^^^^

Whether to log hyperparameters at the start of a run. Defaults to ``True``.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(enable_autolog_hparams=True)

    # disable logging hyperparams
    trainer = Trainer(enable_autolog_hparams=False)

With the parameter set to ``False``, you can add custom code to log hyperparameters, for example on a per-logger basis:

.. code-block:: python

    from lightning.pytorch.loggers import CSVLogger

    model = LitModel()
    trainer = Trainer(enable_autolog_hparams=False)

    # hparams_dict_1 and hparams_dict_2 are user-defined dictionaries of hyperparameters
    for logger in trainer.loggers:
        if isinstance(logger, CSVLogger):
            logger.log_hyperparams(hparams_dict_1)
        else:
            logger.log_hyperparams(hparams_dict_2)

You can also call ``self.logger.log_hyperparams(...)`` inside your ``LightningModule`` to log hyperparameters manually.
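For example, a minimal sketch (the hook used and the hyperparameter keys below are illustrative, not required by the API):

.. code-block:: python

    from lightning.pytorch import LightningModule


    class LitModel(LightningModule):
        def on_fit_start(self):
            # log a custom dictionary of hyperparameters once the run starts
            self.logger.log_hyperparams({"learning_rate": 1e-3, "batch_size": 32})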

-----

3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add `enable_autolog_hparams` argument to `Trainer` ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))


- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))


10 changes: 9 additions & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -128,6 +128,7 @@ def __init__(
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[_PATH] = None,
enable_autolog_hparams: bool = True,
) -> None:
r"""Customize every aspect of training via flags.

@@ -290,6 +291,9 @@ def __init__(
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
Default: ``True``.

Raises:
TypeError:
If ``gradient_clip_val`` is not an int or float.
@@ -496,6 +500,8 @@ def __init__(
num_sanity_val_steps,
)

self.enable_autolog_hparams = enable_autolog_hparams

def fit(
self,
model: "pl.LightningModule",
@@ -962,7 +968,9 @@ def _run(
call._call_callback_hooks(self, "on_fit_start")
call._call_lightning_module_hook(self, "on_fit_start")

_log_hyperparams(self)
# only log hparams if enabled
if self.enable_autolog_hparams:
_log_hyperparams(self)

if self.strategy.restore_checkpoint_after_setup:
log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loggers/test_csv.py
@@ -163,7 +163,7 @@ def test_metrics_reset_after_save(tmp_path):


@mock.patch(
# Mock the existance check, so we can simulate appending to the metrics file
# Mock the existence check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_append_metrics_file(_, tmp_path):