Skip to content

Commit

Permalink
Overhaul of rounding for rational numbers (#34658)
Browse files Browse the repository at this point in the history
* Simplify rounding for Rationals (fix #34645)
* Fix rounding for infinite Rationals (fix #34657)
* Remove explicit DivideError
* Type-stable rounding with digits/sigdigits
  • Loading branch information
sostock authored Apr 30, 2020
1 parent a916783 commit b077c63
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 54 deletions.
3 changes: 2 additions & 1 deletion base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ end
# NOTE: this relies on the current keyword dispatch behaviour (#9498).
function round(x::Real, r::RoundingMode=RoundNearest;
digits::Union{Nothing,Integer}=nothing, sigdigits::Union{Nothing,Integer}=nothing, base::Union{Nothing,Integer}=nothing)
isfinite(x) || return x
if digits === nothing
if sigdigits === nothing
if base === nothing
Expand All @@ -139,10 +138,12 @@ function round(x::Real, r::RoundingMode=RoundNearest;
# or throw(ArgumentError("`round` cannot use `base` argument without `digits` or `sigdigits` arguments."))
end
else
isfinite(x) || return float(x)
_round_sigdigits(x, r, sigdigits, base === nothing ? 10 : base)
end
else
if sigdigits === nothing
isfinite(x) || return float(x)
_round_digits(x, r, digits, base === nothing ? 10 : base)
else
throw(ArgumentError("`round` cannot use both `digits` and `sigdigits` arguments."))
Expand Down
2 changes: 1 addition & 1 deletion base/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ round(::Type{T}, ::Missing, ::RoundingMode=RoundNearest) where {T} =
throw(MissingException("cannot convert a missing value to type $T: use Union{$T, Missing} instead"))
round(::Type{T}, x::Any, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)
# to fix ambiguities
round(::Type{T}, x::Rational, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)
round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T>:Missing,Tr} = round(nonmissingtype_checked(T), x, r)
round(::Type{T}, x::Rational{Bool}, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)

# Handle ceil, floor, and trunc separately as they have no RoundingMode argument
Expand Down
56 changes: 8 additions & 48 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,66 +395,26 @@ for (S, T) in ((Rational, Integer), (Integer, Rational), (Rational, Rational))
end
end

trunc(::Type{T}, x::Rational) where {T} = convert(T,div(x.num,x.den))
floor(::Type{T}, x::Rational) where {T} = convert(T,fld(x.num,x.den))
ceil(::Type{T}, x::Rational) where {T} = convert(T,cld(x.num,x.den))
round(::Type{T}, x::Rational, r::RoundingMode=RoundNearest) where {T} = _round_rational(T, x, r)
round(x::Rational, r::RoundingMode) = round(Rational, x, r)

function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:Nearest}) where {T,Tr}
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr)+iseven(q))>>1 + copysign(Tr(2), numerator(x)))
s += copysign(one(Tr),numerator(x))
end
convert(T, s)
end
trunc(::Type{T}, x::Rational) where {T} = round(T, x, RoundToZero)
floor(::Type{T}, x::Rational) where {T} = round(T, x, RoundDown)
ceil(::Type{T}, x::Rational) where {T} = round(T, x, RoundUp)

function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTiesAway}) where {T,Tr}
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr))>>1 + copysign(Tr(2), numerator(x)))
s += copysign(one(Tr),numerator(x))
end
convert(T, s)
end
round(x::Rational, r::RoundingMode=RoundNearest) = round(typeof(x), x, r)

function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTiesUp}) where {T,Tr}
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
function round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T,Tr}
if iszero(denominator(x)) && !(T <: Integer)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr)+(numerator(x)<0))>>1 + copysign(Tr(2), numerator(x)))
s += copysign(one(Tr),numerator(x))
end
convert(T, s)
convert(T, div(numerator(x), denominator(x), r))
end

function round(::Type{T}, x::Rational{Bool}, ::RoundingMode=RoundNearest) where T
if denominator(x) == false && (T <: Union{Integer, Bool})
if denominator(x) == false && (T <: Integer)
throw(DivideError())
end
convert(T, x)
end

trunc(x::Rational{T}) where {T} = Rational(trunc(T,x))
floor(x::Rational{T}) where {T} = Rational(floor(T,x))
ceil(x::Rational{T}) where {T} = Rational(ceil(T,x))
round(x::Rational{T}) where {T} = Rational(round(T,x))

function ^(x::Rational, n::Integer)
n >= 0 ? power_by_squaring(x,n) : power_by_squaring(inv(x),-n)
end
Expand Down
83 changes: 79 additions & 4 deletions test/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,22 +247,97 @@ end
end

@testset "round" begin
@test round(11//2) == 6//1 # rounds to closest _even_ integer
@test round(-11//2) == -6//1 # rounds to closest _even_ integer
@test round(11//3) == 4//1 # rounds to closest _even_ integer
@test round(-11//3) == -4//1 # rounds to closest _even_ integer
@test round(11//2) == round(11//2, RoundNearest) == 6//1 # rounds to closest _even_ integer
@test round(-11//2) == round(-11//2, RoundNearest) == -6//1 # rounds to closest _even_ integer
@test round(13//2) == round(13//2, RoundNearest) == 6//1 # rounds to closest _even_ integer
@test round(-13//2) == round(-13//2, RoundNearest) == -6//1 # rounds to closest _even_ integer
@test round(11//3) == round(11//3, RoundNearest) == 4//1 # rounds to closest _even_ integer
@test round(-11//3) == round(-11//3, RoundNearest) == -4//1 # rounds to closest _even_ integer

@test round(11//2, RoundNearestTiesAway) == 6//1
@test round(-11//2, RoundNearestTiesAway) == -6//1
@test round(13//2, RoundNearestTiesAway) == 7//1
@test round(-13//2, RoundNearestTiesAway) == -7//1
@test round(11//3, RoundNearestTiesAway) == 4//1
@test round(-11//3, RoundNearestTiesAway) == -4//1

@test round(11//2, RoundNearestTiesUp) == 6//1
@test round(-11//2, RoundNearestTiesUp) == -5//1
@test round(13//2, RoundNearestTiesUp) == 7//1
@test round(-13//2, RoundNearestTiesUp) == -6//1
@test round(11//3, RoundNearestTiesUp) == 4//1
@test round(-11//3, RoundNearestTiesUp) == -4//1

@test trunc(11//2) == round(11//2, RoundToZero) == 5//1
@test trunc(-11//2) == round(-11//2, RoundToZero) == -5//1
@test trunc(13//2) == round(13//2, RoundToZero) == 6//1
@test trunc(-13//2) == round(-13//2, RoundToZero) == -6//1
@test trunc(11//3) == round(11//3, RoundToZero) == 3//1
@test trunc(-11//3) == round(-11//3, RoundToZero) == -3//1

@test ceil(11//2) == round(11//2, RoundUp) == 6//1
@test ceil(-11//2) == round(-11//2, RoundUp) == -5//1
@test ceil(13//2) == round(13//2, RoundUp) == 7//1
@test ceil(-13//2) == round(-13//2, RoundUp) == -6//1
@test ceil(11//3) == round(11//3, RoundUp) == 4//1
@test ceil(-11//3) == round(-11//3, RoundUp) == -3//1

@test floor(11//2) == round(11//2, RoundDown) == 5//1
@test floor(-11//2) == round(-11//2, RoundDown) == -6//1
@test floor(13//2) == round(13//2, RoundDown) == 6//1
@test floor(-13//2) == round(-13//2, RoundDown) == -7//1
@test floor(11//3) == round(11//3, RoundDown) == 3//1
@test floor(-11//3) == round(-11//3, RoundDown) == -4//1

for T in (Float16, Float32, Float64)
@test round(T, true//false) === convert(T, Inf)
@test round(T, true//true) === one(T)
@test round(T, false//true) === zero(T)
@test trunc(T, true//false) === convert(T, Inf)
@test trunc(T, true//true) === one(T)
@test trunc(T, false//true) === zero(T)
@test floor(T, true//false) === convert(T, Inf)
@test floor(T, true//true) === one(T)
@test floor(T, false//true) === zero(T)
@test ceil(T, true//false) === convert(T, Inf)
@test ceil(T, true//true) === one(T)
@test ceil(T, false//true) === zero(T)
end

for T in (Int8, Int16, Int32, Int64, Bool)
@test_throws DivideError round(T, true//false)
@test round(T, true//true) === one(T)
@test round(T, false//true) === zero(T)
@test_throws DivideError trunc(T, true//false)
@test trunc(T, true//true) === one(T)
@test trunc(T, false//true) === zero(T)
@test_throws DivideError floor(T, true//false)
@test floor(T, true//true) === one(T)
@test floor(T, false//true) === zero(T)
@test_throws DivideError ceil(T, true//false)
@test ceil(T, true//true) === one(T)
@test ceil(T, false//true) === zero(T)
end

# issue 34657
@test round(1//0) === round(Rational, 1//0) === 1//0
@test trunc(1//0) === trunc(Rational, 1//0) === 1//0
@test floor(1//0) === floor(Rational, 1//0) === 1//0
@test ceil(1//0) === ceil(Rational, 1//0) === 1//0
@test round(-1//0) === round(Rational, -1//0) === -1//0
@test trunc(-1//0) === trunc(Rational, -1//0) === -1//0
@test floor(-1//0) === floor(Rational, -1//0) === -1//0
@test ceil(-1//0) === ceil(Rational, -1//0) === -1//0
for r = [RoundNearest, RoundNearestTiesAway, RoundNearestTiesUp,
RoundToZero, RoundUp, RoundDown]
@test round(1//0, r) === 1//0
@test round(-1//0, r) === -1//0
end

@test @inferred(round(1//0, digits=1)) === Inf
@test @inferred(trunc(1//0, digits=2)) === Inf
@test @inferred(floor(-1//0, sigdigits=1)) === -Inf
@test @inferred(ceil(-1//0, sigdigits=2)) === -Inf
end

@testset "issue 1552" begin
Expand Down

0 comments on commit b077c63

Please sign in to comment.