@@ -90,7 +90,7 @@ def setUp(self):
9090 @common_utils .parametrize ("compile" , [True , False ])
9191 @common_utils .parametrize (
9292 "granularity" ,
93- [PerTensor (), PerRow (), (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) ))],
93+ [PerTensor (), PerRow (), (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] ))],
9494 )
9595 @common_utils .parametrize (
9696 "kernel_preference" ,
@@ -124,7 +124,7 @@ def test_fp8_linear_variants(
124124 elif mode == "weight-only" :
125125 return unittest .skip ("unimplemented" )
126126
127- elif granularity == (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) )):
127+ elif granularity == (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] )):
128128 if dtype is not torch .bfloat16 :
129129 return unittest .skip ("unimplemented" )
130130 elif mode != "dynamic" :
@@ -198,7 +198,7 @@ def test_fp8_linear_variants(
198198 assert qs1 .shape == (N , 1 )
199199 assert qs2 .shape == (K , 1 )
200200 else :
201- assert granularity == (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) ))
201+ assert granularity == (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] ))
202202 assert qs1 .shape == (N // 128 , K // 128 )
203203 assert qs2 .shape == (K // 128 , N // 128 )
204204
0 commit comments