Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add facility for adding a data front-end to a model implementation #76

Merged
merged 6 commits into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "0.3.6"
version = "0.3.7"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 1 addition & 1 deletion src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export @mlj_model, metadata_pkg, metadata_model
# model api
export fit, update, update_data, transform, inverse_transform,
fitted_params, predict, predict_mode, predict_mean, predict_median,
predict_joint, evaluate, clean!
predict_joint, evaluate, clean!, reformat

# model traits
export input_scitype, output_scitype, target_scitype,
Expand Down
85 changes: 73 additions & 12 deletions src/model_api.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,92 @@
"""
every model interface must implement a `fit` method of the form
`fit(model, verb::Integer, training_args...) -> fitresult, cache, report`
fit(model, verbosity, data...) -> fitresult, cache, report

All models must implement a `fit` method. Here `data` is the
output of `reformat` on user-provided data, or some some resampling
thereof. The fallback of `reformat` returns the user-provided data
(eg, a table).

"""
function fit end

# fallback for static transformations
fit(::Static, ::Integer, a...) = (nothing, nothing, nothing)
fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)

# fallbacks for supervised models that don't support sample weights:
fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y)
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)

# this operation can be optionally overloaded to provide access to
# fitted parameters (eg, coeficients of linear model):
fitted_params(::Model, fitres) = (fitresult=fitres,)
"""
update(model, verbosity, fitresult, cache, data...)

Models may optionally implement an `update` method. The fallback calls
`fit`.

"""
each model interface may overload the `update` refitting method
update(m::Model, verbosity, fitresult, cache, data...) =
fit(m, verbosity, data...)

# to support online learning in the future:
# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
function update_data end

"""
update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...)
MLJModelInterface.reformat(model, args...) -> data

Models optionally overload `reformat` to define transformations of
user-supplied data into some model-specific representation (e.g., from
a table to a matrix). When implemented, the MLJ user can avoid
repeating such transformations unnecessarily, and can additionally
make use of more efficient row subsampling, which is then based on the
model-specific representation of data, rather than the
user-representation. When `reformat` is overloaded,
`selectrows(::Model, ...)` must be as well (see
[`selectrows`](@ref)). Furthermore, the model `fit` method(s), and
operations, such as `predict` and `transform`, must be refactored to
act on the model-specific representions of the data.

To implement the `reformat` data front-end for a model, refer to
"Implementing a data front-end" in the [MLJ
manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/).


"""
each model interface may overload the `update_data` refitting method for online learning
reformat(model::Model, args...) = args

"""
function update_data end
selectrows(::Model, I, data...) -> sampled_data

A model overloads `selectrows` whenever it buys into the optional
`reformat` front-end for data preprocessing. See [`reformat`](@ref)
for details. The fallback assumes `data` is a tuple and calls
`selectrows(X, I)` for each `X` in `data`, returning the results in a
new tuple of the same length. This call makes sense when `X` is a
table, abstract vector or abstract matrix. In the last two cases, a
new object and *not* a view is returned.

"""
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

# this operation can be optionally overloaded to provide access to
# fitted parameters (eg, coeficients of linear model):
"""
fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple

Models may overload `fitted_params`. The fallback returns
`(fitresult=fitresult,)`.

Other training-related outcomes should be returned in the `report`
part of the tuple returned by `fit`.

"""
fitted_params(::Model, fitresult) = (fitresult=fitresult,)

"""
supervised methods must implement the `predict` operation

predict(model, fitresult, new_data...)

`Supervised` models must implement the `predict` operation. Here
`new_data` is the output of `reformat` called on user-specified data.

"""
function predict end

Expand Down
6 changes: 6 additions & 0 deletions test/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ end

mutable struct APIx1 <: Static end

@testset "selectrows(model, data...)" begin
X = (x1 = [2, 4, 6],)
y = [10.0, 20.0, 30.0]
@test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
end

@testset "fit-x" begin
m0 = APIx0(f0=1)
m1 = APIx0b(f0=3)
Expand Down