|
5 | 5 | import os
|
6 | 6 | import time
|
7 | 7 | from collections import defaultdict, deque, OrderedDict
|
| 8 | +from typing import List, Optional, Tuple |
8 | 9 |
|
9 | 10 | import torch
|
10 | 11 | import torch.distributed as dist
|
@@ -400,3 +401,65 @@ def reduce_across_processes(val):
|
400 | 401 | dist.barrier()
|
401 | 402 | dist.all_reduce(t)
|
402 | 403 | return t
|
| 404 | + |
| 405 | + |
| 406 | +def set_weight_decay( |
| 407 | + model: torch.nn.Module, |
| 408 | + weight_decay: float, |
| 409 | + norm_weight_decay: Optional[float] = None, |
| 410 | + norm_classes: Optional[List[type]] = None, |
| 411 | + custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, |
| 412 | +): |
| 413 | + if not norm_classes: |
| 414 | + norm_classes = [ |
| 415 | + torch.nn.modules.batchnorm._BatchNorm, |
| 416 | + torch.nn.LayerNorm, |
| 417 | + torch.nn.GroupNorm, |
| 418 | + torch.nn.modules.instancenorm._InstanceNorm, |
| 419 | + torch.nn.LocalResponseNorm, |
| 420 | + ] |
| 421 | + norm_classes = tuple(norm_classes) |
| 422 | + |
| 423 | + params = { |
| 424 | + "other": [], |
| 425 | + "norm": [], |
| 426 | + } |
| 427 | + params_weight_decay = { |
| 428 | + "other": weight_decay, |
| 429 | + "norm": norm_weight_decay, |
| 430 | + } |
| 431 | + custom_keys = [] |
| 432 | + if custom_keys_weight_decay is not None: |
| 433 | + for key, weight_decay in custom_keys_weight_decay: |
| 434 | + params[key] = [] |
| 435 | + params_weight_decay[key] = weight_decay |
| 436 | + custom_keys.append(key) |
| 437 | + |
| 438 | + def _add_params(module, prefix=""): |
| 439 | + for name, p in module.named_parameters(recurse=False): |
| 440 | + if not p.requires_grad: |
| 441 | + continue |
| 442 | + is_custom_key = False |
| 443 | + for key in custom_keys: |
| 444 | + target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name |
| 445 | + if key == target_name: |
| 446 | + params[key].append(p) |
| 447 | + is_custom_key = True |
| 448 | + break |
| 449 | + if not is_custom_key: |
| 450 | + if norm_weight_decay is not None and isinstance(module, norm_classes): |
| 451 | + params["norm"].append(p) |
| 452 | + else: |
| 453 | + params["other"].append(p) |
| 454 | + |
| 455 | + for child_name, child_module in module.named_children(): |
| 456 | + child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name |
| 457 | + _add_params(child_module, prefix=child_prefix) |
| 458 | + |
| 459 | + _add_params(model) |
| 460 | + |
| 461 | + param_groups = [] |
| 462 | + for key in params: |
| 463 | + if len(params[key]) > 0: |
| 464 | + param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) |
| 465 | + return param_groups |
0 commit comments