
Commit ef17ed6

chore: cleanup
1 parent cf5e286

File tree (5 files changed: +30 -118 lines changed)

  torchao/dtypes/affine_quantized_tensor.py
  torchao/quantization/quant_api.py
  torchao/sparsity/marlin/__init__.py
  torchao/sparsity/marlin/utils.py
  wip_test_llama2.py


torchao/dtypes/affine_quantized_tensor.py

Lines changed: 14 additions & 11 deletions

@@ -40,6 +40,7 @@
 logger = logging.getLogger(__name__)
 
 from torchao.float8.inference import Float8MMConfig
+aten = torch.ops.aten
 
 
 ###############################
@@ -682,11 +683,6 @@ class MarlinSparseAQTLayout(AQTLayout):
         group_size (int): the group size used to pack the tensor
         num_bits (int): the number of bits used to quantize the tensor
     """
-
-    implements = classmethod(_implements)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-
     @staticmethod
     def __new__(
         cls,
@@ -729,6 +725,19 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits
 
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
     def __tensor_flatten__(self):
         return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
 
@@ -826,12 +835,6 @@ def _apply_fn_to_data(self, fn):
         return self
 
 
-# Marlin Sparse op dispatch registration
-@MarlinSparseAQTLayout.implements(aten.detach.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
-
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
     """

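The change above replaces the table-driven registration (implements = classmethod(_implements) plus a module-level @MarlinSparseAQTLayout.implements(aten.detach.default) handler) with a direct __torch_dispatch__ override that services aten.detach.default and rejects everything else. A minimal, self-contained sketch of the same pattern, using a hypothetical DetachOnlyTensor wrapper subclass (not part of torchao):

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten


class DetachOnlyTensor(torch.Tensor):
    """Toy wrapper subclass (hypothetical) that, like MarlinSparseAQTLayout
    after this commit, services only aten.detach.default in dispatch."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(cls, inner.shape, dtype=inner.dtype)

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    def _apply_fn_to_data(self, fn):
        return DetachOnlyTensor(fn(self.inner))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
        if func is aten.detach.default:
            # Re-wrap the detached payload and fix up aliasing metadata.
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        raise NotImplementedError(f"{cls.__name__}: {func} is not supported")


t = DetachOnlyTensor(torch.randn(2, 2))
t.detach()  # routed through __torch_dispatch__ and succeeds
# any other op, e.g. t + 1, would raise NotImplementedError

Keeping the single supported op inside __torch_dispatch__ puts the dispatch logic next to the class instead of in a separate registration table.
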
torchao/quantization/quant_api.py

Lines changed: 3 additions & 1 deletion

@@ -454,7 +454,9 @@ def apply_int4_weight_only_quant(weight):
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
 
-        # Sparse Marlin only supports symmetric quantization
+        # Sparse Marlin only supports symmetric quantization.
+        # NOTE: If we start having lots of layouts that require different configurations,
+        # we should consider moving this logic somewhere else.
         if isinstance(layout_type, MarlinSparseLayoutType):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True

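As background for the comment being expanded above: with symmetric mapping the integer zero point is pinned to the middle of the range, so only per-group scales have to be derived from the weights. A rough illustration of one common symmetric int4 convention (symmetric_int4_scales and the group size are hypothetical sketches, not torchao's implementation):

import torch


def symmetric_int4_scales(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    # Symmetric int4 covers [-8, 7]; pick the per-group scale so the
    # largest-magnitude weight in each group maps to the positive bound.
    grouped = w.reshape(-1, group_size)
    max_abs = grouped.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    return max_abs / 7


w = torch.randn(256, 128)
scales = symmetric_int4_scales(w)
q = torch.clamp(torch.round(w.reshape(-1, 128) / scales), min=-8, max=7)
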
torchao/sparsity/marlin/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils

torchao/sparsity/marlin/utils.py

Lines changed: 13 additions & 13 deletions

@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import List, Tuple
 from dataclasses import dataclass, field
 
@@ -97,17 +96,13 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
     """Precompute permutations for Marlin24 weight and scale shuffling
 
     Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible
-    with the tensor-core format that is described here:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-
-    As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core
-    (without the need to use ldmatrix instructions)
+    with the tensor-core format.
 
     Args:
         num_bits (int): Number of bits to pack.
     Returns:
-        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list and
-        scale permutation list for single group.
+        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and
+        scale permutation list for a single group.
     """
     perm_list: List[int] = []
     for i in range(32):
@@ -125,23 +120,28 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
                              4 * block)
         for j in range(4):
             perm_list.extend([p + 1 * j for p in perm1])
-    perm = np.array(perm_list)
+
+    # Convert to torch tensor
+    perm = torch.tensor(perm_list, dtype=torch.int32)
 
     if num_bits == 4:
-        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+        interleave = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32)
     elif num_bits == 8:
-        interleave = np.array([0, 2, 1, 3])
+        interleave = torch.tensor([0, 2, 1, 3], dtype=torch.int32)
     else:
         raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
 
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
+    # Reshape and apply interleave
+    perm = perm.view(-1, len(interleave))[:, interleave].reshape(-1)
+
     scale_perm: List[int] = []
     for i in range(8):
         scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+
     scale_perm_single: List[int] = []
     for i in range(8):
         scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+
     return perm, scale_perm, scale_perm_single
 
 

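The substantive change above is the numpy-to-torch port of the interleave step: perm.view(-1, len(interleave))[:, interleave].reshape(-1) replaces the old reshape/ravel on a numpy array and drops the final torch.from_numpy copy. A quick sanity check of the torch form on toy values:

import torch

# num_bits == 8 case from the diff; the values here are toy data.
perm = torch.arange(16, dtype=torch.int32)
interleave = torch.tensor([0, 2, 1, 3], dtype=torch.int32)

# Group into rows of len(interleave), reorder the columns, flatten again.
# (Recent PyTorch accepts int32 index tensors, as the committed code relies on.)
out = perm.view(-1, len(interleave))[:, interleave].reshape(-1)
print(out.tolist())  # [0, 2, 1, 3, 4, 6, 5, 7, 8, 10, 9, 11, 12, 14, 13, 15]
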
wip_test_llama2.py

Lines changed: 0 additions & 92 deletions
This file was deleted.
