8
8
#include < benchmark/benchmark.h>
9
9
10
10
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
11
+ #include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
11
12
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
12
13
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
13
14
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
16
17
17
18
namespace {
18
19
20
+ // Benchmark utility to compare variants of uint1 packing
21
+ void pack_uint1_values (
22
+ uint8_t * packed,
23
+ uint8_t * unpacked,
24
+ int packed_size,
25
+ int unpacked_size,
26
+ int variant) {
27
+ constexpr int nbit = 1 ;
28
+ constexpr int bitsPerByte = 8 ;
29
+ assert (unpacked_size * nbit / bitsPerByte == packed_size);
30
+ assert (packed_size % variant == 0 );
31
+
32
+ uint8x16_t unpacked0;
33
+ uint8x16_t unpacked1;
34
+ uint8x16_t unpacked2;
35
+ uint8x16_t unpacked3;
36
+ uint8x16_t unpacked4;
37
+ uint8x16_t unpacked5;
38
+ uint8x16_t unpacked6;
39
+ uint8x16_t unpacked7;
40
+
41
+ switch (variant) {
42
+ case 8 :
43
+ for (int i = 0 ; i < unpacked_size; i += 8 ) {
44
+ torchao::bitpacking::internal::pack_8_uint1_values (
45
+ packed + ((i * nbit) / bitsPerByte), unpacked + i);
46
+ }
47
+ break ;
48
+ case 64 :
49
+ for (int i = 0 ; i < unpacked_size; i += 64 ) {
50
+ torchao::bitpacking::internal::vec_load_64_uint8_values (
51
+ unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
52
+ torchao::bitpacking::internal::vec_pack_64_uint1_values (
53
+ packed + ((i * nbit) / bitsPerByte),
54
+ unpacked0,
55
+ unpacked1,
56
+ unpacked2,
57
+ unpacked3);
58
+ }
59
+ break ;
60
+ case 128 :
61
+ for (int i = 0 ; i < unpacked_size; i += 128 ) {
62
+ torchao::bitpacking::internal::vec_load_64_uint8_values (
63
+ unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
64
+ torchao::bitpacking::internal::vec_load_64_uint8_values (
65
+ unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64 );
66
+ torchao::bitpacking::internal::vec_pack_128_uint1_values (
67
+ packed + ((i * nbit) / bitsPerByte),
68
+ unpacked0,
69
+ unpacked1,
70
+ unpacked2,
71
+ unpacked3,
72
+ unpacked4,
73
+ unpacked5,
74
+ unpacked6,
75
+ unpacked7);
76
+ }
77
+ break ;
78
+ }
79
+ }
80
+
81
+ // Benchmark utility to compare variants of uint1 packing
82
+ void unpack_uint1_values (
83
+ uint8_t * unpacked,
84
+ uint8_t * packed,
85
+ int unpacked_size,
86
+ int packed_size,
87
+ int variant) {
88
+ constexpr int nbit = 1 ;
89
+ constexpr int bitsPerByte = 8 ;
90
+ assert (unpacked_size * nbit / bitsPerByte == packed_size);
91
+ assert (packed_size % variant == 0 );
92
+
93
+ uint8x16_t unpacked0;
94
+ uint8x16_t unpacked1;
95
+ uint8x16_t unpacked2;
96
+ uint8x16_t unpacked3;
97
+ uint8x16_t unpacked4;
98
+ uint8x16_t unpacked5;
99
+ uint8x16_t unpacked6;
100
+ uint8x16_t unpacked7;
101
+
102
+ switch (variant) {
103
+ case 8 :
104
+ for (int i = 0 ; i < unpacked_size; i += 8 ) {
105
+ torchao::bitpacking::internal::unpack_8_uint1_values (
106
+ unpacked + i, packed + ((i * nbit) / bitsPerByte));
107
+ }
108
+ break ;
109
+ case 64 :
110
+ for (int i = 0 ; i < unpacked_size; i += 64 ) {
111
+ torchao::bitpacking::internal::vec_unpack_64_uint1_values (
112
+ unpacked0,
113
+ unpacked1,
114
+ unpacked2,
115
+ unpacked3,
116
+ packed + ((i * nbit) / bitsPerByte));
117
+ torchao::bitpacking::internal::vec_store_64_uint8_values (
118
+ unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
119
+ }
120
+ break ;
121
+ case 128 :
122
+ for (int i = 0 ; i < unpacked_size; i += 128 ) {
123
+ torchao::bitpacking::internal::vec_unpack_128_uint1_values (
124
+ unpacked0,
125
+ unpacked1,
126
+ unpacked2,
127
+ unpacked3,
128
+ unpacked4,
129
+ unpacked5,
130
+ unpacked6,
131
+ unpacked7,
132
+ packed + ((i * nbit) / bitsPerByte));
133
+ torchao::bitpacking::internal::vec_store_64_uint8_values (
134
+ unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
135
+ torchao::bitpacking::internal::vec_store_64_uint8_values (
136
+ unpacked + i + 64 , unpacked4, unpacked5, unpacked6, unpacked7);
137
+ }
138
+ break ;
139
+ }
140
+ }
141
+
19
142
// Benchmark utility to compare variants of uint2 packing
20
143
void pack_uint2_values (
21
144
uint8_t * packed,
@@ -470,6 +593,44 @@ void unpack_uint5_values(
470
593
471
594
} // namespace
472
595
596
+ static void benchmark_pack_uint1_values (benchmark::State& state) {
597
+ int unpacked_size = state.range (0 );
598
+ int variant = state.range (1 );
599
+ int nbit = 1 ;
600
+
601
+ assert (unpacked_size % 8 == 0 );
602
+ int packed_size = (unpacked_size / 8 ) * nbit;
603
+
604
+ auto packed = std::vector<uint8_t >(packed_size, 0 );
605
+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit);
606
+
607
+ for (auto _ : state) {
608
+ pack_uint1_values (
609
+ packed.data (), unpacked.data (), packed_size, unpacked_size, variant);
610
+ }
611
+ }
612
+
613
+ static void benchmark_unpack_uint1_values (benchmark::State& state) {
614
+ int unpacked_size = state.range (0 );
615
+ int variant = state.range (1 );
616
+ int nbit = 1 ;
617
+
618
+ assert (unpacked_size % 8 == 0 );
619
+ int packed_size = (unpacked_size / 8 ) * nbit;
620
+
621
+ auto packed = torchao::get_random_lowbit_vector (packed_size, 8 );
622
+ auto unpacked = std::vector<uint8_t >(unpacked_size, 0 );
623
+
624
+ for (auto _ : state) {
625
+ unpack_uint1_values (
626
+ unpacked.data (),
627
+ packed.data (),
628
+ unpacked.size (),
629
+ packed.size (),
630
+ variant);
631
+ }
632
+ }
633
+
473
634
static void benchmark_pack_uint2_values (benchmark::State& state) {
474
635
int unpacked_size = state.range (0 );
475
636
int variant = state.range (1 );
@@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
478
639
assert (unpacked_size % 8 == 0 );
479
640
int packed_size = (unpacked_size / 8 ) * nbit;
480
641
481
- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
482
- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
642
+ auto packed = std::vector<uint8_t >(packed_size , 0 );
643
+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
483
644
484
645
for (auto _ : state) {
485
646
pack_uint2_values (
@@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
516
677
assert (unpacked_size % 8 == 0 );
517
678
int packed_size = (unpacked_size / 8 ) * nbit;
518
679
519
- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
520
- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
680
+ auto packed = std::vector<uint8_t >(packed_size , 0 );
681
+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
521
682
522
683
for (auto _ : state) {
523
684
pack_uint3_values (
@@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
554
715
assert (unpacked_size % 8 == 0 );
555
716
int packed_size = (unpacked_size / 8 ) * nbit;
556
717
557
- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
558
- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
718
+ auto packed = std::vector<uint8_t >(packed_size , 0 );
719
+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
559
720
560
721
for (auto _ : state) {
561
722
pack_uint4_values (
@@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
592
753
assert (unpacked_size % 8 == 0 );
593
754
int packed_size = (unpacked_size / 8 ) * nbit;
594
755
595
- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
596
- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
756
+ auto packed = std::vector<uint8_t >(packed_size , 0 );
757
+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
597
758
598
759
for (auto _ : state) {
599
760
pack_uint5_values (
@@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
622
783
}
623
784
}
624
785
786
// Register the uint1 benchmarks. ArgsProduct supplies range(0) = 128 (read as
// the number of unpacked values) and range(1) in {8, 64, 128} (read as the
// pack/unpack variant).
+ BENCHMARK (benchmark_pack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
787
+ BENCHMARK (benchmark_unpack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
625
788
BENCHMARK (benchmark_pack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
626
789
BENCHMARK (benchmark_unpack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
627
790
BENCHMARK (benchmark_pack_uint3_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
0 commit comments