Commit 70aef5d

Add FP5 E2M2 support from upstream (#399)
* first update from upstream
* add some primitives to support fp5
* binding for ExMy
* add QuantLlmLinear
* fix
* update README
* update README
* remove fp6_linear from C++
* fix
* fix
* fix
* update
* add more experimental config
* update
* add from tc_fpx
* remove redundant code
* fix import
* fix test
* avoid division by 0
* add subclass. use uint8
* subclass API
* update doc
* remove unused op
* update
* rename. update
* update docs
* rename
* fix for PyTorch 2.2
* _implements -> implements
* set CUDA context
* fix __repr__
1 parent 96d49cd commit 70aef5d

21 files changed (+958, -847 lines)
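The headline change is that the quant_llm path now takes explicit exponent/mantissa bit counts, with FP5 E2M2 corresponding to ebits=2, mbits=2. A minimal sketch of that flow, using only names that appear in the diffs below (quantize from torchao.quantization.quant_api, quant_llm_fpx_weight_only from torchao.prototype.quant_llm) and intended as illustration rather than the canonical API:

import torch
from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only

ebits, mbits = 2, 2  # FP5 E2M2: 1 sign bit + 2 exponent bits + 2 mantissa bits
linear = torch.nn.Linear(64, 256, device="cuda")
quantize(linear, quant_llm_fpx_weight_only(ebits, mbits))  # weight becomes a QuantLlmLinearWeight subclass

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
out = linear(x)  # F.linear on the subclass weight routes to the fused quant_llm_linear kernel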

README.md

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 
 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`
 
 ## Composability
 

@@ -104,7 +104,7 @@ python setup.py install
 * [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
 * [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
 * [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
-* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
+* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
 * [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
 * [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile
 
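For reference, the README's new one-liner is the FP6-specific entry point to the same path. A brief usage sketch, assuming fp6_llm_weight_only is exported from torchao.prototype.quant_llm as the E3M2 convenience wrapper (that import location is an assumption, not confirmed by this diff):

import torch
from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import fp6_llm_weight_only  # assumed export; see the README line above

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda()
quantize(model, fp6_llm_weight_only())  # weight-only FP6 (E3M2) quantization

x = torch.randn(8, 1024, device="cuda", dtype=torch.half)
out = model(x)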

benchmarks/benchmark_fp6_llm.py

Lines changed: 14 additions & 15 deletions
@@ -1,25 +1,24 @@
 import torch
-from torch import nn
-from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
-from torch.utils.benchmark import Timer
 import pandas as pd
+import torch.nn.functional as F
+from torchao.prototype.quant_llm import QuantLlmLinearWeight
+from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
-    fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
-    scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_linear = Fp6LlmLinear(fp6_weight, scales)
+    fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
+    scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
+    fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)
 
-    fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
-    fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]
+    fp16_weight = fp6_weight.dequantize(torch.half)
 
     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = fp6_linear(fp16_act)
-    fp16_output = fp16_linear(fp16_act)
+    fp6_output = F.linear(fp16_act, fp6_weight)
+    fp16_output = F.linear(fp16_act, fp16_weight)
 
-    fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange()
-    fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange()
+    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
+    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
 
     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness

@@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int):
         "m": m,
         "k": k,
         "n": n,
-        "fp6_latency (ms)": fp6_measurement.median * 1000,
-        "fp16_latency (ms)": fp16_measurement.median * 1000,
-        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
+        "fp6_latency (ms)": fp6_time,
+        "fp16_latency (ms)": fp16_time,
+        "speedup (d/s)": fp16_time / fp6_time,
         "correct": correct,
     }
 
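One detail worth calling out in the benchmark above is the packed-size arithmetic: a row of k FP6 values occupies k * 6 / 8 = k * 3 // 4 bytes, and the generic ExMy case used by the new tests below is k // 8 * nbits bytes with nbits = 1 + ebits + mbits. A tiny illustrative helper (not part of the library) that checks this bookkeeping:

def packed_bytes_per_row(k: int, ebits: int, mbits: int) -> int:
    # each value uses 1 sign bit + ebits + mbits bits, so 8 values pack into nbits bytes
    nbits = 1 + ebits + mbits
    assert k % 8 == 0
    return k // 8 * nbits

assert packed_bytes_per_row(4096, 3, 2) == 4096 * 3 // 4  # FP6 E3M2, as in the benchmark
assert packed_bytes_per_row(4096, 2, 2) == 4096 * 5 // 8  # FP5 E2M2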

test/prototype/test_fp6_llm.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

test/prototype/test_quant_llm.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
1+
import copy
2+
3+
import pytest
4+
import torch
5+
from torch.testing._internal.common_utils import (
6+
TestCase,
7+
instantiate_parametrized_tests,
8+
parametrize,
9+
run_tests,
10+
)
11+
from torchao.prototype.quant_llm import (
12+
QuantLlmLinearWeight,
13+
quant_llm_fpx_weight_only,
14+
to_scaled_tc_fpx,
15+
from_scaled_tc_fpx,
16+
)
17+
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
18+
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
19+
from torchao.quantization.quant_api import quantize
20+
21+
22+
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
23+
_FPx_DTYPES = [(3, 2), (2, 2)]
24+
25+
26+
class TestQuantLlmLinearWeight(TestCase):
27+
@parametrize("device", _DEVICES)
28+
def test_pack_tc_fp6_correctness(self, device):
29+
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
30+
31+
expected = _pack_tc_fpx(x, 6)
32+
actual = _pack_tc_fp6(x)
33+
torch.testing.assert_close(actual, expected)
34+
35+
@parametrize("ebits,mbits", _FPx_DTYPES)
36+
@parametrize("device", _DEVICES)
37+
def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
38+
x = torch.randn(256, 64, device=device)
39+
40+
expected = to_scaled_tc_fpx(x, ebits, mbits)
41+
actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
42+
torch.testing.assert_close(actual, expected)
43+
44+
@parametrize("ebits,mbits", _FPx_DTYPES)
45+
@parametrize("device", _DEVICES)
46+
def test_from_tc_fpx_correctness(self, ebits, mbits, device):
47+
x = torch.randn(256, 64, device=device) * 100
48+
49+
# quantize and dequantize so that the values are exactly representable in FPx
50+
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)
51+
52+
tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
53+
actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
54+
torch.testing.assert_close(actual, x)
55+
56+
@parametrize("ebits,mbits", _FPx_DTYPES)
57+
@parametrize("device", _DEVICES)
58+
def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
59+
M, N = 256, 64
60+
nbits = 1 + ebits + mbits
61+
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
62+
scale = torch.randn(M, device=device)
63+
64+
expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
65+
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
66+
torch.testing.assert_close(actual, expected)
67+
68+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
69+
@parametrize("ebits,mbits", _FPx_DTYPES)
70+
@parametrize("leading_dims", [(4,), (2, 4)])
71+
@parametrize("bias", [False, True])
72+
def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
73+
OC, IC = 256, 64
74+
device = "cuda"
75+
76+
fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
77+
fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None
78+
79+
fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)
80+
81+
x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
82+
out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
83+
assert out.shape == leading_dims + (OC,)
84+
85+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
86+
@parametrize("ebits,mbits", _FPx_DTYPES)
87+
@parametrize("bias", [False, True])
88+
def test_quant_llm_quantize(self, ebits, mbits, bias):
89+
N, OC, IC = 4, 256, 64
90+
device = "cuda"
91+
92+
linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
93+
fpx_linear = copy.deepcopy(linear)
94+
quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
95+
96+
x = torch.randn(N, IC, device=device, dtype=torch.half)
97+
expected = fpx_linear(x)
98+
actual = torch.compile(fpx_linear, fullgraph=True)(x)
99+
torch.testing.assert_close(actual, expected)
100+
101+
102+
instantiate_parametrized_tests(TestQuantLlmLinearWeight)
103+
104+
105+
if __name__ == "__main__":
106+
run_tests()
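The round-trip property these tests rely on can be exercised directly. A minimal sketch mirroring test_from_tc_fpx_correctness above, using only the helpers the test file itself imports (the packed tensor-core layout produced by to_scaled_tc_fpx is treated as an opaque implementation detail; only the round trip is checked):

import torch
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

ebits, mbits = 2, 2  # FP5 E2M2
x = torch.randn(256, 64) * 100
# snap x onto exactly representable FP5 values so the round trip is lossless
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

packed, scale = to_scaled_tc_fpx(x, ebits, mbits)  # packed uint8 layout plus per-row scale
recovered = from_scaled_tc_fpx(packed, ebits, mbits, scale=scale)
torch.testing.assert_close(recovered, x)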

test/test_ops.py

Lines changed: 42 additions & 33 deletions
@@ -1,60 +1,69 @@
11
import torch
2-
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
2+
from torch.testing._internal.common_utils import (
3+
TestCase,
4+
instantiate_parametrized_tests,
5+
parametrize,
6+
run_tests,
7+
)
38
from torch.testing._internal.optests import opcheck
4-
import torchao
5-
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
6-
import unittest
7-
from parameterized import parameterized
9+
from torchao.utils import is_fbcode
10+
from torchao.prototype.quant_llm import from_scaled_tc_fpx
811
import pytest
912

13+
if is_fbcode():
14+
pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")
15+
1016
try:
1117
import torchao.ops
1218
except RuntimeError:
1319
pytest.skip("torchao.ops not available")
1420

1521

16-
# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
17-
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
18-
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
19-
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
2022
class TestOps(TestCase):
21-
def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
22-
# Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
23-
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
24-
fp16_scale = torch.rand(OC).half() + 0.5
25-
fp16_activation = torch.rand(BS, IC).half() + 0.5
26-
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
27-
28-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
29-
def test_fp6_llm_linear(self):
23+
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
24+
# Randomly initialize each byte
25+
nbits = 1 + ebits + mbits
26+
fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
27+
scale = torch.rand(OC).half() + 0.5
28+
fp16_act = torch.rand(BS, IC).half() + 0.5
29+
return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
30+
31+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
32+
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
33+
def test_quant_llm_linear(self, ebits, mbits):
3034
BS = 2
3135
OC = 256
3236
IC = 256
3337
splitK = 1
34-
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
38+
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
3539

3640
# smoke test
37-
torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
41+
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
3842

3943
# comprehensive testing
4044
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
41-
opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
45+
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
46+
47+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
48+
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
49+
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
50+
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
51+
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
52+
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
53+
54+
results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
4255

43-
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
44-
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
45-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
46-
def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
47-
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
56+
fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
57+
results_fp16 = fp16_act @ fp16_weight.T
4858

49-
results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
59+
error = (results_fpx - results_fp16).abs().mean()
60+
gt = results_fp16.abs().mean()
61+
relative_error = error / gt
62+
assert relative_error < 1e-3
5063

51-
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
52-
results_fp16 = fp16_activation @ fp16_weight.T
5364

54-
error = (results_fp6 - results_fp16).abs()
55-
relative_error = error / results_fp16.abs()
56-
assert relative_error.mean() < 1e-2
65+
instantiate_parametrized_tests(TestOps)
5766

5867

5968
if __name__ == "__main__":
60-
unittest.main()
69+
run_tests()
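For callers migrating off the removed fp6_llm_linear binding, the change is that the op now takes ebits and mbits up front and handles any supported ExMy packing. A hedged sketch mirroring the smoke test above (FP6 is ebits=3, mbits=2; FP5 E2M2 would be 2, 2), with random packed data standing in for real quantized weights:

import torch
import torchao.ops

OC, IC, BS, splitK = 256, 256, 2, 1
ebits, mbits = 3, 2                 # FP6 E3M2
nbits = 1 + ebits + mbits

fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = torch.rand(OC, dtype=torch.half, device="cuda") + 0.5
fp16_act = torch.rand(BS, IC, dtype=torch.half, device="cuda") + 0.5

# old binding: torchao.ops.fp6_llm_linear(fp16_act, fp6_weight, fp16_scale, splitK)
# new binding: ebits/mbits come first, same activation/weight/scale/splitK afterwards
out = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)  # (BS, OC) fp16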
