3636
3737
3838@patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
39- def test_checkpoint_callbacks_are_last (tmp_path ):
40- """Test that checkpoint callbacks always come last."""
39+ def test_progressbar_and_checkpoint_callbacks_are_last (tmp_path ):
40+ """Test that progress bar and checkpoint callbacks always come last."""
4141 checkpoint1 = ModelCheckpoint (tmp_path / "path1" , filename = "ckpt1" , monitor = "val_loss_c1" )
4242 checkpoint2 = ModelCheckpoint (tmp_path / "path2" , filename = "ckpt2" , monitor = "val_loss_c2" )
4343 early_stopping = EarlyStopping (monitor = "foo" )
@@ -48,9 +48,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
4848 # no model reference
4949 trainer = Trainer (callbacks = [checkpoint1 , progress_bar , lr_monitor , model_summary , checkpoint2 ])
5050 assert trainer .callbacks == [
51- progress_bar ,
5251 lr_monitor ,
5352 model_summary ,
53+ progress_bar ,
5454 checkpoint1 ,
5555 checkpoint2 ,
5656 ]
@@ -62,9 +62,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
6262 cb_connector = _CallbackConnector (trainer )
6363 cb_connector ._attach_model_callbacks ()
6464 assert trainer .callbacks == [
65- progress_bar ,
6665 lr_monitor ,
6766 model_summary ,
67+ progress_bar ,
6868 checkpoint1 ,
6969 checkpoint2 ,
7070 ]
@@ -77,10 +77,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
7777 cb_connector = _CallbackConnector (trainer )
7878 cb_connector ._attach_model_callbacks ()
7979 assert trainer .callbacks == [
80- progress_bar ,
8180 lr_monitor ,
8281 early_stopping ,
8382 model_summary ,
83+ progress_bar ,
8484 checkpoint1 ,
8585 checkpoint2 ,
8686 ]
@@ -95,10 +95,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
9595 cb_connector ._attach_model_callbacks ()
9696 assert trainer .callbacks == [
9797 batch_size_finder ,
98- progress_bar ,
9998 lr_monitor ,
10099 early_stopping ,
101100 model_summary ,
101+ progress_bar ,
102102 checkpoint2 ,
103103 checkpoint1 ,
104104 ]
@@ -200,7 +200,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
200200 trainer_callbacks = [progress_bar , EarlyStopping (monitor = "red" )],
201201 model_callbacks = [early_stopping1 ],
202202 )
203- assert trainer .callbacks == [progress_bar , early_stopping1 ]
203+ assert trainer .callbacks == [early_stopping1 , progress_bar ] # progress_bar should be last
204204
205205 # multiple callbacks of the same type in trainer
206206 trainer = _attach_callbacks (
@@ -225,7 +225,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
225225 ],
226226 model_callbacks = [early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 ],
227227 )
228- assert trainer .callbacks == [progress_bar , early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 ]
228+ assert trainer .callbacks == [early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 , progress_bar ]
229229
230230 class CustomProgressBar (TQDMProgressBar ): ...
231231
0 commit comments