 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     quantize_,
...
     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )



 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
         super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

     def forward(self, x):
         x = self.linear1(x)
         x = self.linear2(x)
         return x


+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -64,7 +88,10 @@ def setUp(self):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "granularity",
+        [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
+    )
     @common_utils.parametrize(
         "kernel_preference",
         [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,9 +101,11 @@ def setUp(self):
         "sizes",
         [
             ((128,), 256, 128),
-            ((32, 128), 64, 256),
+            ((32, 128), 256, 512),
         ],
     )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
     def test_fp8_linear_variants(
         self,
         dtype: torch.dtype,
@@ -85,14 +114,36 @@ def test_fp8_linear_variants(
         granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
+        bias: bool,
     ):
-        if (
-            isinstance(granularity, PerTensor)
-            and kernel_preference == KernelPreference.FBGEMM
-        ):
-            return unittest.skip(
-                "per tensor with fbgemm kernel preferece does not work yet"
-            )
+        if isinstance(granularity, PerTensor):
+            if kernel_preference is KernelPreference.FBGEMM:
+                return unittest.skip(
+                    "per tensor with fbgemm kernel preference does not work yet"
+                )
+            elif mode == "weight-only":
+                return unittest.skip("unimplemented")
+
+        elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
+            if dtype is not torch.bfloat16:
+                return unittest.skip("unimplemented")
+            elif mode != "dynamic":
+                return unittest.skip("unimplemented")
+            elif kernel_preference not in (
+                KernelPreference.AUTO,
+                KernelPreference.TORCH,
+            ):
+                return unittest.skip("unimplemented")
+
+        if bias is True:
+            sizes_to_keep = ((128,), 256, 128)
+            if (
+                sizes != sizes_to_keep
+                or kernel_preference is not KernelPreference.TORCH
+            ):
+                return unittest.skip(
+                    "cut down on number of options to save test time"
+                )

         error_message = None
         if isinstance(granularity, PerRow):
@@ -122,7 +173,7 @@ def test_fp8_linear_variants(
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

         quantized_model = copy.deepcopy(model)

@@ -137,6 +188,20 @@ def test_fp8_linear_variants(

         quantize_(quantized_model, config)

+        # ensure weight scaling is what we expect
+        qs1 = quantized_model.linear1.weight.scale
+        qs2 = quantized_model.linear2.weight.scale
+        if granularity == PerTensor():
+            assert qs1.shape == (1, 1)
+            assert qs2.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs1.shape == (N, 1)
+            assert qs2.shape == (K, 1)
+        else:
+            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert qs1.shape == (N // 128, K // 128)
+            assert qs2.shape == (K // 128, N // 128)
+
         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)

@@ -148,6 +213,85 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize("inference_mode", [True, False])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test for 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        inference_mode: bool,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
+                "fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: this is channel last memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv model with the given dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
+        with inference_mode_ctx:
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert compute_error(output_original, output_quantized) > 20, (
+            f"Quantization error is too high got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
@@ -231,7 +375,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         dtype = torch.bfloat16
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

         # reference kernel preference and results
         # we are using KernelPreference.TORCH as the reference