40 | 40 | // After highway.h
41 | 41 | #include "compression/sfp-inl.h"
42 | 42 | #include "hwy/contrib/sort/vqsort-inl.h"
 | 43 | +#include "hwy/profiler.h" // uses SIMD |
43 | 44 |
44 | 45 | HWY_BEFORE_NAMESPACE();
45 | 46 | namespace gcpp {
@@ -529,12 +530,21 @@ class NuqCodec {
529 | 530 | return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
530 | 531 | }
531 | 532 |
| 533 | + // Offset (in bytes) of a group's table for packed_ofs (in elements) within a |
| 534 | + // set of groups. |
| 535 | + static constexpr size_t TableByteOffset(size_t packed_ofs) { |
| 536 | + const size_t kBytesPerGroup = |
| 537 | + (kClusters * sizeof(SfpStream)) + kGroupSize / 2; |
| 538 | + return (packed_ofs / kGroupSize) * kBytesPerGroup; |
| 539 | + } |
| 540 | + |
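In this layout, each group stores its SFP-encoded table of kClusters centers immediately followed by kGroupSize / 2 bytes of 4-bit indices, and groups are packed back to back. A small worked example of the arithmetic, assuming kClusters = 16, kGroupSize = 256 and sizeof(SfpStream) == 1 (these constants are assumptions for illustration; they are not visible in this hunk):

// Illustration only; the constants below are assumptions, not part of the diff.
constexpr size_t kAssumedClusters = 16;    // table entries per group
constexpr size_t kAssumedGroupSize = 256;  // elements per group
constexpr size_t kAssumedBytesPerGroup =
    kAssumedClusters * 1 + kAssumedGroupSize / 2;  // 16 + 128 = 144
static_assert(kAssumedBytesPerGroup == 144, "per-group table + nibble bytes");
// With these values: TableByteOffset(0) == 0, TableByteOffset(256) == 144,
// and TableByteOffset(600) == (600 / 256) * 144 == 288 (third group, index 2).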
532 | 541 | // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
533 | 542 | // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
534 | 543 | // not be available for bf16.
535 | 544 | template <class DU, HWY_IF_U16_D(DU)>
536 | 545 | static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
537 | 546 | hn::Vec<DU>* HWY_RESTRICT tbl1) {
 | 547 | + PROFILER_FUNC; |
538 | 548 | // Cap to the table size (kClusters) for decoding SFP - sufficient, and may
539 | 549 | // be faster than a large vector.
540 | 550 | const hn::CappedTag<BF16, kClusters> d_table;
@@ -606,6 +616,81 @@ class NuqCodec {
606 | 616 | }
607 | 617 |
608 | 618 | public:
| 619 | + // Encodes `num` floats from `raw` into `packed`. `packed` points to |
| 620 | + // compressed storage and `packed_ofs` indicates the destination offset within |
| 621 | + // it, in number of elements. Tables are interleaved with indices (clustered |
| 622 | + // elements) to allow for easier unpacking. Returns the total number of |
| 623 | + // unused clusters, which is typically zero. |
| 624 | + template <class DF, HWY_IF_F32_D(DF)> |
| 625 | + static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw, |
| 626 | + const size_t num, |
| 627 | + NuqStream::ClusterBuf& buf, |
| 628 | + const PackedSpan<NuqStream>& packed, |
| 629 | + size_t packed_ofs) { |
| 630 | + const hn::Repartition<uint16_t, DF> d16; |
| 631 | + const hn::Repartition<uint8_t, DF> d8; |
| 632 | + using V16 = hn::Vec<decltype(d16)>; |
| 633 | + using V8 = hn::Vec<decltype(d8)>; |
| 634 | + const size_t N16 = hn::Lanes(d16); |
| 635 | + |
| 636 | + HWY_ASSERT(packed_ofs % kGroupSize == 0); |
| 637 | + |
| 638 | + const size_t num_groups = hwy::DivCeil(num, kGroupSize); |
| 639 | + // TODO: dynamic resize should be removed; it is no longer necessary as |
| 640 | + // interleaved encoding uses only a single buffer of the same size. |
| 641 | + buf.Resize(1); |
| 642 | + |
| 643 | + size_t unused_clusters = 0; |
| 644 | + size_t current_offset = packed_ofs; |
| 645 | + for (size_t g = 0; g < num_groups; ++g) { |
| 646 | + const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize); |
| 647 | + const float* HWY_RESTRICT g_in = raw + g * kGroupSize; |
| 648 | + |
| 649 | + float* HWY_RESTRICT g_centers = buf.centers.get(); |
| 650 | + uint16_t* HWY_RESTRICT g_idx = buf.idx.get(); |
| 651 | + |
| 652 | + unused_clusters += |
| 653 | + NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx); |
| 654 | + |
| 655 | + uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset); |
| 656 | + SfpCodec::Enc(df, buf.centers.get(), kClusters, |
| 657 | + reinterpret_cast<SfpStream*>(centers)); |
| 658 | + uint8_t* packed_start = centers + kClusters; |
| 659 | + |
| 660 | + current_offset += g_num; |
| 661 | + |
 | 662 | + // Whole batches of four index vectors; any remainder is handled below. |
| 663 | + |
| 664 | + size_t i = 0; |
| 665 | + HWY_UNROLL(1) |
 | 666 | + for (; i + 4 * N16 <= g_num; i += 4 * N16) { |
| 667 | + const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16); |
| 668 | + const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16); |
| 669 | + const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16); |
| 670 | + const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16); |
| 671 | + const V8 nibbles = |
| 672 | + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3); |
| 673 | + hn::StoreU(nibbles, d8, packed_start + i / 2); |
| 674 | + } |
| 675 | + |
| 676 | + const size_t remaining = g_num - i; |
| 677 | + |
| 678 | + HWY_DASSERT(remaining < 4 * N16); |
| 679 | + if (HWY_UNLIKELY(remaining != 0)) { |
| 680 | + const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16); |
| 681 | + const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16); |
| 682 | + const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16); |
| 683 | + const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16); |
| 684 | + const V8 nibbles = |
| 685 | + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3); |
| 686 | + // i is even, but remaining might not be. |
| 687 | + hn::StoreN(nibbles, d8, packed_start + i / 2, |
| 688 | + hwy::DivCeil(remaining, 2)); |
| 689 | + } |
| 690 | + } |
| 691 | + return unused_clusters; |
| 692 | + } |
| 693 | + |
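For orientation, a minimal calling sketch for this new entry point; it is not part of the diff and assumes the surrounding nuq-inl.h context (the hn alias, kClusters, kGroupSize, SfpStream) plus standard library includes. The byte sizing mirrors TableByteOffset above; the aggregate construction of PackedSpan and the default-constructed ClusterBuf are assumptions about the rest of the codec, not something this hunk shows.

// Sketch only, under the assumptions stated above.
const hn::ScalableTag<float> df;
const size_t num = 2 * kGroupSize;   // two full groups of input values
std::vector<float> raw(num, 0.5f);   // values to compress
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
const size_t bytes =
    num_groups * (kClusters * sizeof(SfpStream) + kGroupSize / 2);
std::vector<uint8_t> storage(bytes);
NuqStream::ClusterBuf buf;
const PackedSpan<NuqStream> packed{
    reinterpret_cast<NuqStream*>(storage.data()), storage.size()};
// packed_ofs must be a multiple of kGroupSize, per the HWY_ASSERT above.
const size_t unused_clusters =
    NuqCodec::EncInterleaved(df, raw.data(), num, buf, packed, /*packed_ofs=*/0);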
609 | 694 | // Encodes `num` floats from `raw`. `packed` points to compressed storage and
610 | 695 | // `packed_ofs` indicates the destination offset within it, in units of float
611 | 696 | // values, for parallel encoding by multiple threads. Returns the total
@@ -733,6 +818,8 @@ class NuqCodec {
733 | 818 | raw1 = BitCast(dbf, c1);
734 | 819 | }
735 | 820 |
 | 821 | + // TODO(philculliton): Remove non-interleaved function versions now that |
 | 822 | + // interleaved is working / the default. |
736 | 823 | // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
737 | 824 | // vectors so that we only have to load one group's table.
738 | 825 | template <class DF, HWY_IF_F32_D(DF)>
@@ -765,6 +852,107 @@ class NuqCodec {
765 | 852 | raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
766 | 853 | }
767 | 854 |
| 855 | + // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two |
| 856 | + // vectors so that we only have to load one group's table. |
| 857 | + template <class DBF, HWY_IF_BF16_D(DBF)> |
| 858 | + static HWY_INLINE void Dec2Interleaved( |
| 859 | + DBF dbf, const PackedSpan<const NuqStream>& packed, |
| 860 | + const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) { |
| 861 | + PROFILER_FUNC; |
| 862 | + const hn::RebindToUnsigned<decltype(dbf)> d16; |
| 863 | + const D8HFromD16<DBF> d8h; |
| 864 | + using V16 = hn::Vec<decltype(d16)>; |
| 865 | + using V8H = hn::Vec<decltype(d8h)>; |
| 866 | + |
| 867 | + const size_t within_group = packed_ofs % kGroupSize; |
| 868 | + HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0); |
| 869 | + const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs); |
| 870 | + const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2); |
| 871 | + |
| 872 | + V16 tbl1 = Zero(d16); |
| 873 | + const V16 tbl0 = LoadTable(d16, table, &tbl1); |
| 874 | + |
| 875 | + const V8H nibbles = hn::LoadU(d8h, indices); |
| 876 | + |
| 877 | + V16 c0, c1; |
| 878 | + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); |
| 879 | + raw0 = BitCast(dbf, c0); |
| 880 | + raw1 = BitCast(dbf, c1); |
| 881 | + } |
| 882 | + |
| 883 | + // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two |
| 884 | + // vectors so that we only have to load one group's table. |
| 885 | + template <class DF, HWY_IF_F32_D(DF)> |
| 886 | + static HWY_INLINE void Dec2Interleaved( |
| 887 | + DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs, |
| 888 | + hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) { |
| 889 | + const hn::Repartition<BF16, decltype(df)> dbf; |
| 890 | + const hn::RebindToUnsigned<decltype(dbf)> d16; |
| 891 | + const hn::Half<D8HFromD16<decltype(d16)>> d8q; |
| 892 | + using V8Q = hn::Vec<decltype(d8q)>; |
| 893 | + using V16 = hn::Vec<decltype(d16)>; |
| 894 | + |
| 895 | + const size_t within_group = packed_ofs % kGroupSize; |
| 896 | + HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0); |
| 897 | + const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs); |
| 898 | + const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2); |
| 899 | + |
| 900 | + V16 tbl1 = Zero(d16); |
| 901 | + const V16 tbl0 = LoadTable(d16, table, &tbl1); |
| 902 | + |
| 903 | + // The single-vector TableLookups overload only calls OrderedUnpackU16<0>, |
| 904 | + // which expects a quarter vector of bytes. |
| 905 | + const V8Q nibbles = hn::LoadU(d8q, indices); |
| 906 | + |
 | 907 | + // TODO(janwas): on AVX-512 we can likely get a bit more speed for |
 | 908 | + // this function by changing LoadTable to return floats; then we |
 | 909 | + // could have a single lookup here instead of PromoteUpperTo, which |
 | 910 | + // is not cheap. |
| 911 | + const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles); |
| 912 | + raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); |
| 913 | + raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); |
| 914 | + } |
| 915 | + |
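Both Dec2Interleaved overloads decode exactly two vectors and load the group's table only once. A sketch of the bf16 form, reusing the `storage` written in the encoding sketch above; the read-only span construction and the output buffer are assumptions, not part of the diff.

// Sketch only; `out` holds two vectors' worth of bf16 values.
const hn::ScalableTag<BF16> dbf;
const PackedSpan<const NuqStream> packed_const{
    reinterpret_cast<const NuqStream*>(storage.data()), storage.size()};
auto out = hwy::AllocateAligned<BF16>(2 * hn::Lanes(dbf));
hn::Vec<decltype(dbf)> raw0, raw1;
// packed_ofs 0 trivially satisfies the two-vector alignment requirement.
NuqCodec::Dec2Interleaved(dbf, packed_const, /*packed_ofs=*/0, raw0, raw1);
hn::StoreU(raw0, dbf, out.get());
hn::StoreU(raw1, dbf, out.get() + hn::Lanes(dbf));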
| 916 | + template <class D, typename Raw = hn::TFromD<D>> |
| 917 | + static HWY_INLINE void DecompressAndZeroPadInterleaved( |
| 918 | + D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs, |
| 919 | + Raw* HWY_RESTRICT raw, size_t num) { |
| 920 | + // If unaligned, load elements from the first group and update the args, |
| 921 | + // from which we compute new tables/indices below. |
| 922 | + size_t current_offset = packed_ofs; |
| 923 | + if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) { |
| 924 | + const uint8_t* tables = |
| 925 | + &packed.ptr->byte + TableByteOffset(current_offset); |
| 926 | + const uint8_t* indices = tables + kClusters; |
| 927 | + const size_t remaining = HWY_MIN(num, kGroupSize - within_group); |
| 928 | + |
| 929 | + DecPartialGroup(d, tables, indices, raw, remaining); |
| 930 | + packed_ofs += remaining; |
| 931 | + current_offset += remaining; |
| 932 | + raw += remaining; |
| 933 | + num -= remaining; |
| 934 | + if (num == 0) return; |
| 935 | + } |
| 936 | + |
| 937 | + HWY_DASSERT(packed_ofs % kGroupSize == 0); |
| 938 | + |
| 939 | + const size_t num_groups = hwy::DivCeil(num, kGroupSize); |
| 940 | + HWY_UNROLL(1) |
| 941 | + for (size_t g = 0; g < num_groups - 1; ++g) { |
| 942 | + const uint8_t* tables = |
| 943 | + &packed.ptr->byte + TableByteOffset(current_offset); |
| 944 | + const uint8_t* indices = tables + kClusters; |
| 945 | + DecWholeGroup(d, tables, indices, raw + g * kGroupSize); |
| 946 | + current_offset += kGroupSize; |
| 947 | + } |
| 948 | + |
| 949 | + const size_t g = num_groups - 1; |
| 950 | + const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset); |
| 951 | + const uint8_t* indices = tables + kClusters; |
| 952 | + DecPartialGroup(d, tables, indices, raw + g * kGroupSize, |
| 953 | + num - g * kGroupSize); |
| 954 | + } |
| 955 | + |
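And a sketch of the bulk path, reusing `dbf` and `packed_const` from the sketch above; the only extra requirement is that the destination has room for `num` rounded up to a whole vector, since the callee zero-pads the tail.

// Sketch only; decode the first values encoded above, at any element offset.
const size_t num_dec = kGroupSize + 7;
const size_t padded = hwy::RoundUpTo(num_dec, hn::Lanes(dbf));
auto decoded = hwy::AllocateAligned<BF16>(padded);
NuqCodec::DecompressAndZeroPadInterleaved(dbf, packed_const, /*packed_ofs=*/0,
                                          decoded.get(), num_dec);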
768 | 956 | // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
769 | 957 | // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
770 | 958 | // round `num` up to one vector, if it is not already.
@@ -955,6 +1143,7 @@ class NuqCodec {
955 | 1143 | }
956 | 1144 |
957 | 1145 | const size_t remaining = num - i;
 | 1146 | + |
958 | 1147 | HWY_DASSERT(remaining < 4 * NF);
959 | 1148 | if (HWY_UNLIKELY(remaining != 0)) {
960 | 1149 | // i is even, but remaining might not be.