 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
+import unittest
 
-import pytest
 import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
 
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
+from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    _is_fbgemm_genai_gpu_available,
 )
 
-if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +45,197 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
-@pytest.fixture(autouse=True)
-def run_before_and_after_tests():
-    yield
-    torch._dynamo.reset()  # reset cache between tests
-
-
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
-
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    if torch.cuda.is_available():
+@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+@unittest.skipIf(
+    not _is_fbgemm_genai_gpu_available(),
+    reason="need to install fbgemm_gpu_genai package",
+)
+@unittest.skipIf(
+    not TORCH_VERSION_AT_LEAST_2_6,
+    reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
+)
+class TestAWQ(TestCase):
+    def test_awq_config(self):
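+        # AWQConfig wraps a base quantization config with a `step` selecting
+        # the phase of the AWQ flow; both AWQStep enum members and their
+        # string values are accepted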
+        base_config = Int4WeightOnlyConfig()
+        AWQConfig(base_config, step=AWQStep.PREPARE)
+        AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        AWQConfig(base_config, step=AWQStep.CONVERT)
+
+        AWQConfig(base_config, step="prepare")
+        AWQConfig(base_config, step="prepare_for_loading")
+        AWQConfig(base_config, step="convert")
+
+        with self.assertRaisesRegex(ValueError, "is not one of"):
+            AWQConfig(base_config, step="not_supported")
+
+    def test_awq_functionality(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+        # baseline quantization
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+
+        # awq quantization
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+        calibration_data = dataset[:n_calibration_examples]
+
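+        # PREPARE inserts activation observers on the target linears so the
+        # calibration forward passes below can record input statistics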
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
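+        # CONVERT consumes the recorded statistics to compute the AWQ
+        # equalization scales and quantizes the weights with base_config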
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
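+        # AWQ should match the float reference more closely than plain
+        # int4 weight-only quantization (lower mean squared error)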
+        loss_awq = (ref_out - awq_out).pow(2).mean().item()
+        loss_base = (ref_out - baseline_out).pow(2).mean().item()
+        assert loss_awq < loss_base
+
+    def test_awq_loading(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
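+        # round-trip the quantized state dict through torch.save/torch.load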
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
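+        # assign=True keeps the deserialized quantized tensors as-is instead
+        # of copying them into the fresh model's float parameters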
+        loaded_model.load_state_dict(state_dict, assign=True)
+
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+    def test_awq_loading_vllm(self):
+        """Simulate weight loading in vllm:
+        * prepare the model weights in the same format (AWQ weights)
+        * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from the checkpoint
+
+        A slicing op is also involved but omitted here; the overall e2e flow is tested in the vllm repo.
+        """
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
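+        # PREPARE_FOR_LOADING puts fresh weights directly into the final
+        # quantized format so checkpoint weights can be copied in place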
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        quantize_(loaded_model, quant_config)
+
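+        # mimic vllm-style loading: copy the quantized checkpoint weights
+        # into the prepared parameters in place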
+        loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+        loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+        loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
         m = torch.compile(m, fullgraph=True)
         loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-        awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-        awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
-
-        assert awq_out is not None
-        assert awq_save_load_out is not None
-        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
-
-
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only.torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
-
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
-
-    assert awq_out is not None
-    assert awq_save_load_out is not None
-    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+
+if __name__ == "__main__":
+    run_tests()