
Commit c36a3ec

Merge branch 'main' into add-swag-weight
2 parents: 9230f40 + 3925946


2 files changed: +87, -6 lines


references/classification/train.py

Lines changed: 24 additions & 6 deletions
@@ -229,12 +229,18 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
-    if args.norm_weight_decay is None:
-        parameters = [p for p in model.parameters() if p.requires_grad]
-    else:
-        param_groups = torchvision.ops._utils.split_normalization_params(model)
-        wd_groups = [args.norm_weight_decay, args.weight_decay]
-        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+    custom_keys_weight_decay = []
+    if args.bias_weight_decay is not None:
+        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+    if args.transformer_embedding_decay is not None:
+        for key in ["class_token", "position_embedding", "relative_position_bias"]:
+            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+    parameters = utils.set_weight_decay(
+        model,
+        args.weight_decay,
+        norm_weight_decay=args.norm_weight_decay,
+        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+    )
 
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):

@@ -393,6 +399,18 @@ def get_args_parser(add_help=True):
         type=float,
         help="weight decay for Normalization layers (default: None, same value as --wd)",
     )
+    parser.add_argument(
+        "--bias-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--transformer-embedding-decay",
+        default=None,
+        type=float,
+        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )
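For illustration, a hypothetical invocation exercising the new flags; the model, optimizer, and decay values below are examples only, not recommendations from this commit:

torchrun --nproc_per_node=8 train.py --model vit_b_16 --opt adamw --wd 0.3 \
    --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0

Both new flags default to None; when neither is passed, custom_keys_weight_decay stays empty and set_weight_decay is called with custom_keys_weight_decay=None, so only the norm/other split applies.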

references/classification/utils.py

Lines changed: 63 additions & 0 deletions
@@ -5,6 +5,7 @@
 import os
 import time
 from collections import defaultdict, deque, OrderedDict
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist

@@ -400,3 +401,65 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def set_weight_decay(
+    model: torch.nn.Module,
+    weight_decay: float,
+    norm_weight_decay: Optional[float] = None,
+    norm_classes: Optional[List[type]] = None,
+    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+):
+    if not norm_classes:
+        norm_classes = [
+            torch.nn.modules.batchnorm._BatchNorm,
+            torch.nn.LayerNorm,
+            torch.nn.GroupNorm,
+            torch.nn.modules.instancenorm._InstanceNorm,
+            torch.nn.LocalResponseNorm,
+        ]
+    norm_classes = tuple(norm_classes)
+
+    params = {
+        "other": [],
+        "norm": [],
+    }
+    params_weight_decay = {
+        "other": weight_decay,
+        "norm": norm_weight_decay,
+    }
+    custom_keys = []
+    if custom_keys_weight_decay is not None:
+        for key, weight_decay in custom_keys_weight_decay:
+            params[key] = []
+            params_weight_decay[key] = weight_decay
+            custom_keys.append(key)
+
+    def _add_params(module, prefix=""):
+        for name, p in module.named_parameters(recurse=False):
+            if not p.requires_grad:
+                continue
+            is_custom_key = False
+            for key in custom_keys:
+                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                if key == target_name:
+                    params[key].append(p)
+                    is_custom_key = True
+                    break
+            if not is_custom_key:
+                if norm_weight_decay is not None and isinstance(module, norm_classes):
+                    params["norm"].append(p)
+                else:
+                    params["other"].append(p)
+
+        for child_name, child_module in module.named_children():
+            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+            _add_params(child_module, prefix=child_prefix)
+
+    _add_params(model)
+
+    param_groups = []
+    for key in params:
+        if len(params[key]) > 0:
+            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+    return param_groups
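A minimal sketch of how the helper can be consumed, assuming utils.py is importable as utils; the toy model and decay values are illustrative, not taken from this commit:

import torch
import utils  # references/classification/utils.py

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, bias=True),
    torch.nn.BatchNorm2d(8),
)

param_groups = utils.set_weight_decay(
    model,
    weight_decay=1e-4,                         # default decay for "other" params
    norm_weight_decay=0.0,                     # normalization-layer params
    custom_keys_weight_decay=[("bias", 0.0)],  # any parameter named "bias"
)

# Three non-empty groups come back: "other" (conv weight, wd=1e-4),
# "norm" (BatchNorm weight, wd=0.0), and "bias" (conv bias + BatchNorm bias,
# wd=0.0) -- custom keys are matched before the norm-class check, which is
# why the BatchNorm bias lands in "bias" rather than "norm".
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)

Note the matching rule in _add_params: a key without a dot, like "bias", is compared against the bare parameter name and therefore matches at any depth, while a dotted key such as "head.bias" (hypothetical) is compared against the fully qualified name.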
