Skip to content

Commit 1129d4c

Browse files
ringohoffman, Borda, lantiga
authored
Add step parameter to TensorBoardLogger.log_hyperparams (#20176)
* Add step to TensorBoardLogger.log_hyperparams: the metrics that get logged using this method are always logged to step 0, so no meaningful graph is made. * Change expected step value to None. --------- Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Luca Antiga <[email protected]>
1 parent a9125c2 commit 1129d4c

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

src/lightning/fabric/loggers/tensorboard.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -220,15 +220,19 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
220220
@override
221221
@rank_zero_only
222222
def log_hyperparams(
223-
self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None
223+
self,
224+
params: Union[dict[str, Any], Namespace],
225+
metrics: Optional[dict[str, Any]] = None,
226+
step: Optional[int] = None,
224227
) -> None:
225228
"""Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
226229
hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
227230
display the new ones with hyperparameters.
228231
229232
Args:
230-
params: a dictionary-like container with the hyperparameters
233+
params: A dictionary-like container with the hyperparameters
231234
metrics: Dictionary with metric names as keys and measured quantities as values
235+
step: Optional global step number for the logged metrics
232236
233237
"""
234238
params = _convert_params(params)
@@ -244,7 +248,7 @@ def log_hyperparams(
244248
metrics = {"hp_metric": metrics}
245249

246250
if metrics:
247-
self.log_metrics(metrics, 0)
251+
self.log_metrics(metrics, step)
248252

249253
if _TENSORBOARD_AVAILABLE:
250254
from torch.utils.tensorboard.summary import hparams
@@ -253,9 +257,9 @@ def log_hyperparams(
253257

254258
exp, ssi, sei = hparams(params, metrics)
255259
writer = self.experiment._get_file_writer()
256-
writer.add_summary(exp)
257-
writer.add_summary(ssi)
258-
writer.add_summary(sei)
260+
writer.add_summary(exp, step)
261+
writer.add_summary(ssi, step)
262+
writer.add_summary(sei, step)
259263

260264
@override
261265
@rank_zero_only

src/lightning/pytorch/loggers/tensorboard.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,19 @@ def save_dir(self) -> str:
153153
@override
154154
@rank_zero_only
155155
def log_hyperparams(
156-
self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None
156+
self,
157+
params: Union[dict[str, Any], Namespace],
158+
metrics: Optional[dict[str, Any]] = None,
159+
step: Optional[int] = None,
157160
) -> None:
158161
"""Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
159162
hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
160163
display the new ones with hyperparameters.
161164
162165
Args:
163-
params: a dictionary-like container with the hyperparameters
166+
params: A dictionary-like container with the hyperparameters
164167
metrics: Dictionary with metric names as keys and measured quantities as values
168+
step: Optional global step number for the logged metrics
165169
166170
"""
167171
if _OMEGACONF_AVAILABLE:
@@ -175,7 +179,7 @@ def log_hyperparams(
175179
else:
176180
self.hparams.update(params)
177181

178-
return super().log_hyperparams(params=params, metrics=metrics)
182+
return super().log_hyperparams(params=params, metrics=metrics, step=step)
179183

180184
@override
181185
@rank_zero_only

tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def test_step(self, batch, batch_idx):
544544
"valid_loss_1",
545545
}
546546
assert mock_log_metrics.mock_calls == [
547-
call({"hp_metric": -1}, 0),
547+
call({"hp_metric": -1}, None),
548548
call(metrics={"train_loss": ANY, "epoch": 0}, step=0),
549549
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0),
550550
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1),

0 commit comments

Comments (0)