@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat


-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)

         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()

         lr_scheduler.step()

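For readers less familiar with torch.cuda.amp, the hunk above follows the standard autocast/GradScaler recipe. Below is a minimal, self-contained sketch of that same pattern; the toy model, data, and hyperparameters are placeholders for illustration and are not part of this PR.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# GradScaler is only created when mixed precision is wanted; scaler=None keeps full precision.
scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

image = torch.randn(4, 10, device=device)
target = torch.randint(0, 2, (4,), device=device)

# The forward pass runs in mixed precision only when a scaler is present.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscale gradients, then run optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()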
@@ -153,6 +159,8 @@ def main(args):
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])

     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser

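With the parser change above, mixed precision stays opt-in. A typical invocation of the reference training script might then look like the line below; the script name and the --lr value are illustrative assumptions rather than something specified in this diff.

python train.py --lr 0.02 --amp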