-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
Copy pathhybrid_adam.py
197 lines (175 loc) · 8.3 KB
/
hybrid_adam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from typing import Any, Optional
import torch
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import get_current_device, multi_tensor_applier
from .cpu_adam import CPUAdam
class HybridAdam(CPUAdam):
    r"""Implements Adam algorithm.

    Supports parameters updating on both GPU and CPU, depending on the device of parameters.
    But the parameters and gradients should be on the same device:

      * Parameters on CPU and gradients on CPU is allowed.
      * Parameters on GPU and gradients on GPU is allowed.
      * Parameters on GPU and gradients on CPU is **not** allowed.

    `HybridAdam` requires CUDA extensions which can be built during installation or runtime.

    This version of Hybrid Adam is a hybrid of CPUAdam and FusedAdam.

      * For parameters updating on CPU, it uses CPUAdam.
      * For parameters updating on GPU, it uses FusedAdam.
      * Hybrid precision calculation of fp16 and fp32 is supported, e.g. fp32 parameters and fp16 gradients.

    :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
    or ``torch.optim.Adam`` with ``adamw_mode=False``

    Adam was proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        model_params (iterable): iterable of parameters or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): stored per parameter group and forwarded
            to the CPU kernel and the fused GPU kernel. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adamw_mode (boolean, optional): Apply L2 regularization or weight decay
            True for decoupled weight decay (also known as AdamW) (default: True)
        nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
        nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
            If it's ``None``, a random temporary directory will be used. Defaults to None.
        **defaults (Any): extra keyword arguments. They are accepted but ignored by
            this constructor (they are not forwarded to the parent class). Options such
            as ``amsgrad`` (the AMSGrad variant from `On the Convergence of Adam and Beyond`_)
            or ``simd_log`` are not parameters of this class; if passed they end up
            here and have no effect.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    # Number of fp32 shards for per parameter
    # Param weight, grad, momentum and variance
    num_fp32_shards_per_param = 4

    def __init__(
        self,
        model_params,
        lr: float = 1e-3,
        bias_correction: bool = True,
        betas=(0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        adamw_mode: bool = True,
        nvme_offload_fraction: float = 0.0,
        nvme_offload_dir: Optional[str] = None,
        **defaults: Any,
    ):
        """Initialise the parent optimizer and, when CUDA is available, the fused GPU kernel.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(
            model_params,
            lr,
            bias_correction,
            betas,
            eps,
            weight_decay,
            adamw_mode,
            nvme_offload_fraction,
            nvme_offload_dir,
        )
        # The fused kernel and its overflow buffer only exist when CUDA is available;
        # `step` relies on them for parameters living on a "cuda" device.
        if torch.cuda.is_available():
            fused_optim = FusedOptimizerLoader().load()
            self.gpu_adam_op = fused_optim.multi_tensor_adam
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())

    @torch.no_grad()
    def step(self, closure=None, div_scale: float = -1):
        """Perform a single optimization step.

        Parameters on ``cpu``/``npu`` devices are updated one by one; parameters on
        ``cuda`` devices are collected per parameter group and updated with a single
        fused multi-tensor call.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.
                It is invoked with gradients enabled.
            div_scale (float, optional): forwarded unchanged to the CPU kernel and to
                the fused GPU kernel. It is NOT forwarded on the ``torch_adam_update``
                path (bf16 gradients or ``npu`` gradients).
                NOTE(review): presumably ``-1`` means "no gradient unscaling" --
                confirm against the kernel implementations. (default: -1)

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter lives on a device type other than
                ``cpu``, ``npu`` or ``cuda``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Hook inherited from the parent class, called with the names of the state tensors.
        self._pre_step("exp_avg", "exp_avg_sq")
        for group in self.param_groups:
            # Gradients, params, momentums and variances gathered for the fused GPU update.
            g_l, p_l, m_l, v_l = [], [], [], []
            # Step count handed to the fused kernel: the step of the last parameter
            # of this group that has a gradient.
            group_step = 0
            beta1, beta2 = group["betas"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                target_device = p.device

                if not state:
                    # Lazy state initialisation on first use of this parameter.
                    state["step"] = 0
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(p, device=target_device)
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(p, device=target_device)
                    self._post_state_init(p)

                state["step"] += 1
                group_step = state["step"]

                if target_device.type in ("cpu", "npu"):
                    # Both "cpu" and "npu" parameters are updated here, one at a time.
                    assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
                    assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg_sq should stay on cpu"
                    self._pre_update(p, "exp_avg", "exp_avg_sq")
                    if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
                        # cpu adam kernel does not support bf16 now
                        bias_correction1 = 1 - beta1 ** state["step"]
                        bias_correction2 = 1 - beta2 ** state["step"]
                        self.torch_adam_update(
                            p.data,
                            p.grad.data,
                            state["exp_avg"],
                            state["exp_avg_sq"],
                            group["lr"],
                            beta1,
                            beta2,
                            group["eps"],
                            group["weight_decay"],
                            bias_correction1,
                            bias_correction2,
                            self.adamw_mode,
                        )
                    else:
                        self.cpu_adam_op.step(
                            state["step"],
                            group["lr"],
                            beta1,
                            beta2,
                            group["eps"],
                            group["weight_decay"],
                            group["bias_correction"],
                            p.data,
                            p.grad.data,
                            state["exp_avg"],
                            state["exp_avg_sq"],
                            div_scale,
                        )
                    self._post_update(p, "exp_avg", "exp_avg_sq")

                elif target_device.type == "cuda":
                    assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
                    assert state["exp_avg_sq"].device.type == "cuda", "exp_avg_sq should stay on cuda"

                    # record the state by group and update at once
                    g_l.append(p.grad.data)
                    p_l.append(p.data)
                    m_l.append(state["exp_avg"])
                    v_l.append(state["exp_avg_sq"])

                else:
                    raise RuntimeError(
                        f"HybridAdam does not support parameters on device type '{target_device.type}'"
                    )

            if g_l:
                # The fused kernel takes integer flags rather than booleans.
                adamw_mode = 1 if self.adamw_mode else 0
                bias_correction = 1 if group["bias_correction"] else 0
                multi_tensor_applier(
                    self.gpu_adam_op,
                    self._dummy_overflow_buf,
                    [g_l, p_l, m_l, v_l],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group_step,
                    adamw_mode,
                    bias_correction,
                    group["weight_decay"],
                    div_scale,
                )
        self._post_step()
        return loss