@@ -133,11 +133,6 @@ def main(args):
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
     model_without_ddp = model.module

-    if args.test_only:
-        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
-        print(confmat)
-        return
-
     params_to_optimize = [
         {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
         {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
@@ -155,10 +150,16 @@ def main(args):

     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
+        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+            args.start_epoch = checkpoint['epoch'] + 1
+
+    if args.test_only:
+        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
+        print(confmat)
+        return

     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
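
The diff moves the args.test_only branch after the resume logic, so an evaluation-only run now sees the restored weights instead of the freshly initialized model, and it loads the checkpoint with strict=not args.test_only. As a minimal, self-contained sketch (separate from the script above, using a toy model I made up for illustration): with strict=False, nn.Module.load_state_dict tolerates missing or unexpected keys and reports them instead of raising, so a checkpoint that does not exactly match the model can still be evaluated.

# Minimal sketch of strict vs. non-strict loading; the toy model and
# truncated checkpoint below are illustrative, not part of the patch.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Simulate a checkpoint missing the final layer's parameters
# (keys '2.weight' / '2.bias' under nn.Sequential's naming).
checkpoint = {k: v for k, v in model.state_dict().items()
              if not k.startswith('2.')}

# strict=True (the default) raises a RuntimeError on the missing keys;
# strict=False loads what it can and returns the mismatches.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # []

With the branch reordered, invoking the script with the test-only and resume options together (presumably spelled --test-only --resume <checkpoint> under the usual argparse conventions for these fields) evaluates the resumed model without requiring the optimizer and lr_scheduler state that only a training run needs.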