|
1 | 1 | import torch
|
2 | 2 | import pandas as pd
|
3 |
| -import torch.nn.functional as F |
4 |
| -from torchao.dtypes import to_affine_quantized_floatx |
5 |
| -from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType |
| 3 | +import torchao |
| 4 | +from torchao.dtypes.floatx import from_scaled_tc_floatx |
6 | 5 | from torchao.utils import benchmark_torch_function_in_microseconds
|
7 | 6 | from tqdm import tqdm
|
8 | 7 |
|
9 | 8 |
|
10 | 9 | def benchmark(m: int, k: int, n: int):
|
11 |
| - float_data = torch.randn(n, k, dtype=torch.half, device="cuda") |
12 |
| - fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2)) |
13 |
| - fp16_weight = fp6_weight.dequantize(torch.half) |
| 10 | + ebits = 3 |
| 11 | + mbits = 2 |
| 12 | + nbits = 1 + ebits + mbits |
14 | 13 |
|
15 |
| - fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") |
16 |
| - fp6_output = F.linear(fp16_act, fp6_weight) |
17 |
| - fp16_output = F.linear(fp16_act, fp16_weight) |
| 14 | + fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda") |
| 15 | + scale = torch.rand(n, device="cuda").half() + 0.5 |
| 16 | + fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5 |
18 | 17 |
|
19 |
| - fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight) |
20 |
| - fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) |
| 18 | + fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1) |
| 19 | + |
| 20 | + fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half() |
| 21 | + fp16_output = torch.matmul(fp16_act, fp16_weight.T) |
| 22 | + |
| 23 | + fp6_time = benchmark_torch_function_in_microseconds(torchao.ops.quant_llm_linear, ebits, mbits, fp16_act, fp6_weight, scale, splitK=1) |
| 24 | + fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, fp16_act, fp16_weight.T) |
21 | 25 |
|
22 | 26 | # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
|
23 | 27 | # doesn't seem to be the right way to check for correctness
|
|
0 commit comments