
Commit cca1699

xiaohu2015 and datumbox authored
support amp training for segmentation models (#4994)

* support amp training for segmentation models

* fix lint

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 58016b0 commit cca1699

File tree

1 file changed (+21, −6)


references/segmentation/train.py

Lines changed: 21 additions & 6 deletions
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
 
         lr_scheduler.step()
 
@@ -153,6 +159,8 @@ def main(args):
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser
 
 