Commit aef7e09

Refactor layout implementation (#491)
Summary: TODO
Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6e7cf71 commit aef7e09

File tree

4 files changed: 137 additions & 61 deletions

torchao/dtypes/__init__.py

Lines changed: 10 additions & 1 deletion

@@ -1,12 +1,21 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    to_affine_quantized,
+    LayoutType,
+    PlainLayoutType,
+    TensorCoreTiledLayoutType,
+)

 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "LayoutType",
+    "PlainLayoutType",
+    "TensorCoreTiledLayoutType",
 ]
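
With this refactor the layout configuration becomes part of the public surface of torchao.dtypes. A minimal usage sketch of the new exports (only the names above come from the diff; the rest is illustrative):

    from torchao.dtypes import (
        to_affine_quantized,
        PlainLayoutType,
        TensorCoreTiledLayoutType,
    )

    # layout types are frozen dataclasses, cheap to construct and compare
    plain = PlainLayoutType()
    tiled = TensorCoreTiledLayoutType(inner_k_tiles=8)  # inner_k_tiles defaults to 8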

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 88 additions & 49 deletions

@@ -20,10 +20,35 @@
     _ATEN_OP_OR_TORCH_FN_TABLE,
     _register_layout_cls,
     _get_layout_tensor_constructor,
+    LayoutType,
 )
+from typing import ClassVar
+from dataclasses import dataclass

 aten = torch.ops.aten

+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
 def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
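
For a concrete sense of what TensorCoreTiledLayoutType.pre_process does, a small sketch with an arbitrary, hypothetical weight shape; the padding targets (multiples of 1024 for in_features, 8 for out_features) come from the code above:

    import torch

    layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8)
    weight = torch.randn(4000, 11000, dtype=torch.bfloat16)  # hypothetical shape
    padded = layout_type.pre_process(weight)
    # in_features 11000 is padded up to 11264 (the next multiple of 1024);
    # out_features 4000 is already a multiple of 8, so it is unchanged
    print(padded.shape)  # torch.Size([4000, 11264])
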
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
     """
     Base class for the layout tensor for `AffineQuantizedTensor`
     """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass

-    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_layout_type(self) -> LayoutType:
         pass

     @classmethod
@@ -64,9 +89,15 @@ def from_plain(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         pass

+    def __repr__(self):
+        int_data, scale, zero_point = self.get_plain()
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
     ):
         original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pre_process(input_float)

         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = layout_type.post_process(int_data)

-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
         return cls(
             layout_tensor,
             block_size,
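
For callers, the practical change in from_float (exposed as to_affine_quantized) is that the layout is now selected by passing a LayoutType instance instead of the extended_layout string plus the layout-specific inner_k_tiles argument. A hedged sketch of a tensor-core-tiled int4 call; the quantization parameters below are illustrative values in the spirit of int4 weight-only quantization, not values stated in this diff:

    import torch
    from torchao.dtypes import to_affine_quantized, TensorCoreTiledLayoutType
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    group_size = 128
    aqt = to_affine_quantized(
        weight,
        MappingType.ASYMMETRIC,
        (1, group_size),              # block_size: per-group along in_features
        torch.int32,                  # target_dtype
        quant_min=0,
        quant_max=15,
        eps=1e-6,
        zero_point_dtype=torch.bfloat16,
        preserve_zero=False,
        zero_point_domain=ZeroPointDomain.FLOAT,
        # previously: extended_layout="tensor_core_tiled", inner_k_tiles=8
        layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
    )
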
@@ -229,8 +247,8 @@ def from_float(
         )

     @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> str:
+        return self.layout_tensor.layout_type

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 def implements(aten_ops_or_torch_fn):
     return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """
     Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
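
Registration is now keyed on the LayoutType subclass rather than a layout string. A hypothetical sketch of adding a new layout under this scheme (MyLayoutType and MyAQTLayout are made-up names for illustration):

    @dataclass(frozen=True)
    class MyLayoutType(LayoutType):
        pass

    @register_layout_cls(MyLayoutType)
    class MyAQTLayout(AQTLayout):
        # implement __new__/__init__, from_plain, get_plain and get_layout_type,
        # accepting and storing a MyLayoutType instance as layout_type
        ...
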
@@ -330,6 +348,7 @@ def __new__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
         self.zero_point = zero_point
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
             self.zero_point.to(kwargs["device"]),
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
             fn(self.zero_point),
+            self.layout_type,
         )

     @classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.int_data, self.scale, self.zero_point

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
     @classmethod
     def from_plain(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)

-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
     Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)

     @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
         # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
         # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
         self.packed_weight = fn(self.packed_weight)
         self.scale_and_zero = fn(self.scale_and_zero)
         return self

-    def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
             ZeroPointDomain,
             quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
         int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
         return int_data, scale, zero

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
 def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     """
     Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.extended_layout == "plain" and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(input_tensor.layout_type, PlainLayoutType) and
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.dtype == torch.bfloat16 and
         len(weight_qtensor.shape) == 2 and
         weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-        weight_qtensor.extended_layout == "tensor_core_tiled"
+        isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
     ):
         assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
         assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.block_size[0] == 1 and
         weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
         weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         # TODO: enable cpu and mps efficient path
         # per channel int8 weight only quantizated mm
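
Downstream, the dispatch conditions in _quantized_linear_op now identify layouts with isinstance checks on layout_type instead of comparing extended_layout strings, as the last three hunks show. A condensed sketch of the resulting pattern (an illustrative helper, not part of the commit, with the other conditions omitted):

    def pick_kernel(weight_qtensor):
        # layout selection only; dtype, device, and block_size checks are elided
        if isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType):
            return "tinygemm int4 path"
        if isinstance(weight_qtensor.layout_type, PlainLayoutType):
            return "plain int8 paths"
        return "fallback (dequantize and run in floating point)"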
