Skip to content

Commit 38558de

Browse files
authored
Merge pull request numpy#20367 from HowJMay/simd-trunc
ENH, SIMD: add new universal intrinsics for trunc
2 parents f146ec1 + 9b1bd0d commit 38558de

File tree

7 files changed

+85
-8
lines changed

7 files changed

+85
-8
lines changed

numpy/core/src/_simd/_simd.dispatch.c.src

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ SIMD_IMPL_INTRIN_1(sumup_@sfx@, @esfx@, v@sfx@)
381381
***************************/
382382
#if @fp_only@
383383
/**begin repeat1
384-
* #intrin = sqrt, recip, abs, square, ceil#
384+
* #intrin = sqrt, recip, abs, square, ceil, trunc#
385385
*/
386386
SIMD_IMPL_INTRIN_1(@intrin@_@sfx@, v@sfx@, v@sfx@)
387387
/**end repeat1**/
@@ -615,7 +615,7 @@ SIMD_INTRIN_DEF(sumup_@sfx@)
615615
***************************/
616616
#if @fp_only@
617617
/**begin repeat1
618-
* #intrin = sqrt, recip, abs, square, ceil#
618+
* #intrin = sqrt, recip, abs, square, ceil, trunc#
619619
*/
620620
SIMD_INTRIN_DEF(@intrin@_@sfx@)
621621
/**end repeat1**/

numpy/core/src/common/simd/avx2/math.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,8 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
109109
#define npyv_ceil_f32 _mm256_ceil_ps
110110
#define npyv_ceil_f64 _mm256_ceil_pd
111111

112+
// trunc
113+
#define npyv_trunc_f32(A) _mm256_round_ps(A, _MM_FROUND_TO_ZERO)
114+
#define npyv_trunc_f64(A) _mm256_round_pd(A, _MM_FROUND_TO_ZERO)
115+
112116
#endif // _NPY_SIMD_AVX2_MATH_H

numpy/core/src/common/simd/avx512/math.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,8 @@ NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b)
116116
#define npyv_ceil_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_POS_INF)
117117
#define npyv_ceil_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_POS_INF)
118118

119+
// trunc
120+
#define npyv_trunc_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_ZERO)
121+
#define npyv_trunc_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_ZERO)
122+
119123
#endif // _NPY_SIMD_AVX512_MATH_H

numpy/core/src/common/simd/neon/math.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,37 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
190190
#define npyv_ceil_f64 vrndpq_f64
191191
#endif // NPY_SIMD_F64
192192

193+
// trunc
194+
#ifdef NPY_HAVE_ASIMD
195+
#define npyv_trunc_f32 vrndq_f32
196+
#else
197+
NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
198+
{
199+
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
200+
const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
201+
/**
202+
* On armv7, vcvtq.f32 handles special cases as follows:
203+
* NaN return 0
204+
* +inf or +outrange return 0x7fffffff(nan)
205+
* -inf or -outrange return 0x80000000(-0.0f)
206+
*/
207+
npyv_s32 roundi = vcvtq_s32_f32(a);
208+
npyv_f32 round = vcvtq_f32_s32(roundi);
209+
// respect signed zero, e.g. -0.5 -> -0.0
210+
npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
211+
vreinterpretq_s32_f32(round),
212+
vandq_s32(vreinterpretq_s32_f32(a), szero)
213+
));
214+
// if nan or overflow return a
215+
npyv_u32 nnan = npyv_notnan_f32(a);
216+
npyv_u32 overflow = vorrq_u32(
217+
vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
218+
);
219+
return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
220+
}
221+
#endif
222+
#if NPY_SIMD_F64
223+
#define npyv_trunc_f64 vrndq_f64
224+
#endif // NPY_SIMD_F64
225+
193226
#endif // _NPY_SIMD_NEON_MATH_H

numpy/core/src/common/simd/sse/math.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,32 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
174174
}
175175
#endif
176176

177+
// trunc
178+
#ifdef NPY_HAVE_SSE41
179+
#define npyv_trunc_f32(A) _mm_round_ps(A, _MM_FROUND_TO_ZERO)
180+
#define npyv_trunc_f64(A) _mm_round_pd(A, _MM_FROUND_TO_ZERO)
181+
#else
182+
NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
183+
{
184+
const npyv_f32 szero = _mm_set1_ps(-0.0f);
185+
npyv_s32 roundi = _mm_cvttps_epi32(a);
186+
npyv_f32 trunc = _mm_cvtepi32_ps(roundi);
187+
// respect signed zero, e.g. -0.5 -> -0.0
188+
npyv_f32 rzero = _mm_or_ps(trunc, _mm_and_ps(a, szero));
189+
// if overflow return a
190+
return npyv_select_f32(_mm_cmpeq_epi32(roundi, _mm_castps_si128(szero)), a, rzero);
191+
}
192+
NPY_FINLINE npyv_f64 npyv_trunc_f64(npyv_f64 a)
193+
{
194+
const npyv_f64 szero = _mm_set1_pd(-0.0);
195+
const npyv_f64 one = _mm_set1_pd(1.0);
196+
const npyv_f64 two_power_52 = _mm_set1_pd(0x10000000000000);
197+
npyv_f64 abs_a = npyv_abs_f64(a);
198+
// round to nearest by adding and then subtracting the magic number 2^52 (only guaranteed exact while |a| < 2^52)
199+
npyv_f64 abs_round = _mm_sub_pd(_mm_add_pd(abs_a, two_power_52), two_power_52);
200+
npyv_f64 subtrahend = _mm_and_pd(_mm_cmpgt_pd(abs_round, abs_a), one);
201+
return _mm_or_pd(_mm_sub_pd(abs_round, subtrahend), _mm_and_pd(a, szero));
202+
}
203+
#endif
204+
177205
#endif // _NPY_SIMD_SSE_MATH_H

numpy/core/src/common/simd/vsx/math.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,8 @@ NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
7373
#define npyv_ceil_f32 vec_ceil
7474
#define npyv_ceil_f64 vec_ceil
7575

76+
// trunc
77+
#define npyv_trunc_f32 vec_trunc
78+
#define npyv_trunc_f64 vec_trunc
79+
7680
#endif // _NPY_SIMD_VSX_MATH_H

numpy/core/tests/test_simd.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,15 @@ def test_square(self):
330330
square = self.square(vdata)
331331
assert square == data_square
332332

333-
@pytest.mark.parametrize("intrin, func", [("self.ceil", math.ceil)])
333+
@pytest.mark.parametrize("intrin, func", [("self.ceil", math.ceil),
334+
("self.trunc", math.trunc)])
334335
def test_rounding(self, intrin, func):
335336
"""
336337
Test intrinsics:
337338
npyv_ceil_##SFX
339+
npyv_trunc_##SFX
338340
"""
341+
intrin_name = intrin
339342
intrin = eval(intrin)
340343
pinf, ninf, nan = self._pinfinity(), self._ninfinity(), self._nan()
341344
# special cases
@@ -352,11 +355,12 @@ def test_rounding(self, intrin, func):
352355
_round = intrin(vdata)
353356
assert _round == data_round
354357
# signed zero
355-
for w in (-0.25, -0.30, -0.45):
356-
_round = self._to_unsigned(intrin(self.setall(w)))
357-
data_round = self._to_unsigned(self.setall(-0.0))
358-
assert _round == data_round
359-
358+
if "ceil" in intrin_name or "trunc" in intrin_name:
359+
for w in (-0.25, -0.30, -0.45):
360+
_round = self._to_unsigned(intrin(self.setall(w)))
361+
data_round = self._to_unsigned(self.setall(-0.0))
362+
assert _round == data_round
363+
360364
def test_max(self):
361365
"""
362366
Test intrinsics:

0 commit comments

Comments
 (0)