@@ -172,6 +172,8 @@ def test_inference_workflow_mx(
     ],
     ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
 )
+@pytest.mark.parametrize("use_inference_mode", [False, True])
+@pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
 @skip_if_rocm("ROCm float4 gemm require gfx950")
 def test_inference_workflow_nvfp4(
@@ -182,6 +184,8 @@ def test_inference_workflow_nvfp4(
     use_triton_kernel: bool,
     use_dynamic_per_tensor_scale: bool,
     shapes: tuple,
+    use_inference_mode: bool,
+    x_rank: int,
 ):
     """
     Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
@@ -196,6 +200,16 @@ def test_inference_workflow_nvfp4(
 
     if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
         pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    if use_inference_mode and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for inference mode")
+    if x_rank == 3 and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for x_rank 3")
+
     batch_size, in_features, out_features = shapes
 
     m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
@@ -212,14 +226,21 @@ def test_inference_workflow_nvfp4(
         m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")
 
     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
+    if x_rank == 3:
+        x = x.unsqueeze(0)
+
     y_ref = m(x)
 
     if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
         with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
             y_mx = m_mx(x)
         assert result["found"], "Expected quantize_nvfp4 kernel to be found"
     else:
-        y_mx = m_mx(x)
+        if use_inference_mode:
+            with torch.inference_mode():
+                y_mx = m_mx(x)
+        else:
+            y_mx = m_mx(x)
 
     sqnr = compute_error(y_ref, y_mx)
 
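For context, here is a minimal standalone sketch (not part of the PR) of what the new use_inference_mode parameter exercises. torch.inference_mode() produces "inference tensors" that carry no autograd metadata or version counter, which tensor-subclass-based quantized weights can mishandle, so the test runs the quantized forward under both contexts. A plain nn.Linear on CPU stands in for the NVFP4-quantized module here:

import torch
import torch.nn as nn

# Hedged sketch: plain nn.Linear is a stand-in for the quantized module.
m = nn.Linear(64, 256)
x = torch.randn(128, 64)

with torch.no_grad():
    y_ref = m(x)  # ordinary no-grad forward, as in the test's other branch

with torch.inference_mode():
    y_inf = m(x)
    # Outputs produced under inference_mode are "inference tensors".
    assert y_inf.is_inference()

torch.testing.assert_close(y_ref, y_inf)  # the numerics are unaffected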
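Similarly, the x_rank=3 case checks that the quantized linear accepts inputs with an extra leading dimension. A common implementation pattern (assumed here for illustration, not taken from the PR) is to flatten the leading dimensions into a single batch dimension for the 2-D gemm and restore them afterwards:

import torch

# Hypothetical helper showing the flatten/restore pattern; the plain matmul
# stands in for the quantized gemm, which typically requires 2-D inputs.
def linear_2d_only(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    leading = x.shape[:-1]            # e.g. (1, 128) for a rank-3 input
    x2d = x.reshape(-1, x.shape[-1])  # collapse to (batch, in_features)
    y2d = x2d @ w.t()                 # stand-in for the quantized kernel
    return y2d.reshape(*leading, w.shape[0])

x = torch.randn(128, 64).unsqueeze(0)  # rank-3 input, as in the test
w = torch.randn(256, 64)
assert linear_2d_only(x, w).shape == (1, 128, 256)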