Skip to content

Commit

Permalink
rename dataset generators and make public
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Sep 19, 2022
1 parent 36555f5 commit 7c27a12
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 23 deletions.
21 changes: 17 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,13 @@ Query the document strings for details, or see

## Testing models in a new MLJ model interface implementation

The following tests the model interface implemented by some model type
`MyClassifier`, as might appear in tests for a package providing that
type:
The following tests the model interface implemented by some model type `MyClassifier` for
multiclass classification, as might appear in tests for a package providing that type:

```julia
import MLJTestIntegration
using Test
X, y = MLJTestIntegration.MLJ.make_blobs()
X, y = MLJTestIntegration.make_multiclass()
failures, summary = MLJTestIntegration.test([MyClassifier, ], X, y, verbosity=1, mod=@__MODULE__)
@test isempty(failures)
```
Expand Down Expand Up @@ -78,3 +77,17 @@ failures, summary =

summary |> DataFrame
```

# Datasets

The following commands generate datasets of the form `(X, y)` suitable for integration
tests:

- `MLJTestIntegration.make_binary`

- `MLJTestIntegration.make_multiclass`

- `MLJTestIntegration.make_regression`

- `MLJTestIntegration.make_count`

69 changes: 51 additions & 18 deletions src/special_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,60 @@ end
_test(data; ignore=true, kwargs...) = _test([], data; ignore, kwargs...)


# # SINGLE TARGET CLASSIFICATION
# # BABY DATA SETS

"""
make_binary()
function _make_binary()
Return data `(X, y)` for the crabs dataset, restricted to the two features `:FL`,
`:RW`. Target is `Multiclass{2}`.
"""
function make_binary()
data = MLJ.load_crabs()
y_, X = unpack(data, ==(:sp), col->col in [:FL, :RW])
y = coerce(y_, MLJ.OrderedFactor)
return X, y
end

"""
make_multiclass()
Return data `(X, y)` for the unshuffled iris dataset. Target is `Multiclass{3}`.
"""
make_multiclass() = MLJ.@load_iris

"""
make_regression()
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
`:Rm`. Target is `Continuous`.
"""
function make_regression()
data = MLJ.load_boston()
y, X = unpack(data, ==(:MedV), col->col in [:LStat, :Rm])
return X, y
end

"""
make_regression()
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
`:Rm`, with the `Continuous` target converted to `Count` (integer).
"""
function make_count()
X, y_ = make_regression()
y = map-> round(Int, η), y_)
return X, y
end


# # SINGLE TARGET CLASSIFICATION


"""
MLJTestIntegration.test_single_target_classifiers(; keyword_options...)
Expand All @@ -62,17 +107,11 @@ $DOC_AS_ABOVE
"""
test_single_target_classifiers(args...; kwargs...) =
_test(args..., _make_binary(); kwargs...)
_test(args..., make_binary(); kwargs...)


# # SINGLE TARGET REGRESSION

function _make_baby_boston()
data = MLJ.load_boston()
y, X = unpack(data, ==(:MedV), col->col in [:LStat, :Rm])
return X, y
end

"""
MLJTestIntegration.test_single_target_regressors(; keyword_options...)
Expand All @@ -87,17 +126,11 @@ $DOC_AS_ABOVE
"""
test_single_target_regressors(args...; kwargs...) =
_test(args..., _make_baby_boston(); kwargs...)
_test(args..., make_regression(); kwargs...)


# # SINGLE TARGET COUNT REGRESSORS

function _make_count()
X, y_ = _make_baby_boston()
y = map-> round(Int, η), y_)
return X, y
end

"""
MLJTestIntegration.test_single_count_regressors(; keyword_options...)
Expand All @@ -114,12 +147,12 @@ $DOC_AS_ABOVE
"""
test_single_target_count_regressors(args...; kwargs...) =
_test(args..., _make_count(); kwargs...)
_test(args..., make_count(); kwargs...)


# # CONTINUOUS TABLE TRANSFORMERS

_make_transformer() = (first(_make_baby_boston()),)
_make_transformer() = (first(make_regression()),)

"""
test_continuous_table_transformers(; keyword_options...)
Expand Down
2 changes: 1 addition & 1 deletion test/special_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ regressors = [
]

@testset "actual_proxies" begin
data = MTI._make_baby_boston()
data = MTI.make_regression()
proxies = @test_logs MTI.actual_proxies(regressors, data, false, 1)
@test proxies == regressors
proxies2 = @test_logs MTI.actual_proxies(regressors, data, true, 1)
Expand Down

0 comments on commit 7c27a12

Please sign in to comment.