 #include <stdint.h>
 #include <stdio.h>
 
+#include <cstdio>
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,12 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }
 
+  // Returns the byte offset of the table for the group containing element
+  // `packed_ofs`: each group stores `kClusters` SFP-encoded centers followed
+  // by kGroupSize / 2 bytes of packed 4-bit indices.
+  static constexpr size_t TableOffset(size_t packed_ofs) {
+    const size_t group_size =
+        kClusters + kGroupSize / 2;  // == NuqStream::PackedEnd(kGroupSize)
+    return (packed_ofs / kGroupSize) * group_size;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
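To make the new group layout easier to follow, here is a minimal standalone sketch of the arithmetic that TableOffset implements. The kClusters = 16 and kGroupSize = 256 values are assumptions (16 follows from the 4-bit indices; 256 is the usual NuqStream group size); this is illustrative only, not part of the patch.

    // Standalone sketch of the interleaved NUQ layout arithmetic.
    #include <stddef.h>
    #include <stdio.h>

    constexpr size_t kClusters = 16;    // assumed: 4-bit indices -> 16 centers
    constexpr size_t kGroupSize = 256;  // assumed NuqStream group size

    // Mirrors NuqCodec::TableOffset: byte offset of a group's table.
    constexpr size_t TableOffset(size_t packed_ofs) {
      return (packed_ofs / kGroupSize) * (kClusters + kGroupSize / 2);
    }

    // Element 1000 is in group 3; each group occupies 16 + 128 = 144 bytes.
    static_assert(TableOffset(1000) == 3 * 144, "group stride is 144 bytes");

    int main() {
      const size_t packed_ofs = 1000;
      const size_t within_group = packed_ofs % kGroupSize;  // 232
      // The element's 4-bit index follows the 16 table bytes, two per byte.
      const size_t index_byte =
          TableOffset(packed_ofs) + kClusters + within_group / 2;
      printf("table at byte %zu, index nibble in byte %zu\n",
             TableOffset(packed_ofs), index_byte);  // 432 and 564
      return 0;
    }

Storing each 16-byte table directly in front of its 128 bytes of packed indices is what lets the encode and decode paths below derive both pointers from a single TableOffset call.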
@@ -606,6 +614,81 @@ class NuqCodec {
   }
 
  public:
+  // Encodes `num` floats from `raw` into `packed`, starting at element offset
+  // `packed_ofs`, which must be a multiple of kGroupSize. Each group's table
+  // of cluster centers is written immediately before that group's packed
+  // indices, so decoding only needs a single pointer per group. Returns the
+  // total number of unused clusters, which is typically zero.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
+                                          const size_t num,
+                                          NuqStream::ClusterBuf& buf,
+                                          const PackedSpan<NuqStream>& packed,
+                                          size_t packed_ofs) {
+    const hn::Repartition<uint16_t, DF> d16;
+    const hn::Repartition<uint8_t, DF> d8;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8 = hn::Vec<decltype(d8)>;
+    const size_t N16 = hn::Lanes(d16);
+
+    HWY_ASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    // TODO: remove the dynamic resize; interleaved encoding only needs a
+    // single group's worth of buffer.
+    buf.Resize(1);
+
+    size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
+    for (size_t g = 0; g < num_groups; ++g) {
+      const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
+      const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
+      unused_clusters +=
+          NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
+
+      // Write the group's SFP-encoded table, then its packed indices.
+      uint8_t* table = &packed.ptr->byte + TableOffset(current_offset);
+      SfpCodec::Enc(df, g_centers, kClusters,
+                    reinterpret_cast<SfpStream*>(table));
+      uint8_t* packed_start = table + kClusters;
+
+      current_offset += g_num;
+
+      size_t i = 0;
+      HWY_UNROLL(1)
+      for (; i + 4 * N16 <= g_num; i += 4 * N16) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
+
+      const size_t remaining = g_num - i;
+
+      HWY_DASSERT(remaining < 4 * N16);
+      if (HWY_UNLIKELY(remaining != 0)) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        // i is even, but remaining might not be.
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
+      }
+    }
+    return unused_clusters;
+  }
+
   // Encodes `num` floats from `raw`. `packed` points to compressed storage and
   // `packed_ofs` indicates the destination offset within it, in units of float
   // values, for parallel encoding by multiple threads. Returns the total
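The encoder's inner loop packs two 4-bit cluster indices per byte via NibbleCodec::OrderedPackU16 and writes them right after the group's table. A scalar reference of that packing is sketched below; the assumption that the first index of each pair occupies the low nibble is mine, not a statement about NibbleCodec's exact bit order.

    // Scalar reference for in-order 4-bit index packing (two per byte).
    #include <stddef.h>
    #include <stdint.h>
    #include <vector>

    // Packs one 4-bit index per nibble, in element order, low nibble first.
    std::vector<uint8_t> PackNibbles(const std::vector<uint16_t>& idx) {
      std::vector<uint8_t> out((idx.size() + 1) / 2, 0);
      for (size_t i = 0; i < idx.size(); ++i) {
        const uint8_t nibble = idx[i] & 0xF;  // cluster indices are 0..15
        out[i / 2] |= (i % 2 == 0) ? nibble : uint8_t(nibble << 4);
      }
      return out;
    }

    // Inverse of PackNibbles; `num` may be odd (last high nibble unused).
    std::vector<uint16_t> UnpackNibbles(const std::vector<uint8_t>& bytes,
                                        size_t num) {
      std::vector<uint16_t> idx(num);
      for (size_t i = 0; i < num; ++i) {
        const uint8_t b = bytes[i / 2];
        idx[i] = (i % 2 == 0) ? (b & 0xF) : (b >> 4);
      }
      return idx;
    }

    int main() {
      const std::vector<uint16_t> idx = {1, 15, 7, 0, 9};     // five indices
      const std::vector<uint8_t> packed = PackNibbles(idx);   // three bytes
      return UnpackNibbles(packed, idx.size()) == idx ? 0 : 1;  // round-trips
    }

This also illustrates why the tail in EncInterleaved stores hwy::DivCeil(remaining, 2) bytes: an odd remainder still needs its final half-filled byte.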
@@ -765,6 +848,103 @@ class NuqCodec {
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }
 
+  // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DBF, HWY_IF_BF16_D(DBF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DBF dbf, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const D8HFromD16<DBF> d8h;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8H = hn::Vec<decltype(d8h)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    const V8H nibbles = hn::LoadU(d8h, indices);
+
+    V16 c0, c1;
+    TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
+    raw0 = BitCast(dbf, c0);
+    raw1 = BitCast(dbf, c1);
+  }
+
+  // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
+      hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
+    const hn::Repartition<BF16, decltype(df)> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const hn::Half<D8HFromD16<decltype(d16)>> d8q;
+    using V8Q = hn::Vec<decltype(d8q)>;
+    using V16 = hn::Vec<decltype(d16)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    // The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
+    // which expects a quarter vector of bytes.
+    const V8Q nibbles = hn::LoadU(d8q, indices);
+
+    const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
+    raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
+    raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
+  }
+
+  template <class D, typename Raw = hn::TFromD<D>>
+  static HWY_INLINE void DecompressAndZeroPadInterleaved(
+      D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
+      Raw* HWY_RESTRICT raw, size_t num) {
+    // If `packed_ofs` is not group-aligned, first decode the remainder of its
+    // group, then advance the arguments so the loop below starts on a group
+    // boundary.
+    if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(packed_ofs);
+      // Skip the packed nibbles of the `within_group` elements before the
+      // starting offset.
+      const uint8_t* indices =
+          tables + kClusters + hwy::DivCeil(within_group, 2);
+      const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
+      DecPartialGroup(d, tables, indices, raw, remaining);
+      packed_ofs += remaining;
+      raw += remaining;
+      num -= remaining;
+      if (num == 0) return;
+    }
+
+    HWY_DASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups - 1; ++g) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(packed_ofs);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      packed_ofs += kGroupSize;
+    }
+
+    const size_t g = num_groups - 1;
+    const uint8_t* tables = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
+  }
+
   // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
   // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
   // round `num` up to one vector, if it is not already.
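DecompressAndZeroPadInterleaved above walks the stream as an optional unaligned head, a run of whole groups, and a possibly partial tail. The scalar sketch below traces only the offsets of that traversal (no SFP table or nibble decoding), again assuming kClusters = 16 and kGroupSize = 256; it is an illustration, not the library's API.

    // Traversal-only sketch of the head / whole-groups / tail decode order.
    #include <stddef.h>
    #include <stdio.h>

    constexpr size_t kClusters = 16, kGroupSize = 256;  // assumed values

    constexpr size_t TableOffset(size_t packed_ofs) {
      return (packed_ofs / kGroupSize) * (kClusters + kGroupSize / 2);
    }

    size_t Min(size_t a, size_t b) { return a < b ? a : b; }

    void WalkInterleaved(size_t packed_ofs, size_t num) {
      // Head: finish the group containing packed_ofs if it starts mid-group.
      const size_t within = packed_ofs % kGroupSize;
      if (within != 0) {
        const size_t take = Min(num, kGroupSize - within);
        printf("head:  table@%zu indices@%zu %zu elems\n",
               TableOffset(packed_ofs),
               TableOffset(packed_ofs) + kClusters + within / 2, take);
        packed_ofs += take;
        num -= take;
      }
      // Body and tail: now group-aligned; the last group may be partial.
      while (num != 0) {
        const size_t take = Min(num, kGroupSize);
        printf("group: table@%zu indices@%zu %zu elems\n",
               TableOffset(packed_ofs), TableOffset(packed_ofs) + kClusters,
               take);
        packed_ofs += take;
        num -= take;
      }
    }

    int main() {
      WalkInterleaved(/*packed_ofs=*/200, /*num=*/700);  // 56 + 256 + 256 + 132
      return 0;
    }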
@@ -955,6 +1135,7 @@ class NuqCodec {
     }
 
     const size_t remaining = num - i;
+
     HWY_DASSERT(remaining < 4 * NF);
     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.