Skip to content

Commit

Permalink
ensembles
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 24, 2023
1 parent 3a23b31 commit b03b9c3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
61 changes: 61 additions & 0 deletions src/conformal_models/training/inductive_classification.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using MLJEnsembles: EitherEnsembleModel
using MLJFlux: MLJFluxModel
using MLUtils

"""
score(conf_model::InductiveModel, model::MLJFluxModel, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Overloads the `score` function for the `MLJFluxModel` type.
"""
function score(conf_model::SimpleInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
X = permutedims(matrix(X))
probas = permutedims(fitresult[1](X))
Expand All @@ -12,6 +19,32 @@ function score(conf_model::SimpleInductiveClassifier, ::Type{<:MLJFluxModel}, fi
end
end

"""
score(conf_model::SimpleInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Overloads the `score` function for ensembles of `MLJFluxModel` types.
"""
function score(conf_model::SimpleInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
X = permutedims(matrix(X))
_chains = map(res -> res[1], fitresult.ensemble)
probas = MLUtils.stack(map(chain -> chain(X), _chains)) |>
p -> mean(p, dims=ndims(p)) |>
p -> MLUtils.unstack(p, dims=ndims(p))[1] |>
p -> permutedims(p)
scores = @.(conf_model.heuristic(probas))
if isnothing(y)
return scores
else
cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
return cal_scores, scores
end
end

"""
score(conf_model::AdaptiveInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Overloads the `score` function for the `MLJFluxModel` type.
"""
function score(conf_model::AdaptiveInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
L = levels(fitresult[2])
X = permutedims(matrix(X))
Expand All @@ -29,4 +62,32 @@ function score(conf_model::AdaptiveInductiveClassifier, ::Type{<:MLJFluxModel},
cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
return cal_scores, scores
end
end

"""
score(conf_model::AdaptiveInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Overloads the `score` function for ensembles of `MLJFluxModel` types.
"""
function score(conf_model::AdaptiveInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
L = levels(fitresult.ensemble[1][2])
X = permutedims(matrix(X))
_chains = map(res -> res[1], fitresult.ensemble)
probas = MLUtils.stack(map(chain -> chain(X), _chains)) |>
p -> mean(p, dims=ndims(p)) |>
p -> MLUtils.unstack(p, dims=ndims(p))[1] |>
p -> permutedims(p)
scores = map(Base.Iterators.product(eachrow(probas), L)) do Z
probasᵢ, yₖ = Z
ranks = sortperm(.-probasᵢ) # rank in descending order
index_y = findall(L[ranks] .== yₖ)[1] # index of true y in sorted array
scoresᵢ = last(cumsum(probasᵢ[ranks][1:index_y])) # sum up until true y is reached
return scoresᵢ
end
if isnothing(y)
return scores
else
cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
return cal_scores, scores
end
end

0 comments on commit b03b9c3

Please sign in to comment.