Skip to content

Commit 503aadd

Browse files
pcullitoncopybara-github
authored andcommitted
Add 8-bit integer quantization (I8Stream) to Gemma.cpp.
PiperOrigin-RevId: 819787856
1 parent ee18916 commit 503aadd

25 files changed

+1428
-64
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ cc_library(
349349
"ops/matmul_static_f32.cc",
350350
"ops/matmul_static_nuq.cc",
351351
"ops/matmul_static_sfp.cc",
352+
"ops/matmul_static_i8.cc",
352353
],
353354
hdrs = [
354355
"ops/matmul_static.h",

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ set(SOURCES
6868
compression/compress.h
6969
compression/nuq-inl.h
7070
compression/sfp-inl.h
71+
compression/int-inl.h
7172
compression/types.h
7273
compression/test_util-inl.h
7374
evals/benchmark_helper.cc
@@ -109,6 +110,7 @@ set(SOURCES
109110
ops/matmul_static_f32.cc
110111
ops/matmul_static_nuq.cc
111112
ops/matmul_static_sfp.cc
113+
ops/matmul_static_i8.cc
112114
ops/matmul-inl.h
113115
ops/matmul.cc
114116
ops/matmul.h

compression/BUILD.bazel

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,37 @@ cc_library(
8080
],
8181
)
8282

83+
cc_library(
84+
name = "int",
85+
textual_hdrs = ["int-inl.h"],
86+
deps = [
87+
":types",
88+
"//:basics",
89+
"@highway//:hwy",
90+
],
91+
)
92+
93+
cc_test(
94+
name = "int_test",
95+
size = "small",
96+
timeout = "long",
97+
srcs = ["int_test.cc"],
98+
features = ["fully_static_link"],
99+
linkstatic = True,
100+
local_defines = ["HWY_IS_TEST"],
101+
# for test_suite.
102+
tags = ["hwy_ops_test"],
103+
deps = [
104+
":distortion",
105+
":int",
106+
"@googletest//:gtest_main", # buildcleaner: keep
107+
"//:test_util",
108+
"@highway//:hwy",
109+
"@highway//:hwy_test_util",
110+
"@highway//:nanobenchmark",
111+
],
112+
)
113+
83114
cc_library(
84115
name = "test_util",
85116
textual_hdrs = [
@@ -144,6 +175,7 @@ cc_library(
144175
textual_hdrs = ["compress-inl.h"],
145176
deps = [
146177
":distortion",
178+
":int",
147179
":nuq",
148180
":sfp",
149181
"//:basics",
@@ -182,6 +214,7 @@ cc_library(
182214
name = "analyze",
183215
textual_hdrs = ["analyze.h"],
184216
deps = [
217+
":int",
185218
":nuq",
186219
":sfp",
187220
":types",

compression/compress-inl.h

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
#include "hwy/highway.h"
4949
// After highway.h
50+
#include "compression/int-inl.h"
5051
#include "compression/nuq-inl.h"
5152
#include "compression/sfp-inl.h"
5253

@@ -416,6 +417,34 @@ struct CompressTraits<SfpStream> {
416417
}
417418
};
418419

420+
// Integer quantization.
421+
template <>
422+
struct CompressTraits<I8Stream> {
423+
using Packed = I8Stream;
424+
425+
template <class DF, HWY_IF_F32_D(DF)>
426+
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
427+
size_t num, CompressPerThread& tls,
428+
const PackedSpan<Packed>& packed,
429+
const size_t packed_ofs) {
430+
IntCodec::Enc(df, raw, num, packed, packed_ofs);
431+
}
432+
433+
template <class D> // Caller checks this is f32 or bf16
434+
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
435+
const size_t packed_ofs, hn::Vec<D>& raw0,
436+
hn::Vec<D>& raw1) {
437+
IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
438+
}
439+
440+
template <class D, typename Raw>
441+
static HWY_INLINE void DecompressAndZeroPad(
442+
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
443+
Raw* raw, const size_t num) {
444+
IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
445+
}
446+
};
447+
419448
// Nonuniform quantization, 4.5 bits per element, two separate streams.
420449
template <>
421450
struct CompressTraits<NuqStream> {
@@ -737,9 +766,10 @@ template <class DF, typename T, typename T1, class Func>
737766
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
738767
size_t num,
739768
const T1* HWY_RESTRICT p1,
769+
const size_t p1_ofs,
740770
Func&& func) {
741771
const auto packed_inout = MakeSpan(inout, num);
742-
const auto packed1 = MakeSpan(p1, num);
772+
const auto packed1 = MakeSpan(p1, p1_ofs + num);
743773

744774
using VF = hn::Vec<decltype(df)>;
745775
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@@ -749,7 +779,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
749779
VF v0, v1;
750780
Decompress2(df, packed_inout, i, v0, v1);
751781
VF v10, v11;
752-
Decompress2(df, packed1, i, v10, v11);
782+
Decompress2(df, packed1, p1_ofs + i, v10, v11);
753783
const VF out0 = func(df, v0, v10);
754784
const VF out1 = func(df, v1, v11);
755785
Compress2(df, out0, out1, packed_inout, i);
@@ -765,7 +795,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
765795
hn::Store(hn::Zero(df), df, buf_inout + NF);
766796
hn::Store(hn::Zero(df), df, buf1 + NF);
767797
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
768-
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
798+
DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining);
769799
const VF v0 = hn::Load(df, buf_inout);
770800
const VF v1 = hn::Load(df, buf_inout + NF);
771801
const VF v10 = hn::Load(df, buf1);
@@ -827,10 +857,10 @@ template <class DF, typename T, typename T1, typename T2, class Func>
827857
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
828858
const T1* HWY_RESTRICT p1,
829859
const T2* HWY_RESTRICT p2,
830-
Func&& func) {
860+
const size_t p2_ofs, Func&& func) {
831861
const auto packed_out = MakeSpan(out, num);
832862
const auto packed1 = MakeSpan(p1, num);
833-
const auto packed2 = MakeSpan(p2, num);
863+
const auto packed2 = MakeSpan(p2, p2_ofs + num);
834864

835865
using VF = hn::Vec<decltype(df)>;
836866
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@@ -839,7 +869,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
839869
for (; i <= num - 2 * NF; i += 2 * NF) {
840870
VF v10, v11, v20, v21;
841871
Decompress2(df, packed1, i, v10, v11);
842-
Decompress2(df, packed2, i, v20, v21);
872+
Decompress2(df, packed2, p2_ofs + i, v20, v21);
843873
const VF out0 = func(df, v10, v20);
844874
const VF out1 = func(df, v11, v21);
845875
Compress2(df, out0, out1, packed_out, i);
@@ -856,7 +886,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
856886
hn::Store(hn::Zero(df), df, buf1 + NF);
857887
hn::Store(hn::Zero(df), df, buf2 + NF);
858888
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
859-
DecompressAndZeroPad(df, packed2, i, buf2, remaining);
889+
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
860890
const VF v10 = hn::Load(df, buf1);
861891
const VF v11 = hn::Load(df, buf1 + NF);
862892
const VF v20 = hn::Load(df, buf2);

compression/compress_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,17 +243,17 @@ class TestDecompressAndCompress {
243243

244244
// Uses `out` so as not to overwrite `p`.
245245
Decompress1AndCompressInplace(
246-
df, out.get(), num, p1.get(),
246+
df, out.get(), num, p1.get(), /*p1_ofs=*/0,
247247
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
248248
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
249249

250250
Decompress1AndCompressTo(df, out.get(), num, p.get(),
251251
[](DF, VF v) HWY_ATTR -> VF { return v; });
252252
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
253253

254-
Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(),
255-
[](DF, VF v, VF v1)
256-
HWY_ATTR -> VF { return hn::Add(v, v1); });
254+
Decompress2AndCompressTo(
255+
df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0,
256+
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
257257
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
258258

259259
Decompress3AndCompressTo(

0 commit comments

Comments
 (0)