Skip to content

Commit

Permalink
Merge pull request #501 from alan-turing-institute/mlj2-again
Browse files Browse the repository at this point in the history
Realize performance improvements for models implementing new data front-end
  • Loading branch information
ablaom authored Jan 26, 2021
2 parents 1e1d75b + 7cacd24 commit a2eee18
Show file tree
Hide file tree
Showing 15 changed files with 787 additions and 305 deletions.
3 changes: 1 addition & 2 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,9 @@ following:
called `cache`, for passing onto the MLJ logic that handles smart
updating (namely, an `MLJBase.update` fallback for composite models).
- Calls `fit!(mach, verbosity=verbosity)`.
- Moves any data in sources nodes of the learning network into `cache`
- Moves any data in source nodes of the learning network into `cache`
(for data-anonymization purposes).
- Records a copy of `model` in `cache`.
Expand Down
2 changes: 2 additions & 0 deletions src/composition/models/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# *Note.* Be sure to read Note 4 in src/operations.jl to see see how
# fallbacks are provided for operations acting on Composite models.

caches_data_by_default(::Type{<:Composite}) = true

fitted_params(::Union{Composite,Surrogate},
fitresult::NamedTuple) =
fitted_params(glb(values(fitresult)...))
Expand Down
152 changes: 105 additions & 47 deletions src/machines.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
## MACHINE TYPE

mutable struct Machine{M<:Model} <: MLJType
caches_data_by_default(::Type{<:Model}) = true
caches_data_by_default(m::M) where M<:Model = caches_data_by_default(M)

mutable struct Machine{M<:Model,C} <: MLJType

model::M
old_model::M # for remembering the model used in last call to `fit!`
fitresult
cache

# training arguments (`Node`s or user-specified data wrapped in
# `Source`s):
args::Tuple{Vararg{AbstractNode}}

# cached model-specific reformatting of args (for C=true):
data

# cached subsample of data (for C=true):
resampled_data

report
frozen::Bool
old_rows
Expand All @@ -16,8 +29,9 @@ mutable struct Machine{M<:Model} <: MLJType
# cleared by fit!(::Node) calls; put! by `fit_only!(machine, true)` calls:
fit_okay::Channel{Bool}

function Machine(model::M, args::AbstractNode...) where M<:Model
mach = new{M}(model)
function Machine(model::M, args::AbstractNode...;
cache=caches_data_by_default(M)) where M<:Model
mach = new{M,cache}(model)
mach.frozen = false
mach.state = 0
mach.args = args
Expand Down Expand Up @@ -122,12 +136,14 @@ end


"""
machine(model, args...)
machine(model, args...; cache=true)
Construct a `Machine` object binding a `model`, storing
hyper-parameters of some machine learning algorithm, to some data,
`args`. When building a learning network, `Node` objects can be
substituted for concrete data.
substituted for concrete data. Specify `cache=false` to prioritize
memory managment over speed, and to guarantee data anonymity when
serializing composite models.
machine(Xs; oper1=node1, oper2=node2)
machine(Xs, ys; oper1=node1, oper2=node2)
Expand Down Expand Up @@ -200,7 +216,7 @@ predictions = yhat(Xnew)
"""
function machine end

machine(T::Type{<:Model}, args...) =
machine(T::Type{<:Model}, args...; kwargs...) =
throw(ArgumentError("Model *type* provided where "*
"model *instance* expected. "))

Expand All @@ -209,33 +225,35 @@ static_error() =
"has no training arguments. "*
"Use `machine(model)`. "))

function machine(model::Static, args...)
function machine(model::Static, args...; kwargs...)
isempty(args) || static_error()
return Machine(model)
return Machine(model; kwargs...)
end

function machine(model::Static, args::AbstractNode...)
function machine(model::Static, args::AbstractNode...; kwargs...)
isempty(args) || static_error()
return Machine(model)
return Machine(model; kwargs...)
end

machine(model::Model, raw_arg1, arg2::AbstractNode, args::AbstractNode...) =
machine(model::Model, raw_arg1, arg2::AbstractNode, args::AbstractNode...;
kwargs...) =
error("Mixing concrete data with `Node` training arguments "*
"is not allowed. ")

machine(model::Model, arg1::AbstractNode, arg2, args...) =
machine(model::Model, arg1::AbstractNode, arg2, args...; kwargs...) =
error("Mixing concrete data with `Node` training arguments "*
"is not allowed. ")

function machine(model::Model, raw_arg1, raw_args...)
function machine(model::Model, raw_arg1, raw_args...; kwargs...)
args = source.((raw_arg1, raw_args...))
check(model, args...; full=true)
return Machine(model, args...)
return Machine(model, args...; kwargs...)
end

function machine(model::Model, arg1::AbstractNode, args::AbstractNode...)
function machine(model::Model, arg1::AbstractNode, args::AbstractNode...;
kwargs...)
check(model, arg1, args...)
return Machine(model, arg1, args...)
return Machine(model, arg1, args...; kwargs...)
end


Expand Down Expand Up @@ -273,14 +291,18 @@ machines(::Source) = Machine[]

## DISPLAY

_cache_status(::Machine{<:Any,true}) = " caches data"
_cache_status(::Machine{<:Any,false}) = " does not cache data"

function Base.show(io::IO, ::MIME"text/plain", mach::Machine)
show(io, mach)
print(io, " trained $(mach.state) time")
if mach.state == 1
println(io, ".")
print(io, ";")
else
println(io, "s.")
print(io, "s;")
end
println(io, _cache_status(mach))
println(io, " args: ")
for i in eachindex(mach.args)
arg = mach.args[i]
Expand All @@ -307,7 +329,7 @@ end
# - `fit!`: trains a machine after first progressively training all
# machines on which the machine depends. Implicitly this involves
# making `fit_only!` calls on those machines, scheduled by the node
# `@tuple N1 N2 ... `.)
# `glb(N1, N2, ... )`, where `glb` means greatest lower bound.)


function fitlog(mach, action::Symbol, verbosity)
Expand All @@ -325,6 +347,16 @@ function fitlog(mach, action::Symbol, verbosity)
end
end

# for getting model specific representation of the row-restricted
# training data from a machine, according to the value of the machine
# type parameter `C` (`true` or `false`):
_resampled_data(mach::Machine{<:Model,true}, rows) = mach.resampled_data
function _resampled_data(mach::Machine{<:Model,false}, rows)
raw_args = map(N -> N(), mach.args)
data = MMI.reformat(mach.model, raw_args...)
return selectrows(mach.model, rows, data...)
end

"""
MLJBase.fit_only!(mach::Machine; rows=nothing, verbosity=1, force=false)
Expand All @@ -348,22 +380,23 @@ bound to it, and restricting the data to `rows` if specified:
### Training action logic
For the action to be a no-operation, either `mach.frozen == true` or
none of the following apply:
or none of the following apply:
- (i) `mach` has never been trained (`mach.state == 0`).
- (ii) `force == true`
- (ii) `force == true`.
- (iii) The `state` of some other machine on which `mach` depends has
changed since the last time `mach` was trained (ie, the last time
`mach.state` was last incremented)
`mach.state` was last incremented).
- (iv) The specified `rows` have changed since the last retraining.
- (iv) The specified `rows` have changed since the last retraining and
`mach.model` does not have `Static` type.
- (v) `mach.model` has changed since the last retraining.
In cases (i) - (iv), `mach` is trained ab initio. In case (v) a
training update is applied.
In any of the cases (i) - (iv), `mach` is trained ab initio. If only
(v) fails, then a training update is applied.
To freeze or unfreeze `mach`, use `freeze!(mach)` or `thaw!(mach)`.
Expand All @@ -381,7 +414,10 @@ and either `MLJBase.fit` (ab initio training) or `MLJBase.update`
more on these lower-level training methods.
"""
function fit_only!(mach::Machine; rows=nothing, verbosity=1, force=false)
function fit_only!(mach::Machine{<:Model,cache_data};
rows=nothing,
verbosity=1,
force=false) where cache_data

if mach.frozen
# no-op; do not increment `state`.
Expand All @@ -398,58 +434,80 @@ function fit_only!(mach::Machine; rows=nothing, verbosity=1, force=false)
warning = clean!(mach.model)
isempty(warning) || verbosity < 0 || @warn warning

upstream_state = upstream(mach)

rows === nothing && (rows = (:))
rows_is_new = !isdefined(mach, :old_rows) || rows != mach.old_rows

rows_have_changed = !isdefined(mach, :old_rows) ||
rows != mach.old_rows
condition_iv = rows_is_new && !(mach.model isa Static)

upstream_state = upstream(mach)
upstream_has_changed = mach.old_upstream_state != upstream_state

data_has_changed =
rows_have_changed ||upstream_state != mach.old_upstream_state
previously_fit = (mach.state > 0)
data_is_valid = isdefined(mach, :data) && !upstream_has_changed

raw_args = [N(rows=rows) for N in mach.args]
# build or update cached `data` if necessary:
if cache_data && !data_is_valid
raw_args = map(N -> N(), mach.args)
mach.data = MMI.reformat(mach.model, raw_args...)
end

# build or update cached `resampled_data` if necessary:
if cache_data && (!data_is_valid || condition_iv)
mach.resampled_data = selectrows(mach.model, rows, mach.data...)
end

# `fit`, `update`, or return untouched:
if mach.state == 0 || # condition (i)
force == true || # condition (ii)
upstream_has_changed || # condition (iii)
condition_iv

# we `fit`, `update`, or return untouched:
if !previously_fit || data_has_changed || force
# fit the model:
fitlog(mach, :train, verbosity)
mach.fitresult, mach.cache, mach.report =
try
fit(mach.model, verbosity, raw_args...)
fit(mach.model, verbosity, _resampled_data(mach, rows)...)
catch exception
@error "Problem fitting the machine $mach, "*
"possibly because an upstream node in a learning "*
"network is providing data of incompatible scitype. "
@error "Problem fitting the machine $mach. "
_sources = sources(glb(mach.args...))
length(_sources) > 2 ||
mach.model isa Composite ||
all((!isempty).(_sources)) ||
@warn "Some learning network source nodes are empty. "
@info "Running type checks... "
check(mach.model, source.(raw_args)... ; full=true) &&
raw_args = map(N -> N(), mach.args)
if check(mach.model, source.(raw_args)... ; full=true)
@info "Type checks okay. "
else
@info "It seems an upstream node in a learning "*
"network is providing data of incompatible scitype. See "*
"above. "
end
rethrow()
end
elseif mach.model != mach.old_model

elseif mach.model != mach.old_model # condition (v)

# update the model:
fitlog(mach, :update, verbosity)
mach.fitresult, mach.cache, mach.report =
update(mach.model,
verbosity,
mach.fitresult,
mach.cache,
raw_args...)
_resampled_data(mach, rows)...)

else

# don't fit the model and return without incrementing `state`:
fitlog(mach, :skip, verbosity)
return mach

end

# If we get to here it's because we have run `fit` or `update`!

if rows_have_changed
if rows_is_new
mach.old_rows = deepcopy(rows)
end

Expand Down Expand Up @@ -690,19 +748,19 @@ function MMI.save(file::Union{String,IO},
end

# deserializing:
function machine(file::Union{String,IO}, args...; kwargs...)
function machine(file::Union{String,IO}, args...; cache=true, kwargs...)
dict = JLSO.load(file)
model = dict[:model]
serializable_fitresult = dict[:fitresult]
report = dict[:report]
fitresult = restore(_filename(file), model, serializable_fitresult)
if isempty(args)
mach = Machine(model)
mach = Machine(model, cache=cache)
else
mach = machine(model, args...)
mach = machine(model, args..., cache=cache)
end
mach.fitresult = fitresult
mach.state = 1
mach.fitresult = fitresult
mach.report = report
return mach
end
Loading

0 comments on commit a2eee18

Please sign in to comment.