 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
 from torchao.dtypes import MarlinSparseLayoutType
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.sparsity.marlin import (
     pack_to_marlin_24,
     unpack_from_marlin_24,
     inject_24,
 )
-from torchao.quantization.utils import (
-    get_group_qparams_symmetric,
-    groupwise_affine_quantize_tensor_from_qparams,
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    quantize_affine,
+    ZeroPointDomain,
+    MappingType,
 )


 class SparseMarlin24(TestCase):

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_quant_sparse_marlin_layout_eager(self):
+    def setUp(self):
+        super().setUp()
         torch.manual_seed(0)

-        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
-        model = (
+        self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
+        self.model = (
             nn.Sequential(
                 nn.Linear(4096, 21504),
                 nn.Linear(21504, 4096),
@@ -37,48 +40,38 @@ def test_quant_sparse_marlin_layout_eager(self):
             .cuda()
         )

-        apply_fake_sparsity(model)
-        model_copy = copy.deepcopy(model)
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_quant_sparse_marlin_layout_eager(self):
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)

         # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
-        dense_result = model_copy(input.bfloat16()).half()
+        dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
-        sparse_result = model(input)
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        sparse_result = self.model(self.input)

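         # Loose tolerance (atol=3e-1): the dense reference runs in bfloat16
         # while the sparse Marlin path runs in float16.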
         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_quant_sparse_marlin_layout_compile(self):
-        torch.manual_seed(0)
-
-        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
-        model = (
-            nn.Sequential(
-                nn.Linear(4096, 21504),
-                nn.Linear(21504, 4096),
-                nn.ReLU(),
-                nn.Linear(4096, 21504),
-                nn.Linear(21504, 4096),
-            )
-            .half()
-            .cuda()
-        )
-
-        apply_fake_sparsity(model)
-        model_copy = copy.deepcopy(model)
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)

         # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
-        dense_result = model_copy(input.bfloat16()).half()
+        dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
-        model.forward = torch.compile(model.forward, fullgraph=True)
-        sparse_result = model(input)
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
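+        # NOTE: prior to PyTorch 2.5, tensor-subclass weights generally have
+        # to be unwrapped into plain module parametrizations before
+        # torch.compile can trace the model with fullgraph=True.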
70
+ if not TORCH_VERSION_AT_LEAST_2_5 :
71
+ unwrap_tensor_subclass (self .model )
72
+
73
+ self .model .forward = torch .compile (self .model .forward , fullgraph = True )
74
+ sparse_result = self .model (self .input )
82
75
83
76
assert torch .allclose (dense_result , sparse_result , atol = 3e-1 ), "Results are not close"
@@ -87,34 +80,26 @@ def test_pack_unpack_equivalence(self):
         num_bits = 4
         group_size = 128
         shape = (11008, 4096)
-        max_q_val = 2**num_bits - 1
-        half_q_val = (max_q_val + 1) // 2
+        block_size = (1, group_size)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        zero_point_dtype = torch.bfloat16
+        mapping_type = MappingType.SYMMETRIC
+        preserve_zero = True
+        zero_point_domain = ZeroPointDomain.INT
+        scale_dtype = None
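+        # quant_min/quant_max = [0, 15] is the unsigned range for num_bits = 4
+        # (2**num_bits - 1 = 15); block_size = (1, group_size) yields one scale
+        # and zero point per group of 128 values along the last dimension.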

         w = torch.rand(shape, dtype=torch.float16, device="cuda")
-        size_k, size_n = w.shape

         # Inject 2:4 sparsity mask
         w_24, _ = inject_24(w, *w.shape)

         # Quantize weights
-        w_24 = w_24.reshape((-1, group_size, size_n))
-        w_24 = w_24.permute(1, 0, 2)
-        w_24 = w_24.reshape((group_size, -1))
-
-        # Compute scale for each group
-        scales = torch.max(torch.abs(w_24), 0, keepdim=True)[0]
-        scales *= 2 / max_q_val  # 2 => symmetric
-
-        # Quantize
-        w_q_24 = torch.round(w_24 / scales).int()
-        w_q_24 += half_q_val
-        w_q_24 = torch.clamp(w_q_24, 0, max_q_val)
-
-        # Shape back to original shape
-        w_q_24 = w_q_24.reshape((group_size, -1, size_n))
-        w_q_24 = w_q_24.permute(1, 0, 2)
-        w_q_24 = w_q_24.reshape((size_k, size_n)).contiguous()
-        scales = scales.reshape((-1, size_n)).contiguous()
+        scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+        w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
+        scales = scales.reshape(-1, w_q_24.shape[1])
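+        # choose_qparams_affine returns one scale per (1, group_size) block;
+        # the reshape flattens them to (size_k // group_size, size_n), the
+        # layout the old manual code produced for the Marlin pack helpers.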

         # Test pack/unpack equivalence
         q_w_comp, packed_scales, meta = pack_to_marlin_24(