Skip to content

Commit 7fbbcca

Browse files
More elegant weight initialization for FP6 benchmark
1 parent 4211b85 commit 7fbbcca

File tree

1 file changed: +3 additions, -4 deletions

benchmarks/benchmark_fp6.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,17 @@
 import torch
 import pandas as pd
 import torchao
-from torchao.dtypes.floatx import from_scaled_tc_floatx
+from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
     ebits = 3
     mbits = 2
-    nbits = 1 + ebits + mbits

-    fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda")
-    scale = torch.rand(n, device="cuda").half() + 0.5
+    fp32_weight = torch.randn(n, k, device="cuda")
+    fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5

     fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)

0 commit comments

Comments (0)