Merge pull request #76 from alan-turing-institute/mlj2
Add facility for adding a data front-end to a model implementation
ablaom authored Jan 4, 2021
2 parents 2c358e9 + 9c937d8 commit 1bd56dc
Showing 4 changed files with 81 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "MLJModelInterface"
 uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 authors = ["Thibaut Lienart and Anthony Blaom"]
-version = "0.3.6"
+version = "0.3.7"

 [deps]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2 changes: 1 addition & 1 deletion src/MLJModelInterface.jl
@@ -25,7 +25,7 @@ export @mlj_model, metadata_pkg, metadata_model
 # model api
 export fit, update, update_data, transform, inverse_transform,
     fitted_params, predict, predict_mode, predict_mean, predict_median,
-    predict_joint, evaluate, clean!
+    predict_joint, evaluate, clean!, reformat

 # model traits
 export input_scitype, output_scitype, target_scitype,
85 changes: 73 additions & 12 deletions src/model_api.jl
@@ -1,31 +1,92 @@
 """
-every model interface must implement a `fit` method of the form
-`fit(model, verb::Integer, training_args...) -> fitresult, cache, report`
+    fit(model, verbosity, data...) -> fitresult, cache, report
+All models must implement a `fit` method. Here `data` is the
+output of `reformat` on user-provided data, or some resampling
+thereof. The fallback of `reformat` returns the user-provided data
+(eg, a table).
 """
 function fit end

 # fallback for static transformations
-fit(::Static, ::Integer, a...) = (nothing, nothing, nothing)
+fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)

 # fallbacks for supervised models that don't support sample weights:
-fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y)
+fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)

-# this operation can be optionally overloaded to provide access to
-# fitted parameters (eg, coefficients of linear model):
-fitted_params(::Model, fitres) = (fitresult=fitres,)
+"""
+    update(model, verbosity, fitresult, cache, data...)
+Models may optionally implement an `update` method. The fallback calls
+`fit`.
 """
-each model interface may overload the `update` refitting method
+update(m::Model, verbosity, fitresult, cache, data...) =
+    fit(m, verbosity, data...)
+
+# to support online learning in the future:
+# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
+function update_data end
+
 """
-update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...)
+    MLJModelInterface.reformat(model, args...) -> data
+Models optionally overload `reformat` to define transformations of
+user-supplied data into some model-specific representation (e.g., from
+a table to a matrix). When implemented, the MLJ user can avoid
+repeating such transformations unnecessarily, and can additionally
+make use of more efficient row subsampling, which is then based on the
+model-specific representation of data, rather than the
+user-representation. When `reformat` is overloaded,
+`selectrows(::Model, ...)` must be as well (see
+[`selectrows`](@ref)). Furthermore, the model `fit` method(s), and
+operations, such as `predict` and `transform`, must be refactored to
+act on the model-specific representations of the data.
+To implement the `reformat` data front-end for a model, refer to
+"Implementing a data front-end" in the [MLJ
+manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/).
 """
-each model interface may overload the `update_data` refitting method for online learning
+reformat(model::Model, args...) = args
+
 """
-function update_data end
+    selectrows(::Model, I, data...) -> sampled_data
+A model overloads `selectrows` whenever it buys into the optional
+`reformat` front-end for data preprocessing. See [`reformat`](@ref)
+for details. The fallback assumes `data` is a tuple and calls
+`selectrows(X, I)` for each `X` in `data`, returning the results in a
+new tuple of the same length. This call makes sense when `X` is a
+table, abstract vector or abstract matrix. In the last two cases, a
+new object and *not* a view is returned.
+"""
+selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)
+
+# this operation can be optionally overloaded to provide access to
+# fitted parameters (eg, coefficients of linear model):
+"""
+    fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
+Models may overload `fitted_params`. The fallback returns
+`(fitresult=fitresult,)`.
+Other training-related outcomes should be returned in the `report`
+part of the tuple returned by `fit`.
+"""
+fitted_params(::Model, fitresult) = (fitresult=fitresult,)

 """
-supervised methods must implement the `predict` operation
+    predict(model, fitresult, new_data...)
+`Supervised` models must implement the `predict` operation. Here
+`new_data` is the output of `reformat` called on user-specified data.
 """
 function predict end

6 changes: 6 additions & 0 deletions test/model_api.jl
@@ -7,6 +7,12 @@ end

 mutable struct APIx1 <: Static end

+@testset "selectrows(model, data...)" begin
+    X = (x1 = [2, 4, 6],)
+    y = [10.0, 20.0, 30.0]
+    @test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
+end
+
 @testset "fit-x" begin
     m0 = APIx0(f0=1)
     m1 = APIx0b(f0=3)
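
The data front-end described in the `reformat` and `selectrows` docstrings of this commit can be sketched as follows. This is a minimal, self-contained illustration and not the package's implementation: `Model`, the generic `selectrows` methods, and the fallbacks are local stand-ins defined here so the snippet runs without MLJModelInterface, and `ToyRegressor` is a hypothetical model assumed to want its features as a `p x n` matrix (one observation per column).

```julia
# Local stand-ins (assumptions): these mimic, but are NOT, the MLJModelInterface definitions.
abstract type Model end

# Generic row selection for vectors, matrices and column tables (named tuples of vectors).
selectrows(X::AbstractVector, I) = X[I]
selectrows(X::AbstractMatrix, I) = X[I, :]
selectrows(X::NamedTuple, I) = map(col -> col[I], X)

# Fallbacks with the same shape as those added in this commit.
reformat(::Model, args...) = args
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

# Hypothetical model that opts in to the data front-end.
struct ToyRegressor <: Model end

# Helper (hypothetical): column table -> p x n matrix, one observation per column.
table_to_matrix(X) = permutedims(reduce(hcat, collect(values(X))))

# Front-end: convert user data once into the model-specific representation.
reformat(::ToyRegressor, X, y) = (table_to_matrix(X), y)
reformat(::ToyRegressor, X) = (table_to_matrix(X),)

# Row selection acts on the model-specific representation and returns views.
selectrows(::ToyRegressor, I, Xmat, y) = (view(Xmat, :, I), view(y, I))
selectrows(::ToyRegressor, I, Xmat) = (view(Xmat, :, I),)

X = (x1 = [2.0, 4.0, 6.0], x2 = [1.0, 3.0, 5.0])
y = [10.0, 20.0, 30.0]

data = reformat(ToyRegressor(), X, y)           # done once per data set
sub = selectrows(ToyRegressor(), 2:3, data...)  # cheap resampling on the matrix
```

With such overloads in place, `fit` and `predict` for the model would receive the matrix form directly, which is why the docstring requires them to be refactored to act on the model-specific representation.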
