Commit e934e1d

Merge branch 'main' into add-aqt-docs
2 parents ec351f2 + 4ca3985 commit e934e1d

20 files changed: +948, -204 lines

dev-requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -12,3 +12,6 @@ pandas
 
 # Custom CUDA Extensions
 ninja
+
+# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
+qtorch

docs/source/api_ref_dtypes.rst

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ torchao.dtypes
 
     to_nf4
     UInt4Tensor
+    to_float6_e3m2
+    from_float6_e3m2
 
 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
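A minimal usage sketch of the two newly documented converters (not part of the commit; the import path and the dtype keyword mirror how test/test_ops.py and the new tests call them, and the packed shape follows the 4-codes-into-3-bytes layout those tests check):

import torch
from torchao.dtypes import to_float6_e3m2, from_float6_e3m2

w = torch.randn(64, 64, dtype=torch.float16)            # last dim must be divisible by 4 for packing
packed = to_float6_e3m2(w)                               # uint8, shape (64, 64 // 4 * 3)
codes = to_float6_e3m2(w, no_bit_packing=True)           # one 6-bit code per uint8, shape (64, 64)
w_fp16 = from_float6_e3m2(packed, dtype=torch.float16)   # decode back; values saturate at +/-28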

setup.py

Lines changed: 2 additions & 1 deletion
@@ -46,11 +46,12 @@ def get_extensions():
     use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
     extension = CUDAExtension if use_cuda else CppExtension
 
-    extra_link_args = []
+    extra_link_args = ["-fopenmp"]
     extra_compile_args = {
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
+            "-fopenmp",
         ],
         "nvcc": [
             "-O3" if not debug_mode else "-O0",

test/dtypes/test_float6_e3m2.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6(TestCase):

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    @parametrize(
        "input_output",
        [
            (0.0, 0b000000),     # exact values
            (1.0, 0b001100),     # normal numbers
            (1.25, 0b001101),
            (28.0, 0b011111),    # max
            (0.1875, 0b000011),  # subnormal number
            (0.0625, 0b000001),  # min
            (29.0, 0b011111),    # normal round down
            (26.0, 0b011110),    # normal round to nearest even
            (0.1251, 0b000010),  # subnormal round down
            (0.0314, 0b000001),  # subnormal round up
            (0.03, 0b000000),    # underflow
        ],
    )
    def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output):
        input, output = input_output
        input = torch.tensor(input, device=device, dtype=dtype)
        assert to_float6_e3m2(input, no_bit_packing=True).item() == output

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype):
        x = torch.randn(128, 128, device=device, dtype=dtype)
        results_unpacked = to_float6_e3m2(x, no_bit_packing=True)
        results_packed = to_float6_e3m2(x)

        val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1)
        bits0 = (val0 << 2) | (val1 >> 4)  # 0000 0011
        bits1 = (val1 << 4) | (val2 >> 2)  # 1111 2222
        bits2 = (val2 << 6) | (val3)       # 2233 3333

        expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)
        assert (results_packed == expected_packed).all()

    @parametrize("device", _DEVICES)
    @parametrize("shape", [(), (0,), (10,), (20, 20)])
    def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape):
        x = torch.randn(shape, device=device)
        result = to_float6_e3m2(x, no_bit_packing=True)
        assert result.shape == shape

    @parametrize("device", _DEVICES)
    @parametrize("shape", [(4,), (20, 20)])
    def test_to_float6_e3m2_bit_packing_shape(self, device, shape):
        x = torch.randn(shape, device=device)
        result = to_float6_e3m2(x)
        assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,)

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    @parametrize("no_bit_packing", [False, True])
    def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing):
        x = torch.randn(20, 20, device=device, dtype=dtype)
        expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing)

        to_float6_e3m2_compiled = torch.compile(to_float6_e3m2)
        actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
        torch.testing.assert_close(actual, expected)

    @parametrize("device", _DEVICES)
    @parametrize(
        "input_output",
        [
            (0b000000, 0.0),
            (0b001100, 1.0),
            (0b011111, 28.0),    # max
            (0b000001, 0.0625),  # min
            (0b001110, 1.5),
            (0b000011, 0.1875),  # subnormal
        ],
    )
    def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output):
        input, output = input_output
        input = torch.tensor(input, device=device, dtype=torch.uint8)
        assert from_float6_e3m2(input, no_bit_packing=True).item() == output

    @parametrize("device", _DEVICES)
    def test_from_float6_e3m2_bit_packing_correctness(self, device):
        x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8)
        actual = from_float6_e3m2(x)

        bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1)
        x_unpacked0 = bits0 >> 2
        x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4)
        x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6)
        x_unpacked3 = bits2 & 0x3F

        x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2)
        expected = from_float6_e3m2(x_unpacked, no_bit_packing=True)
        torch.testing.assert_close(actual, expected)

    @parametrize("device", _DEVICES)
    @parametrize("no_bit_packing", [False, True])
    def test_from_float6_e3m2_compile(self, device, no_bit_packing):
        x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8)
        expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing)

        from_float6_e3m2_compiled = torch.compile(from_float6_e3m2)
        actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
        torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFp6)


if __name__ == "__main__":
    run_tests()
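For reference, a standalone pure-Python sketch (not part of the commit) of the bit layout the packing tests above verify: four 6-bit codes go into three bytes as 0000 0011 / 1111 2222 / 2233 3333, and unpacking recovers them exactly.

def pack4(v0, v1, v2, v3):
    # pack four 6-bit codes into three bytes
    b0 = ((v0 << 2) | (v1 >> 4)) & 0xFF
    b1 = ((v1 << 4) | (v2 >> 2)) & 0xFF
    b2 = ((v2 << 6) | v3) & 0xFF
    return b0, b1, b2

def unpack3(b0, b1, b2):
    # recover the four 6-bit codes from three bytes
    v0 = b0 >> 2
    v1 = ((b0 & 0x3) << 4) | (b1 >> 4)
    v2 = ((b1 & 0xF) << 2) | (b2 >> 6)
    v3 = b2 & 0x3F
    return v0, v1, v2, v3

codes = (0b001100, 0b001101, 0b011111, 0b000001)  # 1.0, 1.25, 28.0 (max), 0.0625 (min)
assert unpack3(*pack4(*codes)) == codes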

test/test_ops.py

Lines changed: 9 additions & 25 deletions
@@ -50,24 +50,21 @@ def test_prepack_fp6_weight(self):
         opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp16_to_fp6(self):
+    def test_fp16_to_fp6_original(self):
         OC = 256
         IC = 256
-
-        # in this fp6, we use 3 bits for exponent and 2 bits for mantissa
-        # also, we don't have nan/inf
-        fp6_absmax = 28.0  # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
-        fp6_absmin = 0.0625  # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
         fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
-        fp16_weight.clip_(-fp6_absmax, fp6_absmax)
-        fp16_weight[fp16_weight.abs() < fp6_absmin] = 0
+
+        # the original FP16->FP6 kernel checks for overflow/underflow
+        fp16_weight.clip_(-28.0, 28.0)
+        fp16_weight[fp16_weight.abs() < 0.0625] = 0.0
 
         # smoke test
-        torchao.ops.fp16_to_fp6(fp16_weight)
+        torchao.ops.fp16_to_fp6_original(fp16_weight)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp16act_fp6weight_linear(self):

@@ -89,19 +86,6 @@ def test_fp16act_fp6weight_linear(self):
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
         opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_weight_dequant(self):
-        OC = 256
-        IC = 256
-        fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)
-
-        # smoke test
-        torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)
-
-        # comprehensive testing
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)
-
     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")

@@ -115,8 +99,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
 
         results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
 
-        fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
-        results_fp16 = act_cuda @ fp16_weight.T
+        fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
+        results_fp16 = act_cuda @ fp16_weight.cuda().T
 
         error = (results_fp6 - results_fp16).abs()
         relative_error = error / results_fp16.abs()
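The 28.0 and 0.0625 constants hard-coded above are the FP6 e3m2 extremes (exponent bias 3, no inf/nan), as the removed comments explained. A small sketch, not from the commit, that derives them from the bit fields:

def fp6_e3m2_value(code: int) -> float:
    """Decode a 6-bit e3m2 code: 1 sign bit, 3 exponent bits (bias 3), 2 mantissa bits."""
    sign = -1.0 if (code >> 5) & 0x1 else 1.0
    exp = (code >> 2) & 0x7
    man = code & 0x3
    if exp == 0:                                   # subnormal: no implicit leading 1
        return sign * 2.0 ** -2 * (man / 4.0)
    return sign * 2.0 ** (exp - 3) * (1.0 + man / 4.0)

assert fp6_e3m2_value(0b011111) == 28.0    # largest magnitude: 2**4 * 1.75
assert fp6_e3m2_value(0b000001) == 0.0625  # smallest nonzero magnitude: 2**-2 * 0.25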

torchao/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -1,9 +1,3 @@
-from torchao.quantization import (
-    apply_weight_only_int8_quant,
-    apply_dynamic_quant,
-    autoquant,
-)
-from . import dtypes
 import torch
 _IS_FBCODE = (
     hasattr(torch._utils_internal, "IS_FBSOURCE") and

@@ -14,6 +8,13 @@
     from . import _C
     from . import ops
 
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    autoquant,
+)
+from . import dtypes
+
 __all__ = [
     "dtypes",
     "apply_dynamic_quant",

torchao/csrc/cuda/fp6_llm/weight_quant.cu

Lines changed: 2 additions & 67 deletions
@@ -13,7 +13,6 @@
 // limitations under the License.
 //
 // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h
-// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h
 
 #include <cuda_fp16.h>
 #include <iostream>

@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
     }
 }
 
-void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
-    assert(M%64==0);  // Currently, M must be a multiple of 64.
-    assert(K%64==0);  // Currently, K must be a multiple of 64.
-    size_t TotalSizeInByte = M*K*6/8;
-    //
-    half* OutPTR = A_16bit_h;
-    for(size_t i=0; i<TotalSizeInByte/3; i++) {  // 4 FP6 = 3 Bytes for each Loop
-        unsigned char B1 = A_6bit_h[i*3+0] & 0xfc;
-        B1 = (B1&0x80) | ((B1>>2)&0x1f);
-        unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc);
-        B2 = (B2&0x80) | ((B2>>2)&0x1f);
-        unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc);
-        B3 = (B3&0x80) | ((B3>>2)&0x1f);
-        unsigned char B4 = A_6bit_h[i*3+2]<<2;
-        B4 = (B4&0x80) | ((B4>>2)&0x1f);
-        half FP1, FP2, FP3, FP4;
-        unsigned char *PTR1, *PTR2, *PTR3, *PTR4;
-        PTR1 = reinterpret_cast<unsigned char*>(&FP1);
-        PTR2 = reinterpret_cast<unsigned char*>(&FP2);
-        PTR3 = reinterpret_cast<unsigned char*>(&FP3);
-        PTR4 = reinterpret_cast<unsigned char*>(&FP4);
-        PTR1[0] = 0; PTR1[1] = B1;  // small endian for X86 CPU
-        PTR2[0] = 0; PTR2[1] = B2;
-        PTR3[0] = 0; PTR3[1] = B3;
-        PTR4[0] = 0; PTR4[1] = B4;
-        OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) );
-        OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) );
-        OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) );
-        OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) );
-        //
-        OutPTR +=4;
-    }
-}
-
-
 #include <torch/extension.h>
 #include <ATen/ATen.h>
 #include <torch/library.h>
 
 namespace torchao {
 
 // https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194
-at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
+at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor)
 {
     TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional");
     TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16");

@@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
     return packed_fp6_tensor;
 }
 
-/*
- * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.
- * A useful tool to construct input matrices for the FP16 GEMM baseline.
- * [Input]
- *  fp6_tensor:  int tensor of shape [OC, IC // 16 * 3];  // 3 INT32 words contains 16 FP6 weights.
- *  fp16_scale:  half tensor of shape [OC];               // for row-wise quantization.
- * [Output]
- *  fp16_tensor: half tensor of shape [OC, IC].
- */
-at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale)
-{
-    int OC = fp6_tensor.size(0);
-    TORCH_CHECK(fp6_tensor.size(1) % 3 == 0);
-    int IC = fp6_tensor.size(1) / 3 * 16;
-    TORCH_CHECK(fp16_scale.size(0) == OC);
-    //
-    auto fp6_tensor_ptr = reinterpret_cast<unsigned char*>(fp6_tensor.data_ptr<int>());
-    auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());
-    //
-    auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device());
-    at::Tensor fp16_tensor = at::empty({OC, IC}, options);
-    auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());
-    //
-    DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr);
-    //
-    return fp16_tensor;
-}
-
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
-    m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu);
-    m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu);
+    m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu);
 }
 
 }
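The removed DeQuantMatrix_FP6_To_FP16 decoded each code by dropping its bits into the high byte of an FP16 and multiplying by 4096.0f, which is 2**(15 - 3) and so converts the FP6 exponent bias of 3 into the FP16 bias of 15. A PyTorch sketch of that trick, not part of the commit:

import torch

def fp6_code_to_fp16(code: int) -> float:
    sign = (code >> 5) & 0x1
    em = code & 0x1F                                     # 3 exponent + 2 mantissa bits
    hi_byte = (sign << 7) | em                           # s 0 0 e2 e1 e0 m1 m0
    raw = torch.tensor([0, hi_byte], dtype=torch.uint8)  # little-endian byte pair, low byte zero
    return (raw.view(torch.float16)[0] * 4096.0).item()  # 4096 = 2**(15 - 3)

assert fp6_code_to_fp16(0b001100) == 1.0
assert fp6_code_to_fp16(0b011111) == 28.0
assert fp6_code_to_fp16(0b000001) == 0.0625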
