Skip to content

Commit 901b2ba

Browse files
authored
Unify current_fx_name and current_hook_fx_name [2/n] (#7594)
* Minor logger connector cleanup [1/n] * Missing line * Address comments * Rely on validator * Unify `current_fx_name` and `current_hook_fx_name` * Fix test
1 parent dbea5bb commit 901b2ba

File tree

10 files changed

+33
-55
lines changed

10 files changed

+33
-55
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
104104
self._example_input_array = None
105105
self._datamodule = None
106106
self._results: Optional[Result] = None
107-
self._current_fx_name: str = ''
107+
self._current_fx_name: Optional[str] = None
108108
self._running_manual_backward: bool = False
109-
self._current_hook_fx_name: Optional[str] = None
110109
self._current_dataloader_idx: Optional[int] = None
111110
self._automatic_optimization: bool = True
112111
self._truncated_bptt_steps: int = 0
@@ -316,8 +315,9 @@ def log(
316315
on_step = self.__auto_choose_log_on_step(on_step)
317316
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
318317

318+
assert self._current_fx_name is not None
319319
self.trainer.logger_connector.check_logging_in_callbacks(
320-
self._current_hook_fx_name, on_step=on_step, on_epoch=on_epoch
320+
self._current_fx_name, on_step=on_step, on_epoch=on_epoch
321321
)
322322

323323
# make sure user doesn't introduce logic for multi-dataloaders

pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
class CallbackHookNameValidator:
1919

2020
@staticmethod
21-
def check_logging_in_callbacks(current_hook_fx_name: str, on_step: bool, on_epoch: bool) -> None:
22-
internal_func = getattr(CallbackHookNameValidator, f"_{current_hook_fx_name}_log", None)
21+
def check_logging_in_callbacks(fx_name: str, on_step: bool, on_epoch: bool) -> None:
22+
internal_func = getattr(CallbackHookNameValidator, f"_{fx_name}_log", None)
2323
if internal_func is None:
2424
return
2525

@@ -28,16 +28,14 @@ def check_logging_in_callbacks(current_hook_fx_name: str, on_step: bool, on_epoc
2828
if current_callback_hook_auth_args is not None:
2929
m = "{} function supports only {} in {}. Provided {}"
3030
if on_step not in current_callback_hook_auth_args["on_step"]:
31-
msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step)
31+
msg = m.format(fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step)
3232
raise MisconfigurationException(msg)
3333

3434
if on_epoch not in current_callback_hook_auth_args["on_epoch"]:
35-
msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch)
35+
msg = m.format(fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch)
3636
raise MisconfigurationException(msg)
3737
else:
38-
raise MisconfigurationException(
39-
f"{current_hook_fx_name} function doesn't support logging using self.log() yet."
40-
)
38+
raise MisconfigurationException(f"{fx_name} function doesn't support logging using self.log() yet.")
4139

4240
@staticmethod
4341
def _on_before_accelerator_backend_setup_log():

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ class EpochResultStore:
228228
229229
..example::
230230
231-
model._current_fx_name = 'something'
232231
model._results = Result()
232+
model._current_fx_name = 'something'
233233
model.log('a', ...)
234234
epoch_result_store.cache_result()
235235
"""
@@ -250,7 +250,7 @@ def info(self):
250250
model_ref = self.trainer.lightning_module
251251
return {
252252
"batch_idx": self.trainer.train_loop.batch_idx,
253-
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
253+
"fx_name": model_ref._current_fx_name,
254254
"dataloader_idx": model_ref._current_dataloader_idx or 0,
255255
"opt_idx": self._opt_idx or 0,
256256
"split_idx": self._split_idx or 0,
@@ -266,8 +266,7 @@ def reset_model(self):
266266
"""
267267
model_ref = self.trainer.lightning_module
268268
model_ref._results = Result()
269-
model_ref._current_hook_fx_name = None
270-
model_ref._current_fx_name = ''
269+
model_ref._current_fx_name = None
271270

272271
def cache_result(self) -> None:
273272
"""
@@ -280,8 +279,7 @@ def cache_result(self) -> None:
280279
hook_result = model_ref._results
281280

282281
if len(hook_result) == 1:
283-
model_ref._current_hook_fx_name = None
284-
model_ref._current_fx_name = ''
282+
model_ref._current_fx_name = None
285283
return
286284

287285
info = self.info

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,8 @@ def set_metrics(self, key: str, val: Dict) -> None:
9595
def reset(self) -> None:
9696
self.cached_results.reset()
9797

98-
def check_logging_in_callbacks(self, hook_fx_name: str, on_step: bool, on_epoch: bool) -> None:
99-
self._callback_hook_validator.check_logging_in_callbacks(
100-
current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch
101-
)
98+
def check_logging_in_callbacks(self, fx_name: str, on_step: bool, on_epoch: bool) -> None:
99+
self._callback_hook_validator.check_logging_in_callbacks(fx_name=fx_name, on_step=on_step, on_epoch=on_epoch)
102100

103101
def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders):
104102
model = self.trainer.lightning_module

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def on_evaluation_epoch_end(self) -> None:
253253
model_ref = self.trainer.lightning_module
254254
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
255255

256-
self.trainer._reset_result_and_set_hook_fx_name(hook_name)
256+
self.trainer._reset_result_and_set_fx_name(hook_name)
257257

258258
with self.trainer.profiler.profile(hook_name):
259259

pytorch_lightning/trainer/predict_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
113113

114114
self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
115115

116-
model_ref._current_fx_name = "predict"
116+
model_ref._current_fx_name = "predict_step"
117117
predictions = self.trainer.accelerator.predict_step(step_kwargs)
118118

119119
if predictions is None:

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,19 +1186,18 @@ def _call_teardown_hook(self, model: LightningModule) -> None:
11861186
self.teardown(stage=fn)
11871187
model.teardown(stage=fn)
11881188

1189-
model._current_fx_name = ""
1190-
model._current_hook_fx_name = None
1189+
model._current_fx_name = None
11911190
model._current_dataloader_idx = None
11921191

1193-
def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool:
1192+
def _reset_result_and_set_fx_name(self, hook_name: str) -> bool:
11941193
# on_before_zero_grad is called within training_step
11951194
if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):
11961195
return True
11971196
model_ref = self.lightning_module
11981197
if model_ref is not None:
11991198
# used to track current hook name called
12001199
model_ref._results = Result()
1201-
model_ref._current_hook_fx_name = hook_name
1200+
model_ref._current_fx_name = hook_name
12021201
return False
12031202

12041203
def _cache_logged_metrics(self):
@@ -1214,7 +1213,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
12141213
# TrainLoop._on_train_epoch_end_hook
12151214

12161215
# set hook_name to model + reset Result obj
1217-
skip = self._reset_result_and_set_hook_fx_name(hook_name)
1216+
skip = self._reset_result_and_set_fx_name(hook_name)
12181217

12191218
# always profile hooks
12201219
with self.profiler.profile(hook_name):

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,6 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
595595
# run training_epoch_end
596596
# refresh the result for custom logging at the epoch level
597597
model._current_fx_name = 'training_epoch_end'
598-
599-
# lightningmodule hook
600598
training_epoch_end_output = model.training_epoch_end(processed_epoch_output)
601599

602600
if training_epoch_end_output is not None:
@@ -621,7 +619,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
621619
hook_name = "on_train_epoch_end"
622620

623621
# set hook_name to model + reset Result obj
624-
skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name)
622+
skip = self.trainer._reset_result_and_set_fx_name(hook_name)
625623

626624
# always profile hooks
627625
with self.trainer.profiler.profile(hook_name):

tests/trainer/logging_/test_logger_connector.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -372,19 +372,17 @@ def test_call_back_validator(tmpdir):
372372
and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
373373
)
374374
if allowed:
375-
validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
375+
validator.check_logging_in_callbacks(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
376376
if not is_start and is_stage:
377377
with pytest.raises(MisconfigurationException, match="function supports only"):
378-
validator.check_logging_in_callbacks(
379-
current_hook_fx_name=func_name, on_step=True, on_epoch=on_epoch
380-
)
378+
validator.check_logging_in_callbacks(fx_name=func_name, on_step=True, on_epoch=on_epoch)
381379
else:
382380
assert func_name in not_supported
383381
with pytest.raises(MisconfigurationException, match="function doesn't support"):
384-
validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
382+
validator.check_logging_in_callbacks(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
385383

386384
# should not fail
387-
validator.check_logging_in_callbacks(current_hook_fx_name=None, on_step=None, on_epoch=None)
385+
validator.check_logging_in_callbacks(fx_name=None, on_step=None, on_epoch=None)
388386

389387

390388
@RunIf(min_gpus=2)

tests/trainer/test_trainer.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1912,30 +1912,19 @@ def test_on_load_checkpoint_missing_callbacks(tmpdir):
19121912

19131913

19141914
def test_module_current_fx_attributes_reset(tmpdir):
1915-
""" Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
1915+
""" Ensure that lightning module's attributes related to current fx are reset at the end of execution. """
19161916
model = BoringModel()
1917-
model.validation_step = None
1918-
model.training_epoch_end = None
19191917
trainer = Trainer(
19201918
default_root_dir=tmpdir,
1921-
max_epochs=1,
1919+
fast_dev_run=1,
19221920
checkpoint_callback=False,
19231921
logger=False,
1924-
limit_val_batches=0,
19251922
)
1923+
19261924
trainer.fit(model)
1927-
assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}"
1928-
assert (
1929-
model._current_hook_fx_name is None
1930-
), f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}"
1931-
assert (
1932-
model._current_dataloader_idx is None
1933-
), f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}"
1925+
assert model._current_fx_name is None
1926+
assert model._current_dataloader_idx is None
1927+
19341928
trainer.test(model)
1935-
assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}"
1936-
assert (
1937-
model._current_hook_fx_name is None
1938-
), f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}"
1939-
assert (
1940-
model._current_dataloader_idx is None
1941-
), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"
1929+
assert model._current_fx_name is None
1930+
assert model._current_dataloader_idx is None

0 commit comments

Comments (0)