diff --git a/test/kernel/test_blockwise_triton.py b/test/kernel/test_blockwise_triton.py
index 5de88ab7d9..ba377c560f 100644
--- a/test/kernel/test_blockwise_triton.py
+++ b/test/kernel/test_blockwise_triton.py
@@ -66,6 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
     A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
     B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
     C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
+    assert C_q.dtype == torch.bfloat16, "expected blockwise_fp8_gemm output dtype to be bfloat16"
     error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C)
     print(f"Relative Error: {error.item():.6f}")
 
diff --git a/torchao/kernel/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py
index 192f6d5887..1a43e71a97 100644
--- a/torchao/kernel/blockwise_quantization.py
+++ b/torchao/kernel/blockwise_quantization.py
@@ -92,7 +92,7 @@ def blockwise_fp8_gemm(
     M = a.numel() // K
     N = b.size(0)
     M_BUCKET = math.ceil(math.log2(M))
-    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+    c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -105,7 +105,7 @@ def blockwise_fp8_gemm(
 @blockwise_fp8_gemm.register_fake
 def _(a, a_s, b, b_s, block_size=128):
     N = b.size(0)
-    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+    c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16)
     return c
 
 @triton.jit
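
For reference, a minimal sketch (not part of the patch) of the behavior the change pins down: blockwise_fp8_gemm now returns bfloat16 even when the process-wide default dtype is float32, where previously the output inherited torch.get_default_dtype(). It assumes the helpers are importable from torchao.kernel.blockwise_quantization as used in the test above, and a CUDA device with fp8 support; the shapes and fp8 dtype are illustrative, not prescribed by the patch.

import torch

from torchao.kernel.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Global default stays float32; before this patch the GEMM output
# would have inherited it via torch.get_default_dtype().
torch.set_default_dtype(torch.float32)

# Illustrative shapes: multiples of the 128-element quantization block.
A = torch.randn(256, 512, device="cuda")  # activations (M, K)
B = torch.randn(128, 512, device="cuda")  # weights (N, K)

A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)

assert C_q.dtype == torch.bfloat16  # fixed output dtype after this change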