diff --git a/Project.toml b/Project.toml index 7668cda..e853eb3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJEnsembles" uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" authors = ["Anthony D. Blaom "] -version = "0.3.3" +version = "0.4.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -9,11 +9,11 @@ CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" +StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] @@ -21,9 +21,21 @@ CategoricalArrays = "0.8, 0.9, 0.10" CategoricalDistributions = "0.1.2" ComputationalResources = "0.3" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -MLJBase = "0.20, 0.21" MLJModelInterface = "0.4.1, 1.1" ProgressMeter = "1.1" ScientificTypesBase = "2,3" +StatisticalMeasuresBase = "0.1" StatsBase = "0.32, 0.33, 0.34" julia = "1.6" + +[extras] +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Distances", "MLJBase", "NearestNeighbors", "Serialization", "StableRNGs", "StatisticalMeasures", "Test"] \ No newline at end of file diff --git a/src/MLJEnsembles.jl b/src/MLJEnsembles.jl index 1283e45..5ab4c99 100644 --- a/src/MLJEnsembles.jl +++ 
b/src/MLJEnsembles.jl @@ -2,7 +2,6 @@ module MLJEnsembles using MLJModelInterface import MLJModelInterface: predict, fit, save, restore -import MLJBase # still needed for aggregating measures in oob-estimates of error using Random using CategoricalArrays using CategoricalDistributions @@ -11,6 +10,7 @@ using Distributed import Distributions using ProgressMeter import StatsBase +import StatisticalMeasuresBase export EnsembleModel diff --git a/src/ensembles.jl b/src/ensembles.jl index 1d26500..4334584 100644 --- a/src/ensembles.jl +++ b/src/ensembles.jl @@ -321,11 +321,10 @@ If a single measure or non-empty vector of measures is specified by written to the training report (call `report` on the trained machine wrapping the ensemble model). -*Important:* If sample weights `w` (not to be confused with atomic -weights) are specified when constructing a machine for the ensemble -model, as in `mach = machine(ensemble_model, X, y, w)`, then `w` is -used by any measures specified in `out_of_bag_measure` that support -sample weights. +*Important:* If per-observation or class weights `w` (not to be confused with atomic +weights) are specified when constructing a machine for the ensemble model, as in `mach = +machine(ensemble_model, X, y, w)`, then `w` is used by any measures specified in +`out_of_bag_measure` that support them. """ function EnsembleModel( @@ -395,34 +394,56 @@ function _fit(res::CPUProcesses, func, verbosity, stuff) if i != nworkers() func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...) else - func(atom, 0, chunk_size + left_over, n_patterns, n_train, rng, progress_meter, args...) 
+            func(
+                atom,
+                0,
+                chunk_size + left_over,
+                n_patterns,
+                n_train,
+                rng,
+                progress_meter,
+                args...,
+            )
         end
     end
 end
 
-@static if VERSION >= v"1.3.0-DEV.573"
-    function _fit(res::CPUThreads, func, verbosity, stuff)
-        atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
-        if verbosity > 0
-            println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
-        end
-        nthreads = Threads.nthreads()
-        chunk_size = div(n, nthreads)
-        left_over = mod(n, nthreads)
-        resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?
-
-        Threads.@threads for i = 1:nthreads
-            resvec[i] = if i != nworkers()
-                func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
-            else
-                func(atom, 0, chunk_size + left_over, n_patterns, n_train, rng, progress_meter, args...)
-            end
-        end
+function _fit(res::CPUThreads, func, verbosity, stuff) # multithreaded variant of `_fit`
+    atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
+    if verbosity > 0
+        println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
+    end
+    nthreads = Threads.nthreads()
+    chunk_size = div(n, nthreads) # even share of `n` passed to `func` on every thread
+    left_over = mod(n, nthreads) # remainder of `n`, added to the last thread's share
+    resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?
 
-        return reduce(_reducer, resvec)
+    Threads.@threads for i = 1:nthreads
+        resvec[i] = if i != nthreads # not `nworkers()`: chunks are per thread here
+            func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
+        else
+            func(
+                atom,
+                0,
+                chunk_size + left_over,
+                n_patterns,
+                n_train,
+                rng,
+                progress_meter,
+                args...,
+            )
+        end
     end
+
+    return reduce(_reducer, resvec) # merge the per-thread results with `_reducer`
 end
 
+# for subsampling weights, which could be `nothing`, per-observation weights, or
+# class_weights:
+_view(class_weights::AbstractDict, rows) = class_weights
+_view(::Nothing, rows) = nothing
+_view(weights, rows) = view(weights, rows)
+
 function MMI.fit(
     model::EitherEnsembleModel{Atom}, verbosity::Int, args...
) where Atom<:Supervised @@ -446,10 +467,14 @@ function MMI.fit( acceleration = CPU1() end + # we wrap the measures in `robust_measure` so they can be called with weights, even + # when they don't support them, and just ignore them silently. if model.out_of_bag_measure isa Vector - out_of_bag_measure = model.out_of_bag_measure + out_of_bag_measure = + StatisticalMeasuresBase.robust_measure.(model.out_of_bag_measure) else - out_of_bag_measure = [model.out_of_bag_measure,] + out_of_bag_measure = + [StatisticalMeasuresBase.robust_measure(model.out_of_bag_measure),] end if model.rng isa Integer @@ -484,7 +509,7 @@ function MMI.fit( if !isempty(out_of_bag_measure) - metrics=zeros(length(ensemble),length(out_of_bag_measure)) + measurements=zeros(length(ensemble),length(out_of_bag_measure)) for i= 1:length(ensemble) #oob indices ooB_indices= setdiff(1:n_patterns, ensemble_indices[i]) @@ -493,42 +518,44 @@ function MMI.fit( "Data size too small or "* "bagging_fraction too close to 1.0. ") end - yhat = predict(atom, ensemble[i], selectrows(atom, ooB_indices, atom_specific_X)...) 
+ yhat = predict( + atom, + ensemble[i], + selectrows(atom, ooB_indices, atom_specific_X)..., + ) Xtest = selectrows(X, ooB_indices) ytest = selectrows(y, ooB_indices) - if w === nothing - wtest = nothing - else - wtest = selectrows(w, ooB_indices) - end + # this could be class weights OR per-observation weights, OR `nothing`: + wtest = _view(w, ooB_indices) for k in eachindex(out_of_bag_measure) m = out_of_bag_measure[k] - if MMI.reports_each_observation(m) - s = MLJBase.aggregate( - MLJBase.value(m, yhat, Xtest, ytest, wtest), - m - ) - else - s = MLJBase.value(m, yhat, Xtest, ytest, wtest) - end - metrics[i,k] = s + s = m(yhat, ytest, wtest) + measurements[i,k] = s end end - # aggregate metrics across the ensembles: - aggregated_metrics = map(eachindex(out_of_bag_measure)) do k - MLJBase.aggregate(metrics[:,k], out_of_bag_measure[k]) + # aggregate measurements across the ensembles: + aggregated_measurements = map(eachindex(out_of_bag_measure)) do k + StatisticalMeasuresBase.aggregate( + measurements[:,k], + mode=StatisticalMeasuresBase.external_aggregation_mode( + out_of_bag_measure[k], + ) + ) end names = Symbol.(string.(out_of_bag_measure)) else - aggregated_metrics = missing + aggregated_measurements = missing end - report=(measures=out_of_bag_measure, oob_measurements=aggregated_metrics,) + report=( + measures=out_of_bag_measure, + oob_measurements=aggregated_measurements, + ) cache = deepcopy(model) return fitresult, cache, report @@ -542,7 +569,7 @@ function MMI.update(model::EitherEnsembleModel, n = model.n - if MLJBase.is_same_except(model.model, old_model.model, + if MMI.is_same_except(model.model, old_model.model, :n, :atomic_weights, :acceleration) if n > old_model.n verbosity < 1 || diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 53b8d70..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,16 +0,0 @@ -[deps] -CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" -Distances = 
"b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -Distances = "0.10" -NearestNeighbors = "0.4" -StableRNGs = "1" diff --git a/test/ensembles.jl b/test/ensembles.jl index 411f134..f78f9bf 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -8,7 +8,7 @@ using MLJBase using ..Models using CategoricalArrays import Distributions - +using StatisticalMeasures ## HELPER FUNCTIONS