
Commit 3bbf42a

Update torchao.prototype.parq and add 4-bit Llama 3.2 1B benchmark (#2017)
Replace torchao.prototype.parq with facebookresearch/parq submodule
1 parent 5a78b70 commit 3bbf42a

File tree: 11 files changed (+420 / -91 lines)


torchao/prototype/parq/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 from .optim import (  # noqa: F401
     ProxBinaryRelax,
     ProxHardQuant,

torchao/prototype/parq/optim/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 from .binarelax import ProxBinaryRelax  # noqa: F401
 from .parq import ProxPARQ  # noqa: F401
 from .proxmap import ProxHardQuant, ProxMap  # noqa: F401

torchao/prototype/parq/optim/binarelax.py

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
 from typing import Optional
 
 import torch
@@ -41,7 +42,6 @@ def apply_(
 
         if step_count >= self.anneal_end:
             p.copy_(q)
-            return
         else:
             # linear annealing of relaxation coefficient
             theta = (step_count - self.anneal_start) / (

torchao/prototype/parq/optim/parq.py

Lines changed: 17 additions & 4 deletions
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
 import math
 from functools import partial
 from typing import Optional
@@ -23,28 +24,35 @@ def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
         return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)
 
 
-def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
+def normalized_mirror_sigmoid(
+    t: float, t1: float, t2: float, s: float, c: float
+) -> float:
     """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2).
     s is steepness of the sigmoid-like function, almost linear for s < 1.
     'mirror' means decreasing instead of increasing as true sigmoid,
     'normalized' means value 1 at starting point t1 and 0 at end point t2."""
     assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
     ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
-    st = 1 / (1 + math.exp(s * (ft - 0.5)))  # scaled and shifted mirror sigmoid
+    st = 1 / (1 + math.exp(s * (ft - c)))  # scaled and shifted mirror sigmoid
     s1 = 1 / (1 + math.exp(-0.5 * s))  # st value when t = t1 -> ft = 0
     s2 = 1 / (1 + math.exp(0.5 * s))  # st value when t = t2 -> ft = 1
     return (st - s2) / (s1 - s2)  # shift and scale to range (0, 1]
 
 
 class ProxPARQ(ProxMap):
     def __init__(
-        self, anneal_start: int, anneal_end: int, steepness: float = 10
+        self,
+        anneal_start: int,
+        anneal_end: int,
+        steepness: float = 10,
+        anneal_center: float = 0.5,
     ) -> None:
         assert anneal_start < anneal_end, "PARQ annealing: start before end."
         assert steepness > 0, "PARQ annealing steepness should be positive."
         self.anneal_start = anneal_start
         self.anneal_end = anneal_end
         self.steepness = steepness
+        self.anneal_center = anneal_center
 
     @torch.no_grad()
     @amp_custom_fwd(cast_inputs=torch.float32)
@@ -72,8 +80,13 @@ def apply_(
             p.copy_(q)
         else:
             inv_slope = normalized_mirror_sigmoid(
-                step_count, self.anneal_start, self.anneal_end, self.steepness
+                step_count,
+                self.anneal_start,
+                self.anneal_end,
+                self.steepness,
+                self.anneal_center,
             )
+            inv_slope = max(torch.finfo(p.dtype).tiny, inv_slope)
            # it is important to clamp idx-1 and then clamping idx itself
            # idx_1[k] == idx[k] iff p[k] > Q.max() or p[k] <= Q.min()
            if dim is None:
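
The new anneal_center argument (c above) shifts the midpoint of the mirrored sigmoid that schedules PARQ's inverse slope, so most of the annealing toward hard quantization can be concentrated earlier or later in the [anneal_start, anneal_end) window. A minimal standalone sketch, condensed from the function as committed above (step counts and centers are illustrative only):

import math

def normalized_mirror_sigmoid(t, t1, t2, s, c):
    ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
    st = 1 / (1 + math.exp(s * (ft - c)))  # mirrored sigmoid centered at ft = c
    s1 = 1 / (1 + math.exp(-0.5 * s))  # normalization constants, kept at the
    s2 = 1 / (1 + math.exp(0.5 * s))   # 0.5-centered values as in the commit
    return (st - s2) / (s1 - s2)

for step in (0, 25, 50, 75, 99):
    early = normalized_mirror_sigmoid(step, 0, 100, 10, 0.25)  # decays mostly early
    late = normalized_mirror_sigmoid(step, 0, 100, 10, 0.75)   # decays mostly late
    print(f"step={step:3d}  center=0.25: {early:.3f}  center=0.75: {late:.3f}")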

torchao/prototype/parq/optim/proxmap.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
 from abc import ABC, abstractmethod
 from typing import Optional
 

torchao/prototype/parq/optim/quantopt.py

Lines changed: 80 additions & 42 deletions
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
 from collections.abc import Callable
 from functools import partial
 from typing import Any, Optional
@@ -11,15 +13,14 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
-from ..quant import LSBQuantizer, Quantizer
+from ..quant import Quantizer
+from ..utils import HAS_DTENSOR, is_dtensor
 from .proxmap import ProxMap
 
-try:
-    from torch.distributed.tensor import DTensor
-
-    HAS_DTENSOR = True
-except ImportError:
-    HAS_DTENSOR = False
+if HAS_DTENSOR:
+    from torch.distributed.tensor import distribute_tensor
+    from torch.distributed.tensor.experimental import local_map
+    from torch.distributed.tensor.placement_types import Shard
 
 
 class QuantOptimizer(Optimizer):
@@ -31,7 +32,7 @@ class QuantOptimizer(Optimizer):
     a proximal mapping (e.g, HardQuant/STE, PARQ, BinaryRelax)
     - update model parameters based on the above two updates
     Other parameters:
-    - warmup_steps: int > 0
+    - warmup_steps: int >= 0
     - quant_period: int > 0
     - quant_per_channel: True or False
     - quant_shrink: True or False
@@ -86,23 +87,23 @@ def __repr__(self) -> str:
         extra_repr = "\n ".join(("(", base_optimizer, f"{quantizer=}", f"{prox_map=}"))
         return f"{self.__class__.__name__} {extra_repr}\n)"
 
+    @property
+    def state(self) -> defaultdict[Tensor, Any]:  # pyre-ignore[3]
+        return self._state if hasattr(self, "_state") else self.base_optimizer.state
+
     @staticmethod
     def quantize_(
         p: Tensor,
         quants: Tensor,
         quantizer: Quantizer,
         b: int,
-        quant_update: bool,
         dim: Optional[int] = None,
     ) -> Optional[Tensor]:
         """Optionally update the quantization targets `quants` in place.
         Return the quantized `p` as a by-product if `quant_update=True`.
         """
-        if quant_update:  # update Q for each channel
-            q, Q = quantizer.quantize(p, b, dim=dim)  # pyre-ignore[28]
-            quants.copy_(Q)
-        else:
-            q = None
+        q, Q = quantizer.quantize(p, b, dim=dim)  # pyre-ignore[28]
+        quants.copy_(Q)
         return q
 
     def regularized_param_groups(self):  # pyre-ignore[3]
@@ -122,12 +123,13 @@ def state_dict(self) -> dict[str, Any]:
     def load_state_dict(
         self, state_dict: dict[str, Any], start_step: Optional[int] = None
     ) -> None:
-        qat_state = state_dict.pop("qat_state")
+        qat_state = state_dict.get("qat_state")
         # resume from check points usually not corresponds to saved num_steps
         # so allow explicit start_step computed from epochs * steps_per_epoc
         if start_step is not None:
             self.num_steps = start_step
-        else:  # hope discrepancy in num_steps does not cause major problem!
+        elif qat_state is not None:
+            # hope discrepancy in num_steps does not cause major problem!
             self.num_steps = qat_state["num_steps"]
         self.base_optimizer.load_state_dict(state_dict)
 
@@ -144,16 +146,23 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             self.num_steps += 1
             return loss
 
-        # call base optimizer step() method to update latent parameters
-        loss = self.base_optimizer.step(closure=closure)  # pyre-ignore[6]
-
         if self.num_steps == self.warmup_steps:
             # first step of qat, save latent params, instead of restore
             self.save_latent_params()
         else:
             # qat: restore latent params for update by the base optimizer
             self.restore_latent_params()
 
+        # call base optimizer step() method to update latent parameters
+        loss = self.base_optimizer.step(closure=closure)  # pyre-ignore[6]
+
+        if hasattr(self, "_state"):
+            assert self.warmup_steps == 0
+            # restore the temporary state to the base optimizer's state
+            for p in self._state.keys():
+                self.base_optimizer.state[p]["latent"] = self._state[p]["latent"]
+            del self._state
+
         # check if it is time to update set of quantization values Q
         if (self.num_steps - self.warmup_steps) % self.quant_period == 0:
             quant_update = True
@@ -165,6 +174,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             group["cumu_lr"] += group["lr"]
             gamma = max(1.0, group["cumu_lr"])
             b = group["quant_bits"]
+            block_size = group.get("quant_block_size")
             inv_slope = 0.0
             for p in group["params"]:
                 if not p.requires_grad:
@@ -177,44 +187,66 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 if self.quant_shrink:
                     p.div_(gamma)
 
+                # reshape p according to block size if specified
+                if block_size is not None:
+                    assert (
+                        p.size(-1) % block_size == 0
+                    ), f"{p.size(-1)=} is not divisible by {block_size=}"
+                    assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
+                    if p.dim() == 1:
+                        p = p.unsqueeze(0)
+
+                    # row-major ordering ensures this is correct
+                    p = p.view(-1, block_size)
+
                 # quantization by channel or by layer
                 # update quantization targets periodically
                 per_channel = self.quant_per_channel and p.dim() > 1
                 if quant_update:
-                    quants_size = 3 if b == 0 else 2**b
-                    if per_channel:
-                        quants_size = (p.size(0), quants_size)
-                    state["quants"] = torch.empty(
-                        quants_size, device=p.device
-                    )  # pyre-ignore[6]
+                    quant_size = self.quantizer.get_quant_size(b)
 
-                # avoid type mismatch between sharded and full tensors
-                if HAS_DTENSOR and isinstance(p, DTensor):
-                    p = p.full_tensor()
+                    if per_channel:
+                        quant_size = (p.size(0), quant_size)
+                    state["quants"] = torch.empty(quant_size, device=p.device)
+                    if is_dtensor(p):
+                        state["quants"] = distribute_tensor(
+                            state["quants"],
+                            device_mesh=p.device_mesh,
+                            placements=p.placements,
+                        )
 
                 dim = -1 if per_channel else None
                 if per_channel and p.dim() > 2:
                     p = p.flatten(start_dim=1)
 
-                # NOTE: for LSBQ and optimal=False, use faster per-channel
-                # implementation instead of vmap
-                if isinstance(self.quantizer, LSBQuantizer) and self.quantizer.optimal:
+                q = None
+                if quant_update:
                     qfunc = partial(
-                        self.quantize_,
-                        quantizer=self.quantizer,
-                        b=b,
-                        quant_update=quant_update,
-                    )
-                    q = torch.vmap(qfunc, in_dims=0, out_dims=0)(p, state["quants"])
-                else:
-                    q = self.quantize_(
-                        p, state["quants"], self.quantizer, b, quant_update, dim=dim
+                        self.quantize_, quantizer=self.quantizer, b=b, dim=dim
                     )
+                    if is_dtensor(p):
+                        qfunc = local_map(
+                            qfunc,
+                            out_placements=[*p.placements],
+                            in_placements=([Shard(0)], [Shard(0)]),
+                        )
+                    q = qfunc(p, state["quants"])
 
                 # apply (step-dependent) proximal mapping in place
-                inv_slope = self.prox_map.apply_(  # pyre-ignore[28]
-                    p, q, state["quants"], self.num_steps, dim=dim
+                pfunc = partial(
+                    self.prox_map.apply_, step_count=self.num_steps, dim=dim
                 )
+                if is_dtensor(p):
+                    pfunc = local_map(
+                        pfunc,
+                        out_placements=None,
+                        in_placements=(
+                            [Shard(0)],
+                            None if q is None else [Shard(0)],
+                            [Shard(0)],
+                        ),
+                    )
+                inv_slope = pfunc(p, q, state["quants"])
 
                 # quantized parameters share the same PARQ inverse slope
                 if inv_slope:
@@ -239,6 +271,12 @@ def restore_latent_params(self) -> None:
     @torch._disable_dynamo
     def save_latent_params(self) -> None:
         """Save updated latent parameters before applying prox-map"""
+        if self.warmup_steps == 0:
+            assert len(self.state) == 0, "Expected empty state at first step()"
+            # Maintain the invariant that `len(self.state) == 0` before first
+            # self.base_optimizer.step() call by using a temporary state buffer
+            self._state = defaultdict(dict)
+
        for group in self.regularized_param_groups():
            for p in group["params"]:
                if p.requires_grad:
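
The new per-group "quant_block_size" option reshapes each weight so quantization targets are computed per block of the last dimension rather than per output channel. A minimal sketch of the reshape shown in the diff above; the shapes and the max-abs reduction are illustrative assumptions, not taken from the Llama 3.2 1B benchmark:

import torch

weight = torch.randn(4, 64)  # stand-in for one linear layer's weight
block_size = 16              # hypothetical group["quant_block_size"]

assert weight.size(-1) % block_size == 0
blocks = weight.view(-1, block_size)  # row-major order keeps each block contiguous
print(blocks.shape)                   # torch.Size([16, 16]): one row per block

# with quant_per_channel=True, "per channel" now acts per block: reductions over
# dim=-1 treat each block independently, e.g. a per-block max-abs statistic
scale = blocks.abs().amax(dim=-1, keepdim=True)
print(scale.shape)                    # torch.Size([16, 1])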
torchao/prototype/parq/quant/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -1,3 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 from .lsbq import LSBQuantizer  # noqa: F401
 from .quantizer import Quantizer  # noqa: F401
-from .uniform import UnifQuantizer  # noqa: F401
+from .uniform import (  # noqa: F401
+    MaxUnifQuantizer,
+    TernaryUnifQuantizer,
+    UnifQuantizer,
+)
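
The expanded export list accompanies the switch in quantopt.py from the inline `3 if b == 0 else 2**b` sizing to a per-quantizer `Quantizer.get_quant_size(b)` hook, so each uniform variant can report its own number of target values. A hypothetical standalone sketch of that sizing rule, reproducing only the inline logic the hook replaces (the actual quantizer classes may compute this differently):

import torch

def get_quant_size(b: int) -> int:
    # ternary quantization (b == 0) uses 3 target values; otherwise 2**b levels
    return 3 if b == 0 else 2**b

for b in (0, 2, 4):
    quants = torch.empty(get_quant_size(b))  # like the buffer QuantOptimizer.step() allocates
    print(f"b={b}: {quants.numel()} quantization values")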
