diff --git a/src/scalarstats.jl b/src/scalarstats.jl index 41486c41bf8269..c5e0bfee92def0 100644 --- a/src/scalarstats.jl +++ b/src/scalarstats.jl @@ -47,10 +47,11 @@ end # compute mode, given the range of integer values """ mode(a, [r]) + mode(a::AbstractArray, wv::AbstractWeights) Return the mode (most common number) of an array, optionally -over a specified range `r`. If several modes exist, the first -one (in order of appearance) is returned. +over a specified range `r` or weighted via a vector `wv`. +If several modes exist, the first one (in order of appearance) is returned. """ function mode(a::AbstractArray{T}, r::UnitRange{T}) where T<:Integer isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) @@ -75,9 +76,10 @@ end """ modes(a, [r])::Vector + mode(a::AbstractArray, wv::AbstractWeights)::Vector Return all modes (most common numbers) of an array, optionally over a -specified range `r`. +specified range `r` or weighted via vector `wv`. """ function modes(a::AbstractArray{T}, r::UnitRange{T}) where T<:Integer r0 = r[1] @@ -158,6 +160,47 @@ function modes(a) return [x for (x, c) in cnts if c == mc] end +# Weighted mode of arbitrary vectors of values +function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real + isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + length(a) == length(wv) || + throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) + + # Iterate through the data + mv = first(a) + mw = first(wv) + weights = Dict{eltype(a), T}() + for (x, w) in zip(a, wv) + _w = get!(weights, x, zero(T)) + w + if _w > mw + mv = x + mw = _w + end + weights[x] = _w + end + + return mv +end + +function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real + isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + length(a) == length(wv) || + throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) + + # Iterate through the data + mw = first(wv) + weights = Dict{eltype(a), T}() + for (x, w) in zip(a, wv) + _w = get!(weights, x, zero(T)) + w + if _w > mw + mw = _w + end + weights[x] = _w + end + + # find values corresponding to maximum counts + return [x for (x, w) in weights if w == mw] +end ############################# # diff --git a/test/scalarstats.jl b/test/scalarstats.jl index f163988f7d27ed..220246367a1782 100644 --- a/test/scalarstats.jl +++ b/test/scalarstats.jl @@ -44,10 +44,24 @@ using Statistics @test modes(skipmissing([1, missing, missing, 3, 2, 2, missing])) == [2] @test sort(modes(skipmissing([1, missing, 3, 3, 2, 2, missing]))) == [2, 3] +d1 = [1, 2, 3, 3, 4, 5, 5, 3] +d2 = ['a', 'b', 'c', 'c', 'd', 'e', 'e', 'c'] +wv = weights([0.1:0.1:0.7; 0.1]) +@test mode(d1) == 3 +@test mode(d2) == 'c' +@test mode(d1, wv) == 5 +@test mode(d2, wv) == 'e' +@test sort(modes(d1[1:end-1], weights(ones(7)))) == [3, 5] +@test sort(modes(d1, weights([.9, .1, .1, .1, .9, .1, .1, .1]))) == [1, 4] + @test_throws ArgumentError mode(Int[]) @test_throws ArgumentError modes(Int[]) @test_throws ArgumentError mode(Any[]) @test_throws ArgumentError modes(Any[]) +@test_throws ArgumentError mode([], weights(Float64[])) +@test_throws ArgumentError modes([], weights(Float64[])) +@test_throws ArgumentError mode([1, 2, 3], weights([0.1, 0.3])) +@test_throws ArgumentError modes([1, 2, 3], weights([0.1, 0.3])) ## zscores