Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for weighted StatsBase methods #28

Merged
merged 7 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@ version = "0.1.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
AbstractFFTs = "0.5"
BenchmarkTools = "0.5"
CovarianceEstimation = "0.2"
IntervalSets = "0.5.1"
InvertedIndices = "1.0"
LazyStack = "0.0.7, 0.0.8"
NamedDims = "0.2.27"
OffsetArrays = "0.10, 0.11, 1.0"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1"
julia = "1"

Expand Down
2 changes: 2 additions & 0 deletions src/AxisKeys.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ include("stack.jl") # LazyStack.jl

include("fft.jl") # AbstractFFTs.jl

include("statsbase.jl") # StatsBase.jl

end
25 changes: 19 additions & 6 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,36 @@ end

using Statistics
for fun in [:mean, :std, :var] # These don't use mapreduce, but could perhaps be handled better?
@eval function Statistics.$fun(A::KeyedArray; dims=:)
@eval function Statistics.$fun(A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return $fun(parent(A))
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = $fun(parent(A); dims=numerical_dims)
data = $fun(parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
VERSION >= v"1.3" &&
@eval function Statistics.$fun(f, A::KeyedArray; dims=:)
dims === Colon() && return $fun(f, parent(A))
end

# Handle function interface for `mean` only
if VERSION >= v"1.3"
@eval function Statistics.mean(f, A::KeyedArray; dims=:, kwargs...)
dims === Colon() && return mean(f, parent(A))
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = $fun(f, parent(A); dims=numerical_dims)
data = mean(f, parent(A); dims=numerical_dims, kwargs...)
new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
end
end

# `cov`/`cor` differ from the reductions above: on a matrix they return a
# square result over the axis that was NOT reduced, so both output
# dimensions carry that surviving axis's keys.
for fun in [:cov, :cor]
    @eval function Statistics.$fun(A::KeyedMatrix; dims=1, kwargs...)
        d = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
        result = $fun(parent(A); dims=d, kwargs...)
        # For a 2-d array the surviving dimension is `3 - d` (1 ↔ 2);
        # reuse its keys for both axes of the square output.
        surviving = axiskeys(A, 3 - d)
        return KeyedArray(result, (copy(surviving), copy(surviving)))
    end
end

function Base.dropdims(A::KeyedArray; dims)
numerical_dims = hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
data = dropdims(parent(A); dims=dims)
Expand Down
3 changes: 3 additions & 0 deletions src/names.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ NdaKaVoM{L,T} = Union{NamedDimsArray{L,T,1,<:KeyedArray}, NamedDimsArray{L,T,2,<
# Forward the dimension-name tuple of the wrapped NamedDimsArray.
NamedDims.dimnames(A::KaNda{L}) where {L} = L
# Single-dimension lookup; indices beyond `N` report the anonymous
# name `:_` (NamedDims' convention for unnamed trailing dimensions).
NamedDims.dimnames(A::KaNda{L,T,N}, d::Int) where {L,T,N} = d <= N ? L[d] : :_

# Special case `dim` for KeyedArrays around NamedDimsArrays: resolve a
# dimension selector (name or number) against the name tuple `L`.
NamedDims.dim(A::KaNda{L}, name) where {L} = NamedDims.dim(L, name)

# Allow `axes`/`size` to be queried by dimension name as well as by number.
Base.axes(A::KaNda{L}, s::Symbol) where {L} = axes(A, NamedDims.dim(L,s))
Base.size(A::KaNda{L,T,N}, s::Symbol) where {T,N,L} = size(A, NamedDims.dim(L,s))

Expand Down
83 changes: 83 additions & 0 deletions src/statsbase.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using StatsBase


# Support some of the weighted statistics function in StatsBase
# NOTES:
# - Ambiguity errors are still possible for weights with overly specific methods (e.g., UnitWeights)
# - Ideally, when the weighted statistics is moved to Statistics.jl we can remove this entire file.
# https://github.com/JuliaLang/Statistics.jl/pull/2
# Weighted mean of a `KeyedArray`: reduce with StatsBase on the parent data,
# then rebuild the key tuple, collapsing reduced dimensions to `OneTo(1)`.
function Statistics.mean(A::KeyedArray, wv::AbstractWeights; dims=:, kwargs...)
    # Reducing over every dimension yields a scalar — no keys survive.
    if dims === Colon()
        return mean(parent(A), wv; kwargs...)
    end
    # Translate symbolic dimension names into integer indices when present.
    dnums = AxisKeys.hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
    reduced = mean(parent(A), wv; dims=dnums, kwargs...)
    # Reduced dimensions get a trivial key; the rest keep their own.
    new_keys = ntuple(ndims(A)) do d
        d in dnums ? Base.OneTo(1) : axiskeys(A, d)
    end
    return KeyedArray(reduced, map(copy, new_keys))#, copy(A.meta))
end

# `var` and `std` are kept separate from `mean` because the weighted StatsBase
# methods take the reduction dimension as a positional argument, and the
# `corrected` keyword must be threaded through with a default of `true`.
for fun in [:var, :std]
    @eval function Statistics.$fun(A::KeyedArray, wv::AbstractWeights; dims=:, corrected=true, kwargs...)
        # Fix: the whole-array (`dims === Colon()`) path previously dropped the
        # caller's `corrected` value (it was captured as a separate keyword but
        # never forwarded), silently falling back to StatsBase's own default.
        dims === Colon() && return $fun(parent(A), wv; corrected=corrected, kwargs...)
        numerical_dims = AxisKeys.hasnames(A) ? NamedDims.dim(dimnames(A), dims) : dims
        data = $fun(parent(A), wv, numerical_dims; corrected=corrected, kwargs...)
        # Reduced dimensions collapse to a trivial `OneTo(1)` key.
        new_keys = ntuple(d -> d in numerical_dims ? Base.OneTo(1) : axiskeys(A,d), ndims(A))
        return KeyedArray(data, map(copy, new_keys))#, copy(A.meta))
    end
end

# Weighted `cov`/`cor`: reduce on the bare parent array, then restore names
# (when present) and keys, using the surviving axis for both output dimensions.
for fun in [:cov, :cor]
    @eval function Statistics.$fun(A::KeyedMatrix, wv::AbstractWeights; dims=1, kwargs...)
        d = NamedDims.dim(A, dims)
        raw = $fun(unname(keyless(A)), wv, d; kwargs...)
        # For a matrix, dimension `3 - d` is the one that survives reduction.
        remaining = 3 - d
        name = dimnames(A, remaining)
        wrapped = hasnames(A) ? NamedDimsArray(raw, (name, name)) : raw
        key = axiskeys(A, remaining)
        return KeyedArray(wrapped, (copy(key), copy(key)))
    end
end

# `scattermat` lives in StatsBase and — unlike the weighted `cov`/`cor`
# methods above — accepts `dims` as a keyword argument on the raw-array form.
function StatsBase.scattermat(A::KeyedMatrix, wv::AbstractWeights; dims=1, kwargs...)
    d = NamedDims.dim(A, dims)
    raw = scattermat(unname(keyless(A)), wv; dims=d, kwargs...)
    # Both output dimensions correspond to the axis that was not reduced.
    remaining = 3 - d
    name = dimnames(A, remaining)
    wrapped = hasnames(A) ? NamedDimsArray(raw, (name, name)) : raw
    key = axiskeys(A, remaining)
    return KeyedArray(wrapped, (copy(key), copy(key)))
end

# `mean_and_var`/`mean_and_std`/`mean_and_cov`: return the (optionally
# weighted) mean together with the corresponding spread statistic as a 2-tuple.
for fun in (:std, :var, :cov)
    full_name = Symbol("mean_and_$fun")

    # `wv::AbstractWeights...` replaces the deprecated `Vararg{<:AbstractWeights}`
    # (wrapping `Vararg` directly in a `UnionAll`); it is equivalent and still
    # accepts zero or one weight vectors.
    @eval function StatsBase.$full_name(A::KeyedMatrix, wv::AbstractWeights...; dims=:, corrected::Bool=true, kwargs...)
        return (
            mean(A, wv...; dims=dims, kwargs...),
            $fun(A, wv...; dims=dims, corrected=corrected, kwargs...),
        )
    end
end

# Since we get ambiguity errors with specific implementations we need to wrap
# each supported CovarianceEstimation estimator type explicitly.
# A better approach might be to add `NamedDims` support to CovarianceEstimation.jl.
using CovarianceEstimation

# One wrapper per concrete estimator; iterating a tuple literal avoids leaving
# a non-const `estimators` global behind after the file is included.
for estimator in (
    :SimpleCovariance,
    :LinearShrinkage,
    :DiagonalUnitVariance,
    :DiagonalCommonVariance,
    :DiagonalUnequalVariance,
    :CommonCovariance,
    :PerfectPositiveCorrelation,
    :ConstantCorrelation,
    :AnalyticalNonlinearShrinkage,
)
    # `wv::AbstractWeights...` replaces the deprecated `Vararg{<:AbstractWeights}`
    # syntax; zero or one weight vectors are accepted, exactly as before.
    @eval function Statistics.cov(ce::$estimator, A::KeyedMatrix, wv::AbstractWeights...; dims=1, kwargs...)
        d = NamedDims.dim(A, dims)
        data = cov(ce, unname(keyless(A)), wv...; dims=d, kwargs...)
        # Both axes of the covariance matrix correspond to the surviving dimension.
        L1 = dimnames(A, 3 - d)
        data2 = hasnames(A) ? NamedDimsArray(data, (L1, L1)) : data
        K1 = axiskeys(A, 3 - d)
        return KeyedArray(data2, (copy(K1), copy(K1)))
    end
end
94 changes: 94 additions & 0 deletions test/_packages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,97 @@ end
@test_broken sortkeys(fft(A)) ≈ fftshift(fft(A)) # isapprox should be used for keys

end

# Integration tests for the weighted StatsBase / CovarianceEstimation wrappers
# in src/statsbase.jl. Each statistic is checked against the raw-array result,
# for a plain KeyedArray and for a named KeyedArray addressed both by integer
# dimension and by dimension name.
@testset "statsbase" begin
using CovarianceEstimation, StatsBase

# Shared fixtures: a 4x3 matrix, its keyed/named wrappers, analytic weights.
A = rand(4, 3)
A_ka = KeyedArray(A, (0:3, [:a, :b, :c]))
A_kanda = KeyedArray(A; time = 0:3, id = [:a, :b, :c])
wv = aweights(rand(4))

# Reductions along dims=1: the reduced axis keeps a trivial OneTo(1) key.
@testset "$f" for f in (mean, std, var)
# StatsBase API quirk: weighted `mean` takes `dims` as a keyword, while
# weighted `std`/`var` take the dimension positionally.
R = f == mean ? f(A, wv; dims=1) : f(A, wv, 1; corrected=true)

R_ka = f(A_ka, wv; dims=1)
R_kanda_int = f(A_kanda, wv; dims=1)
R_kanda_sym = f(A_kanda, wv; dims=:time)
expected_keys = (Base.OneTo(1), [:a, :b, :c])
expected_names = (:time, :id)

@test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
@test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
@test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
end

# Pairwise statistics return a square matrix over the surviving (:id) axis.
@testset "$f" for f in (cov, cor, scattermat)
# Inconsistent statsbase behaviour: only `cov` accepts `corrected`, and only
# `scattermat` takes `dims` as a keyword on the raw-array form.
kwargs = f === cov ? [:corrected => true] : []
R = f === scattermat ? f(A, wv; dims=1, kwargs...) : f(A, wv, 1; kwargs...)

R_ka = f(A_ka, wv; dims=1, kwargs...)
R_kanda_int = f(A_kanda, wv; dims=1, kwargs...)
R_kanda_sym = f(A_kanda, wv; dims=:time, kwargs...)
expected_keys = ([:a, :b, :c], [:a, :b, :c])
expected_names = (:id, :id)

@test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
@test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
@test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
end

# Combined helpers return a 2-tuple (mean, spread statistic).
@testset "$f" for f in (mean_and_var, mean_and_std, mean_and_cov)
R1, R2 = f(A, wv, 1; corrected=true)
R1_ka, R2_ka = f(A_ka, wv; dims=1, corrected=true)
R1_kanda_int, R2_kanda_int = f(A_kanda, wv; dims=1, corrected=true)
R1_kanda_sym, R2_kanda_sym = f(A_kanda, wv; dims=:time, corrected=true)

@test parent(R1_ka) ≈ parent(parent(R1_kanda_int)) ≈ parent(parent(R1_kanda_sym)) ≈ R1
@test parent(R2_ka) ≈ parent(parent(R2_kanda_int)) ≈ parent(parent(R2_kanda_sym)) ≈ R2
end

# Round-trips between covariance and correlation matrices (unweighted paths).
@testset "conversions" begin
@testset "cov2cor" begin
@test cov2cor(cov(A_ka; dims=1), std(A_ka; dims=1)) ≈ cor(A_ka; dims=1)
@test cov2cor(cov(A_ka; dims=2), std(A_ka; dims=2)) ≈ cor(A_ka; dims=2)
@test cov2cor(cov(A_kanda; dims=:time), std(A_kanda; dims=:time)) ≈ cor(A_kanda; dims=:time)
@test cov2cor(cov(A_kanda; dims=:id), std(A_kanda; dims=:id)) ≈ cor(A_kanda; dims=:id)
end
@testset "cor2cov" begin
@test cor2cov(cor(A_ka; dims=1), std(A_ka; dims=1)) ≈ cov(A_ka; dims=1)
@test cor2cov(cor(A_ka; dims=2), std(A_ka; dims=2)) ≈ cov(A_ka; dims=2)
@test cor2cov(cor(A_kanda; dims=:time), std(A_kanda; dims=:time)) ≈ cov(A_kanda; dims=:time)
# NOTE(review): `std(A_kanda, dims=:id)` passes `dims` after a comma rather
# than the semicolon used everywhere else; it still parses as a keyword
# argument, but should be normalised to `; dims=:id` for consistency.
@test cor2cov(cor(A_kanda; dims=:id), std(A_kanda, dims=:id)) ≈ cov(A_kanda; dims=:id)
end
end

# Estimator-based covariance (CovarianceEstimation.jl), with/without weights.
@testset "covariance estimation" begin
ce = SimpleCovariance()

@testset "unweighted" begin
R = cov(ce, A; dims=1)
R_ka = cov(ce, A_ka; dims=1)
R_kanda_int = cov(ce, A_kanda; dims=1)
R_kanda_sym = cov(ce, A_kanda; dims=:time)
expected_keys = ([:a, :b, :c], [:a, :b, :c])
expected_names = (:id, :id)

@test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
@test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
@test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
end

@testset "weighted" begin
R = cov(ce, A, wv; dims=1)
R_ka = cov(ce, A_ka, wv; dims=1)
R_kanda_int = cov(ce, A_kanda, wv; dims=1)
R_kanda_sym = cov(ce, A_kanda, wv; dims=:time)
expected_keys = ([:a, :b, :c], [:a, :b, :c])
expected_names = (:id, :id)

@test dimnames(R_kanda_int) == dimnames(R_kanda_sym) == expected_names
@test axiskeys(R_ka) == axiskeys(R_kanda_int) == axiskeys(R_kanda_sym) == expected_keys
@test parent(R_ka) ≈ parent(parent(R_kanda_int)) ≈ parent(parent(R_kanda_sym)) ≈ R
end
end
end