Skip to content

Commit 1ebda73

Browse files
datumbox and fmassa authored
Load variables when --resume /path/to/checkpoint --test-only (#3285)
Co-authored-by: Francisco Massa <[email protected]>
1 parent 24f5fa5 commit 1ebda73

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

references/segmentation/train.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -133,11 +133,6 @@ def main(args):
133133
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
134134
model_without_ddp = model.module
135135

136-
if args.test_only:
137-
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
138-
print(confmat)
139-
return
140-
141136
params_to_optimize = [
142137
{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
143138
{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
@@ -155,10 +150,16 @@ def main(args):
155150

156151
if args.resume:
157152
checkpoint = torch.load(args.resume, map_location='cpu')
158-
model_without_ddp.load_state_dict(checkpoint['model'])
159-
optimizer.load_state_dict(checkpoint['optimizer'])
160-
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
161-
args.start_epoch = checkpoint['epoch'] + 1
153+
model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
154+
if not args.test_only:
155+
optimizer.load_state_dict(checkpoint['optimizer'])
156+
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
157+
args.start_epoch = checkpoint['epoch'] + 1
158+
159+
if args.test_only:
160+
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
161+
print(confmat)
162+
return
162163

163164
start_time = time.time()
164165
for epoch in range(args.start_epoch, args.epochs):

0 commit comments

Comments (0)