Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[R] Add early stop callback function (#3938)
Browse files Browse the repository at this point in the history
Adding early stop round function
  • Loading branch information
ameshkoff authored and Qiang Kou (KK) committed Nov 23, 2016
1 parent 37f61d6 commit 35b5346
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,84 @@ mx.callback.save.checkpoint <- function(prefix, period=1) {
}
}

#' Early stop with different conditions
#'
#' Early stopping applying different conditions: hard thresholds or epochs number from the best score. Tested with "epoch.end.callback" function.
#'
#' @param train.metric Numeric. Hard threshold for the metric of the training data set (optional)
#' @param eval.metric Numeric. Hard threshold for the metric of the evaluating data set (if set, optional)
#' @param bad.steps Integer. How much epochs should gone from the best score? Use this option with evaluation data set
#' @param maximize Logical. Do your model use maximizing or minimizing optimization?
#' @param verbose Logical
#'
#' @export
#'
mx.callback.early.stop <- function(train.metric = NULL, eval.metric = NULL, bad.steps = NULL, maximize = FALSE, verbose = FALSE) {

function(iteration, nbatch, env, verbose = verbose) {

# hard threshold for train metric
if (!is.null(env$metric)) {
if (!is.null(train.metric)) {
result <- env$metric$get(env$train.metric)
if (result$value < train.metric | (maximize == TRUE & result$value > train.metric)) {
return(FALSE)
}
}

# hard threshold for test metric
if (!is.null(eval.metric)) {
if (!is.null(env$eval.metric)) {
result <- env$metric$get(env$eval.metric)
if (result$value < eval.metric | (maximize == TRUE & result$value > eval.metric)) {
return(FALSE)
}
}
}
}

# not worse than previous X steps
if (!is.null(bad.steps)) {

# set / reset iteration variables
# it may be not the best practice to use global variables,
# but let's not touch "model.r" file
if (iteration == 1){
# reset iterator
mx.best.iter <<- 1

# reset best score
if (maximize) {
mx.best.score <<- 0
}
else {
mx.best.score <<- Inf
}
}

# test early stop round
if (!is.null(env$eval.metric)) {

result <- env$metric$get(env$eval.metric)

if (result$value > mx.best.score | (maximize == TRUE & result$value < mx.best.score)) {

if (mx.best.iter == bad.steps) {
if (verbose) {
cat(paste0("Best score=", mx.best.score, ", iteration [", iteration - bad.steps, "] \n"))
}
return(FALSE)
} else {
mx.best.iter <<- mx.best.iter + 1
}

} else {
mx.best.score <<- result$value
mx.best.iter <<- 1
}
}
}

return(TRUE)
}
}

0 comments on commit 35b5346

Please sign in to comment.