
Commit d0d2d63

prabhat00155 authored and facebook-github-bot committed
[fbsync] [references/classification] Adding gradient clipping (#4824)
Summary:

* [references] Adding gradient clipping
* ufmt formatting
* remove apex code
* resolve naming issue

Reviewed By: kazhang

Differential Revision: D32216659

fbshipit-source-id: 9c5ffb102fa5fd9861ae5ba0c44052920c34ebaf
1 parent 5357fc9 commit d0d2d63

2 files changed: +13 -0 lines changed

2 files changed

+13
-0
lines changed

references/classification/train.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         else:
             loss = criterion(output, target)
             loss.backward()
+
+        if args.clip_grad_norm is not None:
+            nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+
         optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
```
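For context, `torch.nn.utils.clip_grad_norm_` rescales gradients in place so that their total norm does not exceed the given maximum, and it is applied after the backward pass and before the optimizer step. A minimal, self-contained sketch of the clip-before-step pattern this hunk adds (the model, optimizer, and batch below are illustrative stand-ins, and `model.parameters()` replaces the optimizer-parameter generator used in the actual change):

```python
import torch
import torch.nn as nn

# Illustrative model, optimizer, and batch; the reference script builds these
# from command-line arguments.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
image, target = torch.randn(4, 10), torch.randint(0, 2, (4,))

clip_grad_norm = 1.0  # stands in for args.clip_grad_norm

optimizer.zero_grad()
loss = criterion(model(image), target)
loss.backward()

if clip_grad_norm is not None:
    # Rescale all gradients in place so their combined L2 norm is at most
    # clip_grad_norm, then take the optimizer step on the clipped gradients.
    nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
optimizer.step()
```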
```diff
@@ -472,6 +476,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
```

references/classification/utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -409,3 +409,11 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def get_optimizer_params(optimizer):
+    """Generator to iterate over all parameters in the optimizer param_groups."""
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            yield p
```
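The new helper yields every parameter across all of the optimizer's `param_groups`, which is an iterable shape `clip_grad_norm_` accepts directly. A self-contained sketch of how it behaves (the two-group optimizer below is illustrative; the helper body is copied from the diff above):

```python
import torch
import torch.nn as nn


def get_optimizer_params(optimizer):
    """Generator to iterate over all parameters in the optimizer param_groups."""
    # Copied from the diff above so this sketch is self-contained.
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
# Two param groups with different learning rates; the helper walks both.
optimizer = torch.optim.SGD(
    [
        {"params": model[0].parameters()},
        {"params": model[1].parameters(), "lr": 0.01},
    ],
    lr=0.1,
)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# clip_grad_norm_ accepts any iterable of tensors, so the generator can be
# passed straight in, as train_one_epoch now does.
total_norm = nn.utils.clip_grad_norm_(get_optimizer_params(optimizer), max_norm=1.0)
print(f"gradient norm before clipping: {total_norm:.4f}")
optimizer.step()
```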
