Skip to content

Commit

Permalink
use have_fma intrinsic for math (#43259)
Browse files Browse the repository at this point in the history
This removes all bad `fma` checks in `Base` and replaces them with
multiversioned `have_fma` checks. It also adds `Base.Math.two_mul`, which
gives the fastest compensated multiplication available regardless of whether
the user has `fma` hardware. This is what should be used when possible for
compensated multiplication in the future.
  • Loading branch information
oscardssmith authored Nov 30, 2021
1 parent 4486567 commit cbc2ce8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
4 changes: 2 additions & 2 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ end
return hi, x-hi
end

@inline function twomul(a::Float64, b::Float64)
function twomul(a::Float64, b::Float64)
ahi, alo = splitbits(a)
bhi, blo = splitbits(b)
abhi = a*b
Expand All @@ -368,7 +368,7 @@ end
end

function fma_emulated(a::Float64, b::Float64,c::Float64)
abhi, ablo = twomul(a,b)
abhi, ablo = @inline twomul(a,b)
if !isfinite(abhi+c) || isless(abs(abhi), nextfloat(0x1p-969)) || issubnormal(a) || issubnormal(b)
aandbfinite = isfinite(a) && isfinite(b)
if !(isfinite(c) && aandbfinite)
Expand Down
35 changes: 26 additions & 9 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ end

# non-type specific math functions

# Compensated product of two `Float64` values.
# Returns a tuple `(hi, lo)`: `hi` is the ordinary rounded product `x*y`, and
# `lo` is the correction term obtained from `fma(x, y, -hi)` when the
# `have_fma` intrinsic reports fma support for `Float64`; otherwise the pair
# is produced by `Base.twomul`.
@inline function two_mul(x::Float64, y::Float64)
    # No fma reported for Float64: defer to the software routine in Base.
    Core.Intrinsics.have_fma(Float64) || return Base.twomul(x, y)
    hi = x * y
    lo = fma(x, y, -hi)
    return hi, lo
end

# Compensated product for `Float16` and `Float32` arguments of the same type `T`.
# Returns a tuple `(hi, lo)` of `T` values: `hi` is the product rounded to `T`
# and `lo` is the correction term.
@inline function two_mul(x::T, y::T) where T<: Union{Float16, Float32}
    if !Core.Intrinsics.have_fma(T)
        # No fma reported for T: form the product in the widened type, round
        # it back to T, and return what is left over as the low part.
        wide = widen(x)*y
        hi = T(wide)
        return hi, T(wide - hi)
    end
    # fma reported for T: the low part comes from a single fused operation.
    hi = x*y
    return hi, fma(x, y, -hi)
end

"""
clamp(x, lo, hi)
Expand Down Expand Up @@ -278,8 +296,7 @@ end
hi, lo = p[end], zero(x)
for i in length(p)-1:-1:1
pi = p[i]
prod = hi*x
err = fma(hi, x, -prod)
prod, err = two_mul(hi,x)
hi = pi+prod
lo = fma(lo, x, prod - (hi - pi) + err)
end
Expand Down Expand Up @@ -686,7 +703,7 @@ function _hypot(x, y)
end
h = sqrt(muladd(ax, ax, ay*ay))
# This branch is correctly rounded but requires a native hardware fma.
if Base.Math.FMA_NATIVE
if Core.Intrinsics.have_fma(typeof(h))
hsquared = h*h
axsquared = ax*ax
h -= (fma(-ay, ay, hsquared-axsquared) + fma(h, h,-hsquared) - fma(ax, ax, -axsquared))/(2*h)
Expand Down Expand Up @@ -990,13 +1007,13 @@ function ^(x::Float64, n::Integer)
n == 3 && return x*x*x # keep compatibility with literal_pow
while n > 1
if n&1 > 0
yn = x*y
ynlo = fma(x, y , -yn) + muladd(y, xnlo, x*ynlo)
y = yn
err = muladd(y, xnlo, x*ynlo)
y, ynlo = two_mul(x,y)
ynlo += err
end
xn = x * x
xnlo = muladd(x, 2*xnlo, fma(x, x, -xn))
x = xn
err = x*2*xnlo
x, xnlo = two_mul(x, x)
xnlo += err
n >>>= 1
end
!isfinite(x) && return x*y
Expand Down
3 changes: 1 addition & 2 deletions base/special/hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ SINH_SMALL_X(::Type{Float32}) = 3.0f0

# For Float64, use DoubleFloat scheme for extra accuracy
function sinh_kernel(x::Float64)
x2 = x*x
x2lo = fma(x,x,-x2)
x2, x2lo = two_mul(x,x)
hi_order = evalpoly(x2, (8.333333333336817e-3, 1.9841269840165435e-4,
2.7557319381151335e-6, 2.5052096530035283e-8,
1.6059550718903307e-10, 7.634842144412119e-13,
Expand Down
36 changes: 12 additions & 24 deletions base/special/log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ const t_log_Float32 = (0.0,0.007782140442054949,0.015504186535965254,0.023167059
0.6773988235918061,0.6813592248079031,0.6853040030989194,0.689233281238809,
0.6931471805599453)

# determine if hardware FMA is available
# should probably check with LLVM, see #9855.
const FMA_NATIVE = muladd(nextfloat(1.0),nextfloat(1.0),-nextfloat(1.0,2)) != 0

# truncate lower order bits (up to 26)
# ideally, this should be able to use ANDPD instructions, see #9868.
@inline function truncbits(x::Float64)
Expand Down Expand Up @@ -209,18 +205,10 @@ end
# 2(f-u1-u2) - f*(u1+u2) = 0
# 2(f-u1) - f*u1 = (2+f)u2
# u2 = (2(f-u1) - f*u1)/(2+f)
if FMA_NATIVE
return u + fma(fma(-u,f,2(f-u)), g, q)
else
u1 = truncbits(u) # round to 24 bits
f1 = truncbits(f)
f2 = f-f1
u2 = ((2.0*(f-u1)-u1*f1)-u1*f2)*g
## Step 4
m_hi = logbU(Float64, base)
m_lo = logbL(Float64, base)
return fma(m_hi, u1, fma(m_hi, (u2 + q), m_lo*u1))
end

m_hi = logbU(Float64, base)
m_lo = logbL(Float64, base)
return fma(m_hi, u, fma(m_lo, u, m_hi*fma(fma(-u,f,2(f-u)), g, q)))
end


Expand Down Expand Up @@ -417,8 +405,8 @@ end
0.153846227114512262845736, 0.13332981086846273921509,
0.117754809412463995466069, 0.103239680901072952701192,
0.116255524079935043668677))
res_hi = hi_order * x_hi
res_lo = fma(x_lo, hi_order, fma(hi_order, x_hi, -res_hi))
res_hi, res_lo = two_mul(hi_order, x_hi)
res_lo = fma(x_lo, hi_order, res_lo)
ans_hi = c1hi + res_hi
ans_lo = ((c1hi - ans_hi) + res_hi) + (res_lo + 3.80554962542412056336616e-17)
return ans_hi, ans_lo
Expand All @@ -440,19 +428,19 @@ function _log_ext(d::Float64)
invy = inv(mp1hi)
xhi = (m - 1.0) * invy
xlo = fma(-xhi, mp1lo, fma(-xhi, mp1hi, m - 1.0)) * invy
x2hi = xhi * xhi
x2lo = muladd(xhi, xlo * 2.0, fma(xhi, xhi, -x2hi))
x2hi, x2lo = two_mul(xhi, xhi)
x2lo = muladd(xhi, xlo * 2.0, x2lo)
thi, tlo = log_ext_kernel(x2hi, x2lo)

shi = 0.6931471805582987 * e
xhi2 = xhi * 2.0
shinew = muladd(xhi, 2.0, shi)
slo = muladd(1.6465949582897082e-12, e, muladd(xlo, 2.0, (((shi - shinew) + xhi2))))
shi = shinew
x3hi = x2hi * xhi
x3lo = muladd(x2hi, xlo, muladd(xhi, x2lo,fma(x2hi, xhi, -x3hi)))
x3thi = x3hi * thi
x3tlo = muladd(x3hi, tlo, muladd(x3lo, thi, fma(x3hi, thi, -x3thi)))
x3hi, x3lo = two_mul(x2hi, xhi)
x3lo = muladd(x2hi, xlo, muladd(xhi, x2lo,x3lo))
x3thi, x3tlo = two_mul(x3hi, thi)
x3tlo = muladd(x3hi, tlo, muladd(x3lo, thi, x3tlo))
anshi = x3thi + shi
anslo = slo + x3tlo - ((anshi - shi) - x3thi)
return anshi, anslo
Expand Down

0 comments on commit cbc2ce8

Please sign in to comment.