Commit d934fe3

pculliton authored and copybara-github committed
Internal change
PiperOrigin-RevId: 686665933
1 parent ed40919 · commit d934fe3

14 files changed: +590 −9 lines

compression/compress-inl.h

Lines changed: 5 additions & 5 deletions
@@ -386,7 +386,7 @@ struct CompressTraits<NuqStream> {
                                   size_t num, CompressPerThread& tls,
                                   const PackedSpan<Packed>& packed,
                                   const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);
 
     if (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {
@@ -396,8 +396,8 @@ struct CompressTraits<NuqStream> {
       const hn::Repartition<BF16, DF> dbf;
       const size_t N16 = hn::Lanes(dbf);
       auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
-      NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
-                                     distorted.get(), num);
+      NuqCodec::DecompressAndZeroPadInterleaved(
+          dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
       DistortionStats stats;
       for (size_t i = 0; i < num; ++i) {
         stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
@@ -410,7 +410,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
                                const size_t packed_ofs, hn::Vec<D>& raw0,
                                hn::Vec<D>& raw1) {
-    NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
+    NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
   }
 
   // Store2 is not yet implemented.
@@ -419,7 +419,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
       Raw* raw, const size_t num) {
-    NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
   }
 };

compression/nuq-inl.h

Lines changed: 181 additions & 0 deletions
@@ -21,6 +21,8 @@
 #include <stdint.h>
 #include <stdio.h>
 
+#include <cstdio>
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,12 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }
 
+  static constexpr size_t TableOffset(size_t packed_ofs) {
+    const size_t group_size =
+        16 + kGroupSize / 2;  // NuqStream::PackedEnd(kGroupSize);
+    return (packed_ofs / kGroupSize) * group_size;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
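Worked numbers for TableOffset, as a sketch: 4-bit indices address 16 centers per group, so each group stores a 16-byte SFP table followed by kGroupSize / 2 bytes of packed indices. The constants below are stand-ins (kGroupSize = 256 is an assumption; the real values live in compression/shared.h):

#include <cstddef>

// Hypothetical stand-ins for the real constants in compression/shared.h.
constexpr size_t kClusters = 16;    // 4-bit indices address 16 centers
constexpr size_t kGroupSize = 256;  // elements per group (assumed)

constexpr size_t TableOffset(size_t packed_ofs) {
  // Per group: 16 table bytes + 256 / 2 = 128 nibble bytes = 144 bytes.
  const size_t group_size = kClusters + kGroupSize / 2;
  return (packed_ofs / kGroupSize) * group_size;
}

static_assert(TableOffset(0) == 0);
static_assert(TableOffset(255) == 0);    // same group, same table
static_assert(TableOffset(256) == 144);  // second group's table
static_assert(TableOffset(512) == 288);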
@@ -606,6 +614,81 @@ class NuqCodec {
   }
 
  public:
+  // Encodes `num` floats from `raw` into `packed`. `packed` points to
+  // compressed storage and `packed_ofs` indicates the destination offset within
+  // it, in number of elements. Tables are interleaved with indices (clustered
+  // elements) to allow for easier unpacking. Returns the total number of
+  // unused clusters, which is typically zero.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
+                                          const size_t num,
+                                          NuqStream::ClusterBuf& buf,
+                                          const PackedSpan<NuqStream>& packed,
+                                          size_t packed_ofs) {
+    const hn::Repartition<uint16_t, DF> d16;
+    const hn::Repartition<uint8_t, DF> d8;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8 = hn::Vec<decltype(d8)>;
+    const size_t N16 = hn::Lanes(d16);
+
+    HWY_ASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    // TODO: dynamic resize should be removed; it is no longer necessary as
+    // interleaved encoding uses only a single buffer of the same size.
+    buf.Resize(1);
+
+    size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
+    for (size_t g = 0; g < num_groups; ++g) {
+      const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
+      const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
+      unused_clusters +=
+          NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
+
+      uint8_t* centers = &packed.ptr->byte + TableOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));
+      uint8_t* packed_start = centers + kClusters;
+
+      current_offset += g_num;
+
+      HWY_DASSERT(g_num % (4 * N16) == 0);
+
+      size_t i = 0;
+      HWY_UNROLL(1)
+      for (; i + 4 * N16 <= g_num; i += 4 * N16) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
+
+      const size_t remaining = g_num - i;
+
+      HWY_DASSERT(remaining < 4 * N16);
+      if (HWY_UNLIKELY(remaining != 0)) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        // i is even, but remaining might not be.
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
+      }
+    }
+    return unused_clusters;
+  }
+
   // Encodes `num` floats from `raw`. `packed` points to compressed storage and
   // `packed_ofs` indicates the destination offset within it, in units of float
   // values, for parallel encoding by multiple threads. Returns the total
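For reference, a scalar sketch of what one iteration of EncInterleaved's group loop writes, with the SIMD and SfpCodec parts factored out: centers_sfp stands in for the already-SFP-encoded table, and the low-nibble-first byte order is an assumption about NibbleCodec::OrderedPackU16, not something the diff states.

#include <cstddef>
#include <cstdint>

constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;

// Scalar sketch of the interleaved group layout. `idx` holds per-element
// cluster indices in [0, 16); `group_out` must have room for
// kClusters + kGroupSize / 2 bytes.
void PackGroupInterleaved(const uint8_t centers_sfp[kClusters],
                          const uint16_t idx[kGroupSize],
                          uint8_t* group_out) {
  // Table first, so the decoder finds it at TableOffset(packed_ofs).
  for (size_t c = 0; c < kClusters; ++c) group_out[c] = centers_sfp[c];
  // Then the indices, two per byte (assumed low nibble = even element).
  uint8_t* nibbles = group_out + kClusters;
  for (size_t i = 0; i < kGroupSize; i += 2) {
    nibbles[i / 2] =
        static_cast<uint8_t>((idx[i] & 0xF) | ((idx[i + 1] & 0xF) << 4));
  }
}

Because each group's table and indices are adjacent, a decoder can reach both with a single TableOffset computation; that locality is the point of interleaving them instead of storing all tables in a separate region.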
@@ -765,6 +848,103 @@ class NuqCodec {
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }
 
+  // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DBF, HWY_IF_BF16_D(DBF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DBF dbf, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const D8HFromD16<DBF> d8h;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8H = hn::Vec<decltype(d8h)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    const V8H nibbles = hn::LoadU(d8h, indices);
+
+    V16 c0, c1;
+    TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
+    raw0 = BitCast(dbf, c0);
+    raw1 = BitCast(dbf, c1);
+  }
+
+  // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
+      hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
+    const hn::Repartition<BF16, decltype(df)> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const hn::Half<D8HFromD16<decltype(d16)>> d8q;
+    using V8Q = hn::Vec<decltype(d8q)>;
+    using V16 = hn::Vec<decltype(d16)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    // The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
+    // which expects a quarter vector of bytes.
+    const V8Q nibbles = hn::LoadU(d8q, indices);
+
+    const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
+    raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
+    raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
+  }
+
+  template <class D, typename Raw = hn::TFromD<D>>
+  static HWY_INLINE void DecompressAndZeroPadInterleaved(
+      D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
+      Raw* HWY_RESTRICT raw, size_t num) {
+    // If unaligned, load elements from the first group and update the args,
+    // from which we compute new tables/indices below.
+    size_t current_offset = packed_ofs;
+    if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
+      DecPartialGroup(d, tables, indices, raw, remaining);
+      packed_ofs += remaining;
+      current_offset += remaining;
+      raw += remaining;
+      num -= remaining;
+      if (num == 0) return;
+    }
+
+    HWY_DASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups - 1; ++g) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
+    }
+
+    const size_t g = num_groups - 1;
+    const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
+  }
+
   // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
   // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
   // round `num` up to one vector, if it is not already.
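Decoding mirrors the encoder sketch above: TableOffset locates the group's table, the indices begin kClusters bytes later, and each nibble selects one center. A scalar sketch under the same assumptions, with decode_sfp as a hypothetical stand-in for decoding a single SFP byte:

#include <cstddef>
#include <cstdint>

constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;

constexpr size_t TableOffset(size_t packed_ofs) {
  return (packed_ofs / kGroupSize) * (kClusters + kGroupSize / 2);
}

// Scalar sketch of one element's lookup in the interleaved layout.
float DecodeOneInterleaved(const uint8_t* packed_bytes, size_t packed_ofs,
                           float (*decode_sfp)(uint8_t)) {
  const uint8_t* table = packed_bytes + TableOffset(packed_ofs);
  const uint8_t* indices = table + kClusters;
  const size_t within_group = packed_ofs % kGroupSize;
  const uint8_t pair = indices[within_group / 2];
  // Nibble order matches the encoder sketch: low nibble = even element.
  const uint8_t nibble = (within_group & 1) ? (pair >> 4) : (pair & 0xF);
  return decode_sfp(table[nibble]);  // look up the selected center
}

The vectorized Dec2Interleaved above does the same thing in bulk: LoadTable decodes the 16 centers into vector registers once per group, and TableLookups applies them to a whole vector of nibbles.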
@@ -955,6 +1135,7 @@ class NuqCodec {
     }
 
     const size_t remaining = num - i;
+
     HWY_DASSERT(remaining < 4 * NF);
     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.
