From 35b53468463c9033b98d11d6325616aad82f4a88 Mon Sep 17 00:00:00 2001 From: Aleksandr Meshkov Date: Wed, 23 Nov 2016 17:20:43 +0300 Subject: [PATCH] [R] Add early stop callback function (#3938) Adding early stop round function --- R-package/R/callback.R | 81 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/R-package/R/callback.R b/R-package/R/callback.R index 58b43b468ee8..8154f29cb580 100644 --- a/R-package/R/callback.R +++ b/R-package/R/callback.R @@ -75,3 +75,84 @@ mx.callback.save.checkpoint <- function(prefix, period=1) { } } +#' Early stop with different conditions +#' +#' Early stopping applying different conditions: hard thresholds or epochs number from the best score. Tested with "epoch.end.callback" function. +#' +#' @param train.metric Numeric. Hard threshold for the metric of the training data set (optional) +#' @param eval.metric Numeric. Hard threshold for the metric of the evaluating data set (if set, optional) +#' @param bad.steps Integer. How much epochs should gone from the best score? Use this option with evaluation data set +#' @param maximize Logical. Do your model use maximizing or minimizing optimization? +#' @param verbose Logical +#' +#' @export +#' +mx.callback.early.stop <- function(train.metric = NULL, eval.metric = NULL, bad.steps = NULL, maximize = FALSE, verbose = FALSE) { + + function(iteration, nbatch, env, verbose = verbose) { + + # hard threshold for train metric + if (!is.null(env$metric)) { + if (!is.null(train.metric)) { + result <- env$metric$get(env$train.metric) + if (result$value < train.metric | (maximize == TRUE & result$value > train.metric)) { + return(FALSE) + } + } + + # hard threshold for test metric + if (!is.null(eval.metric)) { + if (!is.null(env$eval.metric)) { + result <- env$metric$get(env$eval.metric) + if (result$value < eval.metric | (maximize == TRUE & result$value > eval.metric)) { + return(FALSE) + } + } + } + } + + # not worse than previous X steps + if (!is.null(bad.steps)) { + + # set / reset iteration variables + # it may be not the best practice to use global variables, + # but let's not touch "model.r" file + if (iteration == 1){ + # reset iterator + mx.best.iter <<- 1 + + # reset best score + if (maximize) { + mx.best.score <<- 0 + } + else { + mx.best.score <<- Inf + } + } + + # test early stop round + if (!is.null(env$eval.metric)) { + + result <- env$metric$get(env$eval.metric) + + if (result$value > mx.best.score | (maximize == TRUE & result$value < mx.best.score)) { + + if (mx.best.iter == bad.steps) { + if (verbose) { + cat(paste0("Best score=", mx.best.score, ", iteration [", iteration - bad.steps, "] \n")) + } + return(FALSE) + } else { + mx.best.iter <<- mx.best.iter + 1 + } + + } else { + mx.best.score <<- result$value + mx.best.iter <<- 1 + } + } + } + + return(TRUE) + } +}