Skip to content

Commit

Permalink
For a 0.5.1 release (#156)
Browse files Browse the repository at this point in the history
* add `feature_importances` stub  (#148)

* add intrinsic_importances stub and set fallback to
nothing.

* fix error in intrinsic_importances docstring.

* rename intrinsic_importances method to feature_importances.

* remove fallback for feature_importances.

* Update src/model_api.jl

Co-authored-by: Anthony Blaom, PhD <[email protected]>

* bump 1.5

* Update `metadata_model` to include traits for feature importances and training losses (#155)

* + supports_training_losses,reports_feature_importances to model_metadata

* bump 1.5.1

* bump StatisticalTraits = "3.1"

Co-authored-by: Okon Samuel <[email protected]>
  • Loading branch information
ablaom and OkonSamuel authored Jul 8, 2022
1 parent e8da6ba commit 5886633
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.5"
version = "1.5.1"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -10,7 +10,7 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"

[compat]
ScientificTypesBase = "3.0"
StatisticalTraits = "3.0"
StatisticalTraits = "3.1"
julia = "1"

[extras]
Expand Down
9 changes: 8 additions & 1 deletion src/metadata_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ Helper function to write the metadata for a model `T`.
* `supports_class_weights=false`: whether the model supports class weights
* `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
* `human_name=nothing`: human name of the model
* `supports_training_losses=nothing`: whether the (necessarily iterative) model can report
training losses
* `reports_feature_importances=nothing`: whether the model reports feature importances
## Example
Expand Down Expand Up @@ -115,7 +118,9 @@ function metadata_model(
supports_class_weights::Union{Nothing,Bool}=class_weights,
docstring::Union{Nothing,String}=descr,
load_path::Union{Nothing,String}=path,
human_name::Union{Nothing,String}=nothing
human_name::Union{Nothing,String}=nothing,
supports_training_losses::Union{Nothing,Bool}=nothing,
reports_feature_importances::Union{Nothing,Bool}=nothing,
)
docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model)

Expand All @@ -132,6 +137,8 @@ function metadata_model(
_extend!(program, :docstring, docstring, T)
_extend!(program, :load_path, load_path, T)
_extend!(program, :human_name, human_name, T)
_extend!(program, :supports_training_losses, supports_training_losses, T)
_extend!(program, :reports_feature_importances, reports_feature_importances, T)

parentmodule(T).eval(program)
end
Expand Down
17 changes: 9 additions & 8 deletions src/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ in historical order. If the model calculates scores instead, then the
sign of the scores should be reversed.
The following trait overload is also required:
`supports_training_losses(::Type{<:M}) = true`
`MLJModelInterface.supports_training_losses(::Type{<:M}) = true`.
"""
training_losses(model, report) = nothing
Expand Down Expand Up @@ -168,16 +168,17 @@ function evaluate end
"""
feature_importances(model::M, fitresult, report)
For a given `model` of model type `M` supporting intrinsic feature importances, calculate
the feature importances from the model's `fitresult` and `report` as an
abstract vector of `feature::Symbol => importance::Real` pairs
(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
For a given `model` of model type `M` supporting intrinsic feature importances, calculate
the feature importances from the model's `fitresult` and `report` as an
abstract vector of `feature::Symbol => importance::Real` pairs
(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
The following trait overload is also required:
`reports_feature_importances(::Type{<:M}) = true`
`MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`
If for some reason a model is sometimes unable to report feature importances then
`feature_importances` should return all importances as 0.0, as in
`[:gender =>0.0, :height =>0.0, :weight => 0.0]`.
`feature_importances` should return all importances as 0.0, as in
`[:gender =>0.0, :height =>0.0, :weight => 0.0]`.
"""
function feature_importances end
2 changes: 2 additions & 0 deletions test/metadata_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
@test infos[:hyperparameters] == (:a, :b)
@test infos[:hyperparameter_types] == ("Int64", "Any")
@test infos[:hyperparameter_ranges] == (nothing, nothing)
@test !infos[:supports_training_losses]
@test !infos[:reports_feature_importances]
end

@testset "doc_header(ModelType)" begin
Expand Down
30 changes: 21 additions & 9 deletions test/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@ mutable struct APIx1 <: Static end
@test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
end

M.metadata_model(
APIx0,
supports_training_losses = true,
reports_feature_importances = true,
)

dummy_losses = [1.0, 2.0, 3.0]
M.training_losses(::APIx0, report) = report
M.feature_importances(::APIx0, fitresult, report) = [:a=>0, :b=>0]

@testset "fit-x" begin
m0 = APIx0(f0=1)
m1 = APIx0b(f0=3)
# no weight support: fallback
M.fit(m::APIx0, v::Int, X, y) = (5, nothing, nothing)
@test fit(m0, 1, randn(2), randn(2), 5) == (5, nothing, nothing)
M.fit(m::APIx0, v::Int, X, y) = (5, nothing, dummy_losses)
@test fit(m0, 1, randn(2), randn(2), 5) == (5, nothing, dummy_losses)
# with weight support: use
M.fit(m::APIx0b, v::Int, X, y, w) = (7, nothing, nothing)
@test fit(m1, 1, randn(2), randn(2), 5) == (7, nothing, nothing)
Expand All @@ -32,16 +42,18 @@ end
@test fit(s1, 1, 0) == (nothing, nothing, nothing)

# update fallback = fit
@test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, nothing)
@test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, dummy_losses)

# training losses:
f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
@test M.training_losses(m0, r) === nothing

# intrinsic_importances
@test M.training_losses(m0, r) == dummy_losses

# training losses:
f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
@test M.training_losses(m0, r) == dummy_losses

# feature_importances
f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
MLJModelInterface.reports_feature_importances(::Type{APIx0}) = true
MLJModelInterface.feature_importances(::APIx0, fitresult, report) = [:a=>0, :b=>0]
@test MLJModelInterface.feature_importances(m0, f, r) == [:a=>0, :b=>0]
end

Expand All @@ -67,7 +79,7 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
end

MMI.input_scitype(::Type{<:UnivariateFiniteFitter}) = Nothing

MMI.target_scitype(::Type{<:UnivariateFiniteFitter}) = AbstractVector{<:Finite}

y = categorical(collect("aabbccaa"))
Expand Down

0 comments on commit 5886633

Please sign in to comment.