Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions monai/handlers/lr_schedule_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
name: Optional[str] = None,
epoch_level: bool = True,
step_transform: Callable[[Engine], Any] = lambda engine: (),
logger_handler: Optional[logging.Handler] = None,
) -> None:
"""
Args:
Expand All @@ -47,6 +48,9 @@ def __init__(
`True` is epoch level, `False` is iteration level.
step_transform: a callable that is used to transform the information from `engine`
to expected input data of lr_scheduler.step() function if necessary.
logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc.
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.

Raises:
TypeError: When ``step_transform`` is not ``callable``.
Expand All @@ -59,6 +63,8 @@ def __init__(
if not callable(step_transform):
raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.")
self.step_transform = step_transform
if logger_handler is not None:
self.logger.addHandler(logger_handler)

self._name = name

Expand Down
3 changes: 2 additions & 1 deletion monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
key_var_format: a formatting string to control the output string format of key: value.
logger_handler: add additional handler to handle the stats data: save to file, etc.
Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
"""

self.epoch_print_logger = epoch_print_logger
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def __init__(
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
additional_info: user can define callable function to extract additional info from input data.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.

Raises:
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def __init__(
additional info from input data. it also can be a sequence of string, each element
corresponds to a key in ``keys``.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
allow_missing_keys: don't raise exception if key is missing.

Expand Down
48 changes: 37 additions & 11 deletions tests/test_handler_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# limitations under the License.

import logging
import os
import re
import sys
import tempfile
import unittest

import numpy as np
Expand All @@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase):
def test_content(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
data = [0] * 8
test_lr = 0.1
gamma = 0.1

# set up engine
def _train_func(engine, batch):
Expand All @@ -41,24 +46,45 @@ def run_validation(engine):
net = torch.nn.PReLU()

def _reduce_lr_on_plateau():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
handler.attach(train_engine)
return lr_scheduler
return handler

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
handler = LrScheduleHandler(lr_scheduler)
handler.attach(train_engine)
return lr_scheduler
with tempfile.TemporaryDirectory() as tempdir:
key_to_handler = "test_log_lr"
key_to_print = "Current learning rate"
filename = os.path.join(tempdir, "test_lr.log")
# test with additional logging handler
file_saver = logging.FileHandler(filename, mode="w")
file_saver.setLevel(logging.INFO)

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver)
handler.attach(train_engine)
handler.logger.setLevel(logging.INFO)
return handler

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

train_engine.run(data, max_epochs=5)
file_saver.close()
schedulers[1].logger.removeHandler(file_saver)

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
with open(filename) as f:
output_str = f.read()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

train_engine.run(data, max_epochs=5)
for scheduler in schedulers:
np.testing.assert_allclose(scheduler._last_lr[0], 0.001)
np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)


if __name__ == "__main__":
Expand Down
56 changes: 31 additions & 25 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ def _update_metric(engine):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [5, 10]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_print(self):
log_stream = StringIO()
Expand All @@ -74,18 +75,19 @@ def _train_func(engine, batch):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_dict(self):
log_stream = StringIO()
Expand All @@ -102,21 +104,22 @@ def _train_func(engine, batch):

# set up testing handler
stats_handler = StatsHandler(
name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler
name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler
)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_file(self):
key_to_handler = "test_logging"
Expand All @@ -136,18 +139,19 @@ def _train_func(engine, batch):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)
handler.close()
stats_handler.logger.removeHandler(handler)
with open(filename) as f:
output_str = f.read()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_exception(self):
# set up engine
Expand Down Expand Up @@ -190,17 +194,19 @@ def _update_metric(engine):
name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler
)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(".*State values.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line) and idx in [5, 10]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)


if __name__ == "__main__":
Expand Down