Skip to content

Commit

Permalink
Optimize multiplication for Normed
Browse files Browse the repository at this point in the history
This adds `wrapping_mul`, `saturating_mul` and `checked_mul` binary operations.
However, this does not specialize them for `Fixed` and does not change `*` for `Fixed`.

This replaces most of Normed's implementation of multiplication with integer operations.
This improves the speed in many cases and the accuracy in some cases.
  • Loading branch information
kimikage committed Aug 14, 2020
1 parent 9f889f6 commit e66a277
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/FixedPointNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ float(x::FixedPoint) = convert(floattype(x), x)
# Negation, addition and subtraction operate directly on the raw field `i` and
# rebuild an `X` from the raw result; no range check is performed here.
# NOTE(review): the two-argument form `X(raw, 0)` appears to construct `X` from a raw
# value without conversion — confirm against the type definitions.
wrapping_neg(x::X) where {X <: FixedPoint} = X(-x.i, 0)
wrapping_add(x::X, y::X) where {X <: FixedPoint} = X(x.i + y.i, 0)
wrapping_sub(x::X, y::X) where {X <: FixedPoint} = X(x.i - y.i, 0)
# Generic wrapping multiplication: form the product in floating point, then bring
# it back into `X` with the `%` (rem) conversion.
function wrapping_mul(x::X, y::X) where {X <: FixedPoint}
    product = float(x) * float(y)
    return product % X
end

# saturating arithmetic
# In two's complement `~(i - 1) == -i`. When `x.i - true` wraps around, `min` picks
# `x.i` itself instead, so the result is `~x.i` rather than an overflowed negation.
saturating_neg(x::X) where {X <: FixedPoint} = X(~min(x.i - true, x.i), 0)
Expand All @@ -202,10 +203,13 @@ saturating_sub(x::X, y::X) where {X <: FixedPoint} =
X(x.i - ifelse(x.i < 0, min(y.i, x.i - typemin(x.i)), max(y.i, x.i - typemax(x.i))), 0)
# Unsigned raw values: subtract at most `x.i`, so the raw result cannot drop below zero.
saturating_sub(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i - min(x.i, y.i), 0)

# Generic saturating multiplication: multiply in floating point, then clamp the
# product with `clamp(_, X)`.
function saturating_mul(x::X, y::X) where {X <: FixedPoint}
    product = float(x) * float(y)
    return clamp(product, X)
end

# checked arithmetic
# Checked negation/addition/subtraction delegate to the same-named checked operations
# on the raw values `x.i` / `y.i`, then rebuild an `X` from the raw result.
checked_neg(x::X) where {X <: FixedPoint} = X(checked_neg(x.i), 0)
checked_add(x::X, y::X) where {X <: FixedPoint} = X(checked_add(x.i, y.i), 0)
checked_sub(x::X, y::X) where {X <: FixedPoint} = X(checked_sub(x.i, y.i), 0)
# Generic checked multiplication: multiply in floating point and convert the product
# with the `X` constructor.
# NOTE(review): presumably the constructor rejects out-of-range products — confirm.
function checked_mul(x::X, y::X) where {X <: FixedPoint}
    product = float(x) * float(y)
    return X(product)
end

# default arithmetic
const DEFAULT_ARITHMETIC = :wrapping
Expand All @@ -216,7 +220,7 @@ for (op, name) in ((:-, :neg), )
$op(x::X) where {X <: FixedPoint} = $f(x)
end
end
for (op, name) in ((:+, :add), (:-, :sub))
for (op, name) in ((:+, :add), (:-, :sub), (:*, :mul))
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
@eval begin
$op(x::X, y::X) where {X <: FixedPoint} = $f(x, y)
Expand Down Expand Up @@ -427,6 +431,13 @@ scaledual(::Type{Tdual}, x::AbstractArray{T}) where {Tdual, T <: FixedPoint} =
throw(ArgumentError(String(take!(io))))
end

# Compose the message "<x> <op> <y> overflowed for type <typeof(x)>" and raise it as
# an `OverflowError`. Marked `@noinline`, with both operands `@nospecialize`d.
@noinline function throw_overflowerror(op::Symbol, @nospecialize(x), @nospecialize(y))
    buf = IOBuffer()
    print(buf, x, ' ', op, ' ', y, " overflowed for type ")
    showtype(buf, typeof(x))
    message = String(take!(buf))
    throw(OverflowError(message))
end

# Draw a random value of `rawtype(X)` from `r` and use it as the raw value of an `X`.
function Random.rand(r::AbstractRNG, ::SamplerType{X}) where X <: FixedPoint
    raw = rand(r, rawtype(X))
    return X(raw, 0)
end
Expand Down
42 changes: 39 additions & 3 deletions src/normed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ function rem(x::Float64, ::Type{N}) where {f, N <: Normed{UInt64,f}}
reinterpret(N, r << UInt8(f - 53) - unsigned(signed(r) >> 0x35))
end


function (::Type{T})(x::Normed) where {T <: AbstractFloat}
# The following optimization for constant division may cause rounding errors.
# y = reinterpret(x)*(one(rawtype(x))/convert(T, rawone(x)))
Expand Down Expand Up @@ -248,8 +247,45 @@ Base.BigFloat(x::Normed) = reinterpret(x) / BigFloat(rawone(x))

# Rational conversion: the raw value `reinterpret(x)` over `rawone(x)`
# (presumably the raw representation of one — confirm).
Base.Rational(x::Normed) = reinterpret(x)//rawone(x)

# unchecked arithmetic
*(x::T, y::T) where {T <: Normed} = convert(T,convert(floattype(T), x)*convert(floattype(T), y))
# Division by `2^f-1` with RoundNearest. The result would be in the lower half bits.
#
# `x` is expected to be a widened product (the multiplication routines pass the result
# of `widemul` on two raw values), so each specialized method takes the integer type
# with `2f` bits.

# Generic method: add half of the divisor (rounded down), then use integer division.
div_2fm1(x::T, ::Val{f}) where {T, f} = (x + (T(1)<<(f - 1) - 0x1)) ÷ (T(1) << f - 0x1)
# `2^1 - 1 == 1`, so the division is the identity.
div_2fm1(x::T, ::Val{1}) where T = x
# Division-free variants: with `t = x + 2^(f-1)`, the quotient is `(t + (t >> f)) >> f`.
div_2fm1(x::UInt16, ::Val{8}) = (((x + 0x80) >> 0x8) + x + 0x80) >> 0x8
div_2fm1(x::UInt32, ::Val{16}) = (((x + 0x8000) >> 0x10) + x + 0x8000) >> 0x10
div_2fm1(x::UInt64, ::Val{32}) = (((x + 0x80000000) >> 0x20) + x + 0x80000000) >> 0x20
# The widened product of two `UInt64` values is a `UInt128`. Declaring this method for
# `UInt64` would never match that product, and `>> 0x40` on a `UInt64` shifts out
# every bit, so the argument type must be `UInt128`.
div_2fm1(x::UInt128, ::Val{64}) = (((x + 0x8000000000000000) >> 0x40) + x + 0x8000000000000000) >> 0x40

# wrapping arithmetic
# Multiply the raw values in the widened integer type, divide by `2^f - 1` with
# rounding, and keep only the bits that fit in `T`.
function wrapping_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
    wide = widemul(x.i, y.i)
    raw = div_2fm1(wide, Val(Int(f))) % T
    return N(raw, 0)
end

# saturating arithmetic
# Same integer pipeline as the wrapping version, except that the widened product is
# capped at `typemax(N).i * rawone(N)` before the division.
function saturating_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
    # When `f` equals the bit width of `T`, defer to the wrapping version.
    if f == bitwidth(T)
        return wrapping_mul(x, y)
    end
    wide = widemul(x.i, y.i)
    ceiling = widemul(typemax(N).i, rawone(N))
    capped = min(wide, ceiling)
    return N(div_2fm1(capped, Val(Int(f))) % T, 0)
end

# checked arithmetic
# Fallback for any `Normed`: multiply in floating point and accept the product only
# when it is below `typemax(N) + eps(N)/2`; otherwise raise through
# `throw_overflowerror`.
function checked_mul(x::N, y::N) where {N <: Normed}
    product = float(x) * float(y)
    limit = typemax(N) + eps(N)/2
    if !(product < limit)
        throw_overflowerror(:*, x, y)
    end
    return product % N
end
# Integer-only checked multiplication for `Normed` types whose raw type is one of
# `UInt8`, `UInt16`, `UInt32` or `UInt64`. Raises through `throw_overflowerror` when
# the rounded product would exceed `typemax(N)`.
function checked_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
    # When `f` equals the bit width of `T`, defer to the wrapping version.
    f == bitwidth(T) && return wrapping_mul(x, y)
    z = widemul(x.i, y.i)
    # Largest widened product that still rounds to `typemax(N)`: the rounding division
    # adds `rawone(N) >> 1` (assuming `rawone(N) == 2^f - 1`) before dividing, so every
    # product up to and including `m` is in range.
    m = widemul(typemax(N).i, rawone(N)) + (rawone(N) >> 0x1)
    # `<=`, not `<`: `z == m` is still representable. With `f == 1`, for example,
    # `one(N) * typemax(N)` yields exactly `z == m` and must not be reported as overflow.
    z <= m || throw_overflowerror(:*, x, y)
    return N(div_2fm1(z, Val(Int(f))) % T, 0)
end

# TODO: decide the default arithmetic for `Normed` mul
# Override the default arithmetic with `checked` for backward compatibility
# (so `*` on two `Normed` values of the same type goes through `checked_mul`).
*(x::N, y::N) where {N <: Normed} = checked_mul(x, y)

# Division goes through floating point: convert both operands to `floattype(T)`,
# divide, and convert the quotient back to `T`.
/(x::T, y::T) where {T <: Normed} = convert(T,convert(floattype(T), x)/convert(floattype(T), y))

# Functions
Expand Down
7 changes: 4 additions & 3 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ exponent_bias(::Type{Float32}) = 127
exponent_bias(::Type{Float64}) = 1023

# Integers are reduced into `T` with `%`.
_unsafe_trunc(::Type{T}, x::Integer) where {T} = x % T
# Any other input falls back to Base's `unsafe_trunc`.
_unsafe_trunc(::Type{T}, x) where {T} = unsafe_trunc(T, x)
_unsafe_trunc(::Type{T}, x) where {T} = unsafe_trunc(T, x)
# issue #202, #211
# `BigFloat` inputs are truncated to `BigInt` first and then reduced into `T` with `%`.
_unsafe_trunc(::Type{T}, x::BigFloat) where {T <: Integer} = trunc(BigInt, x) % T

# Probe evaluated when this file is loaded: if converting a negative float to `UInt`
# does not yield a value whose signed reinterpretation is negative, define methods
# that truncate through the signed counterpart type first.
if !signbit(signed(unsafe_trunc(UInt, -12.345)))
# a workaround for ARM (issue #134)
function _unsafe_trunc(::Type{T}, x::AbstractFloat) where {T <: Integer}
unsafe_trunc(T, unsafe_trunc(signedtype(T), x))
end
# exclude BigFloat (issue #202)
_unsafe_trunc(::Type{T}, x::BigFloat) where {T <: Integer} = unsafe_trunc(T, x)
end

# Return `Base.typename(T).wrapper` for the given type; the argument is marked
# `@nospecialize`.
wrapper(@nospecialize(T)) = Base.typename(T).wrapper
25 changes: 25 additions & 0 deletions test/fixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,31 @@ end
end
end

@testset "mul" begin
# Spot checks for each type from `target(Fixed; ex = :thin)`: multiplication by zero,
# multiplication by -1, and the extreme product `typemin(F) * typemax(F)`.
for F in target(Fixed; ex = :thin)
@test wrapping_mul(typemax(F), zero(F)) === zero(F)
@test saturating_mul(typemax(F), zero(F)) === zero(F)
@test checked_mul(typemax(F), zero(F)) === zero(F)

@test wrapping_mul(F(-1), typemax(F)) === -typemax(F)
@test saturating_mul(F(-1), typemax(F)) === -typemax(F)
@test checked_mul(F(-1), typemax(F)) === -typemax(F)

@test wrapping_mul(typemin(F), typemax(F)) === big(typemin(F)) * big(typemax(F)) % F
@test saturating_mul(typemin(F), typemax(F)) === typemin(F)
@test_throws Exception checked_mul(typemin(F), typemax(F)) # TODO: Exception -> OverflowError
end
# Exhaustive check over every pair of values for the types selected by
# `target(Fixed, :i8; ex = :thin)`, using floating-point multiplication as reference.
for F in target(Fixed, :i8; ex = :thin)
xs = typemin(F):eps(F):typemax(F)
xys = ((x, y) for x in xs, y in xs)
fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
@test all(((x, y),) -> wrapping_mul(x, y) === fmul(x, y) % F, xys)
@test all(((x, y),) -> saturating_mul(x, y) === clamp(fmul(x, y), F), xys)
# Where the reference product lies strictly between `typemin(F)` and `typemax(F)`,
# the checked result must equal the wrapping result.
@test all(((x, y),) -> !(typemin(F) < fmul(x, y) < typemax(F)) ||
wrapping_mul(x, y) === checked_mul(x, y), xys)
end
end

@testset "rounding" begin
for sym in (:i8, :i16, :i32, :i64)
T = symbol_to_inttype(Fixed, sym)
Expand Down
29 changes: 29 additions & 0 deletions test/normed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,35 @@ end
end
end

@testset "mul" begin
# Spot checks for each type from `target(Normed; ex = :thin)`: multiplication by zero,
# multiplication by one, and the extreme product `typemax(N)^2`.
for N in target(Normed; ex = :thin)
@test wrapping_mul(typemax(N), zero(N)) === zero(N)
@test saturating_mul(typemax(N), zero(N)) === zero(N)
@test checked_mul(typemax(N), zero(N)) === zero(N)

@test wrapping_mul(one(N), typemax(N)) === typemax(N)
@test saturating_mul(one(N), typemax(N)) === typemax(N)
@test checked_mul(one(N), typemax(N)) === typemax(N)

@test wrapping_mul(typemax(N), typemax(N)) === big(typemax(N))^2 % N
@test saturating_mul(typemax(N), typemax(N)) === typemax(N)
# `typemax(N)^2` only fits when `typemax(N) == 1`; otherwise it must overflow.
if typemax(N) == 1
@test checked_mul(typemax(N), typemax(N)) === typemax(N)
else
@test_throws OverflowError checked_mul(typemax(N), typemax(N))
end
end
# Exhaustive check over every pair of values for the types selected by
# `target(Normed, :i8; ex = :thin)`, using floating-point multiplication as reference.
for N in target(Normed, :i8; ex = :thin)
xs = typemin(N):eps(N):typemax(N)
xys = ((x, y) for x in xs, y in xs)
fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
@test all(((x, y),) -> wrapping_mul(x, y) === fmul(x, y) % N, xys)
@test all(((x, y),) -> saturating_mul(x, y) === clamp(fmul(x, y), N), xys)
# Where the reference product lies strictly between `typemin(N)` and `typemax(N)`,
# the checked result must equal the wrapping result.
@test all(((x, y),) -> !(typemin(N) < fmul(x, y) < typemax(N)) ||
wrapping_mul(x, y) === checked_mul(x, y), xys)
end
end

@testset "div/fld1" begin
@test div(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == 8
@test div(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == 7
Expand Down

0 comments on commit e66a277

Please sign in to comment.