@@ -67,10 +67,14 @@ def cuda_kernel_profiler(kernel_pattern):
6767@pytest .mark .skipif (
6868 not torch_version_at_least ("2.8.0" ), reason = "torch.compile requires PyTorch 2.8+"
6969)
70- @pytest .mark .parametrize ("elem_dtype" , [torch .float8_e4m3fn , torch .float4_e2m1fn_x2 ])
71- @pytest .mark .parametrize ("bias" , [True , False ])
72- @pytest .mark .parametrize ("compile" , [True , False ])
73- @pytest .mark .parametrize ("emulate" , [True , False ])
70+ # @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
71+ @pytest .mark .parametrize ("elem_dtype" , [torch .float8_e4m3fn ])
72+ # @pytest.mark.parametrize("bias", [True, False])
73+ # @pytest.mark.parametrize("compile", [True, False])
74+ # @pytest.mark.parametrize("emulate", [True, False])
75+ @pytest .mark .parametrize ("bias" , [False ])
76+ @pytest .mark .parametrize ("compile" , [False ])
77+ @pytest .mark .parametrize ("emulate" , [False ])
7478@torch .no_grad ()
7579@skip_if_rocm (
7680 "ROCm float4 gemm require gfx950"
@@ -93,7 +97,11 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
9397 # TODO(future PR): investigate and fix this
9498 pytest .skip ("mxfp4 + compile currently does not work, low SQNR" )
9599
96- m = nn .Linear (32 , 128 , bias = bias , dtype = torch .bfloat16 , device = "cuda" )
100+ # M, N, K = 16, 3072, 4096
101+ # M, N, K = 1920, 3072, 256
102+ M , N , K = 1920 , 18432 , 3072
103+ # m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
104+ m = nn .Linear (K , N , bias = bias , dtype = torch .bfloat16 , device = "cuda" )
97105 m_mx = copy .deepcopy (m )
98106
99107 if emulate :
@@ -108,18 +116,22 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
108116 gemm_kernel_choice = kernel_choice ,
109117 )
110118 quantize_ (m_mx , config = config )
119+ print ("m_mx:" , m_mx )
120+
111121 if compile :
112122 m_mx = torch .compile (m_mx , fullgraph = True )
113123
114- x = torch .randn (128 , 32 , device = "cuda" , dtype = torch .bfloat16 )
115- y_ref = m (x )
116- y_mx = m_mx (x )
124+ with torch .inference_mode ():
125+ x = torch .randn (1 , M , K , device = "cuda" , dtype = torch .bfloat16 )
126+ y_ref = m (x )
127+ y_mx = m_mx (x )
117128 sqnr = compute_error (y_ref , y_mx )
118129 SQNR_THRESHOLD = 25.0 if elem_dtype == torch .float8_e4m3fn else 15.0
119130 assert sqnr >= SQNR_THRESHOLD , (
120131 f"Got a sqnr of { sqnr } for { elem_dtype } and bias={ bias } "
121132 )
122133
134+ raise Exception ("stop" )
123135 # serialization
124136 with tempfile .NamedTemporaryFile () as f :
125137 torch .save (m_mx .state_dict (), f )
0 commit comments