Skip to content

Commit

Permalink
Cache missing counts (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg authored Apr 13, 2021
1 parent f547000 commit 82762ed
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 24 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Imports:
lgr (>= 0.3.4),
mlbench,
mlr3measures (>= 0.3.0),
mlr3misc (>= 0.7.0),
mlr3misc (>= 0.9.0),
parallelly,
palmerpenguins,
paradox (>= 0.6.0),
Expand Down
29 changes: 25 additions & 4 deletions R/DataBackendDataTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
initialize = function(data, primary_key) {
assert_data_table(data, col.names = "unique")
super$initialize(setkeyv(data, primary_key), primary_key, data_formats = "data.table")
assert_choice(primary_key, names(data))
ii = match(primary_key, names(data))
if (is.na(ii)) {
stopf("Primary key '%s' not in 'data'", primary_key)
}
private$.cache = set_names(replace(rep(NA, ncol(data)), ii, FALSE), names(data))
},

#' @description
Expand Down Expand Up @@ -105,8 +109,23 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
#'
#' @return Total of missing values per column (named `numeric()`).
missings = function(rows, cols) {
data = self$data(rows, cols)
map_int(data, function(x) sum(is.na(x)))
missind = private$.cache
missind = missind[reorder_vector(names(missind), cols)]

# update cache
ii = which(is.na(missind))
if (length(ii)) {
missind[ii] = map_lgl(private$.data[, names(missind[ii]), with = FALSE], anyMissing)
private$.cache = insert_named(private$.cache, missind[ii])
}

# query required columns
query_cols = which(missind)
insert_named(
named_vector(names(missind), 0L),
map_int(self$data(rows, names(query_cols)), count_missing)
)

}
),

Expand Down Expand Up @@ -139,6 +158,8 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
private = list(
.calculate_hash = function() {
hash(self$compact_seq, private$.data)
}
},

.cache = NULL
)
)
4 changes: 2 additions & 2 deletions R/DataBackendMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ DataBackendMatrix = R6Class("DataBackendMatrix", inherit = DataBackend, cloneabl
cols_dense = intersect(cols, colnames(private$.data$dense))

res = c(
apply(private$.data$sparse[rows, cols_sparse, drop = FALSE], 2L, function(x) sum(is.na(x))),
private$.data$dense[, map_int(.SD, function(x) sum(is.na(x))), .SDcols = cols_dense]
apply(private$.data$sparse[rows, cols_sparse, drop = FALSE], 2L, count_missing),
private$.data$dense[, map_int(.SD, count_missing), .SDcols = cols_dense]
)

res[reorder_vector(names(res), cols)]
Expand Down
14 changes: 9 additions & 5 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ assert_tasks = function(tasks, task_type = NULL, feature_types = NULL, task_prop
assert_learner = function(learner, task = NULL, properties = character(), .var.name = vname(learner)) {
assert_class(learner, "Learner", .var.name = .var.name)

if (!is.null(task)) {
assert_learnable(task, learner)
}

if (length(properties)) {
miss = setdiff(properties, learner$properties)
if (length(miss)) {
Expand All @@ -97,7 +93,7 @@ assert_learners = function(learners, task = NULL, properties = character(), .var
#' @rdname mlr_assertions
assert_learnable = function(task, learner) {
pars = learner$param_set$get_values(type = "only_token")
if(length(pars) > 0) {
if (length(pars) > 0) {
stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(pars)))
}

Expand All @@ -110,6 +106,14 @@ assert_learnable = function(task, learner) {
if (length(tmp)) {
stopf("%s has the following unsupported feature types: %s", task$format(), str_collapse(tmp))
}

if ("missings" %nin% learner$properties) {
miss = task$missings() > 0L
if (any(miss)) {
stopf("Task '%s' has missing values in column(s) %s, but learner '%s' does not support this",
task$id, str_collapse(names(miss)[miss], quote = "'"), learner$id)
}
}
}

#' @export
Expand Down
5 changes: 0 additions & 5 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,3 @@ get_progressor = function(n, label = NA_character_) {
allow_utf8_names = function() {
isTRUE(getOption("mlr3.allow_utf8_names"))
}


reorder_vector = function(x, y, na_last = NA) {
order(match(x, y), na.last = na_last)
}
3 changes: 0 additions & 3 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ learner_train = function(learner, task, row_ids = NULL) {
model
}


assert_task(task)
assert_learner(learner)
assert_learnable(task, learner)

# subset to train set w/o cloning
if (!is.null(row_ids)) {
Expand Down Expand Up @@ -104,7 +102,6 @@ learner_predict = function(learner, task, row_ids = NULL) {

assert_task(task)
assert_learner(learner)
assert_learnable(task, learner)

# subset to test set w/o cloning
if (!is.null(row_ids)) {
Expand Down
1 change: 0 additions & 1 deletion inst/testthat/helper_autotest.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ run_experiment = function(task, learner, seed = NULL) {

task = mlr3::assert_task(mlr3::as_task(task))
learner = mlr3::assert_learner(mlr3::as_learner(learner, clone = TRUE))
mlr3::assert_learnable(task, learner)
prediction = NULL
score = NULL
learner$encapsulate = c(train = "evaluate", predict = "evaluate")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ test_that("task$droplevels works", {
test_that("task$missings() works", {
task = tsk("pima")
x = task$missings()
y = map_int(task$data(), function(x) sum(is.na(x)))
y = map_int(task$data(), count_missing)
expect_equal(x, y[match(names(x), names(y))])
})

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_mlr_learners_classif_debug.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ test_that("NA predictions", {
learner = lrn("classif.debug", predict_missing = 0.5, predict_type = "response")
learner$train(task)
p = learner$predict(task)
expect_equal(sum(is.na(p$response)), 75L)
expect_equal(count_missing(p$response), 75L)

learner = lrn("classif.debug", predict_missing = 0.5, predict_type = "prob")
learner$train(task)
p = learner$predict(task)
expect_equal(sum(is.na(p$response)), 75L)
expect_equal(count_missing(p$response), 75L)
expect_equal(is.na(p$response), apply(p$prob, 1, anyMissing))
})

0 comments on commit 82762ed

Please sign in to comment.