Skip to content

Commit 8dbf031

Browse files
authored
Fix FP6-LLM benchmark (#312)
1 parent 12f44ab commit 8dbf031

File tree

2 files changed

+52
-82
lines changed

2 files changed

+52
-82
lines changed

benchmarks/benchmark_fp6.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

benchmarks/benchmark_fp6_llm.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
from torch import nn
3+
from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
4+
from torch.utils.benchmark import Timer
5+
import pandas as pd
6+
from tqdm import tqdm
7+
8+
9+
def benchmark(m: int, k: int, n: int):
10+
fp6_weight = torch.randint(256, size=(n, k // 4 * 3), dtype=torch.uint8, device="cuda")
11+
scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
12+
fp6_linear = Fp6LlmLinear(fp6_weight.view(torch.int32), scales)
13+
14+
fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
15+
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight.view(-1), n, k, dtype=torch.half) * scales[:, None]
16+
17+
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
18+
fp6_output = fp6_linear(fp16_act)
19+
fp16_output = fp16_linear(fp16_act)
20+
21+
fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange()
22+
fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange()
23+
24+
# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
25+
# doesn't seem to be the right way to check for correctness
26+
correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
27+
28+
return {
29+
"m": m,
30+
"k": k,
31+
"n": n,
32+
"fp6_latency (ms)": fp6_measurement.median * 1000,
33+
"fp16_latency (ms)": fp16_measurement.median * 1000,
34+
"speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
35+
"correct": correct,
36+
}
37+
38+
39+
if __name__ == "__main__":
40+
# from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh
41+
k_vals = (8192, 8192, 8192, 28672)
42+
n_vals = (8192, 10240, 57344, 8192)
43+
44+
results = []
45+
46+
for m in tqdm([1 << i for i in range(10)]):
47+
for n, k in zip(n_vals, k_vals):
48+
results.append(benchmark(m, n, k))
49+
50+
df = pd.DataFrame(results)
51+
df.to_csv("fp6_llm_benchmark_results.csv", index=False)
52+
print(df.to_markdown(index=False))

0 commit comments

Comments
 (0)