2 changes: 1 addition & 1 deletion .github/workflows/test-check.yaml
@@ -12,7 +12,7 @@ on:
 
 jobs:
   python-tests:
-    runs-on: ubuntu-22.04
+    runs-on: gcp-k8s-vllm-l4-solo
     env:
       HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
     steps:
@@ -90,7 +90,6 @@ def compress(
         desc = "Compressing with quantization"
         for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]
-
             # compress weights
             if name.endswith("weight"):
                 prefix = name.removesuffix("weight")
@@ -129,10 +128,18 @@ def compress(
             if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
                 continue
 
+            if name.endswith("weight_scale") and self._skip_scale():
+                continue
+
             compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
+    def _skip_scale(self):
+        from compressed_tensors.compressors import NVFP4PackedCompressor
+
+        return isinstance(self, NVFP4PackedCompressor)
+
     def _skip_zp(
         self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
@@ -26,7 +26,7 @@
 from torch import Tensor
 
 
-__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]
+__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8", "NVFP4PackedCompressor"]
 
 FLOAT_TO_E2M1 = [
     0.0,
@@ -103,6 +103,7 @@ def compress_weight(
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
+        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
         return compressed_dict
 
     def decompress_weight(
@@ -111,8 +112,8 @@
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
         weight = compressed_data["weight_packed"]
-        scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
+        scale = compressed_data["weight_scale"]
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
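A rough sketch of what the new `weight_scale` entry means for the stored checkpoint: `compress_weight` now downcasts the per-group scale to `quantization_args.scale_dtype` (fp8 E4M3 in the NVFP4 presets below), and `decompress_weight` reads it back alongside `weight_global_scale`. The shapes and the error check are illustrative only, and assume a PyTorch build with `torch.float8_e4m3fn`:

```python
import torch

# Illustrative per-group NVFP4 scales (assumed shape: one scale per group of 16)
scale = torch.rand(128, 16, dtype=torch.float32) + 0.5

stored = scale.to(torch.float8_e4m3fn)   # what compress_weight now writes as "weight_scale"
restored = stored.to(torch.float32)      # what decompress_weight works with after upcasting

# The downcast is lossy, bounded by E4M3's 3 mantissa bits (worst case a few percent)
max_rel_err = ((restored - scale).abs() / scale).max().item()
print(f"max relative error introduced by the fp8 scale cast: {max_rel_err:.3%}")
```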
15 changes: 6 additions & 9 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -21,7 +21,7 @@
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    round_to_quantized_type,
+    round_to_quantized_type_args,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -466,20 +466,17 @@ def _quantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     scaled = x / scale
 
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
-    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-    clamped_value = torch.clamp(
-        scaled,
-        q_min,
-        q_max,
+    # clamp and round
+    quantized_value = round_to_quantized_type_args(
+        tensor=scaled, args=args, min=q_min, max=q_max
     )
-    quantized_value = round_to_quantized_type(clamped_value, args)
 
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
@@ -499,7 +496,7 @@ def _dequantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     dequant_value = x_q.to(scale.dtype)
 
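The `_quantize` change is a pure refactor: the inline clamp is folded into `round_to_quantized_type_args`, which clamps to `[q_min, q_max]`, rounds per the args, and keeps the input dtype (per its docstring). A minimal sketch of the fp8 branch shown in quant_args.py; the hard-coded range and the standalone `QuantizationArgs(num_bits=8, type="float")` construction are assumptions for illustration, in practice `q_min`/`q_max` come from the quantization args:

```python
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.quant_args import round_to_quantized_type_args

args = QuantizationArgs(num_bits=8, type="float")       # fp8 E4M3
q_min, q_max = torch.tensor(-448.0), torch.tensor(448.0)

scale = torch.tensor(2.0)
x = torch.tensor([-1000.0, 0.25, 1.5, 1000.0])
scaled = x / scale                                       # [-500., 0.125, 0.75, 500.]

# One call now clamps to [q_min, q_max] and snaps to the E4M3 grid,
# returning the result in the input's float32 dtype
quantized = round_to_quantized_type_args(tensor=scaled, args=args, min=q_min, max=q_max)
print(quantized)                                         # ≈ [-448., 0.125, 0.75, 448.]
```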
18 changes: 7 additions & 11 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -24,7 +24,6 @@
     QuantizedKVCache,
 )
 from compressed_tensors.quantization import (
-    FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
     QuantizationArgs,
@@ -36,7 +35,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -250,20 +249,15 @@ def initialize_qparams(
 
     # 2. Identify quantization scale and zp dtype
     scale_dtype = observed_dtype
-
-    if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        # TODO: consider erroring out in the future as if the dtype if not one of these,
-        # there is likely bug
+    if quantization_args.scale_dtype is None:
         if scale_dtype not in [
             torch.float16,
             torch.bfloat16,
             torch.float32,
             torch.float64,
         ]:
-            scale_dtype = torch.bfloat16
-        zp_dtype = quantization_args.pytorch_dtype()
+            scale_dtype = torch.float16
+        quantization_args.scale_dtype = scale_dtype
 
     # 3. Initializes scale/zp for the module
     init_scale = Parameter(
@@ -274,7 +268,9 @@
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
-            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            torch.zeros(
+                expected_shape, device=device, dtype=quantization_args.zp_dtype
+            ),
             requires_grad=False,
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
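Net effect in `initialize_qparams`: the fp4-specific branch is gone, the scale dtype comes from `quantization_args.scale_dtype` when set (otherwise the observed dtype, with a float16 fallback for non-standard float dtypes), and the zero point is now created with `quantization_args.zp_dtype`. A hedged mirror of just the fallback rule; the helper below is illustrative, not a library function:

```python
import torch
from compressed_tensors.quantization import QuantizationArgs

_STANDARD_FLOATS = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

def resolve_scale_dtype(args: QuantizationArgs, observed_dtype: torch.dtype) -> torch.dtype:
    """Illustrative mirror of the fallback added above (not part of the library)."""
    if args.scale_dtype is None:
        scale_dtype = observed_dtype if observed_dtype in _STANDARD_FLOATS else torch.float16
        args.scale_dtype = scale_dtype  # written back onto the args, as in initialize_qparams
    return args.scale_dtype

print(resolve_scale_dtype(QuantizationArgs(num_bits=8, type="int"), torch.bfloat16))
# observed dtype kept -> torch.bfloat16
print(resolve_scale_dtype(QuantizationArgs(num_bits=8, type="int"), torch.float8_e4m3fn))
# non-standard observed dtype -> falls back to torch.float16
```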
53 changes: 47 additions & 6 deletions src/compressed_tensors/quantization/quant_args.py
@@ -19,6 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
+from compressed_tensors.utils.type import TorchDtype
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
@@ -30,7 +31,8 @@
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
-    "round_to_quantized_type",
+    "round_to_quantized_type_args",
+    "round_to_quantized_type_dtype",
     "ActivationOrdering",
     "DynamicType",
 ]
@@ -174,6 +176,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
+    scale_dtype: Optional[TorchDtype] = None
+    zp_dtype: Optional[TorchDtype] = None
     observer: Optional[str] = Field(
         default=None,
         description=(
@@ -266,6 +270,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         dynamic = model.dynamic
         observer = model.observer
         dynamic = model.dynamic
+        zp_dtype = model.zp_dtype
 
         # infer strategy
         if strategy is None:
@@ -353,9 +358,16 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
             # default to minmax for non-dynamic cases
             observer = "minmax"
 
+        if zp_dtype is None:
+            if model.num_bits == 4 and model.type == QuantizationType.FLOAT:
+                zp_dtype = FP8_E4M3_DATA.dtype
+            else:
+                zp_dtype = model.pytorch_dtype()
+
         # write back modified values
         model.strategy = strategy
         model.observer = observer
+        model.zp_dtype = zp_dtype
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
@@ -381,18 +393,47 @@ def get_observer(self) -> str:
     model_config = ConfigDict(extra="forbid")
 
 
-def round_to_quantized_type(
-    tensor: torch.Tensor, args: QuantizationArgs
+def round_to_quantized_type_dtype(
+    tensor: torch.Tensor, dtype: torch.dtype
 ) -> torch.Tensor:
     """
-    Rounds each element of the input tensor to the nearest quantized representation,
-    keeping to original dtype
+    Rounds an input tensor to the nearest quantized representation given a dtype.
+    The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
-    :param args: QuantizationArgs to pull appropriate dtype from
+    :param dtype: dtype to use for rounding
     :return: rounded tensor
     """
+    original_dtype = tensor.dtype
+    if torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        finfo = torch.finfo(dtype)
+        rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
+    else:
+        iinfo = torch.iinfo(dtype)
+        rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
+
+    return rounded.to(original_dtype)
+
+
+def round_to_quantized_type_args(
+    tensor: torch.Tensor,
+    args: QuantizationArgs,
+    min: torch.Tensor,
+    max: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Rounds an input tensor to the nearest quantized representation given
+    quantization args. The original dtype is kept post-rounding.
+
+    :param tensor: tensor to round
+    :param args: quantization args to use for rounding
+    :param min: min value to use for clamping
+    :param max: max value to use for clamping
+    :return: rounded tensor
+    """
+
     original_dtype = tensor.dtype
+    tensor = torch.clamp(tensor, min, max)
     if args.type == QuantizationType.FLOAT:
         if args.num_bits == 8:
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
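A short usage sketch of `round_to_quantized_type_dtype` plus the new `zp_dtype` default chosen in `validate_model_after`; the concrete values are illustrative and it is assumed these minimal `QuantizationArgs` constructions pass validation:

```python
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.quant_args import round_to_quantized_type_dtype

# Integer dtype: values are clamped to the int8 range and rounded,
# but the result keeps the input's bfloat16 dtype
x = torch.tensor([-300.0, 0.4, 1.6, 250.0], dtype=torch.bfloat16)
print(round_to_quantized_type_dtype(x, torch.int8))           # ≈ [-128., 0., 2., 127.]

# Float dtype: values are clamped to the finite fp8 range and snapped to its grid
y = torch.tensor([-1000.0, 0.1875, 1000.0])
print(round_to_quantized_type_dtype(y, torch.float8_e4m3fn))  # ≈ [-448., 0.1875, 448.]

# zp_dtype now defaults per-args: fp4 float -> fp8 E4M3, otherwise pytorch_dtype()
print(QuantizationArgs(num_bits=4, type="float", group_size=16).zp_dtype)  # float8_e4m3fn
print(QuantizationArgs(num_bits=8, type="int").zp_dtype)                   # torch.int8
```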
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/quant_config.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union
7 changes: 7 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -18,6 +18,7 @@
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
+    FP8_E4M3_DATA,
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
@@ -160,6 +161,8 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     )
 )
 
@@ -173,6 +176,8 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=False,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -182,6 +187,8 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=DynamicType.LOCAL,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
 )
 