Skip to content

Commit 3379b27

Browse files
vaishnavi17jainapurva
authored andcommitted
Introducing 1-bit quantization for Llama in torchchat (#910)
Differential Revision: D63052325 Pull Request resolved: #911
1 parent a2d77ce commit 3379b27

File tree

7 files changed

+505
-22
lines changed

7 files changed

+505
-22
lines changed

torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp

Lines changed: 171 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <benchmark/benchmark.h>
99

1010
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
11+
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
1112
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
1213
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
1314
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
@@ -16,6 +17,128 @@
1617

1718
namespace {
1819

20+
// Benchmark utility to compare variants of uint1 packing
21+
void pack_uint1_values(
22+
uint8_t* packed,
23+
uint8_t* unpacked,
24+
int packed_size,
25+
int unpacked_size,
26+
int variant) {
27+
constexpr int nbit = 1;
28+
constexpr int bitsPerByte = 8;
29+
assert(unpacked_size * nbit / bitsPerByte == packed_size);
30+
assert(packed_size % variant == 0);
31+
32+
uint8x16_t unpacked0;
33+
uint8x16_t unpacked1;
34+
uint8x16_t unpacked2;
35+
uint8x16_t unpacked3;
36+
uint8x16_t unpacked4;
37+
uint8x16_t unpacked5;
38+
uint8x16_t unpacked6;
39+
uint8x16_t unpacked7;
40+
41+
switch (variant) {
42+
case 8:
43+
for (int i = 0; i < unpacked_size; i += 8) {
44+
torchao::bitpacking::internal::pack_8_uint1_values(
45+
packed + ((i * nbit) / bitsPerByte), unpacked + i);
46+
}
47+
break;
48+
case 64:
49+
for (int i = 0; i < unpacked_size; i += 64) {
50+
torchao::bitpacking::internal::vec_load_64_uint8_values(
51+
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
52+
torchao::bitpacking::internal::vec_pack_64_uint1_values(
53+
packed + ((i * nbit) / bitsPerByte),
54+
unpacked0,
55+
unpacked1,
56+
unpacked2,
57+
unpacked3);
58+
}
59+
break;
60+
case 128:
61+
for (int i = 0; i < unpacked_size; i += 128) {
62+
torchao::bitpacking::internal::vec_load_64_uint8_values(
63+
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
64+
torchao::bitpacking::internal::vec_load_64_uint8_values(
65+
unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64);
66+
torchao::bitpacking::internal::vec_pack_128_uint1_values(
67+
packed + ((i * nbit) / bitsPerByte),
68+
unpacked0,
69+
unpacked1,
70+
unpacked2,
71+
unpacked3,
72+
unpacked4,
73+
unpacked5,
74+
unpacked6,
75+
unpacked7);
76+
}
77+
break;
78+
}
79+
}
80+
81+
// Benchmark utility to compare variants of uint1 packing
82+
void unpack_uint1_values(
83+
uint8_t* unpacked,
84+
uint8_t* packed,
85+
int unpacked_size,
86+
int packed_size,
87+
int variant) {
88+
constexpr int nbit = 1;
89+
constexpr int bitsPerByte = 8;
90+
assert(unpacked_size * nbit / bitsPerByte == packed_size);
91+
assert(packed_size % variant == 0);
92+
93+
uint8x16_t unpacked0;
94+
uint8x16_t unpacked1;
95+
uint8x16_t unpacked2;
96+
uint8x16_t unpacked3;
97+
uint8x16_t unpacked4;
98+
uint8x16_t unpacked5;
99+
uint8x16_t unpacked6;
100+
uint8x16_t unpacked7;
101+
102+
switch (variant) {
103+
case 8:
104+
for (int i = 0; i < unpacked_size; i += 8) {
105+
torchao::bitpacking::internal::unpack_8_uint1_values(
106+
unpacked + i, packed + ((i * nbit) / bitsPerByte));
107+
}
108+
break;
109+
case 64:
110+
for (int i = 0; i < unpacked_size; i += 64) {
111+
torchao::bitpacking::internal::vec_unpack_64_uint1_values(
112+
unpacked0,
113+
unpacked1,
114+
unpacked2,
115+
unpacked3,
116+
packed + ((i * nbit) / bitsPerByte));
117+
torchao::bitpacking::internal::vec_store_64_uint8_values(
118+
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
119+
}
120+
break;
121+
case 128:
122+
for (int i = 0; i < unpacked_size; i += 128) {
123+
torchao::bitpacking::internal::vec_unpack_128_uint1_values(
124+
unpacked0,
125+
unpacked1,
126+
unpacked2,
127+
unpacked3,
128+
unpacked4,
129+
unpacked5,
130+
unpacked6,
131+
unpacked7,
132+
packed + ((i * nbit) / bitsPerByte));
133+
torchao::bitpacking::internal::vec_store_64_uint8_values(
134+
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
135+
torchao::bitpacking::internal::vec_store_64_uint8_values(
136+
unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7);
137+
}
138+
break;
139+
}
140+
}
141+
19142
// Benchmark utility to compare variants of uint2 packing
20143
void pack_uint2_values(
21144
uint8_t* packed,
@@ -470,6 +593,44 @@ void unpack_uint5_values(
470593

471594
} // namespace
472595

596+
static void benchmark_pack_uint1_values(benchmark::State& state) {
597+
int unpacked_size = state.range(0);
598+
int variant = state.range(1);
599+
int nbit = 1;
600+
601+
assert(unpacked_size % 8 == 0);
602+
int packed_size = (unpacked_size / 8) * nbit;
603+
604+
auto packed = std::vector<uint8_t>(packed_size, 0);
605+
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);
606+
607+
for (auto _ : state) {
608+
pack_uint1_values(
609+
packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
610+
}
611+
}
612+
613+
static void benchmark_unpack_uint1_values(benchmark::State& state) {
614+
int unpacked_size = state.range(0);
615+
int variant = state.range(1);
616+
int nbit = 1;
617+
618+
assert(unpacked_size % 8 == 0);
619+
int packed_size = (unpacked_size / 8) * nbit;
620+
621+
auto packed = torchao::get_random_lowbit_vector(packed_size, 8);
622+
auto unpacked = std::vector<uint8_t>(unpacked_size, 0);
623+
624+
for (auto _ : state) {
625+
unpack_uint1_values(
626+
unpacked.data(),
627+
packed.data(),
628+
unpacked.size(),
629+
packed.size(),
630+
variant);
631+
}
632+
}
633+
473634
static void benchmark_pack_uint2_values(benchmark::State& state) {
474635
int unpacked_size = state.range(0);
475636
int variant = state.range(1);
@@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
478639
assert(unpacked_size % 8 == 0);
479640
int packed_size = (unpacked_size / 8) * nbit;
480641

481-
auto packed = std::vector<uint8_t>(unpacked_size, 0);
482-
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
642+
auto packed = std::vector<uint8_t>(packed_size, 0);
643+
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);
483644

484645
for (auto _ : state) {
485646
pack_uint2_values(
@@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
516677
assert(unpacked_size % 8 == 0);
517678
int packed_size = (unpacked_size / 8) * nbit;
518679

519-
auto packed = std::vector<uint8_t>(unpacked_size, 0);
520-
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
680+
auto packed = std::vector<uint8_t>(packed_size, 0);
681+
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);
521682

522683
for (auto _ : state) {
523684
pack_uint3_values(
@@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
554715
assert(unpacked_size % 8 == 0);
555716
int packed_size = (unpacked_size / 8) * nbit;
556717

557-
auto packed = std::vector<uint8_t>(unpacked_size, 0);
558-
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
718+
auto packed = std::vector<uint8_t>(packed_size, 0);
719+
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);
559720

560721
for (auto _ : state) {
561722
pack_uint4_values(
@@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
592753
assert(unpacked_size % 8 == 0);
593754
int packed_size = (unpacked_size / 8) * nbit;
594755

595-
auto packed = std::vector<uint8_t>(unpacked_size, 0);
596-
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
756+
auto packed = std::vector<uint8_t>(packed_size, 0);
757+
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);
597758

598759
for (auto _ : state) {
599760
pack_uint5_values(
@@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
622783
}
623784
}
624785

786+
BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
787+
BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
625788
BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
626789
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
627790
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});

torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
228228
false>) \
229229
->ArgsProduct(BENCHMARK_PARAMS)
230230

231+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
232+
1);
231233
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
232234
2);
233235
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
@@ -236,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
236238
4);
237239
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
238240
5);
241+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
242+
1);
239243
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
240244
2);
241245
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
@@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
244248
4);
245249
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
246250
5);
251+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
252+
1);
247253
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
248254
2);
249255
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(

0 commit comments

Comments
 (0)