Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add simplex tools #183

Merged
merged 3 commits into from
Oct 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/QuantEcon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export
smooth, periodogram, ar_periodogram,

# util
meshgrid, gridmake, gridmake!, ckron, is_stable,
meshgrid, gridmake, gridmake!, ckron, is_stable, num_compositions, simplex_grid, simplex_index,

# robustlq
RBLQ,
Expand Down
125 changes: 125 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,128 @@ function is_stable(A::AbstractMatrix)

end


"""
The total number of m-part compositions of n, which is equal to

(n + m - 1) choose (m - 1)

##### Arguments

- `m`::Int : Number of parts of composition
- `n`::Int : Integer to decompose

##### Returns

- ::Int - Total number of m-part compositions of n
"""
function num_compositions(m, n)
return binomial(n+m-1, m-1)
end


"""
Construct an array consisting of the integer points in the
(m-1)-dimensional simplex :math:`\{x \mid x_0 + \cdots + x_{m-1} = n
\}`, or equivalently, the m-part compositions of n, which are listed
in lexicographic order. The total number of the points (hence the
length of the output array) is L = (n+m-1)!/(n!*(m-1)!) (i.e.,
(n+m-1) choose (m-1)).

##### Arguments

- `m`::Int : Dimension of each point. Must be a positive integer.
- `n`::Int : Number which the coordinates of each point sum to. Must
be a nonnegative integer.

##### Returns
- `out`::Matrix{Int} : Array of shape (m, L) containing the integer
points in the simplex, aligned in lexicographic
order.

##### Notes

A grid of the (m-1)-dimensional *unit* simplex with n subdivisions
along each dimension can be obtained by `simplex_grid(m, n) / n`.

##### Examples

>>> simplex_grid(3, 4)

3×15 Array{Int64,2}:
0 0 0 0 0 1 1 1 1 2 2 2 3 3 4
0 1 2 3 4 0 1 2 3 0 1 2 0 1 0
4 3 2 1 0 3 2 1 0 2 1 0 1 0 0

##### References

A. Nijenhuis and H. S. Wilf, Combinatorial Algorithms, Chapter 5,
Academic Press, 1978.
"""
function simplex_grid(m, n)
# Get number of elements in array and allocate
L = num_compositions(m, n)
out = Array{Int}(m, L)

x = zeros(Int, m)
x[m] = n

# Fill in first column
copy!(out, 1, x, 1, m)

h = m
for i in 2:L
h -= 1

val = x[h+1]
x[h+1] = 0
x[m] = val - 1
x[h] += 1

copy!(out, 3*(i-1) + 1, x, 1, m)

if val != 1
h = m
end
end

return out
end


"""
Return the index of the point x in the lexicographic order of the
integer points of the (m-1)-dimensional simplex :math:`\{x \mid x_0
+ \cdots + x_{m-1} = n\}`.

##### Arguments

- `x`::Array{Int,1} : Integer point in the simplex, i.e., an array of
m nonnegative integers that sum to n.
- `m`::Int : Dimension of each point. Must be a positive integer.
- `n`::Int : Number which the coordinates of each point sum to. Must be a
nonnegative integer.

##### Returns
- `idx`::Int : Index of x.

"""
function simplex_index(x, m, n)
# If only one element then only one point in simplex
if m==1
return 1
end

decumsum = reverse(cumsum(reverse(x[2:end])))
idx = binomial(n+m-1, m-1)
for i in 1:m-1
if decumsum[i] == 0
break
end

idx -= num_compositions(m - (i-1), decumsum[i]-1)
end

return idx
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ tests = [
"optimization",
"interp",
"sampler",
"util",
]


Expand Down
29 changes: 24 additions & 5 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,35 @@
end

@testset "fix" begin
@test 1 == @inferred fix(1.2)
@test [0, 2] == @inferred fix([0.9, 2.1])
@test [0, 2] == @inferred fix([0, 2])
@test 1 == @inferred QuantEcon.fix(1.2)
@test [0, 2] == @inferred QuantEcon.fix([0.9, 2.1])
@test [0, 2] == @inferred QuantEcon.fix([0, 2])

out = [100, 100]
@inferred fix!([0, 2], out)
@inferred QuantEcon.fix!([0, 2], out)
@test [0, 2] == out

@test [0 2; 2 0] == @inferred fix([0.5 2.9999; 3-2eps() -0.9])
@test [0 2; 2 0] == @inferred QuantEcon.fix([0.5 2.9999; 3-2eps() -0.9])
end

@test ([1 2; 1 2], [3 3; 4 4]) == @inferred meshgrid([1, 2], [3, 4])

@testset "simplex_tools" begin
grid_3_4 = [0 0 0 0 0 1 1 1 1 2 2 2 3 3 4
0 1 2 3 4 0 1 2 3 0 1 2 0 1 0
4 3 2 1 0 3 2 1 0 2 1 0 1 0 0]
points = [0 1 4
0 1 0
4 2 0]
idx = Array{Int}(3)

for i in 1:3
idx[i] = simplex_index(points[:, i], 3, 4)
end

@test all(simplex_grid(3, 4) .== grid_3_4)
@test all(grid_3_4[:, idx] .== points)
@test size(grid_3_4, 2) == num_compositions(3, 4)

end
end