@@ -121,12 +121,6 @@ def on_train_end(self):
121
121
return
122
122
self ._teardown_already_run = True
123
123
124
- # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
125
- # when a checkpoint was saved at the last step
126
- self .trainer .global_step -= 1
127
- self .check_checkpoint_callback (should_update = True , is_last = True )
128
- self .trainer .global_step += 1
129
-
130
124
# hook
131
125
self .trainer .call_hook ("on_train_end" )
132
126
@@ -145,28 +139,6 @@ def on_train_end(self):
145
139
# reset bookkeeping
146
140
self .trainer ._running_stage = None
147
141
148
- def check_checkpoint_callback (self , should_update , is_last = False ):
149
- # TODO bake this logic into the ModelCheckpoint callback
150
- if should_update and self .trainer .checkpoint_connector .has_trained :
151
- callbacks = self .trainer .checkpoint_callbacks
152
-
153
- if is_last and any (cb .save_last and cb .verbose for cb in callbacks ):
154
- rank_zero_info ("Saving latest checkpoint..." )
155
-
156
- model = self .trainer .lightning_module
157
-
158
- for cb in callbacks :
159
- cb .on_validation_end (self .trainer , model )
160
-
161
- def check_early_stopping_callback (self , should_update ):
162
- # TODO bake this logic into the EarlyStopping callback
163
- if should_update and self .trainer .checkpoint_connector .has_trained :
164
- callbacks = [c for c in self .trainer .callbacks if isinstance (c , EarlyStopping )]
165
- model = self .trainer .lightning_module
166
-
167
- for cb in callbacks :
168
- cb .on_validation_end (self .trainer , model )
169
-
170
142
def on_train_epoch_start (self , epoch ):
171
143
172
144
# update training progress in trainer
@@ -562,15 +534,14 @@ def run_training_epoch(self):
562
534
if (val_loop_called and not should_check_val ) or should_train_only :
563
535
self .trainer .optimizer_connector .update_learning_rates (interval = 'epoch' )
564
536
565
- if should_train_only :
566
- self .check_checkpoint_callback (True )
567
- self .check_early_stopping_callback (True )
568
-
569
537
if should_check_val :
570
538
self .trainer .validating = True
571
539
self .trainer .run_evaluation (on_epoch = True )
572
540
self .trainer .training = True
573
541
542
+ if should_train_only :
543
+ self .trainer .call_hook ('on_train_epoch_final_end' )
544
+
574
545
# increment the global step once
575
546
# progress global step according to grads progress
576
547
self .increment_accumulated_grad_global_step ()
0 commit comments