Add aSHAP - aggregated SHAP values for a set of observations #520

Merged: 14 commits, Jan 15, 2023
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: DALEX
Title: moDel Agnostic Language for Exploration and eXplanation
- Version: 2.4.2
+ Version: 2.4.3
Authors@R: c(person("Przemyslaw", "Biecek", email = "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-8423-1823")),
person("Szymon", "Maksymiuk", role = "aut",
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
DALEX 2.4.3
---------------------------------------------------------------
* added implementation of aSHAP (aggregated SHAP) and waterfall plot ([#519](https://github.com/ModelOriented/DALEX/issues/519))

DALEX 2.4.2
---------------------------------------------------------------
* removed the `yardstick` dependency
127 changes: 127 additions & 0 deletions R/plot_shap_aggregated.R
@@ -0,0 +1,127 @@
#' Plot Generic for SHAP Aggregated Objects
#'
#' Displays a waterfall plot of aggregated SHAP values for objects of the \code{shap_aggregated} class.
#'
#' @param x an object of the \code{shap_aggregated} class, created with the \code{\link[DALEX]{shap_aggregated}} function.
#' @param ... other parameters, passed to the underlying \code{break_down} plot.
#' @param max_features maximal number of features to be included in the plot. Default value is \code{10}.
#' @param min_max a range of the OX axis. By default \code{NA}, so it is extracted from the contributions of \code{x}; it can be set to constant values, which is useful when the plots are meant to be compared.
#' @param add_contributions if \code{TRUE}, variable contributions will be added to the plot
#' @param add_boxplots if \code{TRUE}, boxplots of SHAP will be shown
#' @param shift_contributions number describing how much labels should be shifted to the right, as a fraction of range. By default equal to \code{0.05}.
#' @param vcolors If \code{NA} (default), DrWhy colors are used.
#' @param vnames a character vector; if specified, it will be used as labels on the OY axis. By default \code{NULL}.
#' @param digits number of decimal places (\code{\link{round}}) or significant digits (\code{\link{signif}}) to be used.
#' See the \code{rounding_function} argument.
#' @param rounding_function a function to be used for rounding numbers.
#' Either \code{\link{signif}}, which keeps a specified number of significant digits, or \code{\link{round}} (the default), which keeps the same precision for all components.
#' @param baseline if numeric, then the vertical line starts at \code{baseline}.
#' @param title a character. Plot title. By default \code{"Aggregated SHAP"}.
#' @param subtitle a character. Plot subtitle. By default \code{""}.
#' @param max_vars alias for the \code{max_features} parameter.
#'
#' @return a \code{ggplot2} object.
#'
#' @import ggplot2
#' @importFrom utils tail
#'
#' @examples
#' library("DALEX")
#' set.seed(1313)
#' model_titanic_glm <- glm(survived ~ gender + age + fare,
#' data = titanic_imputed, family = "binomial")
#' explain_titanic_glm <- explain(model_titanic_glm,
#' data = titanic_imputed,
#' y = titanic_imputed$survived,
#' label = "glm")
#'
#' bd_glm <- shap_aggregated(explain_titanic_glm, titanic_imputed[1:10, ])
#' bd_glm
#' plot(bd_glm)
#' plot(bd_glm, max_features = 3)
#' plot(bd_glm, max_features = 3,
#' vnames = c("average","+ male","+ young","+ cheap ticket", "+ other factors", "final"))
#'
#' @export
plot.shap_aggregated <- function(x, ..., shift_contributions = 0.05, add_contributions = TRUE, add_boxplots = TRUE, max_features = 10, title = "Aggregated SHAP") {
x <- select_only_k_features(x, k = max_features)
aggregate <- x[[1]]
raw <- x[[2]]
class(aggregate) <- c('break_down', class(aggregate))

# aggregate has at least 3 rows: the first and last are intercept and prediction
aggregate$mean_boxplot <- c(0, aggregate$cumulative[1:(nrow(aggregate)-2)], 0)
raw <- merge(x = as.data.frame(aggregate[,c('variable', 'position', 'mean_boxplot')]), y = raw, by.x = "variable", by.y = "variable_name", all.y = TRUE)

# max_features = max_features + 1 because one extra bar is already present - "+ all other factors"
p <- plot(aggregate, ..., add_contributions = FALSE, max_features = max_features + 1, title = title)

if(add_boxplots){
p <- p + geom_boxplot(data = raw,
aes(y = contribution + mean_boxplot,
x = position + 0.5,
group = position,
fill = "#371ea3",
xmin = min(contribution) - 0.85,
xmax = max(contribution) + 0.85),
color = "#371ea3",
fill = "#371ea3",
width = 0.15)
}

if (add_contributions) {
aggregate$right_side <- pmax(aggregate$cumulative, aggregate$cumulative - aggregate$contribution)
drange <- diff(range(aggregate$cumulative))

p <- p + geom_text(aes(y = right_side),
vjust = -1,
nudge_y = drange*shift_contributions,
hjust = -0.2,
color = "#371ea3")
}


p
}

select_only_k_features <- function(input, k = 10) {
x <- input[[1]]
y <- input[[2]]

# keep only the k variables with the largest total |contribution|
contribution_sum <- tapply(x$contribution, x$variable_name, function(contribution) sum(abs(contribution), na.rm = TRUE))
contribution_ordered_vars <- names(sort(contribution_sum[!(names(contribution_sum) %in% c("", "intercept"))]))
variables_keep <- tail(contribution_ordered_vars, k)
variables_remove <- setdiff(contribution_ordered_vars, variables_keep)

if (length(variables_remove) > 0) {
x_remove <- x[x$variable_name %in% variables_remove,]
x_keep <- x[!(x$variable_name %in% c(variables_remove, "")),]
x_prediction <- x[x$variable == "prediction",]
row.names(x_prediction) <- x_prediction$label
remainings <- tapply(x_remove$contribution, x_remove$label, sum, na.rm=TRUE)
# fix position and cumulative in x_keep
x_keep$position <- as.numeric(as.factor(x_keep$position)) + 2
for (i in 1:nrow(x_keep)) {
if (x_keep[i,"variable_name"] == "intercept") {
x_keep[i,"cumulative"] <- x_keep[i,"contribution"]
} else {
x_keep[i,"cumulative"] <- x_keep[i - 1,"cumulative"] + x_keep[i,"contribution"]
}
}
# for each model (label), compute the aggregated "+ all other factors" contribution
x_others <- data.frame(variable = "+ all other factors",
contribution = remainings,
variable_name = "+ all other factors",
variable_value = "",
cumulative = x_prediction[names(remainings),"cumulative"],
sign = sign(remainings),
position = 2,
label = names(remainings))
#
x <- rbind(x_keep, x_others, x_prediction)
y$variable_name <- factor(ifelse(y$variable_name %in% variables_remove, "+ all other factors", as.character(y$variable_name)), levels = levels(x$variable_name))
}

list(aggregated = x, raw = y)
}
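
A minimal usage sketch for the new plot method (only the titanic_imputed data shipped with DALEX is assumed; B is lowered from its default of 25 just to keep the run fast):

library("DALEX")
set.seed(1313)
model_titanic_glm <- glm(survived ~ gender + age + fare,
                         data = titanic_imputed, family = "binomial")
explainer <- explain(model_titanic_glm,
                     data = titanic_imputed,
                     y = titanic_imputed$survived,
                     label = "glm")
agg <- shap_aggregated(explainer, new_observations = titanic_imputed[1:10, ], B = 10)
plot(agg)                                          # waterfall plus SHAP boxplots
plot(agg, add_boxplots = FALSE, max_features = 3)  # contributions only, top 3 features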
22 changes: 19 additions & 3 deletions R/predict_parts.R
@@ -1,13 +1,14 @@
#' Instance Level Parts of the Model Predictions
#'
- #' Instance Level Variable Attributions as Break Down, SHAP or Oscillations explanations.
+ #' Instance Level Variable Attributions as Break Down, SHAP, aggregated SHAP or Oscillations explanations.
#' Model prediction is decomposed into parts that are attributed for particular variables.
#' From DALEX version 1.0 this function calls the \code{\link[iBreakDown]{break_down}} or
#' \code{\link[iBreakDown:break_down_uncertainty]{shap}} functions from the \code{iBreakDown} package or
#' \code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package.
#' Find information how to use the \code{break_down} method here: \url{https://ema.drwhy.ai/breakDown.html}.
#' Find information how to use the \code{shap} method here: \url{https://ema.drwhy.ai/shapley.html}.
#' Find information how to use the \code{oscillations} method here: \url{https://ema.drwhy.ai/ceterisParibusOscillations.html}.
#' The aSHAP method provides aggregated explanations for a set of observations, based on SHAP values.
#'
#' @param explainer a model to be explained, preprocessed by the \code{explain} function
#' @param new_observation a new observation for which predictions need to be explained
@@ -16,7 +17,7 @@
#' @param variables names of variables for which splits shall be calculated. Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
#' @param N the maximum number of observations used for calculation of attributions. By default NULL (use all) or 500 (for oscillations).
#' @param variable_splits_type how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
- #' @param type the type of variable attributions. Either \code{shap}, \code{oscillations}, \code{oscillations_uni},
+ #' @param type the type of variable attributions. Either \code{shap}, \code{shap_aggregated}, \code{oscillations}, \code{oscillations_uni},
#' \code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}.
#'
#' @return Depending on the \code{type} there are different classes of the resulting object.
@@ -82,7 +83,8 @@ predict_parts <- function(explainer, new_observation, ..., N = if(substr(type, 1
"oscillations" = predict_parts_oscillations(explainer, new_observation, ...),
"oscillations_uni" = predict_parts_oscillations_uni(explainer, new_observation, ...),
"oscillations_emp" = predict_parts_oscillations_emp(explainer, new_observation, ...),
stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp'")
"shap_aggregated" = predict_parts_shap_aggregated(explainer, new_observation, ...),
stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp' or 'shap_aggregated'")
)
}

@@ -184,6 +186,20 @@ predict_parts_shap <- function(explainer, new_observation, ...) {
res
}

#' @name predict_parts
#' @export
predict_parts_shap_aggregated <- function(explainer, new_observations, ...) {
test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_shap_aggregated")

res <- shap_aggregated(explainer,
new_observations = new_observations,
...)

class(res) <- c('predict_parts', class(res))

res
}

#' @name predict_parts
#' @export
variable_attribution <- predict_parts
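
The same explanation is also reachable through the predict_parts front end via the new switch branch; a short sketch reusing the explain_titanic_glm explainer from the roxygen examples above:

pp <- predict_parts(explain_titanic_glm,
                    new_observation = titanic_imputed[1:10, ],
                    type = "shap_aggregated")
class(pp)  # "predict_parts" "shap_aggregated" "list"
plot(pp)   # expected to reach plot.shap_aggregated via the class chain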
197 changes: 197 additions & 0 deletions R/shap_aggregated.R
@@ -0,0 +1,197 @@
#' SHAP aggregated values
#'
#' This function works similarly to the \code{shap} function from \code{iBreakDown}, but it calculates explanations for a set of observations and then aggregates them.
#'
#' @param explainer an explainer created with function \code{\link[DALEX]{explain}}.
#' @param new_observations a set of new observations with columns that correspond to variables used in the model.
#' @param order if not \code{NULL}, then it will be a fixed order of variables. It can be a numeric vector or a vector with names of variables.
#' @param B number of random paths used to compute the SHAP values, passed to \code{\link[iBreakDown:break_down_uncertainty]{shap}}. By default \code{25}.
#' @param ... other parameters passed to \code{\link[iBreakDown:break_down_uncertainty]{shap}}.
#'
#' @return an object of the \code{shap_aggregated} class.
#'
#' @references Explanatory Model Analysis. Explore, Explain and Examine Predictive Models. \url{https://ema.drwhy.ai}
#'
#' @examples
#' library("DALEX")
#' set.seed(1313)
#' model_titanic_glm <- glm(survived ~ gender + age + fare,
#' data = titanic_imputed, family = "binomial")
#' explain_titanic_glm <- explain(model_titanic_glm,
#' data = titanic_imputed,
#' y = titanic_imputed$survived,
#' label = "glm")
#'
#' bd_glm <- shap_aggregated(explain_titanic_glm, titanic_imputed[1:10, ])
#' bd_glm
#' plot(bd_glm, max_features = 3)
#' @export
shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, ...) {
ret_raw <- data.frame(contribution = c(), variable_name = c(), label = c())

for(i in 1:nrow(new_observations)){
new_obs <- new_observations[i,]
shap_vals <- iBreakDown::shap(explainer, new_observation = new_obs, B = B, ...)
shap_vals <- shap_vals[shap_vals$B != 0, c('contribution', 'variable_name', 'label')]
ret_raw <- rbind(ret_raw, shap_vals)
}

data_preds <- predict(explainer, explainer$data)
mean_prediction <- mean(data_preds)

subset_preds <- predict(explainer, new_observations)
mean_subset <- mean(subset_preds)

if(is.null(order)) {
order <- calculate_order(explainer, mean_prediction, new_observations, predict)
}

ret <- raw_to_aggregated(ret_raw, mean_prediction, mean_subset, order, explainer$label)

predictions_new <- data.frame(contribution = subset_preds, variable_name='prediction', label=ret$label[1])
predictions_old <- data.frame(contribution = data_preds, variable_name='intercept', label=ret$label[1])
ret_raw <- rbind(ret_raw, predictions_new, predictions_old)

out <- list(aggregated = ret, raw = ret_raw)
class(out) <- c('shap_aggregated', class(out))

out
}

raw_to_aggregated <- function(ret_raw, mean_prediction, mean_subset, order, label){
ret <- aggregate(ret_raw$contribution, list(ret_raw$variable_name, ret_raw$label), FUN=mean)
colnames(ret) <- c('variable', 'label', 'contribution')
ret$variable <- as.character(ret$variable)
rownames(ret) <- ret$variable

ret <- ret[order,]

ret$position <- (nrow(ret) + 1):2
ret$sign <- ifelse(ret$contribution >= 0, "1", "-1")

ret <- rbind(ret, data.frame(variable = "intercept",
label = label,
contribution = mean_prediction,
position = max(ret$position) + 1,
sign = "X"),
make.row.names=FALSE)

ret <- rbind(ret, data.frame(variable = "prediction",
label = label,
contribution = mean_subset,
position = 1,
sign = "X"),
make.row.names=FALSE)

ret <- ret[call_order_func(ret$position, decreasing = TRUE), ]

ret$cumulative <- cumsum(ret$contribution)
ret$cumulative[nrow(ret)] <- ret$contribution[nrow(ret)]

ret$variable_name <- ret$variable
ret$variable_name <- factor(ret$variable_name, levels=c(ret$variable_name, ''))
ret$variable_name[nrow(ret)] <- ''

ret$variable_value <- '' # column for consistency

ret
}
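
# A toy check (illustration only, not package code) of the aggregation step
# above: mean contribution per (variable_name, label) pair.
raw_toy <- data.frame(contribution = c(0.1, 0.3, -0.2, -0.4),
                      variable_name = c("age", "age", "fare", "fare"),
                      label = "glm")
aggregate(raw_toy$contribution, list(raw_toy$variable_name, raw_toy$label), FUN = mean)
#   Group.1 Group.2    x
# 1     age     glm  0.2
# 2    fare     glm -0.3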

call_order_func <- function(...) {
order(...)
}

calculate_1d_changes <- function(model, new_observation, data, predict_function) {
average_yhats <- list()
j <- 1
for (i in colnames(new_observation)) {
current_data <- data
current_data[,i] <- new_observation[,i]
yhats <- predict_function(model, current_data)
average_yhats[[j]] <- colMeans(as.data.frame(yhats))
j <- j + 1
}
names(average_yhats) <- colnames(new_observation)
average_yhats
}

generate_average_observation <- function(subset) {
is_numeric_not_int <- function(...){(is.numeric(...) & !is.integer(...)) | is.complex(...)}

# takes average / median of columns

# (numeric not integer) or complex
numeric_cols <- unlist(lapply(subset, is_numeric_not_int))
numeric_cols <- names(numeric_cols[numeric_cols == TRUE])
if(length(numeric_cols) == 1){
df_numeric <- data.frame(tmp = mean(subset[,numeric_cols]))
colnames(df_numeric) <- numeric_cols[1]
} else {
df_numeric <- t(as.data.frame(colMeans(subset[,numeric_cols])))
}

# integer
int_cols <- unlist(lapply(subset, is.integer))
int_cols <- names(int_cols[int_cols == TRUE])
df_int <- as.data.frame(lapply(int_cols, function(col) {
tab <- table(subset[,col])
tab_val <- attr(tab, 'dimnames')[[1]]
tab_val <- tab_val[which.max(tab)]
as.integer(tab_val)
}), stringsAsFactors = FALSE)
colnames(df_int) <- int_cols

# logical
logical_cols <- unlist(lapply(subset, is.logical))
logical_cols <- names(logical_cols[logical_cols == TRUE])
df_logical <- as.data.frame(lapply(logical_cols, function(col) {
tab <- table(subset[,col])
tab_val <- attr(tab, 'dimnames')[[1]]
tab_val <- tab_val[which.max(tab)]
as.logical(tab_val)
}), stringsAsFactors = FALSE)
colnames(df_logical) <- logical_cols

# factors
factor_cols <- unlist(lapply(subset, is.factor))
factor_cols <- names(factor_cols[factor_cols == TRUE])
df_factory <- as.data.frame(lapply(factor_cols, function(col) {
factor(names(which.max(table(subset[,col]))), levels = levels(subset[,col]))
}))
colnames(df_factory) <- factor_cols

# character
other_cols <- unlist(lapply(subset, is.character))
other_cols <- names(other_cols[other_cols == TRUE])
df_others <- as.data.frame(lapply(other_cols, function(col) {
tab <- table(subset[,col])
tab_names <- attr(tab, 'dimnames')[[1]]
tab_names[which.max(tab)]
}), stringsAsFactors = FALSE)
colnames(df_others) <- other_cols

outs <- list()
if(!ncol(df_numeric) == 0){outs <- append(list(df_numeric), outs)}
if(!ncol(df_int) == 0){outs <- append(list(df_int), outs)}
if(!ncol(df_logical) == 0){outs <- append(list(df_logical), outs)}
if(!ncol(df_factory) == 0){outs <- append(list(df_factory), outs)}
if(!ncol(df_others) == 0){outs <- append(list(df_others), outs)}

do.call("cbind", outs)[,colnames(subset)]
}

calculate_order <- function(x, mean_prediction, new_data, predict_function) {
baseline_yhat <- mean_prediction

generated_obs <- generate_average_observation(new_data)

average_yhats <- calculate_1d_changes(x, generated_obs, x$data, predict_function)
diffs_1d <- sapply(seq_along(average_yhats), function(i) {
sqrt(mean((average_yhats[[i]] - baseline_yhat)^2))
})

order(diffs_1d, decreasing = TRUE)
}
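
The ordering heuristic first builds a synthetic "average observation"; the idea behind generate_average_observation can be checked standalone on toy data (illustration only, not package code): numeric columns are averaged, factor columns take their most frequent level.

toy <- data.frame(age = c(20, 30, 40),
                  gender = factor(c("male", "male", "female"),
                                  levels = c("female", "male")))
mean(toy$age)                        # numeric column -> mean: 30
factor(names(which.max(table(toy$gender))),
       levels = levels(toy$gender))  # factor column -> modal level: "male"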