Adding GPU quantization workflows and APIs #1

Closed
wants to merge 1 commit

40 changes: 40 additions & 0 deletions ao/quantization/__init__.py
@@ -0,0 +1,40 @@
from .dynamic_quant import *  # noqa: F403
from .quant_api import *  # noqa: F403
from .quant_primitives import *  # noqa: F403
from .smoothquant import *  # noqa: F403
from .subclass import *  # noqa: F403
from .utils import *  # noqa: F403
from .weight_only import *  # noqa: F403

__all__ = [
    "DynamicallyPerAxisQuantizedLinear",
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
    "insert_subclass",
    "safe_int_mm",
    "dynamically_quantize_per_tensor",
    "quantize_activation_per_token_absmax",
    "dynamically_quantize_per_channel",
    "dequantize_per_tensor",
    "dequantize_per_channel",
    "quant_int8_dynamic_linear",
    "quant_int8_matmul",
    "quant_int8_dynamic_per_token_linear",
    "quant_int8_per_token_matmul",
    "get_scale",
    "SmoothFakeDynQuantMixin",
    "SmoothFakeDynamicallyQuantizedLinear",
    "swap_linear_with_smooth_fq_linear",
    "smooth_fq_linear_to_inference",
    "set_smooth_fq_attribute",
    "DynamicallyQuantizedLinearWeight",
    "log_with_rank",
    "clear_logs",
    "compute_error",
    "forward_hook",
    "apply_logging_hook",
    "get_model_size_in_bytes",
    "WeightOnlyInt8QuantLinear",
]
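
For orientation (a reviewer sketch, not part of the diff): with the package importable as `ao.quantization`, the exports above support a round trip like the following, assuming the int8 primitives support the target device and that `compute_error` (from `utils.py`) returns an SQNR-style error metric:

import torch
from ao.quantization import apply_dynamic_quant, compute_error

# toy model; apply_dynamic_quant swaps every nn.Linear in place
model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
x = torch.randn(4, 16)
y_ref = model(x)
apply_dynamic_quant(model)
y_quant = model(x)
print(compute_error(y_ref, y_quant))  # quantization error vs. the float baseline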
Binary file added ao/quantization/__pycache__/utils.cpython-310.pyc
Binary files added (remaining ao/quantization/__pycache__/*.pyc files not shown)
99 changes: 99 additions & 0 deletions ao/quantization/dynamic_quant.py
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from .quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

__all__ = ["DynamicallyPerAxisQuantizedLinear"]


class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    A replacement for `torch.nn.Linear` that dynamically quantizes the
    input across all axes except the last one.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_fused_int_mm: bool = False,
    ) -> None:
        super().__init__(in_features, out_features, bias)
        self.use_fused_int_mm = use_fused_int_mm
        # note: use_fused_int_mm = True performs best when additionally setting
        # torch._inductor.config.force_fuse_int_mm_with_mul = True

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the quantized linear layer.

        Dynamically quantizes the input tensor across all axes except the
        last axis, then calls `quant_int8_dynamic_per_token_linear`.

        Args:
            X (torch.Tensor): The input tensor to the quantized linear layer.

        Returns:
            torch.Tensor: The output tensor after the quantized matmul and rescale.

        """
        # The following div mimics the behavior of SmoothFakeDynamicallyQuantizedLinear
        if not self.use_fused_int_mm:
            X = X / self.fake_rescale
            # Counterintuitively, the inductor fusion that occurs for most
            # transformer models is faster with this extra div op than
            # without it, at slightly higher memory usage. fake_rescale is
            # the scalar 1.0, so the division does not affect accuracy.
        Y = quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
        )
        return Y

    @classmethod
    def from_float(
        cls, mod: torch.nn.Linear, use_fused_int_mm: bool = False
    ) -> "DynamicallyPerAxisQuantizedLinear":
        """
        Converts a `mod` of class `torch.nn.Linear` to its dynamically
        quantized equivalent.

        Note: this conversion does not require calibration.

        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.

        Returns:
            DynamicallyPerAxisQuantizedLinear: The converted quantized linear module.

        """

        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            use_fused_int_mm=use_fused_int_mm,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -128, 127, torch.int8
        )
        new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t())
        new_mod.W_scales = nn.Parameter(W_scales)
        new_mod.bias = mod.bias
        if not use_fused_int_mm:
            # a buffer, so that `.to(device)` moves it along with the module
            new_mod.register_buffer(
                "fake_rescale",
                torch.tensor([1.0], dtype=mod.weight.dtype, device=mod.weight.device),
            )
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod
74 changes: 74 additions & 0 deletions ao/quantization/quant_api.py
@@ -0,0 +1,74 @@
"""
Quantization API stuff which is not specific to SmoothQuant

Note: this is throwaway code for fast results on Blueberry, this is not
intended to be the actual long term quantization API for server GPUs.
"""

import torch
from .dynamic_quant import (
    DynamicallyPerAxisQuantizedLinear,
)
from .subclass import (
    DynamicallyQuantizedLinearWeight,
)
from .weight_only import (
    WeightOnlyInt8QuantLinear,
)

__all__ = [
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
]


def replace_with_custom_fn_if_matches_filter(
    model, replacement_fn, filter_fn, cur_fqn=""
) -> None:
    """
    Recursively replaces each child `child` of `model` with `replacement_fn(child)`
    if `filter_fn(child, fqn)` returns `True`.
    """
    name_to_child = dict(model.named_children())
    for name, child in name_to_child.items():
        if cur_fqn == "":
            new_fqn = name
        else:
            new_fqn = f"{cur_fqn}.{name}"
        if filter_fn(child, new_fqn):
            new_child = replacement_fn(child)
            setattr(model, name, new_child)
        else:
            replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, new_fqn
            )
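
As an aside (a reviewer sketch, not part of the file): because `filter_fn` receives the fully qualified name as its second argument, replacements can target submodules by name as well as by type. Here `model` and the "attn" name pattern are hypothetical:

def attn_linears_only(mod, fqn):
    # hypothetical filter: only linears whose qualified name contains "attn"
    return isinstance(mod, torch.nn.Linear) and "attn" in fqn

replace_with_custom_fn_if_matches_filter(
    model, WeightOnlyInt8QuantLinear.from_float, attn_linears_only
)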


def apply_weight_only_int8_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        WeightOnlyInt8QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def apply_dynamic_quant(model, use_fused_int_mm=False):
    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def change_linear_weights_to_dqtensors(model):
    def insert_subclass(lin):
        lin.weight = torch.nn.Parameter(
            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
        )
        return lin

    replace_with_custom_fn_if_matches_filter(
        model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
    )
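
Taken together (again a sketch, not part of the diff), the three entry points differ in mechanism: the first two swap modules, while the third swaps only the weight tensor for a subclass and keeps the `nn.Linear` modules intact. An end-to-end example on a toy model, assuming the int8 primitives support the target device:

import torch
from quant_api import (
    apply_dynamic_quant,
    apply_weight_only_int8_quant,
    change_linear_weights_to_dqtensors,
)

def toy_model():
    return torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
    ).eval()

m1, m2, m3 = toy_model(), toy_model(), toy_model()
apply_dynamic_quant(m1)                 # module swap -> DynamicallyPerAxisQuantizedLinear
apply_weight_only_int8_quant(m2)        # module swap -> WeightOnlyInt8QuantLinear
change_linear_weights_to_dqtensors(m3)  # weight swap -> DynamicallyQuantizedLinearWeight

x = torch.randn(4, 16)
print(m1(x).shape, m2(x).shape, m3(x).shape)  # all torch.Size([4, 8])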