Skip to content

Commit 123bf7e

Browse files
pculliton authored and copybara-github committed
Internal change
PiperOrigin-RevId: 686665933
1 parent 96513a8 commit 123bf7e

File tree

7 files changed

+445
-14
lines changed

7 files changed

+445
-14
lines changed

compression/compress-inl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ struct CompressTraits<NuqStream> {
386386
size_t num, CompressPerThread& tls,
387387
const PackedSpan<Packed>& packed,
388388
const size_t packed_ofs) {
389-
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
389+
NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);
390390

391391
if (COMPRESS_STATS) {
392392
for (size_t i = 0; i < num; ++i) {
@@ -396,8 +396,8 @@ struct CompressTraits<NuqStream> {
396396
const hn::Repartition<BF16, DF> dbf;
397397
const size_t N16 = hn::Lanes(dbf);
398398
auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
399-
NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
400-
distorted.get(), num);
399+
NuqCodec::DecompressAndZeroPadInterleaved(
400+
dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
401401
DistortionStats stats;
402402
for (size_t i = 0; i < num; ++i) {
403403
stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
@@ -410,7 +410,7 @@ struct CompressTraits<NuqStream> {
410410
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
411411
const size_t packed_ofs, hn::Vec<D>& raw0,
412412
hn::Vec<D>& raw1) {
413-
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
413+
NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
414414
}
415415

416416
// Store2 is not yet implemented.
@@ -419,7 +419,7 @@ struct CompressTraits<NuqStream> {
419419
static HWY_INLINE void DecompressAndZeroPad(
420420
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
421421
Raw* raw, const size_t num) {
422-
NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
422+
NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
423423
}
424424
};
425425

compression/compress.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
1818
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
1919

20+
#include "hwy/base.h"
2021
#define COMPRESS_STATS 0
2122

2223
#include <stddef.h>
@@ -134,7 +135,7 @@ class MatPtr {
134135
size_t NumElements() const { return num_elements_; }
135136

136137
// Returns the number of bytes in the array.
137-
size_t SizeBytes() const { return num_elements_ * element_size_; }
138+
virtual size_t SizeBytes() const { return num_elements_ * element_size_; }
138139

139140
// Returns the number of rows in the 2-d array (outer dimension).
140141
size_t Rows() const { return rows_; }
@@ -240,10 +241,13 @@ class MatPtrT : public MatPtr {
240241
return name;
241242
}
242243

243-
// Sets the number of elements in the array. For use when the number of
244-
// elements is != rows * cols ONLY.
245-
void SetNumElements(size_t num_elements) {
246-
num_elements_ = CompressedArrayElements<MatT>(num_elements);
244+
// Returns the number of bytes in the array. Overrides MatPtr::SizeBytes()
245+
// to account for NUQ's differing packed size.
246+
size_t SizeBytes() const override {
247+
if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
248+
return NuqStream::PackedEnd(num_elements_);
249+
}
250+
return num_elements_ * element_size_;
247251
}
248252

249253
// 2-d Accessor for a specific type but with a dynamic inner dimension.
@@ -333,6 +337,12 @@ class MatStorageT : public MatPtrT<MatT> {
333337
// from the current num_elements_ which was set by the constructor from the
334338
// rows and cols.
335339
void Allocate(size_t num_elements = 0) {
340+
// size_t num_elements = 0;
341+
// TODO: optimize this check or obviate it.
342+
if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
343+
HWY_DASSERT(num_elements == 0);
344+
}
345+
336346
if (num_elements == 0) {
337347
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
338348
} else {

compression/nuq-inl.h

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
// After highway.h
4141
#include "compression/sfp-inl.h"
4242
#include "hwy/contrib/sort/vqsort-inl.h"
43+
#include "hwy/profiler.h" // uses SIMD
4344

4445
HWY_BEFORE_NAMESPACE();
4546
namespace gcpp {
@@ -529,12 +530,21 @@ class NuqCodec {
529530
return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
530531
}
531532

533+
// Offset (in bytes) of a group's table for packed_ofs (in elements) within a
534+
// set of groups.
535+
static constexpr size_t TableByteOffset(size_t packed_ofs) {
536+
const size_t kBytesPerGroup =
537+
(kClusters * sizeof(SfpStream)) + kGroupSize / 2;
538+
return (packed_ofs / kGroupSize) * kBytesPerGroup;
539+
}
540+
532541
// Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
533542
// for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
534543
// not be available for bf16.
535544
template <class DU, HWY_IF_U16_D(DU)>
536545
static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
537546
hn::Vec<DU>* HWY_RESTRICT tbl1) {
547+
PROFILER_FUNC;
538548
// Cap to the table size (kClusters) for decoding SFP - sufficient, and may
539549
// be faster than a large vector.
540550
const hn::CappedTag<BF16, kClusters> d_table;
@@ -606,6 +616,81 @@ class NuqCodec {
606616
}
607617

608618
public:
619+
// Encodes `num` floats from `raw` into `packed`. `packed` points to
620+
// compressed storage and `packed_ofs` indicates the destination offset within
621+
// it, in number of elements. Tables are interleaved with indices (clustered
622+
// elements) to allow for easier unpacking. Returns the total number of
623+
// unused clusters, which is typically zero.
624+
template <class DF, HWY_IF_F32_D(DF)>
625+
static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
626+
const size_t num,
627+
NuqStream::ClusterBuf& buf,
628+
const PackedSpan<NuqStream>& packed,
629+
size_t packed_ofs) {
630+
const hn::Repartition<uint16_t, DF> d16;
631+
const hn::Repartition<uint8_t, DF> d8;
632+
using V16 = hn::Vec<decltype(d16)>;
633+
using V8 = hn::Vec<decltype(d8)>;
634+
const size_t N16 = hn::Lanes(d16);
635+
636+
HWY_ASSERT(packed_ofs % kGroupSize == 0);
637+
638+
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
639+
// TODO: dynamic resize should be removed; it is no longer necessary as
640+
// interleaved encoding uses only a single buffer of the same size.
641+
buf.Resize(1);
642+
643+
size_t unused_clusters = 0;
644+
size_t current_offset = packed_ofs;
645+
for (size_t g = 0; g < num_groups; ++g) {
646+
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
647+
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
648+
649+
float* HWY_RESTRICT g_centers = buf.centers.get();
650+
uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
651+
652+
unused_clusters +=
653+
NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
654+
655+
uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
656+
SfpCodec::Enc(df, buf.centers.get(), kClusters,
657+
reinterpret_cast<SfpStream*>(centers));
658+
uint8_t* packed_start = centers + kClusters;
659+
660+
current_offset += g_num;
661+
662+
HWY_DASSERT(g_num % (4 * N16) == 0);
663+
664+
size_t i = 0;
665+
HWY_UNROLL(1)
666+
for (; i < g_num; i += 4 * N16) {
667+
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
668+
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
669+
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
670+
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
671+
const V8 nibbles =
672+
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
673+
hn::StoreU(nibbles, d8, packed_start + i / 2);
674+
}
675+
676+
const size_t remaining = g_num - i;
677+
678+
HWY_DASSERT(remaining < 4 * N16);
679+
if (HWY_UNLIKELY(remaining != 0)) {
680+
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
681+
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
682+
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
683+
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
684+
const V8 nibbles =
685+
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
686+
// i is even, but remaining might not be.
687+
hn::StoreN(nibbles, d8, packed_start + i / 2,
688+
hwy::DivCeil(remaining, 2));
689+
}
690+
}
691+
return unused_clusters;
692+
}
693+
609694
// Encodes `num` floats from `raw`. `packed` points to compressed storage and
610695
// `packed_ofs` indicates the destination offset within it, in units of float
611696
// values, for parallel encoding by multiple threads. Returns the total
@@ -733,6 +818,8 @@ class NuqCodec {
733818
raw1 = BitCast(dbf, c1);
734819
}
735820

821+
// TODO(philculliton): Remove non-interleaved function versions now that
822+
// interleaved is working / the default.
736823
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
737824
// vectors so that we only have to load one group's table.
738825
template <class DF, HWY_IF_F32_D(DF)>
@@ -765,6 +852,107 @@ class NuqCodec {
765852
raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
766853
}
767854

855+
// Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
856+
// vectors so that we only have to load one group's table.
857+
template <class DBF, HWY_IF_BF16_D(DBF)>
858+
static HWY_INLINE void Dec2Interleaved(
859+
DBF dbf, const PackedSpan<const NuqStream>& packed,
860+
const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
861+
PROFILER_FUNC;
862+
const hn::RebindToUnsigned<decltype(dbf)> d16;
863+
const D8HFromD16<DBF> d8h;
864+
using V16 = hn::Vec<decltype(d16)>;
865+
using V8H = hn::Vec<decltype(d8h)>;
866+
867+
const size_t within_group = packed_ofs % kGroupSize;
868+
HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
869+
const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
870+
const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
871+
872+
V16 tbl1 = Zero(d16);
873+
const V16 tbl0 = LoadTable(d16, table, &tbl1);
874+
875+
const V8H nibbles = hn::LoadU(d8h, indices);
876+
877+
V16 c0, c1;
878+
TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
879+
raw0 = BitCast(dbf, c0);
880+
raw1 = BitCast(dbf, c1);
881+
}
882+
883+
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
884+
// vectors so that we only have to load one group's table.
885+
template <class DF, HWY_IF_F32_D(DF)>
886+
static HWY_INLINE void Dec2Interleaved(
887+
DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
888+
hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
889+
const hn::Repartition<BF16, decltype(df)> dbf;
890+
const hn::RebindToUnsigned<decltype(dbf)> d16;
891+
const hn::Half<D8HFromD16<decltype(d16)>> d8q;
892+
using V8Q = hn::Vec<decltype(d8q)>;
893+
using V16 = hn::Vec<decltype(d16)>;
894+
895+
const size_t within_group = packed_ofs % kGroupSize;
896+
HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
897+
const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
898+
const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
899+
900+
V16 tbl1 = Zero(d16);
901+
const V16 tbl0 = LoadTable(d16, table, &tbl1);
902+
903+
// The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
904+
// which expects a quarter vector of bytes.
905+
const V8Q nibbles = hn::LoadU(d8q, indices);
906+
907+
// TODO(janwas): From janwas: on AVX-512 I imagine we can get a
908+
// bit more speed for this function by changing LoadTable to return floats,
909+
// then we could have a single lookup here instead of PromoteUpperTo which
910+
// is not cheap.
911+
const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
912+
raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
913+
raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
914+
}
915+
916+
template <class D, typename Raw = hn::TFromD<D>>
917+
static HWY_INLINE void DecompressAndZeroPadInterleaved(
918+
D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
919+
Raw* HWY_RESTRICT raw, size_t num) {
920+
// If unaligned, load elements from the first group and update the args,
921+
// from which we compute new tables/indices below.
922+
size_t current_offset = packed_ofs;
923+
if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
924+
const uint8_t* tables =
925+
&packed.ptr->byte + TableByteOffset(current_offset);
926+
const uint8_t* indices = tables + kClusters;
927+
const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
928+
929+
DecPartialGroup(d, tables, indices, raw, remaining);
930+
packed_ofs += remaining;
931+
current_offset += remaining;
932+
raw += remaining;
933+
num -= remaining;
934+
if (num == 0) return;
935+
}
936+
937+
HWY_DASSERT(packed_ofs % kGroupSize == 0);
938+
939+
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
940+
HWY_UNROLL(1)
941+
for (size_t g = 0; g < num_groups - 1; ++g) {
942+
const uint8_t* tables =
943+
&packed.ptr->byte + TableByteOffset(current_offset);
944+
const uint8_t* indices = tables + kClusters;
945+
DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
946+
current_offset += kGroupSize;
947+
}
948+
949+
const size_t g = num_groups - 1;
950+
const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset);
951+
const uint8_t* indices = tables + kClusters;
952+
DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
953+
num - g * kGroupSize);
954+
}
955+
768956
// Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
769957
// elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
770958
// round `num` up to one vector, if it is not already.
@@ -955,6 +1143,7 @@ class NuqCodec {
9551143
}
9561144

9571145
const size_t remaining = num - i;
1146+
9581147
HWY_DASSERT(remaining < 4 * NF);
9591148
if (HWY_UNLIKELY(remaining != 0)) {
9601149
// i is even, but remaining might not be.

0 commit comments

Comments (0)