Skip to content

Commit d4024cb

Browse files
authored
Merge branch 'master' into models/ssdlite
2 parents 5fbc112 + a78d0d8 commit d4024cb

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

references/detection/train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,19 @@ def main(args):
206206
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
207207
lr_scheduler.step()
208208
if args.output_dir:
209-
utils.save_on_master({
209+
checkpoint = {
210210
'model': model_without_ddp.state_dict(),
211211
'optimizer': optimizer.state_dict(),
212212
'lr_scheduler': lr_scheduler.state_dict(),
213213
'args': args,
214-
'epoch': epoch},
214+
'epoch': epoch
215+
}
216+
utils.save_on_master(
217+
checkpoint,
215218
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
219+
utils.save_on_master(
220+
checkpoint,
221+
os.path.join(args.output_dir, 'checkpoint.pth'))
216222

217223
# evaluate after every epoch
218224
evaluate(model, data_loader_test, device=device)

references/segmentation/train.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,19 @@ def main(args):
157157
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
158158
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
159159
print(confmat)
160+
checkpoint = {
161+
'model': model_without_ddp.state_dict(),
162+
'optimizer': optimizer.state_dict(),
163+
'lr_scheduler': lr_scheduler.state_dict(),
164+
'epoch': epoch,
165+
'args': args
166+
}
160167
utils.save_on_master(
161-
{
162-
'model': model_without_ddp.state_dict(),
163-
'optimizer': optimizer.state_dict(),
164-
'lr_scheduler': lr_scheduler.state_dict(),
165-
'epoch': epoch,
166-
'args': args
167-
},
168+
checkpoint,
168169
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
170+
utils.save_on_master(
171+
checkpoint,
172+
os.path.join(args.output_dir, 'checkpoint.pth'))
169173

170174
total_time = time.time() - start_time
171175
total_time_str = str(datetime.timedelta(seconds=int(total_time)))

0 commit comments

Comments
 (0)