From 82762ed22e4792b00a3c1c93bc8d1d1117388e0f Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Tue, 13 Apr 2021 11:25:49 +0200 Subject: [PATCH] Cache missing counts (#625) --- DESCRIPTION | 2 +- R/DataBackendDataTable.R | 29 ++++++++++++++++--- R/DataBackendMatrix.R | 4 +-- R/assertions.R | 14 +++++---- R/helper.R | 5 ---- R/worker.R | 3 -- inst/testthat/helper_autotest.R | 1 - tests/testthat/test_Task.R | 2 +- .../test_mlr_learners_classif_debug.R | 4 +-- 9 files changed, 40 insertions(+), 24 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 03acaed34..f850c7b3b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,7 +68,7 @@ Imports: lgr (>= 0.3.4), mlbench, mlr3measures (>= 0.3.0), - mlr3misc (>= 0.7.0), + mlr3misc (>= 0.9.0), parallelly, palmerpenguins, paradox (>= 0.6.0), diff --git a/R/DataBackendDataTable.R b/R/DataBackendDataTable.R index 7f434992a..4436dd79a 100644 --- a/R/DataBackendDataTable.R +++ b/R/DataBackendDataTable.R @@ -46,7 +46,11 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend, initialize = function(data, primary_key) { assert_data_table(data, col.names = "unique") super$initialize(setkeyv(data, primary_key), primary_key, data_formats = "data.table") - assert_choice(primary_key, names(data)) + ii = match(primary_key, names(data)) + if (is.na(ii)) { + stopf("Primary key '%s' not in 'data'", primary_key) + } + private$.cache = set_names(replace(rep(NA, ncol(data)), ii, FALSE), names(data)) }, #' @description @@ -105,8 +109,23 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend, #' #' @return Total of missing values per column (named `numeric()`). missings = function(rows, cols) { - data = self$data(rows, cols) - map_int(data, function(x) sum(is.na(x))) + missind = private$.cache + missind = missind[reorder_vector(names(missind), cols)] + + # update cache + ii = which(is.na(missind)) + if (length(ii)) { + missind[ii] = map_lgl(private$.data[, names(missind[ii]), with = FALSE], anyMissing) + private$.cache = insert_named(private$.cache, missind[ii]) + } + + # query required columns + query_cols = which(missind) + insert_named( + named_vector(names(missind), 0L), + map_int(self$data(rows, names(query_cols)), count_missing) + ) + } ), @@ -139,6 +158,8 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend, private = list( .calculate_hash = function() { hash(self$compact_seq, private$.data) - } + }, + + .cache = NULL ) ) diff --git a/R/DataBackendMatrix.R b/R/DataBackendMatrix.R index b174905ad..3aacf78af 100644 --- a/R/DataBackendMatrix.R +++ b/R/DataBackendMatrix.R @@ -158,8 +158,8 @@ DataBackendMatrix = R6Class("DataBackendMatrix", inherit = DataBackend, cloneabl cols_dense = intersect(cols, colnames(private$.data$dense)) res = c( - apply(private$.data$sparse[rows, cols_sparse, drop = FALSE], 2L, function(x) sum(is.na(x))), - private$.data$dense[, map_int(.SD, function(x) sum(is.na(x))), .SDcols = cols_dense] + apply(private$.data$sparse[rows, cols_sparse, drop = FALSE], 2L, count_missing), + private$.data$dense[, map_int(.SD, count_missing), .SDcols = cols_dense] ) res[reorder_vector(names(res), cols)] diff --git a/R/assertions.R b/R/assertions.R index 66058ec85..2edb63eb4 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -71,10 +71,6 @@ assert_tasks = function(tasks, task_type = NULL, feature_types = NULL, task_prop assert_learner = function(learner, task = NULL, properties = character(), .var.name = vname(learner)) { assert_class(learner, "Learner", .var.name = .var.name) - if (!is.null(task)) { - assert_learnable(task, learner) - } - if (length(properties)) { miss = setdiff(properties, learner$properties) if (length(miss)) { @@ -97,7 +93,7 @@ assert_learners = function(learners, task = NULL, properties = character(), .var #' @rdname mlr_assertions assert_learnable = function(task, learner) { pars = learner$param_set$get_values(type = "only_token") - if(length(pars) > 0) { + if (length(pars) > 0) { stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(pars))) } @@ -110,6 +106,14 @@ assert_learnable = function(task, learner) { if (length(tmp)) { stopf("%s has the following unsupported feature types: %s", task$format(), str_collapse(tmp)) } + + if ("missings" %nin% learner$properties) { + miss = task$missings() > 0L + if (any(miss)) { + stopf("Task '%s' has missing values in column(s) %s, but learner '%s' does not support this", + task$id, str_collapse(names(miss)[miss], quote = "'"), learner$id) + } + } } #' @export diff --git a/R/helper.R b/R/helper.R index de4c47b0f..d7b7b5414 100644 --- a/R/helper.R +++ b/R/helper.R @@ -72,8 +72,3 @@ get_progressor = function(n, label = NA_character_) { allow_utf8_names = function() { isTRUE(getOption("mlr3.allow_utf8_names")) } - - -reorder_vector = function(x, y, na_last = NA) { - order(match(x, y), na.last = na_last) -} diff --git a/R/worker.R b/R/worker.R index 709801495..7cf95af9a 100644 --- a/R/worker.R +++ b/R/worker.R @@ -17,10 +17,8 @@ learner_train = function(learner, task, row_ids = NULL) { model } - assert_task(task) assert_learner(learner) - assert_learnable(task, learner) # subset to train set w/o cloning if (!is.null(row_ids)) { @@ -104,7 +102,6 @@ learner_predict = function(learner, task, row_ids = NULL) { assert_task(task) assert_learner(learner) - assert_learnable(task, learner) # subset to test set w/o cloning if (!is.null(row_ids)) { diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index a6655ae44..3087b37df 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -215,7 +215,6 @@ run_experiment = function(task, learner, seed = NULL) { task = mlr3::assert_task(mlr3::as_task(task)) learner = mlr3::assert_learner(mlr3::as_learner(learner, clone = TRUE)) - mlr3::assert_learnable(task, learner) prediction = NULL score = NULL learner$encapsulate = c(train = "evaluate", predict = "evaluate") diff --git a/tests/testthat/test_Task.R b/tests/testthat/test_Task.R index 09bf16b38..f9ffcdb23 100644 --- a/tests/testthat/test_Task.R +++ b/tests/testthat/test_Task.R @@ -305,7 +305,7 @@ test_that("task$droplevels works", { test_that("task$missings() works", { task = tsk("pima") x = task$missings() - y = map_int(task$data(), function(x) sum(is.na(x))) + y = map_int(task$data(), count_missing) expect_equal(x, y[match(names(x), names(y))]) }) diff --git a/tests/testthat/test_mlr_learners_classif_debug.R b/tests/testthat/test_mlr_learners_classif_debug.R index 81dbd8ede..9468129d6 100644 --- a/tests/testthat/test_mlr_learners_classif_debug.R +++ b/tests/testthat/test_mlr_learners_classif_debug.R @@ -35,11 +35,11 @@ test_that("NA predictions", { learner = lrn("classif.debug", predict_missing = 0.5, predict_type = "response") learner$train(task) p = learner$predict(task) - expect_equal(sum(is.na(p$response)), 75L) + expect_equal(count_missing(p$response), 75L) learner = lrn("classif.debug", predict_missing = 0.5, predict_type = "prob") learner$train(task) p = learner$predict(task) - expect_equal(sum(is.na(p$response)), 75L) + expect_equal(count_missing(p$response), 75L) expect_equal(is.na(p$response), apply(p$prob, 1, anyMissing)) })