@@ -604,6 +604,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
604604 Traits::DecompressAndZeroPad (d, MakeConst (packed), packed_ofs, raw, num);
605605}
606606
607+ // NOTE: the following are the recommended way to iterate over arrays of
608+ // potentially compressed elements, including remainder handling. Prefer them
609+ // over calling `Decompress2` directly, which does not handle remainders.
610+ // `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
611+ // `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
612+ // user code expressed as lambdas.
613+
607614// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
608615// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
609616// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
@@ -733,8 +740,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
733740 comp3);
734741}
735742
736- // Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
737- // `DF` is the decompressed type, typically `float`.
743+ // Similar to `hn::Transform*`, but for compressed `T`. Used by ` ops-inl.h` .
744+ // `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
738745template <class DF , typename T, class Func >
739746HWY_INLINE void DecompressAndCompressInplace (DF df, T* HWY_RESTRICT inout,
740747 size_t num, Func&& func) {
@@ -773,6 +780,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
773780}
774781
775782// One extra argument. `DF` is the decompressed type, typically `float`.
783+ // Calls `func(df, v_inout, v1)`.
776784template <class DF , typename T, typename T1, class Func >
777785HWY_INLINE void Decompress1AndCompressInplace (DF df, T* HWY_RESTRICT inout,
778786 size_t num,
@@ -821,8 +829,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
821829 }
822830}
823831
832+ // Two extra arguments. `DF` is the decompressed type, typically `float`.
833+ // Calls `func(df, v_inout, v1, v2)`.
834+ template <class DF , typename T, typename T1, typename T2, class Func >
835+ HWY_INLINE void Decompress2AndCompressInplace (
836+ DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
837+ const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
838+ const auto packed_inout = MakeSpan (inout, num);
839+ const auto packed1 = MakeSpan (p1, num);
840+ const auto packed2 = MakeSpan (p2, p2_ofs + num);
841+
842+ using VF = hn::Vec<decltype (df)>;
843+ HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes (df);
844+ size_t i = 0 ;
845+ if (num >= 2 * NF) {
846+ for (; i <= num - 2 * NF; i += 2 * NF) {
847+ VF v0, v1;
848+ Decompress2 (df, packed_inout, i, v0, v1);
849+ VF v10, v11;
850+ Decompress2 (df, packed1, i, v10, v11);
851+ VF v20, v21;
852+ Decompress2 (df, packed2, p2_ofs + i, v20, v21);
853+ const VF out0 = func (df, v0, v10, v20);
854+ const VF out1 = func (df, v1, v11, v21);
855+ Compress2 (df, out0, out1, packed_inout, i);
856+ }
857+ }
858+
859+ const size_t remaining = num - i;
860+ HWY_DASSERT (remaining < 2 * NF);
861+ if (HWY_UNLIKELY (remaining != 0 )) {
862+ HWY_ALIGN float buf_inout[2 * hn::MaxLanes (df)];
863+ HWY_ALIGN float buf1[2 * hn::MaxLanes (df)];
864+ HWY_ALIGN float buf2[2 * hn::MaxLanes (df)];
865+ // Ensure the second vector is zeroed even if remaining <= NF.
866+ hn::Store (hn::Zero (df), df, buf_inout + NF);
867+ hn::Store (hn::Zero (df), df, buf1 + NF);
868+ hn::Store (hn::Zero (df), df, buf2 + NF);
869+ DecompressAndZeroPad (df, packed_inout, i, buf_inout, remaining);
870+ DecompressAndZeroPad (df, packed1, i, buf1, remaining);
871+ DecompressAndZeroPad (df, packed2, p2_ofs + i, buf2, remaining);
872+ const VF v0 = hn::Load (df, buf_inout);
873+ const VF v1 = hn::Load (df, buf_inout + NF);
874+ const VF v10 = hn::Load (df, buf1);
875+ const VF v11 = hn::Load (df, buf1 + NF);
876+ const VF v20 = hn::Load (df, buf2);
877+ const VF v21 = hn::Load (df, buf2 + NF);
878+ const VF out0 = func (df, v0, v10, v20);
879+ const VF out1 = func (df, v1, v11, v21);
880+ Compress2 (df, out0, out1, MakeSpan (buf_inout, 2 * NF), 0 );
881+ // Clang generates incorrect code for CopyBytes if num = 2.
882+ for (size_t j = 0 ; j < remaining; ++j) {
883+ inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
884+ }
885+ }
886+ }
887+
824888// Single input, separate output. `DF` is the decompressed type, typically
825- // `float`.
889+ // `float`. Calls `func(df, v1)`.
826890template <class DF , typename T, typename T1, class Func >
827891HWY_INLINE void Decompress1AndCompressTo (DF df, T* HWY_RESTRICT out, size_t num,
828892 const T1* HWY_RESTRICT p1,
@@ -863,7 +927,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
863927 }
864928}
865929
866- // Two inputs. `DF` is the decompressed type, typically `float`.
930+ // Two inputs, separate output. `DF` is the decompressed type, typically
931+ // `float`. Calls `func(df, v1, v2)`.
867932template <class DF , typename T, typename T1, typename T2, class Func >
868933HWY_INLINE void Decompress2AndCompressTo (DF df, T* HWY_RESTRICT out, size_t num,
869934 const T1* HWY_RESTRICT p1,
@@ -912,7 +977,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
912977 }
913978}
914979
915- // Three inputs. `DF` is the decompressed type, typically `float`.
980+ // Three inputs, separate output. `DF` is the decompressed type, typically
981+ // `float`. Calls `func(df, v1, v2, v3)`.
916982template <class DF , typename T, typename T1, typename T2, typename T3,
917983 class Func >
918984HWY_INLINE void Decompress3AndCompressTo (DF df, T* HWY_RESTRICT out, size_t num,
0 commit comments