Skip to content

Commit ee7d79c

Browse files
jan-wassenberg and copybara-github
authored and committed
Add Decompress2AndCompressInplace helper
PiperOrigin-RevId: 825966142
1 parent 0069990 commit ee7d79c

File tree

2 files changed

+78
-5
lines changed

2 files changed

+78
-5
lines changed

compression/compress-inl.h

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
604604
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
605605
}
606606

607+
// NOTE: the following are the recommended way to iterate over arrays of
608+
// potentially compressed elements, including remainder handling. Prefer them
609+
// over calling `Decompress2` directly, which does not handle remainders.
610+
// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
611+
// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
612+
// user code expressed as lambdas.
613+
607614
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
608615
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
609616
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
@@ -733,8 +740,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
733740
comp3);
734741
}
735742

736-
// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
737-
// `DF` is the decompressed type, typically `float`.
743+
// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`.
744+
// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
738745
template <class DF, typename T, class Func>
739746
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
740747
size_t num, Func&& func) {
@@ -773,6 +780,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
773780
}
774781

775782
// One extra argument. `DF` is the decompressed type, typically `float`.
783+
// Calls `func(df, v_inout, v1)`.
776784
template <class DF, typename T, typename T1, class Func>
777785
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
778786
size_t num,
@@ -821,8 +829,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
821829
}
822830
}
823831

832+
// Two extra arguments. `DF` is the decompressed type, typically `float`.
833+
// Calls `func(df, v_inout, v1, v2)`.
// Decompresses `num` elements from `inout`, `p1` and `p2` (the latter starting
// at element offset `p2_ofs`), passes corresponding vectors to `func`, and
// compresses the `VF` it returns back into `inout` at the same position.
// In the remainder path, `func` also receives the zero-padded lanes; only the
// first `remaining` results are written back to `inout`.
834+
template <class DF, typename T, typename T1, typename T2, class Func>
835+
HWY_INLINE void Decompress2AndCompressInplace(
836+
DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
837+
const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
838+
const auto packed_inout = MakeSpan(inout, num);
839+
const auto packed1 = MakeSpan(p1, num);
840+
const auto packed2 = MakeSpan(p2, p2_ofs + num);  // `p2` is read from `p2_ofs`, so the span covers `p2_ofs + num`.
841+
842+
using VF = hn::Vec<decltype(df)>;
843+
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
844+
size_t i = 0;
845+
if (num >= 2 * NF) {  // Guards the unsigned subtraction `num - 2 * NF` below.
846+
for (; i <= num - 2 * NF; i += 2 * NF) {  // Two vectors per iteration.
847+
VF v0, v1;
848+
Decompress2(df, packed_inout, i, v0, v1);
849+
VF v10, v11;
850+
Decompress2(df, packed1, i, v10, v11);
851+
VF v20, v21;
852+
Decompress2(df, packed2, p2_ofs + i, v20, v21);
853+
const VF out0 = func(df, v0, v10, v20);
854+
const VF out1 = func(df, v1, v11, v21);
855+
Compress2(df, out0, out1, packed_inout, i);
856+
}
857+
}
858+
859+
const size_t remaining = num - i;
860+
HWY_DASSERT(remaining < 2 * NF);
861+
if (HWY_UNLIKELY(remaining != 0)) {
862+
HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];  // NOTE(review): staging buffers are `float`; presumably `df` must have float lanes - confirm.
863+
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
864+
HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
865+
// Ensure the second vector is zeroed even if remaining <= NF.
866+
hn::Store(hn::Zero(df), df, buf_inout + NF);
867+
hn::Store(hn::Zero(df), df, buf1 + NF);
868+
hn::Store(hn::Zero(df), df, buf2 + NF);
869+
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
870+
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
871+
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
872+
const VF v0 = hn::Load(df, buf_inout);
873+
const VF v1 = hn::Load(df, buf_inout + NF);
874+
const VF v10 = hn::Load(df, buf1);
875+
const VF v11 = hn::Load(df, buf1 + NF);
876+
const VF v20 = hn::Load(df, buf2);
877+
const VF v21 = hn::Load(df, buf2 + NF);
878+
const VF out0 = func(df, v0, v10, v20);
879+
const VF out1 = func(df, v1, v11, v21);
880+
Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);  // Results go to the staging buffer, not directly to `inout`.
881+
// Clang generates incorrect code for CopyBytes if num = 2.
882+
for (size_t j = 0; j < remaining; ++j) {  // Write back only the valid elements, converted to `T`.
883+
inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
884+
}
885+
}
886+
}
887+
824888
// Single input, separate output. `DF` is the decompressed type, typically
825-
// `float`.
889+
// `float`. Calls `func(df, v1)`.
826890
template <class DF, typename T, typename T1, class Func>
827891
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
828892
const T1* HWY_RESTRICT p1,
@@ -863,7 +927,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
863927
}
864928
}
865929

866-
// Two inputs. `DF` is the decompressed type, typically `float`.
930+
// Two inputs, separate output. `DF` is the decompressed type, typically
931+
// `float`. Calls `func(df, v1, v2)`.
867932
template <class DF, typename T, typename T1, typename T2, class Func>
868933
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
869934
const T1* HWY_RESTRICT p1,
@@ -912,7 +977,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
912977
}
913978
}
914979

915-
// Three inputs. `DF` is the decompressed type, typically `float`.
980+
// Three inputs, separate output. `DF` is the decompressed type, typically
981+
// `float`. Calls `func(df, v1, v2, v3)`.
916982
template <class DF, typename T, typename T1, typename T2, typename T3,
917983
class Func>
918984
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,

compression/compress_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ class TestDecompressAndCompress {
259259
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
260260
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
261261

262+
// `out` already contains v + v1.
263+
Decompress2AndCompressInplace(
264+
df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0,
265+
[](DF, VF v, VF /*v1*/, VF v2)
266+
HWY_ATTR -> VF { return hn::Add(v, v2); });
267+
HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
268+
262269
Decompress1AndCompressTo(df, out.get(), num, p.get(),
263270
[](DF, VF v) HWY_ATTR -> VF { return v; });
264271
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);

0 commit comments

Comments
 (0)