-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathconst_value_dist.jl
85 lines (61 loc) · 4.25 KB
/
const_value_dist.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# This file is a part of ValueShapes.jl, licensed under the MIT License (MIT).
"""
    ConstValueDist <: Distributions.Distribution

Represents a delta distribution for a constant value of arbitrary type.

Type parameters:
* `VF`: the variate form (a `VariateForm`).
* `T`: the type of the stored constant `value`.

The outer constructor `ConstValueDist(x)` selects `VF` automatically for real
scalars, real vectors, real matrices, real arrays (when Distributions provides
`ArrayLikeVariate`) and `NamedTuple`s.

Calling `varshape` on a `ConstValueDist` will yield a
[`ConstValueShape`](@ref).
"""
struct ConstValueDist{VF<:VariateForm,T} <: Distribution{VF,Discrete}
    value::T
end
export ConstValueDist

# Outer constructors: infer the variate form from the kind of value being stored,
# keeping the concrete type of the value as the second type parameter.
ConstValueDist(x::Real) = ConstValueDist{Univariate,typeof(x)}(x)
ConstValueDist(x::AbstractVector{<:Real}) = ConstValueDist{Multivariate,typeof(x)}(x)
ConstValueDist(x::AbstractMatrix{<:Real}) = ConstValueDist{Matrixvariate,typeof(x)}(x)

@static if isdefined(Distributions, :ArrayLikeVariate)
    # Real arrays of any dimension `N`, only where Distributions defines
    # `ArrayLikeVariate`. The vector/matrix methods above are more specific and win.
    ConstValueDist(x::AbstractArray{<:Real,N}) where {N} =
        ConstValueDist{ArrayLikeVariate{N},typeof(x)}(x)
end

# Named tuples carry their field names in the variate form.
ConstValueDist(x::NamedTuple{names}) where {names} =
    ConstValueDist{NamedTupleVariate{names},typeof(x)}(x)
# Shared density kernels for every variate form. All probability mass sits on
# `d.value`: the density is one (log-density zero) there and zero (log-density
# -Inf) everywhere else. Results use the floating-point type `float(eltype(d))`.
function _pdf_impl(d::ConstValueDist, x)
    R = float(eltype(d))
    if d.value == x
        return R(1)
    else
        return R(0)
    end
end

function _logpdf_impl(d::ConstValueDist, x)
    R = float(eltype(d))
    if d.value == x
        return R(0)
    else
        return R(-Inf)
    end
end
# Univariate densities are defined directly on the public `pdf`/`logpdf` functions.
Distributions.pdf(d::ConstValueDist{Univariate}, x::Real) = _pdf_impl(d, x)
Distributions.logpdf(d::ConstValueDist{Univariate}, x::Real) = _logpdf_impl(d, x)

@static if isdefined(Distributions, :ArrayLikeVariate)
    # Array-valued variates of any dimension `N`: implement Distributions'
    # internal `_pdf`/`_logpdf` hooks.
    Distributions._pdf(
        d::ConstValueDist{<:ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}
    ) where {N} = _pdf_impl(d, x)
    Distributions._logpdf(
        d::ConstValueDist{<:ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}
    ) where {N} = _logpdf_impl(d, x)
end

# Explicit definitions for Multivariate and Matrixvariate to avoid ambiguities with Distributions.
# The log-densities call `_logpdf_impl` directly, consistent with the methods above;
# this yields the same values as `log(pdf(d, x))` without a second trip through `pdf`.
Distributions._pdf(d::ConstValueDist{Multivariate}, x::AbstractVector{<:Real}) =
    _pdf_impl(d, x)
Distributions._logpdf(d::ConstValueDist{Multivariate}, x::AbstractVector{<:Real}) =
    _logpdf_impl(d, x)
Distributions._pdf(d::ConstValueDist{Matrixvariate}, x::AbstractMatrix{<:Real}) =
    _pdf_impl(d, x)
Distributions._logpdf(d::ConstValueDist{Matrixvariate}, x::AbstractMatrix{<:Real}) =
    _logpdf_impl(d, x)

# Named-tuple variates: the field names of `x` must match those of the distribution.
Distributions.pdf(
    d::ConstValueDist{<:NamedTupleVariate{names}}, x::NamedTuple{names}
) where {names} = _pdf_impl(d, x)
Distributions.logpdf(
    d::ConstValueDist{<:NamedTupleVariate{names}}, x::NamedTuple{names}
) where {names} = _logpdf_impl(d, x)
# The support of a delta distribution is the single point `d.value`:
# a candidate `x` is in the support exactly when it compares equal to it.
function Distributions.insupport(d::ConstValueDist{Univariate}, x::Real)
    return x == d.value
end

@static if isdefined(Distributions, :ArrayLikeVariate)
    function Distributions.insupport(
        d::ConstValueDist{<:ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}
    ) where {N}
        return x == d.value
    end
else
    # Without `ArrayLikeVariate`, cover vectors and matrices explicitly.
    function Distributions.insupport(
        d::ConstValueDist{Multivariate}, x::AbstractVector{<:Real}
    )
        return x == d.value
    end
    function Distributions.insupport(
        d::ConstValueDist{Matrixvariate}, x::AbstractMatrix{<:Real}
    )
        return x == d.value
    end
end

function Distributions.insupport(
    d::ConstValueDist{<:NamedTupleVariate{names}}, x::NamedTuple{names}
) where {names}
    return x == d.value
end
# Univariate summaries of a point mass at `d.value`.

# Step function: 0 below `d.value`, 1 from `d.value` onwards. The result uses the same
# floating-point type as `pdf`/`logpdf` (`float(eltype(d))`) instead of a hard-coded
# Float32, so e.g. a Float64-valued distribution yields Float64 probabilities.
function Distributions.cdf(d::ConstValueDist{Univariate}, x::Real)
    R = float(eltype(d))
    return d.value <= x ? one(R) : zero(R)
end

# Every probability level in [0, 1] maps to the single support point. Levels outside
# that range (including NaN) are not valid probabilities and are rejected.
function Distributions.quantile(d::ConstValueDist{Univariate}, q::Real)
    0 <= q <= 1 || throw(DomainError(q, "quantile level must be in [0, 1]"))
    return d.value
end

# The support is the single point `d.value`, so it is both minimum and maximum.
Distributions.minimum(d::ConstValueDist{Univariate}) = d.value
Distributions.maximum(d::ConstValueDist{Univariate}) = d.value
# For a point mass, the mean and the mode are both the point itself.
# NOTE(review): these return the stored `d.value` object, not a copy — for array values,
# mutating the result mutates the distribution; confirm callers do not rely on a copy.
function StatsBase.mean(d::ConstValueDist)
    return d.value
end

function StatsBase.mode(d::ConstValueDist)
    return d.value
end
# For plain (scalar/array) variate forms, the distribution's size, length and element
# type are those of the stored constant value.
function Base.size(d::ConstValueDist{<:PlainVariate})
    return size(d.value)
end

function Base.length(d::ConstValueDist{<:PlainVariate})
    dims = size(d)
    return prod(dims)
end

function Base.eltype(d::ConstValueDist{<:PlainVariate})
    return eltype(d.value)
end
# Sampling never consumes randomness: every draw is the stored constant value.
# NOTE(review): a single draw returns `d.value` itself, not a copy — confirm that this
# aliasing is acceptable for array-valued distributions.
function Random.rand(rng::AbstractRNG, d::ConstValueDist)
    return d.value
end

# In-place sampling into a preallocated array: copy the constant into `x`.
@static if isdefined(Distributions, :ArrayLikeVariate)
    function Distributions._rand!(
        rng::AbstractRNG, d::ConstValueDist{<:ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}
    ) where {N}
        return copyto!(x, d.value)
    end
else
    function Distributions._rand!(
        rng::AbstractRNG, d::ConstValueDist{<:Multivariate}, x::AbstractVector{<:Real}
    )
        return copyto!(x, d.value)
    end
    function Distributions._rand!(
        rng::AbstractRNG, d::ConstValueDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}
    )
        return copyto!(x, d.value)
    end
end

# Struct-variate forms: many draws are the constant repeated, returned as a `Fill`
# array of shape `dims`, or written into every slot of `A`.
function Random.rand(rng::AbstractRNG, d::ConstValueDist{<:StructVariate}, dims::Dims)
    return Fill(d.value, dims)
end

function Random.rand!(rng::AbstractRNG, d::ConstValueDist{<:StructVariate}, A::AbstractArray)
    return fill!(A, d.value)
end
# The value shape of a constant distribution is a constant shape wrapping the value.
function ValueShapes.varshape(d::ConstValueDist)
    return ConstValueShape(d.value)
end

# A point mass has no spread, so the variance is the zero of the stored value.
# NOTE(review): relies on `zero` being defined for the value's type (Base does not
# define it for `NamedTuple`s) — confirm whether NamedTuple-valued distributions need this.
function Statistics.var(d::ConstValueDist)
    return zero(d.value)
end