Skip to content

Commit f9b15ab

Browse files
authored
Vectorize TensorPrimitives.Tanh/Cosh/Sinh (#93093)
* Vectorize TensorPrimitives.Tanh/Cosh/Sinh

  Tanh and Cosh are based on AOCL-LibM. AOCL-LibM doesn't appear to have a sinh implementation, so this Sinh is derived directly from the exp(x)-based formula for sinh.

  I also augmented the tests further, including:
  - Added more tests for sinh/cosh/tanh
  - Added an equality routine that supports comparing larger values with a tolerance
  - Tightened the tolerance for most functions
  - Changed some tests to be theories, to be consistent with the style elsewhere in the tests
  - Fixed some uses of Math to be MathF
* Remove unnecessary special-handling path from cosh
* Remove unnecessary special-handling path from tanh
* Redo sinh based on cosh
* Address PR feedback
1 parent 7b08680 commit f9b15ab

File tree

4 files changed

+480
-203
lines changed

4 files changed

+480
-203
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,8 @@ public static void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<floa
147147
/// operating systems or architectures.
148148
/// </para>
149149
/// </remarks>
150-
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
151-
{
152-
if (x.Length > destination.Length)
153-
{
154-
ThrowHelper.ThrowArgument_DestinationTooShort();
155-
}
156-
157-
ValidateInputOutputSpanNonOverlapping(x, destination);
158-
159-
for (int i = 0; i < x.Length; i++)
160-
{
161-
destination[i] = MathF.Cosh(x[i]);
162-
}
163-
}
150+
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination) =>
151+
InvokeSpanIntoSpan<CoshOperator>(x, destination);
164152

165153
/// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers.</summary>
166154
/// <param name="x">The first tensor, represented as a span.</param>
@@ -1012,20 +1000,8 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
10121000
/// operating systems or architectures.
10131001
/// </para>
10141002
/// </remarks>
1015-
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
1016-
{
1017-
if (x.Length > destination.Length)
1018-
{
1019-
ThrowHelper.ThrowArgument_DestinationTooShort();
1020-
}
1021-
1022-
ValidateInputOutputSpanNonOverlapping(x, destination);
1023-
1024-
for (int i = 0; i < x.Length; i++)
1025-
{
1026-
destination[i] = MathF.Sinh(x[i]);
1027-
}
1028-
}
1003+
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination) =>
1004+
InvokeSpanIntoSpan<SinhOperator>(x, destination);
10291005

10301006
/// <summary>Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers.</summary>
10311007
/// <param name="x">The tensor, represented as a span.</param>
@@ -1177,20 +1153,8 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
11771153
/// operating systems or architectures.
11781154
/// </para>
11791155
/// </remarks>
1180-
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
1181-
{
1182-
if (x.Length > destination.Length)
1183-
{
1184-
ThrowHelper.ThrowArgument_DestinationTooShort();
1185-
}
1186-
1187-
ValidateInputOutputSpanNonOverlapping(x, destination);
1188-
1189-
for (int i = 0; i < x.Length; i++)
1190-
{
1191-
destination[i] = MathF.Tanh(x[i]);
1192-
}
1193-
}
1156+
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination) =>
1157+
InvokeSpanIntoSpan<TanhOperator>(x, destination);
11941158

11951159
/// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
11961160
[MethodImpl(MethodImplOptions.AggressiveInlining)]

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 155 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Runtime.Intrinsics;
88
using System.Runtime.Intrinsics.Arm;
99
using System.Runtime.Intrinsics.X86;
10+
using System.Security.Cryptography;
1011

1112
namespace System.Numerics.Tensors
1213
{
@@ -147,15 +148,15 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
147148
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
148149
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.
149150

150-
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
151+
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
151152
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
152153
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
153154
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
154155
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
155156
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
156157
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
157158
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
158-
#pragma warning restore IDE0059
159+
#pragma warning restore IDE0059
159160

160161
static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
161162
{
@@ -462,13 +463,13 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
462463
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short> then widened to two VectorXx<int>s and cast to VectorXx<uint>s.
463464
// We loop handling one input vector at a time, producing two output float vectors.
464465

465-
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
466+
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
466467
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
467468
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
468469
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
469470
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
470471
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
471-
#pragma warning restore IDE0059
472+
#pragma warning restore IDE0059
472473

473474
static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
474475
{
@@ -2992,6 +2993,156 @@ public static Vector512<float> Invoke(Vector512<float> x)
29922993
#endif
29932994
}
29942995

2996+
/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
    // This code is based on `vrs4_coshf` from amd/aocl-libm-ose
    // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
    //
    // Licensed under the BSD 3-Clause "New" or "Revised" License
    // See THIRD-PARTY-NOTICES.TXT for the full license text

    // Spec (as stated by the upstream routine):
    //   coshf(|x| > 89.415985107421875) = Infinity
    //   coshf(Infinity)                 = Infinity
    //   coshf(-Infinity)                = Infinity
    //
    // Identities used:
    //   cosh(x)  = (exp(x) + exp(-x)) / 2
    //   cosh(-x) = cosh(x), so only |x| is evaluated.
    //
    // Evaluation performed by the vector paths below, with z = exp(|x| - LogV):
    //   cosh(x) = HalfV * (z + InvV2 / z)
    //
    // The decimal constants are mutually consistent with LogV = log(v),
    // HalfV = v / 2 and InvV2 = 1 / v^2 for v ~= 2.0000277.
    // NOTE(review): the upstream comment gives "v = 0x1.0000e8p-1", which does not
    // match that scale (0x1.0000e8 ~= 1.0000138 == HalfV) - confirm against AOCL-LibM.

    private const float LogV = 0.693161f;
    private const float HalfV = 1.0000138f;
    private const float InvV2 = 0.24999309f;

    /// <summary>Scalar path: defers to <see cref="MathF.Cosh(float)"/>.</summary>
    public static float Invoke(float x)
    {
        return MathF.Cosh(x);
    }

    /// <summary>Computes cosh for each lane of a 128-bit vector.</summary>
    public static Vector128<float> Invoke(Vector128<float> x)
    {
        Vector128<float> shifted = Vector128.Abs(x) - Vector128.Create(LogV);
        Vector128<float> z = ExpOperator.Invoke(shifted);
        Vector128<float> reciprocalTerm = Vector128.Create(InvV2) / z;
        return Vector128.Create(HalfV) * (z + reciprocalTerm);
    }

    /// <summary>Computes cosh for each lane of a 256-bit vector.</summary>
    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<float> shifted = Vector256.Abs(x) - Vector256.Create(LogV);
        Vector256<float> z = ExpOperator.Invoke(shifted);
        Vector256<float> reciprocalTerm = Vector256.Create(InvV2) / z;
        return Vector256.Create(HalfV) * (z + reciprocalTerm);
    }

#if NET8_0_OR_GREATER
    /// <summary>Computes cosh for each lane of a 512-bit vector.</summary>
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<float> shifted = Vector512.Abs(x) - Vector512.Create(LogV);
        Vector512<float> z = ExpOperator.Invoke(shifted);
        Vector512<float> reciprocalTerm = Vector512.Create(InvV2) / z;
        return Vector512.Create(HalfV) * (z + reciprocalTerm);
    }
#endif
}
3049+
3050+
/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
    // Same evaluation as cosh, but with `z -` rather than `z +`:
    //   with z = exp(|x| - LogV), sinh(|x|) = HalfV * (z - InvV2 / z)
    // and, because sinh(-x) = -sinh(x), the sign bit of the input is then
    // XOR-ed into the result computed from |x|.
    //
    // NOTE(review): for |x| close to zero, z and InvV2 / z are nearly equal, so the
    // subtraction may lose precision there - confirm the tolerance used by the tests.

    /// <summary>Mask selecting only the IEEE 754 sign bit of a single-precision value.</summary>
    private const uint SignMask = 0x8000_0000u;
    private const float LogV = 0.693161f;
    private const float HalfV = 1.0000138f;
    private const float InvV2 = 0.24999309f;

    /// <summary>Scalar path: defers to <see cref="MathF.Sinh(float)"/>.</summary>
    public static float Invoke(float x) => MathF.Sinh(x);

    /// <summary>Computes sinh for each lane of a 128-bit vector.</summary>
    public static Vector128<float> Invoke(Vector128<float> x)
    {
        Vector128<float> y = Vector128.Abs(x);
        Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LogV));
        Vector128<float> result = Vector128.Create(HalfV) * (z - (Vector128.Create(InvV2) / z));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector128<uint> sign = x.AsUInt32() & Vector128.Create(SignMask);
        return (sign ^ result.AsUInt32()).AsSingle();
    }

    /// <summary>Computes sinh for each lane of a 256-bit vector.</summary>
    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<float> y = Vector256.Abs(x);
        Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LogV));
        Vector256<float> result = Vector256.Create(HalfV) * (z - (Vector256.Create(InvV2) / z));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector256<uint> sign = x.AsUInt32() & Vector256.Create(SignMask);
        return (sign ^ result.AsUInt32()).AsSingle();
    }

#if NET8_0_OR_GREATER
    /// <summary>Computes sinh for each lane of a 512-bit vector.</summary>
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<float> y = Vector512.Abs(x);
        Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LogV));
        Vector512<float> result = Vector512.Create(HalfV) * (z - (Vector512.Create(InvV2) / z));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector512<uint> sign = x.AsUInt32() & Vector512.Create(SignMask);
        return (sign ^ result.AsUInt32()).AsSingle();
    }
#endif
}
3092+
3093+
/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
    // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose
    // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
    //
    // Licensed under the BSD 3-Clause "New" or "Revised" License
    // See THIRD-PARTY-NOTICES.TXT for the full license text

    // To compute vrs4_tanhf(v_f32x4_t x)
    // Let y = |x|
    // If 0 <= y < 0x1.154246p3
    //    Let z = e^(-2.0 * y) - 1      -(1)
    //
    //    Using (1), tanhf(y) can be calculated as,
    //    tanhf(y) = -z / (z + 2.0)
    //
    // The upstream routine calls scalar tanhf() for inputs outside that range.
    // The vector paths below have no such fallback: they apply the formula
    // above to every lane.
    //
    // If x < 0, then we use the identity
    //    tanhf(-x) = -tanhf(x)
    // which is applied by XOR-ing the input's sign bit into the result for |x|.

    /// <summary>Mask selecting only the IEEE 754 sign bit of a single-precision value.</summary>
    private const uint SignMask = 0x8000_0000u;

    /// <summary>Scalar path: defers to <see cref="MathF.Tanh(float)"/>.</summary>
    public static float Invoke(float x) => MathF.Tanh(x);

    /// <summary>Computes tanh for each lane of a 128-bit vector.</summary>
    public static Vector128<float> Invoke(Vector128<float> x)
    {
        Vector128<float> y = Vector128.Abs(x);
        Vector128<float> z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f);
        Vector128<float> magnitude = -z / (z + Vector128.Create(2f));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector128<uint> sign = x.AsUInt32() & Vector128.Create(SignMask);
        return (sign ^ magnitude.AsUInt32()).AsSingle();
    }

    /// <summary>Computes tanh for each lane of a 256-bit vector.</summary>
    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<float> y = Vector256.Abs(x);
        Vector256<float> z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f);
        Vector256<float> magnitude = -z / (z + Vector256.Create(2f));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector256<uint> sign = x.AsUInt32() & Vector256.Create(SignMask);
        return (sign ^ magnitude.AsUInt32()).AsSingle();
    }

#if NET8_0_OR_GREATER
    /// <summary>Computes tanh for each lane of a 512-bit vector.</summary>
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<float> y = Vector512.Abs(x);
        Vector512<float> z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f);
        Vector512<float> magnitude = -z / (z + Vector512.Create(2f));

        // Isolate the input's sign bit and flip the result's sign wherever it is set.
        Vector512<uint> sign = x.AsUInt32() & Vector512.Create(SignMask);
        return (sign ^ magnitude.AsUInt32()).AsSingle();
    }
#endif
}
3145+
29953146
/// <summary>MathF.Log(x)</summary>
29963147
private readonly struct LogOperator : IUnaryOperator
29973148
{

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
10261026
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
10271027
}
10281028

1029+
/// <summary>MathF.Exp(x)</summary>
10291030
private readonly struct ExpOperator : IUnaryOperator
10301031
{
10311032
public bool CanVectorize => false;
@@ -1035,6 +1036,36 @@ public Vector<float> Invoke(Vector<float> x) =>
10351036
throw new NotImplementedException();
10361037
}
10371038

1039+
/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
    /// <summary>Always <see langword="false"/>: only the scalar overload is implemented here.</summary>
    public bool CanVectorize
    {
        get { return false; }
    }

    /// <summary>Scalar path: defers to <c>MathF.Sinh</c>.</summary>
    public float Invoke(float x)
    {
        return MathF.Sinh(x);
    }

    /// <summary>Not implemented; always throws <see cref="NotImplementedException"/>.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}
1048+
1049+
/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
    /// <summary>Always <see langword="false"/>: only the scalar overload is implemented here.</summary>
    public bool CanVectorize
    {
        get { return false; }
    }

    /// <summary>Scalar path: defers to <c>MathF.Cosh</c>.</summary>
    public float Invoke(float x)
    {
        return MathF.Cosh(x);
    }

    /// <summary>Not implemented; always throws <see cref="NotImplementedException"/>.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}
1058+
1059+
/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
    /// <summary>Always <see langword="false"/>: only the scalar overload is implemented here.</summary>
    public bool CanVectorize
    {
        get { return false; }
    }

    /// <summary>Scalar path: defers to <c>MathF.Tanh</c>.</summary>
    public float Invoke(float x)
    {
        return MathF.Tanh(x);
    }

    /// <summary>Not implemented; always throws <see cref="NotImplementedException"/>.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}
1068+
10381069
/// <summary>MathF.Log(x)</summary>
10391070
private readonly struct LogOperator : IUnaryOperator
10401071
{

0 commit comments

Comments
 (0)