Skip to content

Commit 4211b85

Browse files
Fix failing FP6 benchmark
1 parent 653efe9 commit 4211b85

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

benchmarks/benchmark_fp6.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,27 @@
11
import torch
22
import pandas as pd
3-
import torch.nn.functional as F
4-
from torchao.dtypes import to_affine_quantized_floatx
5-
from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
3+
import torchao
4+
from torchao.dtypes.floatx import from_scaled_tc_floatx
65
from torchao.utils import benchmark_torch_function_in_microseconds
76
from tqdm import tqdm
87

98

109
def benchmark(m: int, k: int, n: int):
11-
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
12-
fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
13-
fp16_weight = fp6_weight.dequantize(torch.half)
10+
ebits = 3
11+
mbits = 2
12+
nbits = 1 + ebits + mbits
1413

15-
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
16-
fp6_output = F.linear(fp16_act, fp6_weight)
17-
fp16_output = F.linear(fp16_act, fp16_weight)
14+
fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda")
15+
scale = torch.rand(n, device="cuda").half() + 0.5
16+
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5
1817

19-
fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
20-
fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
18+
fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
19+
20+
fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half()
21+
fp16_output = torch.matmul(fp16_act, fp16_weight.T)
22+
23+
fp6_time = benchmark_torch_function_in_microseconds(torchao.ops.quant_llm_linear, ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
24+
fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, fp16_act, fp16_weight.T)
2125

2226
# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
2327
# doesn't seem to be the right way to check for correctness

0 commit comments

Comments (0)