
Make WeightOnlyInt8QuantLinear more compatible with torch.compile #19

Open · wants to merge 2 commits into base: main
25 changes: 14 additions & 11 deletions torchao/quantization/weight_only.py
@@ -17,14 +17,17 @@ class WeightOnlyInt8QuantLinear(torch.nn.Linear):
     This class is a replacement for `torch.nn.Linear`. It implements a
     mixed dtype matmul using int8 symmetric per-channel weight quantization
     """
-    def __init__(self, *args, **kwargs):
-        w_int8 = kwargs.pop("w_int8")
-        scales = kwargs.pop("scales")
-        super().__init__(*args, **kwargs)
-        self.w_int8 = w_int8
-        self.scales = scales
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+    ) -> None:
+        super().__init__(in_features, out_features, bias)
+        # self.w_int8 = w_int8
+        # self.scales = scales

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Performs the forward pass of the quantized linear layer which consists
         of mixed dtype matmul using int8 symmetric per-channel weight quantization
@@ -39,15 +42,15 @@ def forward(self, x):
         # if len(x.shape)<=2:
         # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
         # else: # turn x into 2d tensor, then undo it for y
-        x_view = x.view(-1, x.shape[-1])
+        x_view = x.reshape(-1, x.shape[-1])
         y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales
         y = y.reshape(*x.shape[:-1], -1)
         if self.bias is not None:
             y += self.bias
         return y

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod: torch.nn.Linear) -> "WeightOnlyInt8QuantLinear":
         """
         Converts a `mod` of class `torch.nn.Linear` to the
         `WeightOnlyInt8QuantLinear` class
@@ -69,9 +72,9 @@ def from_float(cls, mod):
             fake_in_features,
             fake_out_features,
             bias=mod.bias is not None,
-            w_int8=w_int8.t().contiguous(),
-            scales=scales,
         )
+        new_mod.register_buffer("w_int8", w_int8.t())
+        new_mod.scales = torch.nn.Parameter(scales)
         new_mod.in_features = mod.in_features
         new_mod.out_features = mod.out_features
         del new_mod.weight

Contributor review comment on `new_mod.scales = torch.nn.Parameter(scales)`: probably want requires_grad = False
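For orientation, here is a minimal, self-contained sketch of the pattern this diff moves toward: int8 symmetric per-channel weight quantization of a `torch.nn.Linear`, with the quantized weight held in a registered buffer and the scales in a non-trainable parameter (incorporating the `requires_grad = False` suggestion from the review). This is an illustrative approximation, not the torchao implementation; the names `quantize_per_channel_symmetric` and `Int8WeightOnlyLinearSketch` are made up, and the `fake_in_features`/`fake_out_features` trick used by the real `from_float` is omitted for brevity.

import torch

def quantize_per_channel_symmetric(w: torch.Tensor):
    # w: (out_features, in_features) float weight; one scale per output channel
    max_abs = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    scales = max_abs / 127.0                                    # (out_features, 1)
    w_int8 = torch.round(w / scales).clamp(-127, 127).to(torch.int8)
    return w_int8, scales.squeeze(1)                            # int8 weight, (out_features,) scales

class Int8WeightOnlyLinearSketch(torch.nn.Linear):
    @classmethod
    def from_float(cls, mod: torch.nn.Linear) -> "Int8WeightOnlyLinearSketch":
        w_int8, scales = quantize_per_channel_symmetric(mod.weight.detach())
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        # Registered buffer / non-trainable parameter (rather than plain Python
        # attributes) so they are tracked as module state instead of baked-in constants.
        new_mod.register_buffer("w_int8", w_int8.t().contiguous())
        new_mod.scales = torch.nn.Parameter(scales, requires_grad=False)
        if mod.bias is not None:
            new_mod.bias = mod.bias
        del new_mod.weight                                      # the float weight is no longer needed
        return new_mod

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_2d = x.reshape(-1, x.shape[-1])                       # flatten leading dims for torch.mm
        # Mixed dtype matmul: upcast the int8 weight to the activation dtype,
        # then rescale per output channel.
        y = torch.mm(x_2d, self.w_int8.to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y = y + self.bias
        return y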
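And a hypothetical usage note tying the sketch back to the PR title: because the quantized weight and scales are registered on the module rather than stored as plain Python attributes, the intent is that torch.compile can trace them as ordinary module state. The layer sizes below are arbitrary.

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU())
model[0] = Int8WeightOnlyLinearSketch.from_float(model[0])

compiled = torch.compile(model)           # PyTorch 2.x
out = compiled(torch.randn(4, 16, 128))   # 3-D input exercises the reshape path in forward
print(out.shape)                          # torch.Size([4, 16, 256])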