Update torchao.prototype.parq and add 4-bit Llama 3.2 1B benchmark #2017


Merged 1 commit on Apr 4, 2025
6 changes: 6 additions & 0 deletions torchao/prototype/parq/__init__.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .optim import ( # noqa: F401
ProxBinaryRelax,
ProxHardQuant,
6 changes: 6 additions & 0 deletions torchao/prototype/parq/optim/__init__.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .binarelax import ProxBinaryRelax # noqa: F401
from .parq import ProxPARQ # noqa: F401
from .proxmap import ProxHardQuant, ProxMap # noqa: F401
2 changes: 1 addition & 1 deletion torchao/prototype/parq/optim/binarelax.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
@@ -41,7 +42,6 @@ def apply_(

if step_count >= self.anneal_end:
p.copy_(q)
return
else:
# linear annealing of relaxation coefficient
theta = (step_count - self.anneal_start) / (
21 changes: 17 additions & 4 deletions torchao/prototype/parq/optim/parq.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import math
from functools import partial
from typing import Optional
@@ -23,28 +24,35 @@ def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)


def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
def normalized_mirror_sigmoid(
t: float, t1: float, t2: float, s: float, c: float
) -> float:
"""Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2).
s is the steepness of the sigmoid-like function, almost linear for s < 1.
c is the center of the sigmoid: the fraction of annealing progress at which
the inflection point occurs (0.5 matches the previous fixed midpoint).
'mirror' means decreasing instead of increasing like a true sigmoid,
'normalized' means value 1 at starting point t1 and 0 at end point t2."""
assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
ft = (t - t1) / (t2 - t1) # fraction of progress from t1 to t2
st = 1 / (1 + math.exp(s * (ft - 0.5))) # scaled and shifted mirror sigmoid
st = 1 / (1 + math.exp(s * (ft - c))) # scaled and shifted mirror sigmoid
s1 = 1 / (1 + math.exp(-0.5 * s)) # st value when t = t1 -> ft = 0
s2 = 1 / (1 + math.exp(0.5 * s)) # st value when t = t2 -> ft = 1
return (st - s2) / (s1 - s2) # shift and scale to range (0, 1]


class ProxPARQ(ProxMap):
def __init__(
self, anneal_start: int, anneal_end: int, steepness: float = 10
self,
anneal_start: int,
anneal_end: int,
steepness: float = 10,
anneal_center: float = 0.5,
) -> None:
assert anneal_start < anneal_end, "PARQ annealing: start before end."
assert steepness > 0, "PARQ annealing steepness should be positive."
self.anneal_start = anneal_start
self.anneal_end = anneal_end
self.steepness = steepness
self.anneal_center = anneal_center

@torch.no_grad()
@amp_custom_fwd(cast_inputs=torch.float32)
@@ -72,8 +80,13 @@ def apply_(
p.copy_(q)
else:
inv_slope = normalized_mirror_sigmoid(
step_count, self.anneal_start, self.anneal_end, self.steepness
step_count,
self.anneal_start,
self.anneal_end,
self.steepness,
self.anneal_center,
)
inv_slope = max(torch.finfo(p.dtype).tiny, inv_slope)
# it is important to clamp idx-1 first and then clamp idx itself
# idx_1[k] == idx[k] iff p[k] > Q.max() or p[k] <= Q.min()
if dim is None:
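For intuition about the new anneal_center argument, here is a minimal, self-contained sketch that replicates the updated helper from the hunk above; the schedule values used below (start 0, end 100, steepness 10) are made up purely for illustration.

import math

def normalized_mirror_sigmoid(
    t: float, t1: float, t2: float, s: float, c: float
) -> float:
    # Same computation as in the diff above: a decreasing sigmoid over
    # [t1, t2) whose inflection point sits at a fraction c of the window.
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    ft = (t - t1) / (t2 - t1)
    st = 1 / (1 + math.exp(s * (ft - c)))
    s1 = 1 / (1 + math.exp(-0.5 * s))
    s2 = 1 / (1 + math.exp(0.5 * s))
    return (st - s2) / (s1 - s2)

# Illustrative schedule: anneal over steps 0..99 with steepness 10.
# For c=0.5 this prints roughly 1.0, 0.93, 0.5, 0.07, 0.0 across the five
# steps; for c=0.25 the steep drop happens earlier in the window.
for step in (0, 25, 50, 75, 99):
    print(
        step,
        round(normalized_mirror_sigmoid(step, 0, 100, 10, 0.5), 3),
        round(normalized_mirror_sigmoid(step, 0, 100, 10, 0.25), 3),
    )

Because the normalization constants s1 and s2 still assume a 0.5 center, choosing c != 0.5 lets the value dip slightly below zero near anneal_end, which is presumably why the diff also clamps inv_slope to torch.finfo(p.dtype).tiny.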
1 change: 1 addition & 0 deletions torchao/prototype/parq/optim/proxmap.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional

122 changes: 80 additions & 42 deletions torchao/prototype/parq/optim/quantopt.py
@@ -3,6 +3,8 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Any, Optional
@@ -11,15 +13,14 @@
from torch import Tensor
from torch.optim import Optimizer

from ..quant import LSBQuantizer, Quantizer
from ..quant import Quantizer
from ..utils import HAS_DTENSOR, is_dtensor
from .proxmap import ProxMap

try:
from torch.distributed.tensor import DTensor

HAS_DTENSOR = True
except ImportError:
HAS_DTENSOR = False
if HAS_DTENSOR:
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Shard


class QuantOptimizer(Optimizer):
@@ -31,7 +32,7 @@ class QuantOptimizer(Optimizer):
a proximal mapping (e.g., HardQuant/STE, PARQ, BinaryRelax)
- update model parameters based on the above two updates
Other parameters:
- warmup_steps: int > 0
- warmup_steps: int >= 0
- quant_period: int > 0
- quant_per_channel: True or False
- quant_shrink: True or False
@@ -86,23 +87,23 @@ def __repr__(self) -> str:
extra_repr = "\n ".join(("(", base_optimizer, f"{quantizer=}", f"{prox_map=}"))
return f"{self.__class__.__name__} {extra_repr}\n)"

@property
def state(self) -> defaultdict[Tensor, Any]: # pyre-ignore[3]
return self._state if hasattr(self, "_state") else self.base_optimizer.state

@staticmethod
def quantize_(
p: Tensor,
quants: Tensor,
quantizer: Quantizer,
b: int,
quant_update: bool,
dim: Optional[int] = None,
) -> Optional[Tensor]:
"""Optionally update the quantization targets `quants` in place.
Return the quantized `p` as a by-product if `quant_update=True`.
"""
if quant_update: # update Q for each channel
q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28]
quants.copy_(Q)
else:
q = None
q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28]
quants.copy_(Q)
return q

def regularized_param_groups(self): # pyre-ignore[3]
@@ -122,12 +123,13 @@ def state_dict(self) -> dict[str, Any]:
def load_state_dict(
self, state_dict: dict[str, Any], start_step: Optional[int] = None
) -> None:
qat_state = state_dict.pop("qat_state")
qat_state = state_dict.get("qat_state")
# resuming from a checkpoint usually does not correspond to the saved num_steps,
# so allow an explicit start_step computed from epochs * steps_per_epoch
if start_step is not None:
self.num_steps = start_step
else: # hope discrepancy in num_steps does not cause major problem!
elif qat_state is not None:
# hope discrepancy in num_steps does not cause major problem!
self.num_steps = qat_state["num_steps"]
self.base_optimizer.load_state_dict(state_dict)

@@ -144,16 +146,23 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
self.num_steps += 1
return loss

# call base optimizer step() method to update latent parameters
loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6]

if self.num_steps == self.warmup_steps:
# first step of QAT: save latent params instead of restoring them
self.save_latent_params()
else:
# qat: restore latent params for update by the base optimizer
self.restore_latent_params()

# call base optimizer step() method to update latent parameters
loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6]

if hasattr(self, "_state"):
assert self.warmup_steps == 0
# restore the temporary state to the base optimizer's state
for p in self._state.keys():
self.base_optimizer.state[p]["latent"] = self._state[p]["latent"]
del self._state

# check if it is time to update set of quantization values Q
if (self.num_steps - self.warmup_steps) % self.quant_period == 0:
quant_update = True
@@ -165,6 +174,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
group["cumu_lr"] += group["lr"]
gamma = max(1.0, group["cumu_lr"])
b = group["quant_bits"]
block_size = group.get("quant_block_size")
inv_slope = 0.0
for p in group["params"]:
if not p.requires_grad:
@@ -177,44 +187,66 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
if self.quant_shrink:
p.div_(gamma)

# reshape p according to block size if specified
if block_size is not None:
assert (
p.size(-1) % block_size == 0
), f"{p.size(-1)=} is not divisible by {block_size=}"
assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
if p.dim() == 1:
p = p.unsqueeze(0)

# row-major ordering ensures this is correct
p = p.view(-1, block_size)

# quantization by channel or by layer
# update quantization targets periodically
per_channel = self.quant_per_channel and p.dim() > 1
if quant_update:
quants_size = 3 if b == 0 else 2**b
if per_channel:
quants_size = (p.size(0), quants_size)
state["quants"] = torch.empty(
quants_size, device=p.device
) # pyre-ignore[6]
quant_size = self.quantizer.get_quant_size(b)

# avoid type mismatch between sharded and full tensors
if HAS_DTENSOR and isinstance(p, DTensor):
p = p.full_tensor()
if per_channel:
quant_size = (p.size(0), quant_size)
state["quants"] = torch.empty(quant_size, device=p.device)
if is_dtensor(p):
state["quants"] = distribute_tensor(
state["quants"],
device_mesh=p.device_mesh,
placements=p.placements,
)

dim = -1 if per_channel else None
if per_channel and p.dim() > 2:
p = p.flatten(start_dim=1)

# NOTE: for LSBQ and optimal=False, use faster per-channel
# implementation instead of vmap
if isinstance(self.quantizer, LSBQuantizer) and self.quantizer.optimal:
q = None
if quant_update:
qfunc = partial(
self.quantize_,
quantizer=self.quantizer,
b=b,
quant_update=quant_update,
)
q = torch.vmap(qfunc, in_dims=0, out_dims=0)(p, state["quants"])
else:
q = self.quantize_(
p, state["quants"], self.quantizer, b, quant_update, dim=dim
self.quantize_, quantizer=self.quantizer, b=b, dim=dim
)
if is_dtensor(p):
qfunc = local_map(
qfunc,
out_placements=[*p.placements],
in_placements=([Shard(0)], [Shard(0)]),
)
q = qfunc(p, state["quants"])

# apply (step-dependent) proximal mapping in place
inv_slope = self.prox_map.apply_( # pyre-ignore[28]
p, q, state["quants"], self.num_steps, dim=dim
pfunc = partial(
self.prox_map.apply_, step_count=self.num_steps, dim=dim
)
if is_dtensor(p):
pfunc = local_map(
pfunc,
out_placements=None,
in_placements=(
[Shard(0)],
None if q is None else [Shard(0)],
[Shard(0)],
),
)
inv_slope = pfunc(p, q, state["quants"])

# quantized parameters share the same PARQ inverse slope
if inv_slope:
@@ -239,6 +271,12 @@ def restore_latent_params(self) -> None:
@torch._disable_dynamo
def save_latent_params(self) -> None:
"""Save updated latent parameters before applying prox-map"""
if self.warmup_steps == 0:
assert len(self.state) == 0, "Expected empty state at first step()"
# Maintain the invariant that `len(self.state) == 0` before first
# self.base_optimizer.step() call by using a temporary state buffer
self._state = defaultdict(dict)

for group in self.regularized_param_groups():
for p in group["params"]:
if p.requires_grad:
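To see what the new quant_block_size handling in step() does to a weight before quantization, here is a small standalone sketch; the tensor contents, shape, and block size are invented for illustration.

import torch

block_size = 4  # hypothetical group size
p = torch.arange(16, dtype=torch.float32).reshape(2, 8)  # hypothetical (out=2, in=8) weight

# Mirrors the checks added in QuantOptimizer.step()
assert p.size(-1) % block_size == 0, f"{p.size(-1)=} is not divisible by {block_size=}"
assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
if p.dim() == 1:
    p = p.unsqueeze(0)

# Row-major storage means each resulting row holds block_size consecutive
# weights from a single original row; blocks never straddle two rows.
blocks = p.view(-1, block_size)
print(blocks.shape)  # torch.Size([4, 4])
print(blocks[0])     # first block of row 0: 0..3
print(blocks[2])     # first block of row 1: 8..11

Each block row then gets its own quantization grid under the per-channel path, since state["quants"] is allocated from p.size(0) after the reshape.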
12 changes: 11 additions & 1 deletion torchao/prototype/parq/quant/__init__.py
@@ -1,3 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .lsbq import LSBQuantizer # noqa: F401
from .quantizer import Quantizer # noqa: F401
from .uniform import UnifQuantizer # noqa: F401
from .uniform import ( # noqa: F401
MaxUnifQuantizer,
TernaryUnifQuantizer,
UnifQuantizer,
)
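The size of the quantization grid now comes from the quantizer itself: the quantopt.py hunk above replaces the inline 3 if b == 0 else 2**b with self.quantizer.get_quant_size(b). The classes below are illustrative stand-ins, not the actual torchao quantizers; they only sketch what such a sizing hook computes.

# Illustrative stand-ins only; the real quantizers live in torchao.prototype.parq.quant.
class SketchTernaryQuantizer:
    def get_quant_size(self, b: int) -> int:
        # ternary grid, e.g. {-1, 0, +1}: three values regardless of b
        return 3

class SketchUnifQuantizer:
    def get_quant_size(self, b: int) -> int:
        # b-bit uniform grid
        return 2**b

print(SketchTernaryQuantizer().get_quant_size(0))  # 3
print(SketchUnifQuantizer().get_quant_size(4))     # 16, as in the PR's 4-bit benchmark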