2020
2121
2222class ToyLinearModel (torch .nn .Module ):
23- def __init__ (self , m = 512 , n = 256 , k = 128 ):
24- super ().__init__ ()
25- self .linear1 = torch .nn .Linear (m , n , bias = False )
26- self .linear2 = torch .nn .Linear (n , k , bias = False )
27- self .linear3 = torch .nn .Linear (k , 64 , bias = False )
28-
29- def example_inputs (
30- self , batch_size , sequence_length = 10 , dtype = torch .bfloat16 , device = "cuda"
23+ def __init__ (
24+ self ,
25+ m = 512 ,
26+ n = 256 ,
27+ k = 128 ,
28+ dtype = None ,
29+ device = None ,
3130 ):
32- return [
33- torch .randn (
34- 1 , sequence_length , self .linear1 .in_features , dtype = dtype , device = device
35- )
36- for j in range (batch_size )
37- ]
31+ super ().__init__ ()
32+ self .dtype = dtype
33+ self .device = device
34+ self .linear1 = torch .nn .Linear (m , n , bias = False , device = device , dtype = dtype )
35+ self .linear2 = torch .nn .Linear (n , k , bias = False , device = device , dtype = dtype )
36+ self .linear3 = torch .nn .Linear (k , 64 , bias = False , device = device , dtype = dtype )
37+
38+ def example_inputs (self , batch_size , sequence_length = 10 ):
39+ # For AWQ tests, we intentionally insert some outliers into the input features
40+ x = torch .randn (
41+ batch_size ,
42+ sequence_length ,
43+ self .linear1 .in_features ,
44+ dtype = self .dtype ,
45+ device = self .device ,
46+ )
47+ n_outliers = max (1 , int (x .size (- 1 ) * 0.1 ))
48+ # Randomly select outlier features
49+ outlier_indices = torch .randperm (x .size (- 1 ))[:n_outliers ]
50+ x [:, :, outlier_indices ] *= 10.0
51+ return (x ,)
3852
3953 def forward (self , x ):
4054 x = self .linear1 (x )
@@ -92,14 +106,12 @@ def test_awq_functionality(self, device):
92106 base_configs = device_to_base_configs [device ]
93107
94108 for base_config in base_configs :
95- m = ToyLinearModel (l1 , l2 , l3 ).eval (). to ( original_dtype ). to ( device )
109+ m = ToyLinearModel (l1 , l2 , l3 , device = device , dtype = original_dtype ).eval ()
96110 m_baseline = copy .deepcopy (m )
97111
98112 dataset = m .example_inputs (
99113 dataset_size ,
100114 sequence_length = sequence_length ,
101- dtype = original_dtype ,
102- device = device ,
103115 )
104116 # for test, we use calibration_data = dataset so that awq is
105117 # guaranteed to be better than baseline
@@ -142,12 +154,10 @@ def test_awq_loading(self, device):
142154 base_configs = device_to_base_configs [device ]
143155
144156 for base_config in base_configs :
145- m = ToyLinearModel (l1 , l2 , l3 ).eval (). to ( original_dtype ). to ( device )
157+ m = ToyLinearModel (l1 , l2 , l3 , device = device , dtype = original_dtype ).eval ()
146158 dataset = m .example_inputs (
147159 dataset_size ,
148160 sequence_length = sequence_length ,
149- dtype = original_dtype ,
150- device = device ,
151161 )
152162 # for test purpose, we don't need to get a subset
153163 calibration_data = dataset
@@ -171,9 +181,9 @@ def test_awq_loading(self, device):
171181 f .seek (0 )
172182 state_dict = torch .load (f )
173183
174- loaded_model = (
175- ToyLinearModel ( l1 , l2 , l3 ). eval (). to ( original_dtype ). to ( device )
176- )
184+ loaded_model = ToyLinearModel (
185+ l1 , l2 , l3 , device = device , dtype = original_dtype
186+ ). eval ()
177187 loaded_model .load_state_dict (state_dict , assign = True )
178188
179189 m = torch .compile (m , fullgraph = True )
@@ -203,12 +213,10 @@ def test_awq_loading_vllm(self, device):
203213 base_configs = device_to_base_configs [device ]
204214
205215 for base_config in base_configs :
206- m = ToyLinearModel (l1 , l2 , l3 ).eval (). to ( original_dtype ). to ( device )
216+ m = ToyLinearModel (l1 , l2 , l3 , device = device , dtype = original_dtype ).eval ()
207217 dataset = m .example_inputs (
208218 dataset_size ,
209219 sequence_length = sequence_length ,
210- dtype = original_dtype ,
211- device = device ,
212220 )
213221 # for test purpose, we don't need to get a subset
214222 calibration_data = dataset
@@ -231,9 +239,9 @@ def test_awq_loading_vllm(self, device):
231239 f .seek (0 )
232240 state_dict = torch .load (f )
233241
234- loaded_model = (
235- ToyLinearModel ( l1 , l2 , l3 ). eval (). to ( original_dtype ). to ( device )
236- )
242+ loaded_model = ToyLinearModel (
243+ l1 , l2 , l3 , device = device , dtype = original_dtype
244+ ). eval ()
237245 quant_config = AWQConfig (base_config , step = AWQStep .PREPARE_FOR_LOADING )
238246 quantize_ (loaded_model , quant_config )
239247
0 commit comments