diff --git a/torchao/prototype/parq/__init__.py b/torchao/prototype/parq/__init__.py index 07353d2461..2239139495 100644 --- a/torchao/prototype/parq/__init__.py +++ b/torchao/prototype/parq/__init__.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .optim import ( # noqa: F401 ProxBinaryRelax, ProxHardQuant, diff --git a/torchao/prototype/parq/optim/__init__.py b/torchao/prototype/parq/optim/__init__.py index 627bedb4dd..237a058a12 100644 --- a/torchao/prototype/parq/optim/__init__.py +++ b/torchao/prototype/parq/optim/__init__.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .binarelax import ProxBinaryRelax # noqa: F401 from .parq import ProxPARQ # noqa: F401 from .proxmap import ProxHardQuant, ProxMap # noqa: F401 diff --git a/torchao/prototype/parq/optim/binarelax.py b/torchao/prototype/parq/optim/binarelax.py index 0ce88d8ccb..2cc6611f6b 100644 --- a/torchao/prototype/parq/optim/binarelax.py +++ b/torchao/prototype/parq/optim/binarelax.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from typing import Optional import torch @@ -41,7 +42,6 @@ def apply_( if step_count >= self.anneal_end: p.copy_(q) - return else: # linear annealing of relaxation coefficient theta = (step_count - self.anneal_start) / ( diff --git a/torchao/prototype/parq/optim/parq.py b/torchao/prototype/parq/optim/parq.py index b756efd3ec..ade403a87d 100644 --- a/torchao/prototype/parq/optim/parq.py +++ b/torchao/prototype/parq/optim/parq.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + import math from functools import partial from typing import Optional @@ -23,14 +24,16 @@ def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None): return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs) -def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float: +def normalized_mirror_sigmoid( + t: float, t1: float, t2: float, s: float, c: float +) -> float: """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2). s is steepness of the sigmoid-like function, almost linear for s < 1. 
'mirror' means decreasing instead of increasing as true sigmoid, 'normalized' means value 1 at starting point t1 and 0 at end point t2.""" assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2" ft = (t - t1) / (t2 - t1) # fraction of progress from t1 to t2 - st = 1 / (1 + math.exp(s * (ft - 0.5))) # scaled and shifted mirror sigmoid + st = 1 / (1 + math.exp(s * (ft - c))) # scaled and shifted mirror sigmoid s1 = 1 / (1 + math.exp(-0.5 * s)) # st value when t = t1 -> ft = 0 s2 = 1 / (1 + math.exp(0.5 * s)) # st value when t = t2 -> ft = 1 return (st - s2) / (s1 - s2) # shift and scale to range (0, 1] @@ -38,13 +41,18 @@ def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float class ProxPARQ(ProxMap): def __init__( - self, anneal_start: int, anneal_end: int, steepness: float = 10 + self, + anneal_start: int, + anneal_end: int, + steepness: float = 10, + anneal_center: float = 0.5, ) -> None: assert anneal_start < anneal_end, "PARQ annealing: start before end." assert steepness > 0, "PARQ annealing steepness should be positive." self.anneal_start = anneal_start self.anneal_end = anneal_end self.steepness = steepness + self.anneal_center = anneal_center @torch.no_grad() @amp_custom_fwd(cast_inputs=torch.float32) @@ -72,8 +80,13 @@ def apply_( p.copy_(q) else: inv_slope = normalized_mirror_sigmoid( - step_count, self.anneal_start, self.anneal_end, self.steepness + step_count, + self.anneal_start, + self.anneal_end, + self.steepness, + self.anneal_center, ) + inv_slope = max(torch.finfo(p.dtype).tiny, inv_slope) # it is important to clamp idx-1 and then clamping idx itself # idx_1[k] == idx[k] iff p[k] > Q.max() or p[k] <= Q.min() if dim is None: diff --git a/torchao/prototype/parq/optim/proxmap.py b/torchao/prototype/parq/optim/proxmap.py index 0bfbf57498..da867cc5db 100644 --- a/torchao/prototype/parq/optim/proxmap.py +++ b/torchao/prototype/parq/optim/proxmap.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Optional diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py index 016aea28bc..7ebc1a80a0 100644 --- a/torchao/prototype/parq/optim/quantopt.py +++ b/torchao/prototype/parq/optim/quantopt.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+ +from collections import defaultdict from collections.abc import Callable from functools import partial from typing import Any, Optional @@ -11,15 +13,14 @@ from torch import Tensor from torch.optim import Optimizer -from ..quant import LSBQuantizer, Quantizer +from ..quant import Quantizer +from ..utils import HAS_DTENSOR, is_dtensor from .proxmap import ProxMap -try: - from torch.distributed.tensor import DTensor - - HAS_DTENSOR = True -except ImportError: - HAS_DTENSOR = False +if HAS_DTENSOR: + from torch.distributed.tensor import distribute_tensor + from torch.distributed.tensor.experimental import local_map + from torch.distributed.tensor.placement_types import Shard class QuantOptimizer(Optimizer): @@ -31,7 +32,7 @@ class QuantOptimizer(Optimizer): a proximal mapping (e.g, HardQuant/STE, PARQ, BinaryRelax) - update model parameters based on the above two updates Other parameters: - - warmup_steps: int > 0 + - warmup_steps: int >= 0 - quant_period: int > 0 - quant_per_channel: True or False - quant_shrink: True or False @@ -86,23 +87,23 @@ def __repr__(self) -> str: extra_repr = "\n ".join(("(", base_optimizer, f"{quantizer=}", f"{prox_map=}")) return f"{self.__class__.__name__} {extra_repr}\n)" + @property + def state(self) -> defaultdict[Tensor, Any]: # pyre-ignore[3] + return self._state if hasattr(self, "_state") else self.base_optimizer.state + @staticmethod def quantize_( p: Tensor, quants: Tensor, quantizer: Quantizer, b: int, - quant_update: bool, dim: Optional[int] = None, ) -> Optional[Tensor]: """Optionally update the quantization targets `quants` in place. Return the quantized `p` as a by-product if `quant_update=True`. """ - if quant_update: # update Q for each channel - q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28] - quants.copy_(Q) - else: - q = None + q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28] + quants.copy_(Q) return q def regularized_param_groups(self): # pyre-ignore[3] @@ -122,12 +123,13 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict( self, state_dict: dict[str, Any], start_step: Optional[int] = None ) -> None: - qat_state = state_dict.pop("qat_state") + qat_state = state_dict.get("qat_state") # resume from check points usually not corresponds to saved num_steps # so allow explicit start_step computed from epochs * steps_per_epoc if start_step is not None: self.num_steps = start_step - else: # hope discrepancy in num_steps does not cause major problem! + elif qat_state is not None: + # hope discrepancy in num_steps does not cause major problem! 
self.num_steps = qat_state["num_steps"] self.base_optimizer.load_state_dict(state_dict) @@ -144,9 +146,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] self.num_steps += 1 return loss - # call base optimizer step() method to update latent parameters - loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] - if self.num_steps == self.warmup_steps: # first step of qat, save latent params, instead of restore self.save_latent_params() @@ -154,6 +153,16 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] # qat: restore latent params for update by the base optimizer self.restore_latent_params() + # call base optimizer step() method to update latent parameters + loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] + + if hasattr(self, "_state"): + assert self.warmup_steps == 0 + # restore the temporary state to the base optimizer's state + for p in self._state.keys(): + self.base_optimizer.state[p]["latent"] = self._state[p]["latent"] + del self._state + # check if it is time to update set of quantization values Q if (self.num_steps - self.warmup_steps) % self.quant_period == 0: quant_update = True @@ -165,6 +174,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] group["cumu_lr"] += group["lr"] gamma = max(1.0, group["cumu_lr"]) b = group["quant_bits"] + block_size = group.get("quant_block_size") inv_slope = 0.0 for p in group["params"]: if not p.requires_grad: @@ -177,44 +187,66 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] if self.quant_shrink: p.div_(gamma) + # reshape p according to block size if specified + if block_size is not None: + assert ( + p.size(-1) % block_size == 0 + ), f"{p.size(-1)=} is not divisible by {block_size=}" + assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}" + if p.dim() == 1: + p = p.unsqueeze(0) + + # row-major ordering ensures this is correct + p = p.view(-1, block_size) + # quantization by channel or by layer # update quantization targets periodically per_channel = self.quant_per_channel and p.dim() > 1 if quant_update: - quants_size = 3 if b == 0 else 2**b - if per_channel: - quants_size = (p.size(0), quants_size) - state["quants"] = torch.empty( - quants_size, device=p.device - ) # pyre-ignore[6] + quant_size = self.quantizer.get_quant_size(b) - # avoid type mismatch between sharded and full tensors - if HAS_DTENSOR and isinstance(p, DTensor): - p = p.full_tensor() + if per_channel: + quant_size = (p.size(0), quant_size) + state["quants"] = torch.empty(quant_size, device=p.device) + if is_dtensor(p): + state["quants"] = distribute_tensor( + state["quants"], + device_mesh=p.device_mesh, + placements=p.placements, + ) dim = -1 if per_channel else None if per_channel and p.dim() > 2: p = p.flatten(start_dim=1) - # NOTE: for LSBQ and optimal=False, use faster per-channel - # implementation instead of vmap - if isinstance(self.quantizer, LSBQuantizer) and self.quantizer.optimal: + q = None + if quant_update: qfunc = partial( - self.quantize_, - quantizer=self.quantizer, - b=b, - quant_update=quant_update, - ) - q = torch.vmap(qfunc, in_dims=0, out_dims=0)(p, state["quants"]) - else: - q = self.quantize_( - p, state["quants"], self.quantizer, b, quant_update, dim=dim + self.quantize_, quantizer=self.quantizer, b=b, dim=dim ) + if is_dtensor(p): + qfunc = local_map( + qfunc, + out_placements=[*p.placements], + in_placements=([Shard(0)], [Shard(0)]), + ) + q = qfunc(p, state["quants"]) # apply 
(step-dependent) proximal mapping in place - inv_slope = self.prox_map.apply_( # pyre-ignore[28] - p, q, state["quants"], self.num_steps, dim=dim + pfunc = partial( + self.prox_map.apply_, step_count=self.num_steps, dim=dim ) + if is_dtensor(p): + pfunc = local_map( + pfunc, + out_placements=None, + in_placements=( + [Shard(0)], + None if q is None else [Shard(0)], + [Shard(0)], + ), + ) + inv_slope = pfunc(p, q, state["quants"]) # quantized parameters share the same PARQ inverse slope if inv_slope: @@ -239,6 +271,12 @@ def restore_latent_params(self) -> None: @torch._disable_dynamo def save_latent_params(self) -> None: """Save updated latent parameters before applying prox-map""" + if self.warmup_steps == 0: + assert len(self.state) == 0, "Expected empty state at first step()" + # Maintain the invariant that `len(self.state) == 0` before first + # self.base_optimizer.step() call by using a temporary state buffer + self._state = defaultdict(dict) + for group in self.regularized_param_groups(): for p in group["params"]: if p.requires_grad: diff --git a/torchao/prototype/parq/quant/__init__.py b/torchao/prototype/parq/quant/__init__.py index b7251f2df1..8835740975 100644 --- a/torchao/prototype/parq/quant/__init__.py +++ b/torchao/prototype/parq/quant/__init__.py @@ -1,3 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .lsbq import LSBQuantizer # noqa: F401 from .quantizer import Quantizer # noqa: F401 -from .uniform import UnifQuantizer # noqa: F401 +from .uniform import ( # noqa: F401 + MaxUnifQuantizer, + TernaryUnifQuantizer, + UnifQuantizer, +) diff --git a/torchao/prototype/parq/quant/lsbq.py b/torchao/prototype/parq/quant/lsbq.py index e821b8f460..2d9f4e4c1e 100644 --- a/torchao/prototype/parq/quant/lsbq.py +++ b/torchao/prototype/parq/quant/lsbq.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+
 import itertools
 from collections.abc import Iterable
 from typing import Optional
@@ -23,10 +24,58 @@ def binary_quant_residue(u: Tensor, vs: Iterable[float]) -> Tensor:
     """Return residue for foldable binary quantization"""
     r = u.detach().clone()
     for v in vs:
-        r -= v * binary_sign(r)
+        r.sub_(v * binary_sign(r))
     return r
 
 
+def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool = False):
+    """Vectorized computation of optimal `v` for ternary/2-bit algorithm."""
+    v_cands = p.abs().sort(dim=dim).values
+    cumsum = v_cands.cumsum(dim=dim)
+    cumsum, total_sum = cumsum[:, 1:-1], cumsum[:, -1:]
+
+    # compute cumulative mean from right to left
+    counts = torch.arange(1, p.size(dim=dim), device=p.device)
+    counts_r2l = counts[:-1].flip((-1,))
+    cmean_r2l = (total_sum - cumsum).div_(counts_r2l.mul_(2))
+    v_cands, v_cands2 = v_cands[:, 1:-1], v_cands[:, 2:]
+
+    # mask to estimate conditional expectation
+    mask = (v_cands <= cmean_r2l).logical_and_(v_cands2 >= cmean_r2l)
+    if ternary:
+        # detect and fix any edge cases
+        optimal_v = p.mean(dim=dim, keepdim=True).div_(2)
+        row_invalid = optimal_v < p.min(dim=dim, keepdim=True).values
+        if row_invalid.any():
+            extra_col = row_invalid.to(p.dtype).mul(optimal_v)
+            v_cands = torch.cat((v_cands, extra_col), -1)
+            mask = torch.cat((mask, row_invalid), -1)
+    else:
+        # compute cumulative mean from left to right
+        cmean_l2r = cumsum.div_(counts[1:].mul_(2)).add_(cmean_r2l)
+        mask.logical_or_((v_cands <= cmean_l2r).logical_and_(v_cands2 >= cmean_l2r))
+
+    # handle variable number of candidates per channel
+    split_sizes = mask.sum(dim=dim).tolist()
+    v_cands = v_cands[mask].split(split_sizes)
+    v_cands = torch.nested.nested_tensor(list(v_cands))
+    v_cands = torch.nested.to_padded_tensor(v_cands, 0.0)
+
+    # update residual for each candidate `v`
+    r = p.unsqueeze(dim - 1)
+    v = v_cands.unsqueeze(-1)
+    r = r.sub(v * binary_sign(r))
+    if not ternary:
+        v = v.mean(dim=dim, keepdim=True)
+        r = r.sub(v * binary_sign(r))
+
+    # compute least squares error, then select the `v` that minimizes it
+    costs = r.norm(dim=dim)
+    indices = costs.argmin(dim=dim, keepdim=True)
+    v_best = v_cands.gather(1, indices)
+    return v_best
+
+
 class LSBQuantizer(Quantizer):
     """Least-Square Binary Quantizer, using greedy algorithm by default.
     Optimal solution available for three cases: b=1, b=2 and ternary.
@@ -44,25 +93,31 @@ def __init__( self.optimal = optimal self.ternary_multiplier = ternary_multiplier + def get_quant_size(self, b: int) -> int: + return 2**b if b > 0 else 3 + def quantize( self, p: Tensor, b: int, dim: Optional[int] = None ) -> tuple[Tensor, Tensor]: """Instantiation of Quantizer.quantize(), with b=0 for ternary""" - assert b >= 0 # b==0 means ternary + if b < 0: + raise ValueError(f"Invalid {b=}; must be nonnegative") + if self.optimal and b > 2: + raise NotImplementedError(f"Unsupported {self.optimal=} for {b=}") + if self.center: q, mean = super().remove_mean(p.detach(), dim=dim) else: q = p.detach().clone() mean = torch.zeros(1, dtype=p.dtype, device=p.device) - # b == 0 means ternary; b == 1 optimal same as greedy - if b == 0: - if self.optimal: - q, Q = self.quantize_optimal_ternary(q) - else: - q, Q = self.quantize_simple_ternary(q, self.ternary_multiplier, dim=dim) - elif b == 2 and self.optimal: - q, Q = self.quantize_optimal_2bits(q) + if self.optimal and b != 1: # b == 1 optimal is the same as greedy + if b == 0: + q, Q = self.quantize_optimal_ternary(q, dim=dim) + elif b == 2: + q, Q = self.quantize_optimal_2bits(q, dim=dim) + elif b == 0: + q, Q = self.quantize_simple_ternary(q, self.ternary_multiplier, dim=dim) else: q, Q = self.quantize_greedy(q, b, dim=dim) @@ -81,7 +136,7 @@ def quantize_greedy( keepdim = dim is not None for _ in range(b): v = r.abs().mean(dim=dim, keepdim=keepdim) - r -= v * binary_sign(r) + r.sub_(binary_sign(r).mul_(v)) vs.append(v) q = p - r @@ -90,16 +145,32 @@ def quantize_greedy( B = torch.tensor(basis, dtype=p.dtype, device=p.device) if dim is not None: V = torch.concat(vs, dim=1) # [dim0, b] - Q = torch.sort(V @ B.T, dim=dim)[0] # [dim0, 2^b] + Q = torch.sort(V @ B.T, dim=dim).values # [dim0, 2^b] else: V = torch.tensor(vs, dtype=p.dtype, device=p.device) Q = torch.msort(B.matmul(V)) # [2^b] return q, Q @staticmethod - def quantize_optimal_2bits(p: Tensor) -> tuple[Tensor, Tensor]: + def quantize_optimal_2bits( + p: Tensor, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + # generate 4 x 2 basis tensor B, sorted lexicographically along dim 0 + basis = list(itertools.product((-1, 1), repeat=2)) + B = torch.tensor(basis, dtype=p.dtype, device=p.device) + if dim is not None: + v1 = compute_v_per_channel(p, dim=dim, ternary=False) + s = binary_sign(p).mul_(v1) + r = p.sub(s) + v2 = r.abs().mean(dim=dim, keepdim=True) + q = s.add_(binary_sign(r).mul_(v2)) + + V = torch.cat((v1, v2), dim=-1) # [dim0, b] + Q = V @ B.T # [dim0, 2^b] + return q, Q + # first form the cumulative sum of sorted absolute values of p - p_abs_sorted = torch.msort(torch.flatten(p.abs())) + p_abs_sorted = p.abs().flatten().sort().values cumsum = torch.cumsum(p_abs_sorted, dim=0) n = cumsum.numel() # find all solutions v1 to an inclusion problem (after sorting |p|) @@ -133,18 +204,31 @@ def quantize_optimal_2bits(p: Tensor) -> tuple[Tensor, Tensor]: min_error = error q = p - r v1, v2 = v1v2 - # generate 4 x 2 basis tensor B, sorted lexicographically along dim 0 - basis = list(itertools.product((-1, 1), repeat=2)) - B = torch.tensor(basis, dtype=p.dtype, device=p.device) - # vmap workaround: calling torch.tensor on v1, v2 raises an error - Q = v1 * B[:, 0] + v2 * B[:, 1] + + V = torch.tensor((v1, v2), dtype=p.dtype, device=p.device) + Q = B @ V return q, Q @staticmethod - def quantize_optimal_ternary(p: Tensor) -> tuple[Tensor, Tensor]: + def quantize_optimal_ternary( + p: Tensor, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: """Formula look 
reasonable, but derivation in reference incorrect?""" + if dim is not None: + v = compute_v_per_channel(p, dim=dim, ternary=True) + p_sign = binary_sign(p) + r = p.sub(p_sign.mul(v)) + + # 0 if sign(p) != sign(r), else sign(p) * 2v + q = p_sign.add_(binary_sign(r)).mul_(v) + + # each channel can take values [-2v, 0, 2v] + v.mul_(2) + Q = torch.cat((-v, torch.zeros_like(v), v), dim=-1) # [dim0, 3] + return q, Q + # first form the cumulative sum of sorted absolute values of p - p_abs_sorted = torch.msort(torch.flatten(p.abs())) + p_abs_sorted = p.abs().flatten().sort().values cumsum = torch.cumsum(p_abs_sorted, dim=0) n = cumsum.numel() # find all solutions v1 to an inclusion problem (after sorting |p|) diff --git a/torchao/prototype/parq/quant/quantizer.py b/torchao/prototype/parq/quant/quantizer.py index 22dd2a1bb7..b44050e773 100644 --- a/torchao/prototype/parq/quant/quantizer.py +++ b/torchao/prototype/parq/quant/quantizer.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Optional @@ -15,6 +16,10 @@ class Quantizer(ABC): def __init__(self, center: bool = False) -> None: self.center = center + @abstractmethod + def get_quant_size(self, b: int) -> int: + """Given number of bits b, return total number of quantization values""" + @abstractmethod def quantize(self, p: Tensor, b: int) -> tuple[Tensor, Tensor]: """Provide interface for quantization: diff --git a/torchao/prototype/parq/quant/uniform.py b/torchao/prototype/parq/quant/uniform.py index de9e465bc0..f264894115 100644 --- a/torchao/prototype/parq/quant/uniform.py +++ b/torchao/prototype/parq/quant/uniform.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + +import math from typing import Optional import torch @@ -11,50 +13,202 @@ from .quantizer import Quantizer +def get_q_max( + q: Tensor, b: int, dim: Optional[int] = None, scale_method: str = "mean" +) -> Tensor: + if scale_method == "mean": + # set range of quantization: min(b * |q|.mean(), |q|.max()) + q_abs = q.abs() + if dim is not None: + q_max = torch.minimum( + b * q_abs.mean(dim=dim, keepdim=True), # pyre-ignore[6,9] + torch.max(q_abs, dim=dim, keepdim=True).values, # pyre-ignore[6] + ) + else: + q_max = torch.minimum(b * q_abs.mean(), torch.max(q_abs)) # pyre-ignore[6] + elif scale_method == "max": + q_max = ( + q.abs().max(dim=dim, keepdim=True).values + if dim is not None + else q.abs().max() + ) + else: + raise NotImplementedError(f"Invalid {scale_method=}, choices=('mean','max')") + return q_max + + class UnifQuantizer(Quantizer): - """Uniform quantizer, range determined by multiples of |p|.mean()""" + """Uniform and symmetric quantizer""" + + def __init__( + self, + center: bool = False, + scale_method: str = "mean", + int_shift: float = 0.5, + zero_point: float = 0.5, + ): + """Set quantization function parameters. + + Args: + center: whether to subtract p.mean() prior to quantization + scale_method: compute scale based 'mean', multiples of |p|.mean(), + or 'max', |p|.max() (default: 'mean') + int_shift: float value to shift the lower bound of integer range by: + -2^{b - 1} + int_shift (default: 0.5). Using 0.5 results in 2^b + values. E.g., [-1.5, -0.5, 0.5, 1.5] for b=2. + zero_point: float value to shift p by after scale and round. 
+ """ + assert scale_method in ("max", "mean"), f"Invalid {scale_method=}" + super().__init__(center=center) + + self.scale_method = scale_method + self.int_shift = int_shift + self.zero_point = zero_point + + def get_quant_size(self, b: int) -> int: + """Levels in [-2^{b-1} + self.int_shift, 2^{b-1} - self.int_shift]. - def __init__(self, center: bool = False) -> None: - super().__init__(center) + Note that range_absmax = 2^{b-1} - self.int_shift on both ends of the + boundary and the interval is closed.""" + return math.floor(2**b - 2 * self.int_shift) + 1 def quantize( self, p: Tensor, b: int, dim: Optional[int] = None ) -> tuple[Tensor, Tensor]: """Instantiation of Quantizer.quantize() method""" - assert b >= 1 + assert b != 0, "Please use TernaryUnifQuantizer instead" + if self.center: q, mean = super().remove_mean(p.detach(), dim=dim) else: q = p.detach().clone() mean = torch.zeros(1, dtype=p.dtype, device=p.device) - - # set range of quantization: min( b * |q|.mean(), |q|.max()) - q_abs = q.abs() - if dim is not None: - q_max = torch.minimum( - b * q_abs.mean(dim=dim, keepdim=True), # pyre-ignore[6,9] - torch.max(q_abs, dim=dim, keepdim=True)[0], # pyre-ignore[6] - ) - else: - q_max = torch.minimum(b * q_abs.mean(), torch.max(q_abs)) # pyre-ignore[6] + q_max = get_q_max(q, b, dim=dim, scale_method=self.scale_method) + q_max.clamp_(min=torch.finfo(q.dtype).tiny) # clamp to quantization range q.copy_(torch.minimum(torch.maximum(q, -q_max), q_max)) - # compute scale from [-2^{b-1}+0.5, 2^{b-1}-0.5] to [-q_max, q_max] - s = q_max / (2 ** (b - 1) - 0.5) + # scale from [-2^{b-1}+int_shift, 2^{b-1}-int_shift] to [-q_max, q_max] + range_absmax = 2 ** (b - 1) - self.int_shift + s = q_max / range_absmax - # scale by 1/s -> shift -0.5 -> round -> shift +0.5 -> scale by s - # where shift ensures rounding to integers 2^{b-1}, ..., 2^{b-1}-1 - q.div_(s).sub_(0.5).round_().add_(0.5).mul_(s) + # scale by 1/s -> shift -zero_point -> round -> shift +zero_point -> + # scale by s, where shift ensures rounding to integers + q.div_(s).sub_(self.zero_point).round_().add_(self.zero_point).mul_(s) # set of all target quantization values - Q = s * ( - torch.arange(-(2 ** (b - 1)) + 0.5, 2 ** (b - 1), step=1, device=q.device) + Q = torch.arange( + -range_absmax, range_absmax + 1e-5, dtype=p.dtype, device=p.device ) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) # broadcasted multiply requires copy + else: + Q.mul_(s) + + # return quantized tensor and set of possible quantization values + if self.center: + q += mean + Q += mean + return q, Q + + +class MaxUnifQuantizer(UnifQuantizer): + def __init__( + self, + center: bool = False, + scale_method: str = "max", + int_shift: float = 1.0, + zero_point: float = 0.0, + ): + """Set quantization function with int_shift=1.0. + + The final quantization range includes 2^b - 1 quantized values. E.g., + [-1, 0, 1] for b=2. The quantization scale is determined by |p|.max() + by default and zero point is 0.0. 
+ """ + super().__init__( + center=center, + scale_method=scale_method, + int_shift=int_shift, + zero_point=zero_point, + ) + + +class AsymUnifQuantizer(Quantizer): + def get_quant_size(self, b: int) -> int: + """Equivalent to int_max - int_min + 1, where int_min = -2^{b-1} and + int_max = 2^{b-1} - 1.""" + return 2**b + + def quantize( + self, p: Tensor, b: int, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + assert b != 0, "Please use TernaryUnifQuantizer instead" + + if self.center: + q, mean = super().remove_mean(p.detach(), dim=dim) + else: + q = p.detach().clone() + mean = torch.zeros(1, dtype=p.dtype, device=p.device) + + if dim is not None: + q_min = q.min(dim=dim, keepdim=True).values + q_max = q.max(dim=dim, keepdim=True).values + else: + q_min = q.min() + q_max = q.max() + + int_min = -(2 ** (b - 1)) + int_max = 2 ** (b - 1) - 1 + s = (q_max - q_min) / (int_max - int_min) + s.clamp_(min=torch.finfo(q.dtype).tiny) + + zero_point = q_min.div_(s).round_() + q.div_(s).round_().sub_(zero_point).add_(zero_point).mul_(s) + + Q = torch.arange(int_min, int_max + 1, dtype=p.dtype, device=p.device) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) # broadcasted multiply requires copy + else: + Q.mul_(s) # return quantized tensor and set of possible quantization values if self.center: q += mean Q += mean return q, Q + + +class TernaryUnifQuantizer(Quantizer): + """Uniform quantizer for ternary bit case. Quantization range is [-1, 1].""" + + def get_quant_size(self, b: int) -> int: + return 3 + + def quantize( + self, p: Tensor, b: int, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + assert b == 0, f"Unexpected {b=} for ternary case" + + if self.center: + q, mean = super().remove_mean(p.detach(), dim=dim) + else: + q = p.detach().clone() + mean = torch.zeros(1, dtype=p.dtype, device=p.device) + + q_max = get_q_max(q, b, dim=dim, scale_method="max") + q_max.clamp_(min=torch.finfo(q.dtype).tiny) + s = q_max / 1.5 + q.div_(s).round_().clamp_(min=-1, max=1).mul_(s) + + Q = torch.tensor([-1, 0, 1], dtype=p.dtype, device=p.device) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) + else: + Q.mul_(s) + + if self.center: + q += mean + Q += mean + return q, Q diff --git a/torchao/prototype/parq/utils.py b/torchao/prototype/parq/utils.py index d18257574f..ac5024fb5d 100644 --- a/torchao/prototype/parq/utils.py +++ b/torchao/prototype/parq/utils.py @@ -3,9 +3,21 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + import torch from torch import Tensor +try: + from torch.distributed.tensor import DTensor + + HAS_DTENSOR = True +except ImportError: + HAS_DTENSOR = False + + +def is_dtensor(x): + return HAS_DTENSOR and isinstance(x, DTensor) + def channel_bucketize(input: Tensor, boundaries: Tensor, right: bool = False) -> Tensor: """Generalizes torch.bucketize to run on 2-D boundaries.""" @@ -18,4 +30,4 @@ def channel_bucketize(input: Tensor, boundaries: Tensor, right: bool = False) -> boundaries = boundaries.unsqueeze(1) input = input.unsqueeze(-1) mask = input.ge(boundaries) if right else input.le(boundaries) - return mask.int().argmax(dim=-1) + return mask.to(torch.uint8).argmax(dim=-1)
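
A few usage sketches follow; they are illustrations written against the APIs introduced in this patch, not part of the patch itself. First, the annealing schedule: the sketch below re-implements normalized_mirror_sigmoid exactly as written above (standalone, so it runs without torchao) and prints the inverse slope for the default anneal_center=0.5 and an arbitrary off-center value of 0.3. With a center other than 0.5 the normalized value no longer starts at exactly 1 and can dip below 0 near anneal_end, which appears to be why ProxPARQ.apply_ now clamps inv_slope to torch.finfo(p.dtype).tiny.

import math

def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float, c: float) -> float:
    # copied from the patched helper above
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
    st = 1 / (1 + math.exp(s * (ft - c)))
    s1 = 1 / (1 + math.exp(-0.5 * s))
    s2 = 1 / (1 + math.exp(0.5 * s))
    return (st - s2) / (s1 - s2)

for c in (0.5, 0.3):  # anneal_center values; 0.3 is purely illustrative
    vals = [normalized_mirror_sigmoid(t, 0, 100, 10, c) for t in (0, 25, 50, 75, 99)]
    print(c, [round(v, 4) for v in vals])  # c=0.3 goes slightly negative near the end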
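
The quant_block_size handling added to QuantOptimizer.step() amounts to a row-major view: a 1-D or 2-D parameter is reshaped to (-1, block_size) so the existing per-channel path assigns each contiguous block its own scale and quantization targets. A minimal sketch, assuming a small 2-D weight and block_size=4 (both made up for illustration):

import torch

from torchao.prototype.parq.quant import MaxUnifQuantizer

weight = torch.randn(2, 8)       # stand-in for a linear layer weight
block_size = 4                   # hypothetical group size
p = weight.view(-1, block_size)  # row-major view: one row per block

quantizer = MaxUnifQuantizer()
q, Q = quantizer.quantize(p, b=4, dim=-1)
print(Q.shape)  # torch.Size([4, 15]): one row of targets per block, not per output channel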
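
compute_v_per_channel lets the optimal LSBQ paths run vectorized across channels instead of looping or using vmap. A rough sketch of how the per-channel branches are reached through LSBQuantizer.quantize (tensor sizes are arbitrary):

import torch

from torchao.prototype.parq.quant import LSBQuantizer

p = torch.randn(16, 256)
quantizer = LSBQuantizer(optimal=True)
for b, n_targets in ((0, 3), (2, 4)):  # b=0 means ternary
    q, Q = quantizer.quantize(p, b, dim=-1)
    # one row of quantization targets per channel
    assert q.shape == p.shape and Q.shape == (p.size(0), n_targets)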
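
The uniform quantizer family now differs mainly in scale_method, int_shift, and zero_point, which get_quant_size reflects: UnifQuantizer keeps 2^b half-integer levels, MaxUnifQuantizer keeps 2^b - 1 integer levels including 0, and TernaryUnifQuantizer keeps 3. A quick self-check sketch (random data; bit-widths chosen only for illustration):

import torch

from torchao.prototype.parq.quant import (
    MaxUnifQuantizer,
    TernaryUnifQuantizer,
    UnifQuantizer,
)

p = torch.randn(8, 64)
cases = ((4, UnifQuantizer()), (4, MaxUnifQuantizer()), (0, TernaryUnifQuantizer()))
for b, quantizer in cases:
    q, Q = quantizer.quantize(p, b, dim=-1)  # per-channel quantization
    assert Q.shape == (p.size(0), quantizer.get_quant_size(b))
    assert torch.isin(q[0], Q[0]).all()  # quantized entries come from that channel's targets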