Skip to content

Commit

Permalink
Merge pull request #28 from invenia/rf/statsbase
Browse files Browse the repository at this point in the history
Add support for weighted StatsBase methods
  • Loading branch information
mcabbott authored Nov 17, 2020
2 parents 82d7037 + 0ebe5db commit d8e8580
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 6 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@ version = "0.1.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
AbstractFFTs = "0.5"
BenchmarkTools = "0.5"
CovarianceEstimation = "0.2"
IntervalSets = "0.5.1"
InvertedIndices = "1.0"
LazyStack = "0.0.7, 0.0.8"
NamedDims = "0.2.27"
OffsetArrays = "0.10, 0.11, 1.0"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1"
julia = "1"

Expand Down
2 changes: 2 additions & 0 deletions src/AxisKeys.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ include("stack.jl") # LazyStack.jl

include("fft.jl") # AbstractFFTs.jl

include("statsbase.jl") # StatsBase.jl

end
25 changes: 19 additions & 6 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,36 @@ end

using Statistics
for fun in [:mean, :std, :var] # These don't use mapreduce, but could perhaps be handled better?
@eval function Statistics.$fun(A::KeyedArray; dims=:)
@eval function Statistics.$fun(A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return $fun(parent(A))
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = $fun(parent(A); dims=numerical_dims)
data = $fun(parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
VERSION >= v"1.3" &&
@eval function Statistics.$fun(f, A::KeyedArray; dims=:)
dims === Colon() && return $fun(f, parent(A))
end

# Handle function interface for `mean` only
if VERSION >= v"1.3"
@eval function Statistics.mean(f, A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return mean(f, parent(A))
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = $fun(f, parent(A); dims=numerical_dims)
data = mean(f, parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
end

# The returned axes work differently for `cov` and `cor`: the result is square, and both
# of its dimensions carry the keys of the dimension that was not reduced over.
for fun in (:cov, :cor)
    @eval function Statistics.$fun(A::KeyedMatrix; dims=1, kwargs...)
        # Translate a dimension name into its integer position when `A` has names.
        numerical_dim = if hasnames(A)
            NamedDims.dim(dimnames(A), dims)
        else
            dims
        end
        data = $fun(parent(A); dims=numerical_dim, kwargs...)
        # A matrix only has dimensions 1 and 2, so `3 - numerical_dim` is the other one.
        kept_key = axiskeys(A, 3 - numerical_dim)
        return KeyedArray(data, (copy(kept_key), copy(kept_key)))
    end
end

function Base.dropdims(A::KeyedArray; dims)
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = dropdims(parent(A); dims=dims)
Expand Down
3 changes: 3 additions & 0 deletions src/names.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ NdaKaVoM{L,T} = Union{NamedDimsArray{L,T,1,<:KeyedArray}, NamedDimsArray{L,T,2,<
# The dimension names are the type parameter `L`, so they are read straight off the type.
function NamedDims.dimnames(A::KaNda{L}) where {L}
    return L
end

# Name of dimension `d`; any `d` beyond the `N` dimensions of `A` is reported as `:_`.
function NamedDims.dimnames(A::KaNda{L,T,N}, d::Int) where {L,T,N}
    return d <= N ? L[d] : :_
end

# Special case `dim` for KeyedArrays around NamedDimsArrays.
function NamedDims.dim(A::KaNda{L}, name) where {L}
    return NamedDims.dim(L, name)
end

# `axes` and `size` accept a dimension name, which is resolved to its position via `L`.
function Base.axes(A::KaNda{L}, s::Symbol) where {L}
    return axes(A, NamedDims.dim(L, s))
end
function Base.size(A::KaNda{L,T,N}, s::Symbol) where {T,N,L}
    return size(A, NamedDims.dim(L, s))
end

Expand Down
83 changes: 83 additions & 0 deletions src/statsbase.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using StatsBase


# Support some of the weighted statistics functions in StatsBase
# NOTES:
# - Ambiguity errors are still possible for weights with overly specific methods (e.g., UnitWeights)
# - Ideally, when the weighted statistics are moved to Statistics.jl we can remove this entire file.
# https://github.com/JuliaLang/Statistics.jl/pull/2
# Weighted mean of a `KeyedArray`.
# With `dims=:` the call is passed straight to the parent array and its result returned
# as is. Otherwise `dims` (integer or name) is resolved to numerical dimensions, the
# parent is reduced over them, and the result is re-wrapped in a `KeyedArray`.
function Statistics.mean(A::KeyedArray, wv::AbstractWeights; dims=:, kwargs...)
    if dims === Colon()
        return mean(parent(A), wv; kwargs...)
    end
    reduced = AxisKeys.hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
    data = mean(parent(A), wv; dims=reduced, kwargs...)
    # Reduced dimensions get the trivial key `Base.OneTo(1)`; the others keep their keys.
    new_keys = ntuple(ndims(A)) do d
        d in reduced ? Base.OneTo(1) : axiskeys(A, d)
    end
    return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end

# `var` and `std` are kept separate from `mean` because the weighted methods called here
# take the dimension as a positional argument instead of a `dims` keyword, and because
# `corrected` is given an explicit default of `true`.
for fun in [:var, :std]
    @eval function Statistics.$fun(A::KeyedArray, wv::AbstractWeights; dims=:, corrected=true, kwargs...)
        # `corrected` is a named keyword, so it is never part of `kwargs`; it has to be
        # forwarded by hand, otherwise the caller's choice is silently dropped.
        dims === Colon() && return $fun(parent(A), wv; corrected=corrected, kwargs...)
        # Translate a dimension name into its integer position when `A` has names.
        numerical_dims = AxisKeys.hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
        data = $fun(parent(A), wv, numerical_dims; corrected=corrected, kwargs...)
        # Reduced dimensions get the trivial key `Base.OneTo(1)`; the others keep their keys.
        new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
        return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
    end
end

# Weighted `cov`/`cor` of a `KeyedMatrix`. The computation runs on the array with its keys
# and names stripped; the square result is then re-wrapped so that both of its dimensions
# carry the name (if any) and keys of the dimension that was not reduced over.
for fun in (:cov, :cor)
    @eval function Statistics.$fun(A::KeyedMatrix, wv::AbstractWeights; dims=1, kwargs...)
        d = NamedDims.dim(A, dims)
        # A matrix only has dimensions 1 and 2, so `3 - d` is the one that survives.
        kept = 3 - d
        raw = $fun(unname(keyless(A)), wv, d; kwargs...)
        kept_name = dimnames(A, kept)
        named = if hasnames(A)
            NamedDimsArray(raw, (kept_name, kept_name))
        else
            raw
        end
        kept_key = axiskeys(A, kept)
        return KeyedArray(named, (copy(kept_key), copy(kept_key)))
    end
end

# `scattermat` is a StatsBase function and takes `dims` as a keyword argument.
# The computation runs on the array with its keys and names stripped; the square result
# is re-wrapped with the name (if any) and keys of the dimension that was not reduced.
function StatsBase.scattermat(A::KeyedMatrix, wv::AbstractWeights; dims=1, kwargs...)
    d = NamedDims.dim(A, dims)
    # A matrix only has dimensions 1 and 2, so `3 - d` is the one that survives.
    kept = 3 - d
    raw = scattermat(unname(keyless(A)), wv; dims=d, kwargs...)
    kept_name = dimnames(A, kept)
    named = if hasnames(A)
        NamedDimsArray(raw, (kept_name, kept_name))
    else
        raw
    end
    kept_key = axiskeys(A, kept)
    return KeyedArray(named, (copy(kept_key), copy(kept_key)))
end

# Define `mean_and_std`, `mean_and_var` and `mean_and_cov` for `KeyedMatrix`, with zero or
# more weight vectors. Each returns the tuple `(mean, statistic)` computed by the
# corresponding methods on `A`, so keys (and names) are handled there.
for fun in (:std, :var, :cov)
    full_name = Symbol("mean_and_$fun")
    # `cov` reduces over exactly one dimension, so `:` is not a usable default for it;
    # use `1`, matching the `dims=1` default of the `cov` methods in this package.
    default_dims = fun === :cov ? 1 : Colon()

    # `Vararg{AbstractWeights}` rather than `Vararg{<:AbstractWeights}`: wrapping `Vararg`
    # in a `UnionAll` is deprecated, and this form accepts the same calls.
    @eval function StatsBase.$full_name(A::KeyedMatrix, wv::Vararg{AbstractWeights}; dims=$default_dims, corrected::Bool=true, kwargs...)
        return (
            mean(A, wv...; dims=dims, kwargs...),
            $fun(A, wv...; dims=dims, corrected=corrected, kwargs...)
        )
    end
end

# Since we get ambiguity errors with specific implementations we need to wrap each supported method
# A better approach might be to add `NamedDims` support to CovarianceEstimation.jl in the future.
using CovarianceEstimation
estimators = [
    :SimpleCovariance,
    :LinearShrinkage,
    :DiagonalUnitVariance,
    :DiagonalCommonVariance,
    :DiagonalUnequalVariance,
    :CommonCovariance,
    :PerfectPositiveCorrelation,
    :ConstantCorrelation,
    :AnalyticalNonlinearShrinkage,
]
for estimator in estimators
    # `Vararg{AbstractWeights}` rather than `Vararg{<:AbstractWeights}`: wrapping `Vararg`
    # in a `UnionAll` is deprecated, and this form accepts the same calls.
    @eval function Statistics.cov(ce::$estimator, A::KeyedMatrix, wv::Vararg{AbstractWeights}; dims=1, kwargs...)
        d = NamedDims.dim(A, dims)
        # Run the estimator on the array with its keys and names stripped.
        data = cov(ce, unname(keyless(A)), wv...; dims=d, kwargs...)
        # A matrix only has dimensions 1 and 2, so `3 - d` is the one that survives; its
        # name (if any) and keys label both dimensions of the square result.
        L1 = dimnames(A, 3 - d)
        data2 = hasnames(A) ? NamedDimsArray(data, (L1, L1)) : data
        K1 = axiskeys(A, 3 - d)
        return KeyedArray(data2, (copy(K1), copy(K1)))
    end
end
94 changes: 94 additions & 0 deletions test/_packages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,97 @@ end
@test_broken sortkeys(fft(A)) β‰ˆ fftshift(fft(A)) # isapprox should be used for keys

end

@testset "statsbase" begin
    using CovarianceEstimation, StatsBase

    # The same data three ways: plain matrix, keyed only, and keyed with dimension names.
    A = rand(4, 3)
    A_ka = KeyedArray(A, (0:3, [:a, :b, :c]))
    A_kanda = KeyedArray(A; time = 0:3, id = [:a, :b, :c])
    wv = aweights(rand(4))

    @testset "$f" for f in (mean, std, var)
        # Reference result on the plain matrix; only `mean` takes `dims` as a keyword here.
        R = f === mean ? f(A, wv; dims=1) : f(A, wv, 1; corrected=true)

        R_ka = f(A_ka, wv; dims=1)
        R_kanda_int = f(A_kanda, wv; dims=1)
        R_kanda_sym = f(A_kanda, wv; dims=:time)
        expected_keys = (Base.OneTo(1), [:a, :b, :c])
        expected_names = (:time, :id)

        @test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
        @test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
        @test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
    end

    @testset "$f" for f in (cov, cor, scattermat)
        # Inconsistent statsbase behaviour
        kwargs = f === cov ? [:corrected => true] : []
        R = f === scattermat ? f(A, wv; dims=1, kwargs...) : f(A, wv, 1; kwargs...)

        R_ka = f(A_ka, wv; dims=1, kwargs...)
        R_kanda_int = f(A_kanda, wv; dims=1, kwargs...)
        R_kanda_sym = f(A_kanda, wv; dims=:time, kwargs...)
        expected_keys = ([:a, :b, :c], [:a, :b, :c])
        expected_names = (:id, :id)

        @test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
        @test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
        @test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
    end

    @testset "$f" for f in (mean_and_var, mean_and_std, mean_and_cov)
        R1, R2 = f(A, wv, 1; corrected=true)
        R1_ka, R2_ka = f(A_ka, wv; dims=1, corrected=true)
        R1_kanda_int, R2_kanda_int = f(A_kanda, wv; dims=1, corrected=true)
        R1_kanda_sym, R2_kanda_sym = f(A_kanda, wv; dims=:time, corrected=true)

        @test parent(R1_ka) ≈ parent(parent(R1_kanda_int)) ≈ parent(parent(R1_kanda_sym)) ≈ R1
        @test parent(R2_ka) ≈ parent(parent(R2_kanda_int)) ≈ parent(parent(R2_kanda_sym)) ≈ R2
    end

    @testset "conversions" begin
        @testset "cov2cor" begin
            @test cov2cor(cov(A_ka; dims=1), std(A_ka; dims=1)) ≈ cor(A_ka; dims=1)
            @test cov2cor(cov(A_ka; dims=2), std(A_ka; dims=2)) ≈ cor(A_ka; dims=2)
            @test cov2cor(cov(A_kanda; dims=:time), std(A_kanda; dims=:time)) ≈ cor(A_kanda; dims=:time)
            @test cov2cor(cov(A_kanda; dims=:id), std(A_kanda; dims=:id)) ≈ cor(A_kanda; dims=:id)
        end
        @testset "cor2cov" begin
            @test cor2cov(cor(A_ka; dims=1), std(A_ka; dims=1)) ≈ cov(A_ka; dims=1)
            @test cor2cov(cor(A_ka; dims=2), std(A_ka; dims=2)) ≈ cov(A_ka; dims=2)
            @test cor2cov(cor(A_kanda; dims=:time), std(A_kanda; dims=:time)) ≈ cov(A_kanda; dims=:time)
            @test cor2cov(cor(A_kanda; dims=:id), std(A_kanda; dims=:id)) ≈ cov(A_kanda; dims=:id)
        end
    end

    @testset "covariance estimation" begin
        ce = SimpleCovariance()

        @testset "unweighted" begin
            R = cov(ce, A; dims=1)
            R_ka = cov(ce, A_ka; dims=1)
            R_kanda_int = cov(ce, A_kanda; dims=1)
            R_kanda_sym = cov(ce, A_kanda; dims=:time)
            expected_keys = ([:a, :b, :c], [:a, :b, :c])
            expected_names = (:id, :id)

            @test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
            @test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
            @test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
        end

        @testset "weighted" begin
            R = cov(ce, A, wv; dims=1)
            R_ka = cov(ce, A_ka, wv; dims=1)
            R_kanda_int = cov(ce, A_kanda, wv; dims=1)
            R_kanda_sym = cov(ce, A_kanda, wv; dims=:time)
            expected_keys = ([:a, :b, :c], [:a, :b, :c])
            expected_names = (:id, :id)

            @test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
            @test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
            @test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
        end
    end
end

0 comments on commit d8e8580

Please sign in to comment.