Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions ignite/handlers/wandb_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""WandB logger and its helper handlers."""

from typing import Any, Callable, List, Optional, Union
from warnings import warn

from torch.optim import Optimizer

Expand Down Expand Up @@ -172,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
Default is None, global_step based on attached engine. If provided,
uses function output as global_step. To setup global step from another engine, please use
:meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
the default value of wandb.log.
sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

Examples:
.. code-block:: python
Expand Down Expand Up @@ -284,7 +284,8 @@ def __init__(
state_attributes: Optional[List[str]] = None,
):
super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
self.sync = sync
if sync is not None:
warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, WandBLogger):
Expand All @@ -298,7 +299,7 @@ def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, E
)

metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
logger.log(metrics, step=global_step, sync=self.sync)
logger.log(metrics, step=global_step)


class OptimizerParamsHandler(BaseOptimizerParamsHandler):
Expand All @@ -309,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
as a sequence.
param_name: parameter name
tag: common title for all produced plots. For example, "generator"
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
the default value of wandb.log.
sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

Examples:
.. code-block:: python
Expand Down Expand Up @@ -346,7 +346,8 @@ def __init__(
self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
):
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
self.sync = sync
if sync is not None:
warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, WandBLogger):
Expand All @@ -358,4 +359,4 @@ def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, E
f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
for i, param_group in enumerate(self.optimizer.param_groups)
}
logger.log(params, step=global_step, sync=self.sync)
logger.log(params, step=global_step)
50 changes: 13 additions & 37 deletions tests/ignite/handlers/test_wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def test_optimizer_params():
mock_engine.state.iteration = 123

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123, sync=None)
mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123)

wrapper = OptimizerParamsHandler(optimizer, param_name="lr", tag="generator")
mock_logger = MagicMock(spec=WandBLogger)
mock_logger.log = MagicMock()

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123, sync=None)
mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123)


def test_output_handler_with_wrong_logger_type():
Expand All @@ -62,36 +62,14 @@ def test_output_handler_output_transform():

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=None)
mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123)

wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
mock_logger = MagicMock(spec=WandBLogger)
mock_logger.log = MagicMock()

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=None)


def test_output_handler_output_transform_sync():
wrapper = OutputHandler("tag", output_transform=lambda x: x, sync=False)
mock_logger = MagicMock(spec=WandBLogger)
mock_logger.log = MagicMock()

mock_engine = MagicMock()
mock_engine.state = State()
mock_engine.state.output = 12345
mock_engine.state.iteration = 123

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=False)

wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}, sync=True)
mock_logger = MagicMock(spec=WandBLogger)
mock_logger.log = MagicMock()

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=True)
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123)


def test_output_handler_metric_names():
Expand All @@ -104,7 +82,7 @@ def test_output_handler_metric_names():
mock_engine.state.iteration = 5

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5, sync=None)
mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5)

wrapper = OutputHandler("tag", metric_names=["a", "c"])
mock_engine = MagicMock()
Expand All @@ -115,7 +93,7 @@ def test_output_handler_metric_names():
mock_logger.log = MagicMock()

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7, sync=None)
mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7)

# all metrics
wrapper = OutputHandler("tag", metric_names="all")
Expand All @@ -127,7 +105,7 @@ def test_output_handler_metric_names():
mock_engine.state.iteration = 5

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5, sync=None)
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5)

# log a torch vector
wrapper = OutputHandler("tag", metric_names="all")
Expand All @@ -139,7 +117,7 @@ def test_output_handler_metric_names():
mock_engine.state.iteration = 5

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({f"tag/a/{i}": vector[i].item() for i in range(5)}, step=5, sync=None)
mock_logger.log.assert_called_once_with({f"tag/a/{i}": vector[i].item() for i in range(5)}, step=5)

wrapper = OutputHandler("tag", metric_names=["a"])
mock_engine = MagicMock()
Expand All @@ -151,7 +129,7 @@ def test_output_handler_metric_names():
mock_logger.log = MagicMock()

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
mock_logger.log.assert_called_once_with({f"tag/a/{i}": v for i, v in enumerate(data)}, step=7, sync=None)
mock_logger.log.assert_called_once_with({f"tag/a/{i}": v for i, v in enumerate(data)}, step=7)

wrapper = OutputHandler("tag", metric_names="all")
mock_engine = MagicMock()
Expand Down Expand Up @@ -179,7 +157,6 @@ def test_output_handler_metric_names():
"tag/c/2/e": 32.1,
},
step=7,
sync=None,
)


Expand All @@ -195,7 +172,7 @@ def test_output_handler_both():

wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None)
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5)


def test_output_handler_with_wrong_global_step_transform_output():
Expand Down Expand Up @@ -229,7 +206,7 @@ def global_step_transform(*args, **kwargs):
mock_engine.state.output = 12345

wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10, sync=None)
mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10)


def test_output_handler_with_global_step_from_engine():
Expand All @@ -254,7 +231,7 @@ def test_output_handler_with_global_step_from_engine():

wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
mock_logger.log.assert_called_once_with(
{"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None
{"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch
)

mock_another_engine.state.epoch = 11
Expand All @@ -263,7 +240,7 @@ def test_output_handler_with_global_step_from_engine():
wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
assert mock_logger.log.call_count == 2
mock_logger.log.assert_has_calls(
[call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None)]
[call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
)


Expand Down Expand Up @@ -291,7 +268,6 @@ def test_output_handler_state_attrs():
"tag/delta": "Some Text",
},
step=5,
sync=None,
)


Expand Down
Loading