Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add partype method to lognormal and semicircle #1773

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ stdlogx(d::LogNormal) = d.σ
mean(d::LogNormal) = ((μ, σ) = params(d); exp(μ + σ^2/2))
median(d::LogNormal) = exp(d.μ)
mode(d::LogNormal) = ((μ, σ) = params(d); exp(μ - σ^2))
partype(::LogNormal{T}) where {T<:Real} = T

function var(d::LogNormal)
(μ, σ) = params(d)
Expand Down
2 changes: 2 additions & 0 deletions src/univariate/continuous/semicircle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Semicircle(r::Integer; check_args::Bool=true) = Semicircle(float(r); check_args=
@distr_support Semicircle -d.r +d.r

params(d::Semicircle) = (d.r,)
partype(::Semicircle{T}) where {T<:Real} = T


mean(d::Semicircle) = zero(d.r)
var(d::Semicircle) = d.r^2 / 4
Expand Down
7 changes: 7 additions & 0 deletions src/univariate/discrete/discreteuniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ span(d::DiscreteUniform) = d.b - d.a + 1
probval(d::DiscreteUniform) = d.pv
params(d::DiscreteUniform) = (d.a, d.b)

partype(::DiscreteUniform) = Int


### Show

show(io::IO, d::DiscreteUniform) = show(io, d, (:a, :b))
Expand Down Expand Up @@ -114,3 +117,7 @@ function fit_mle(::Type{DiscreteUniform}, x::AbstractArray{<:Real})
end
return DiscreteUniform(extrema(x)...)
end




1 change: 1 addition & 0 deletions src/univariate/discrete/hypergeometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ end

@distr_support Hypergeometric max(d.n - d.nf, 0) min(d.ns, d.n)

partype(::Hypergeometric) = Int

### Parameters

Expand Down
13 changes: 9 additions & 4 deletions src/univariate/locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,15 @@ struct AffineDistribution{T<:Real, S<:ValueSupport, D<:UnivariateDistribution{S}
end
end


function AffineDistribution(μ::T, σ::T, ρ::UnivariateDistribution; check_args::Bool=true) where {T<:Real}
@check_args AffineDistribution (σ, !iszero(σ))
_T = promote_type(eltype(ρ), T)
return AffineDistribution{_T}(_T(μ), _T(σ), ρ)
# μ and σ act on both random numbers and parameter-like quantities like mean
# hence do not promote: but take care in eltype and partype
#_T = promote_type(eltype(ρ), T)
# _T = typeof(one(eltype(D))*one(T) + one(T))
# return AffineDistribution{_T}(_T(μ), _T(σ), ρ)
return AffineDistribution{T}(μ, σ, ρ)
end

function AffineDistribution(μ::Real, σ::Real, ρ::UnivariateDistribution; check_args::Bool=true)
Expand All @@ -71,7 +76,7 @@ end
const ContinuousAffineDistribution{T<:Real,D<:ContinuousUnivariateDistribution} = AffineDistribution{T,Continuous,D}
const DiscreteAffineDistribution{T<:Real,D<:DiscreteUnivariateDistribution} = AffineDistribution{T,Discrete,D}

Base.eltype(::Type{<:AffineDistribution{T}}) where T = T
Base.eltype(::Type{<:AffineDistribution{T,S,D}}) where {T,S,D} = promote_type(eltype(D), T)

minimum(d::AffineDistribution) =
d.σ > 0 ? d.μ + d.σ * minimum(d.ρ) : d.μ + d.σ * maximum(d.ρ)
Expand Down Expand Up @@ -102,7 +107,7 @@ Base.convert(::Type{AffineDistribution{T}}, d::AffineDistribution{T}) where {T<:
location(d::AffineDistribution) = d.μ
scale(d::AffineDistribution) = d.σ
params(d::AffineDistribution) = (d.μ,d.σ,d.ρ)
partype(::AffineDistribution{T}) where {T} = T
partype(d::AffineDistribution{T}) where {T} = promote_type(partype(d.ρ), T)

#### Statistics

Expand Down
2 changes: 1 addition & 1 deletion test/univariate/continuous/semicircle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ d = Semicircle(2.0)

rng = MersenneTwister(0)
for r in rand(rng, Uniform(0,10), 5)
N = 10^4
local N = 10^4
semi = Semicircle(r)
sample = rand(rng, semi, N)
mi, ma = extrema(sample)
Expand Down
2 changes: 1 addition & 1 deletion test/univariate/locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function test_location_scale(
rand!(rng, dtest, r)
end
@test mean(r) ≈ mean(dref) atol=0.02
@test std(r) ≈ std(dref) atol=0.01
@test std(r) ≈ std(dref) atol=0.02
@test cf(dtest, -0.1) ≈ cf(dref,-0.1)

if dref isa ContinuousDistribution
Expand Down
6 changes: 6 additions & 0 deletions test/univariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ function verify_and_test(D::Union{Type,Function}, d::UnivariateDistribution, dct
# test various constructors for promotion, all-Integer args, etc.
pars = params(d)

# verify parameter type
# truncated parameters may be nothing: Union{Nothing, promote_type()}
# in the following exclude all non-Real from creating the promoted type
@test partype(d) === mapfoldl(
typeof, (S, T) -> T <: Real ? promote_type(S, T) : S, pars; init = Bool)

# promotion constructor:
float_pars = map(x -> isa(x, AbstractFloat), pars)
if length(pars) > 1 && sum(float_pars) > 1 && !isa(D, typeof(truncated))
Expand Down
10 changes: 8 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ r = RealInterval(1.5, 4.0)

# special cases
@test partype(Kolmogorov()) == Float64
@test partype(Hypergeometric(2, 2, 2)) == Float64
@test partype(DiscreteUniform(0, 4)) == Float64
# twutz: Hypergeometric needs integer parameters
#@test partype(Hypergeometric(2, 2, 2)) == Float64
@test partype(Hypergeometric(2, 2, 2)) == Int
@test partype(Hypergeometric(2.0, 2, 2)) == Int
# twutz: should DiscreteUniform needs a partype of Int
#@test partype(DiscreteUniform(0, 4)) == Float64
@test partype(DiscreteUniform(0, 4)) == Int
@test partype(DiscreteUniform(0.0, 4)) == Int

A = rand(1:10, 5, 5)
B = rand(Float32, 4)
Expand Down