Skip to content

Commit

Permalink
Vectorize TensorPrimitives.Tanh/Cosh/Sinh
Browse files Browse the repository at this point in the history
Tanh and Cosh are based on AOCL-LibM.

AOCL-LibM doesn't appear to have a sinh implementation, so Sinh is computed directly from its exp-based definition, sinh(x) = (exp(x) - exp(-x)) / 2.

I also augmented the tests further, including:
- Added more tests for sinh/cosh/tanh
- Added an equality routine that supports comparing larger values with a tolerance
- Tightened the tolerance for most functions
- Changed some tests to be theories to be consistent with style elsewhere in the tests
- Fixed some use of Math to be MathF
  • Loading branch information
stephentoub committed Oct 6, 2023
1 parent 7f32a81 commit 3784cb9
Show file tree
Hide file tree
Showing 4 changed files with 502 additions and 199 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,8 @@ public static void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<floa
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
{
    // The destination must have room for one result per input element.
    int count = x.Length;
    if (destination.Length < count)
    {
        ThrowHelper.ThrowArgument_DestinationTooShort();
    }

    ValidateInputOutputSpanNonOverlapping(x, destination);

    // Scalar evaluation: one MathF.Cosh call per element, written to the same index.
    int index = 0;
    while (index < count)
    {
        destination[index] = MathF.Cosh(x[index]);
        index++;
    }
}
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
{
    // Forward to InvokeSpanIntoSpan with CoshOperator as the operator type argument.
    InvokeSpanIntoSpan<CoshOperator>(x, destination);
}

/// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
Expand Down Expand Up @@ -1012,20 +1000,8 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
{
    // The destination must have room for one result per input element.
    int count = x.Length;
    if (destination.Length < count)
    {
        ThrowHelper.ThrowArgument_DestinationTooShort();
    }

    ValidateInputOutputSpanNonOverlapping(x, destination);

    // Scalar evaluation: one MathF.Sinh call per element, written to the same index.
    int index = 0;
    while (index < count)
    {
        destination[index] = MathF.Sinh(x[index]);
        index++;
    }
}
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
{
    // Forward to InvokeSpanIntoSpan with SinhOperator as the operator type argument.
    InvokeSpanIntoSpan<SinhOperator>(x, destination);
}

/// <summary>Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down Expand Up @@ -1177,20 +1153,8 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
{
    // The destination must have room for one result per input element.
    int count = x.Length;
    if (destination.Length < count)
    {
        ThrowHelper.ThrowArgument_DestinationTooShort();
    }

    ValidateInputOutputSpanNonOverlapping(x, destination);

    // Scalar evaluation: one MathF.Tanh call per element, written to the same index.
    int index = 0;
    while (index < count)
    {
        destination[index] = MathF.Tanh(x[index]);
        index++;
    }
}
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
{
    // Forward to InvokeSpanIntoSpan with TanhOperator as the operator type argument.
    InvokeSpanIntoSpan<TanhOperator>(x, destination);
}

/// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2992,6 +2992,183 @@ public static Vector512<float> Invoke(Vector512<float> x)
#endif
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
    // sinh(x) = (exp(x) - exp(-x)) / 2
    //
    // exp(|x|) overflows to Infinity once |x| exceeds ln(float.MaxValue) (~88.72), whereas
    // sinh(x) itself stays finite until |x| reaches ~89.41. Evaluating the formula directly
    // in that band would therefore yield +/-Infinity where MathF.Sinh returns a finite value,
    // making the vectorized and scalar paths disagree.
    //
    // To avoid that, any vector that contains a lane whose magnitude is above a conservative
    // bound is evaluated lane-by-lane with MathF.Sinh. NaN and +/-Infinity have bit patterns
    // (sign cleared) above the bound as well, so they also take the scalar route.
    //
    // NOTE: for very small |x| the subtraction exp(x) - exp(-x) still loses relative precision;
    // that behavior is unchanged by the range guard.

    // Clears the sign bit, leaving the bit pattern of |x|.
    private const uint SIGN_MASK = 0x7FFFFFFF;

    // Bit pattern of 88.0f: safely below the point where exp(|x|) overflows.
    private const uint VECTOR_ARG_MAX = 0x42B00000;

    public static float Invoke(float x) => MathF.Sinh(x);

    public static Vector128<float> Invoke(Vector128<float> x)
    {
        // Positive floats order the same way as their bit patterns, so |x| can be
        // range-checked with an unsigned integer comparison.
        Vector128<uint> magnitude = x.AsUInt32() & Vector128.Create(SIGN_MASK);
        if (Vector128.GreaterThanAny(magnitude, Vector128.Create(VECTOR_ARG_MAX)))
        {
            return Vector128.Create(
                MathF.Sinh(x.GetElement(0)),
                MathF.Sinh(x.GetElement(1)),
                MathF.Sinh(x.GetElement(2)),
                MathF.Sinh(x.GetElement(3)));
        }

        return (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector128.Create(2f);
    }

    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<uint> magnitude = x.AsUInt32() & Vector256.Create(SIGN_MASK);
        if (Vector256.GreaterThanAny(magnitude, Vector256.Create(VECTOR_ARG_MAX)))
        {
            // Let the 128-bit overload decide per half whether the scalar route is needed.
            return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
        }

        return (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector256.Create(2f);
    }

#if NET8_0_OR_GREATER
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<uint> magnitude = x.AsUInt32() & Vector512.Create(SIGN_MASK);
        if (Vector512.GreaterThanAny(magnitude, Vector512.Create(VECTOR_ARG_MAX)))
        {
            // Let the 256-bit overload decide per half whether the scalar route is needed.
            return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
        }

        return (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector512.Create(2f);
    }
#endif
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
    // This code is based on `vrs4_coshf` from amd/aocl-libm-ose
    // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
    //
    // Licensed under the BSD 3-Clause "New" or "Revised" License
    // See THIRD-PARTY-NOTICES.TXT for the full license text

    // Spec:
    //   coshf(|x| > 89.415985107421875) = Infinity
    //   coshf(Infinity)  = Infinity
    //   coshf(-Infinity) = Infinity
    //
    // Identities used:
    //   cosh(x)  = (exp(x) + exp(-x)) / 2
    //   cosh(-x) = cosh(x), so only |x| is evaluated
    //
    // Fast path (all lanes have |x| <= ARG_MAX):
    //   with z = exp(|x| - log(v)),
    //   coshf = v/2 * (z + (1/v^2) / z)
    //
    // Any vector with a lane outside that range (which includes NaN and +/-Infinity, since
    // their sign-cleared bit patterns compare above ARG_MAX) is evaluated with MathF.Cosh.

    private const uint SIGN_MASK = 0x7FFFFFFF; // clears the sign bit, leaving |x|
    private const uint ARG_MAX = 0x42B2D4FC;   // bit pattern of the largest |x| handled by the fast path
    private const uint LOGV = 0x3f317300;      // bit pattern of log(v)
    private const uint HALFV = 0x3f800074;     // bit pattern of v/2
    private const uint INVV2 = 0x3e7ffe30;     // bit pattern of 1/v^2

    public static float Invoke(float x) => MathF.Cosh(x);

    public static Vector128<float> Invoke(Vector128<float> x)
    {
        Vector128<uint> absBits = x.AsUInt32() & Vector128.Create(SIGN_MASK);

        bool outOfRange = Vector128.GreaterThanAny(absBits, Vector128.Create(ARG_MAX));
        if (outOfRange)
        {
            float lane0 = MathF.Cosh(x.GetElement(0));
            float lane1 = MathF.Cosh(x.GetElement(1));
            float lane2 = MathF.Cosh(x.GetElement(2));
            float lane3 = MathF.Cosh(x.GetElement(3));
            return Vector128.Create(lane0, lane1, lane2, lane3);
        }

        Vector128<float> logV = Vector128.Create(LOGV).AsSingle();
        Vector128<float> halfV = Vector128.Create(HALFV).AsSingle();
        Vector128<float> invV2 = Vector128.Create(INVV2).AsSingle();

        Vector128<float> scaledExp = ExpOperator.Invoke(absBits.AsSingle() - logV);
        return halfV * (scaledExp + invV2 / scaledExp);
    }

    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<uint> absBits = x.AsUInt32() & Vector256.Create(SIGN_MASK);

        bool outOfRange = Vector256.GreaterThanAny(absBits, Vector256.Create(ARG_MAX));
        if (outOfRange)
        {
            // Split in two and let the 128-bit overload handle each half.
            Vector128<float> lower = Invoke(x.GetLower());
            Vector128<float> upper = Invoke(x.GetUpper());
            return Vector256.Create(lower, upper);
        }

        Vector256<float> logV = Vector256.Create(LOGV).AsSingle();
        Vector256<float> halfV = Vector256.Create(HALFV).AsSingle();
        Vector256<float> invV2 = Vector256.Create(INVV2).AsSingle();

        Vector256<float> scaledExp = ExpOperator.Invoke(absBits.AsSingle() - logV);
        return halfV * (scaledExp + invV2 / scaledExp);
    }

#if NET8_0_OR_GREATER
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<uint> absBits = x.AsUInt32() & Vector512.Create(SIGN_MASK);

        bool outOfRange = Vector512.GreaterThanAny(absBits, Vector512.Create(ARG_MAX));
        if (outOfRange)
        {
            // Split in two and let the 256-bit overload handle each half.
            Vector256<float> lower = Invoke(x.GetLower());
            Vector256<float> upper = Invoke(x.GetUpper());
            return Vector512.Create(lower, upper);
        }

        Vector512<float> logV = Vector512.Create(LOGV).AsSingle();
        Vector512<float> halfV = Vector512.Create(HALFV).AsSingle();
        Vector512<float> invV2 = Vector512.Create(INVV2).AsSingle();

        Vector512<float> scaledExp = ExpOperator.Invoke(absBits.AsSingle() - logV);
        return halfV * (scaledExp + invV2 / scaledExp);
    }
#endif
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
    // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose
    // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
    //
    // Licensed under the BSD 3-Clause "New" or "Revised" License
    // See THIRD-PARTY-NOTICES.TXT for the full license text

    // To compute vrs4_tanhf(v_f32x4_t x)
    // Let y = |x|
    // If 0 <= y < 0x1.154246p3
    //    Let z = e^(-2.0 * y) - 1      -(1)
    //
    //    Using (1), tanhf(y) can be calculated as,
    //    tanhf(y) = -z / (z + 2.0)
    //
    // For other cases, call scalar tanhf()
    //
    // If x < 0, then we use the identity
    //    tanhf(-x) = -tanhf(x)
    //
    // Implementation note: the numerator is evaluated as (1 - e^(-2y)) instead of -z. The two
    // are identical whenever e^(-2y) != 1. When the exponential rounds to exactly 1 (y is zero
    // or tiny), z is +0 and -z would be -0, leaving a set sign bit on the magnitude; XORing the
    // input's sign into it would then give tanh(+0) = -0 and tanh(-0) = +0. (1 - 1) is +0, so the
    // magnitude is always non-negative and the XOR restores the correct sign.

    private const uint V4_TANHF_ARG_MAX = 0x410AA123;   // bit pattern of the fast-path bound on |x|
    private const uint V4_TANHF_SIGN_MASK = 0x7FFFFFFF; // clears the sign bit; its complement isolates it

    public static float Invoke(float x) => MathF.Tanh(x);

    public static Vector128<float> Invoke(Vector128<float> x)
    {
        Vector128<uint> ux = x.AsUInt32();
        Vector128<uint> sign = ux & Vector128.Create(~V4_TANHF_SIGN_MASK);

        // Out-of-range lanes (including NaN and +/-Infinity, whose sign-cleared bit patterns
        // compare above the bound) send the whole vector down the scalar path.
        ux &= Vector128.Create(V4_TANHF_SIGN_MASK);
        if (Vector128.GreaterThanAny(ux, Vector128.Create(V4_TANHF_ARG_MAX)))
        {
            return Vector128.Create(
                MathF.Tanh(x.GetElement(0)),
                MathF.Tanh(x.GetElement(1)),
                MathF.Tanh(x.GetElement(2)),
                MathF.Tanh(x.GetElement(3)));
        }

        Vector128<float> y = ux.AsSingle();
        Vector128<float> e = ExpOperator.Invoke(Vector128.Create(-2f) * y);
        Vector128<float> z = e - Vector128.Create(1f);

        // (1 - e) == -z except that it yields +0 rather than -0 when e == 1.
        Vector128<float> magnitude = (Vector128.Create(1f) - e) / (z + Vector128.Create(2f));
        Vector128<uint> result = sign ^ magnitude.AsUInt32();

        return result.AsSingle();
    }

    public static Vector256<float> Invoke(Vector256<float> x)
    {
        Vector256<uint> ux = x.AsUInt32();
        Vector256<uint> sign = ux & Vector256.Create(~V4_TANHF_SIGN_MASK);

        ux &= Vector256.Create(V4_TANHF_SIGN_MASK);
        if (Vector256.GreaterThanAny(ux, Vector256.Create(V4_TANHF_ARG_MAX)))
        {
            // Let the 128-bit overload decide per half whether the scalar route is needed.
            return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
        }

        Vector256<float> y = ux.AsSingle();
        Vector256<float> e = ExpOperator.Invoke(Vector256.Create(-2f) * y);
        Vector256<float> z = e - Vector256.Create(1f);

        // (1 - e) == -z except that it yields +0 rather than -0 when e == 1.
        Vector256<float> magnitude = (Vector256.Create(1f) - e) / (z + Vector256.Create(2f));
        Vector256<uint> result = sign ^ magnitude.AsUInt32();

        return result.AsSingle();
    }

#if NET8_0_OR_GREATER
    public static Vector512<float> Invoke(Vector512<float> x)
    {
        Vector512<uint> ux = x.AsUInt32();
        Vector512<uint> sign = ux & Vector512.Create(~V4_TANHF_SIGN_MASK);

        ux &= Vector512.Create(V4_TANHF_SIGN_MASK);
        if (Vector512.GreaterThanAny(ux, Vector512.Create(V4_TANHF_ARG_MAX)))
        {
            // Let the 256-bit overload decide per half whether the scalar route is needed.
            return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
        }

        Vector512<float> y = ux.AsSingle();
        Vector512<float> e = ExpOperator.Invoke(Vector512.Create(-2f) * y);
        Vector512<float> z = e - Vector512.Create(1f);

        // (1 - e) == -z except that it yields +0 rather than -0 when e == 1.
        Vector512<float> magnitude = (Vector512.Create(1f) - e) / (z + Vector512.Create(2f));
        Vector512<uint> result = sign ^ magnitude.AsUInt32();

        return result.AsSingle();
    }
#endif
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

/// <summary>MathF.Exp(x)</summary>
private readonly struct ExpOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand All @@ -1035,6 +1036,36 @@ public Vector<float> Invoke(Vector<float> x) =>
throw new NotImplementedException();
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
    // Always reports false; the Vector<float> overload below is not implemented.
    public bool CanVectorize
    {
        get { return false; }
    }

    public float Invoke(float x)
    {
        return MathF.Sinh(x);
    }

    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
    // Always reports false; the Vector<float> overload below is not implemented.
    public bool CanVectorize
    {
        get { return false; }
    }

    public float Invoke(float x)
    {
        return MathF.Cosh(x);
    }

    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
    // Always reports false; the Vector<float> overload below is not implemented.
    public bool CanVectorize
    {
        get { return false; }
    }

    public float Invoke(float x)
    {
        return MathF.Tanh(x);
    }

    public Vector<float> Invoke(Vector<float> x)
    {
        // requires ShiftLeft (.NET 7+)
        throw new NotImplementedException();
    }
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Loading

0 comments on commit 3784cb9

Please sign in to comment.