diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml index 9606be02..f357fa0b 100644 --- a/.github/workflows/check.yaml +++ b/.github/workflows/check.yaml @@ -51,10 +51,11 @@ jobs: CMDSTAN=/root/.cmdstan CMDSTAN_PATH=/root/.cmdstan CMDSTANR_NO_VER_CHECK=true - linter: - if: github.event_name == 'pull_request' - name: SuperLinter 🦸‍♀️ - uses: insightsengineering/r.pkg.template/.github/workflows/linter.yaml@main + ### Disabled pending https://github.com/insightsengineering/idr-tasks/issues/667 + # linter: + # if: github.event_name == 'pull_request' + # name: SuperLinter 🦸‍♀️ + # uses: insightsengineering/r.pkg.template/.github/workflows/linter.yaml@main roxygen: name: Roxygen 🅾 uses: insightsengineering/r.pkg.template/.github/workflows/roxygen.yaml@main diff --git a/.lintr b/.lintr index 254b14a8..b9e28d00 100755 --- a/.lintr +++ b/.lintr @@ -2,5 +2,6 @@ linters: linters_with_defaults( line_length_linter(120), object_name_linter = NULL, object_usage_linter = NULL, - cyclocomp_linter = NULL + cyclocomp_linter = NULL, + indentation_linter(4) ) diff --git a/DESCRIPTION b/DESCRIPTION index b1d9ced3..12ca3ac1 100755 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -42,6 +42,7 @@ Imports: tibble, methods, digest, + posterior, stats Suggests: bayesplot, diff --git a/NAMESPACE b/NAMESPACE index a0d328f3..40ce1f9a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -20,9 +20,9 @@ export(Surv) export(SurvivalExponential) export(SurvivalLogLogistic) export(SurvivalModel) +export(SurvivalSamples) export(SurvivalWeibullPH) export(addLink) -export(aggregate) export(as.data.frame) export(as.list) export(autoplot) @@ -37,6 +37,7 @@ export(link_gsf_identity) export(link_gsf_ttg) export(longitudinal) export(merge) +export(predict) export(prior_beta) export(prior_cauchy) export(prior_gamma) @@ -46,7 +47,6 @@ export(prior_normal) export(prior_std_normal) export(read_stan) export(sampleStanModel) -export(samples_median_ci) export(show) export(sim_lm_gsf) export(sim_lm_random_slope) @@ -54,7 +54,7 @@ export(sim_os_exponential) export(sim_os_loglogistic) export(sim_os_weibull) export(simulate_joint_data) -export(survival) +export(subset) export(write_stan) exportClasses(DataJoint) exportClasses(DataLongitudinal) @@ -76,7 +76,6 @@ exportClasses(StanModule) exportClasses(SurvivalExponential) exportClasses(SurvivalLogLogistic) exportClasses(SurvivalModel) -exportClasses(SurvivalSamples) exportClasses(SurvivalWeibullPH) exportClasses(link_gsf_abstract) exportClasses(link_gsf_dsld) @@ -87,7 +86,6 @@ exportMethods(as.character) exportMethods(generateQuantities) exportMethods(longitudinal) exportMethods(names) -exportMethods(survival) import(assertthat) import(ggplot2) import(methods) diff --git a/R/DataJoint.R b/R/DataJoint.R index 4887932a..774d9007 100755 --- a/R/DataJoint.R +++ b/R/DataJoint.R @@ -12,6 +12,19 @@ NULL #' @slot survival (`DataSurvival`)\cr object created by [DataSurvival()]. #' @slot longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()]. #' +#' @name DataJoint +#' @aliases DataJoint-class +#' +#' @param survival (`DataSurvival`)\cr object created by [DataSurvival()]. +#' @param longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()]. +#' @details +#' +#' - `as.list(x)`, `as(x, "list")`: Coerces x into a list of data components required +#' for fitting a [JointModel()]. See the vignette (TODO) for more details. +#' +#' @family DataJoint +#' +#' @export DataJoint #' @exportClass DataJoint .DataJoint <- setClass( Class = "DataJoint", @@ -23,17 +36,7 @@ NULL # DataJoint-constructors ---- -#' @rdname DataJoint-class -#' -#' @param survival (`DataSurvival`)\cr object created by [DataSurvival()]. -#' @param longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()]. -#' -#' @details -#' -#' - `as.list(x)`, `as(x, "list")`: Coerces x into a list of data components required -#' for fitting a [JointModel()]. See the vignette (TODO) for more details. -#' -#' @export +#' @rdname DataJoint DataJoint <- function(survival, longitudinal) { .DataJoint( survival = survival, @@ -76,6 +79,7 @@ setMethod( # coerce-DataJoint,list ---- +#' @param x (`DataJoint`) \cr A [DataJoint][DataJoint-class] object created by [DataJoint()] #' @rdname as.list #' #' @name coerce-DataJoint-list-method @@ -85,3 +89,81 @@ setAs( to = "list", def = function(from) as.list(from) ) + + +#' Subsetting `DataJoint` as a `data.frame` +#' +#' @param x (`DataJoint`) \cr A [DataJoint][DataJoint-class] object created by [DataJoint()] +#' @param patients (`character` or `list`)\cr the patients that you wish to subset the `data.frame` +#' to contain. See details. +#' +#' @description +#' +#' Coerces the object into a `data.frame` containing just event times and status +#' filtering for specific patients. If `patients` is a list then an additional variable `group` will be added +#' onto the dataset specifying which group the row belongs to. +#' +#' @examples +#' \dontrun{ +#' pts <- c("PT1", "PT3", "PT4") +#' subset(x, pts) +#' +#' groups <- list( +#' "g1" = c("PT1", "PT3", "PT4"), +#' "g2" = c("PT2", "PT3") +#' ) +#' subset(x, groups) +#' } +#' @family DataJoint +#' @family subset +setMethod( + f = "subset", + signature = "DataJoint", + definition = function(x, patients) { + data <- as.list(x) + dat <- data.frame( + time = data[["Times"]], + event = as.numeric(seq_along(data[["Times"]]) %in% data[["dead_ind_index"]]), + patient = names(data[["pt_to_ind"]]) + ) + subset_and_add_grouping(dat, patients) + } +) + + +#' `subset_and_add_grouping` +#' +#' @param dat (`data.frame`) \cr Must have a column called `patient` which corresponds to the +#' values passed to `groupings` +#' @param groupings (`character` or `list`)\cr the patients that you wish to subset the dataset +#' to contain. If `groupings` is a list then an additional variable `group` will be added +#' onto the dataset specifying which group the row belongs to. +#' +#' @details +#' Example of usage +#' ``` +#' pts <- c("PT1", "PT3", "PT4") +#' subset_and_add_grouping(dat, pts) +#' +#' groups <- list( +#' "g1" = c("PT1", "PT3", "PT4"), +#' "g2" = c("PT2", "PT3") +#' ) +#' subset_and_add_grouping(dat, groups) +#' ``` +#' +#' @keywords internal +subset_and_add_grouping <- function(dat, groupings) { + groupings <- decompose_patients(groupings, dat$patient)$groups + dat_subset_list <- lapply( + seq_along(groupings), + \(i) { + dat_reduced <- dat[dat$patient %in% groupings[[i]], , drop = FALSE] + dat_reduced[["group"]] <- names(groupings)[[i]] + dat_reduced + } + ) + x <- Reduce(rbind, dat_subset_list) + row.names(x) <- NULL + x +} diff --git a/R/DataLongitudinal.R b/R/DataLongitudinal.R index 8feb515f..8ff89d69 100644 --- a/R/DataLongitudinal.R +++ b/R/DataLongitudinal.R @@ -131,7 +131,6 @@ setMethod( ) |> t() - # TODO - Maybe reimplement this using a more robust approach than magic number adj_threshold <- if (is.null(vars$threshold)) { -999999 } else { diff --git a/R/JointModel.R b/R/JointModel.R index 05a1a25c..ad6c1245 100755 --- a/R/JointModel.R +++ b/R/JointModel.R @@ -111,14 +111,7 @@ setMethod( definition = function(object, data, ...) { args <- list(...) - - if (is(data, "DataJoint")) { - args[["data"]] <- as.list(data) - } else if (is(data, "list")) { - args[["data"]] <- data - } else { - stop("`data` must either be a list or a DataJoint object") - } + args[["data"]] <- as.list(data) if (!"init" %in% names(args)) { values_initial <- initialValues(object) @@ -135,8 +128,7 @@ setMethod( .JointModelSamples( model = object, - data = args$data, - init = values_initial_expanded, + data = data, results = results ) } diff --git a/R/JointModelSamples.R b/R/JointModelSamples.R index 303ec75d..40aadfc9 100644 --- a/R/JointModelSamples.R +++ b/R/JointModelSamples.R @@ -1,6 +1,8 @@ #' @include JointModel.R NULL +setOldClass("CmdStanMCMC") + # JointModelSamples-class ---- #' `JointModelSamples` @@ -9,7 +11,6 @@ NULL #' #' @slot model (`JointModel`)\cr the original model. #' @slot data (`list`)\cr data input. -#' @slot init (`list`)\cr initial values. #' @slot results (`CmdStanMCMC`)\cr the results from [sampleStanModel()]. #' #' @aliases JointModelSamples @@ -18,9 +19,8 @@ NULL "JointModelSamples", slots = c( model = "JointModel", - data = "list", - init = "list", - results = "ANY" + data = "DataJoint", + results = "CmdStanMCMC" ) ) @@ -39,7 +39,7 @@ setMethod( f = "generateQuantities", signature = c(object = "JointModelSamples"), definition = function(object, patients, time_grid_lm, time_grid_sm, ...) { - data <- object@data + data <- as.list(object@data) data[["n_lm_time_grid"]] <- length(time_grid_lm) data[["lm_time_grid"]] <- time_grid_lm data[["n_sm_time_grid"]] <- length(time_grid_sm) @@ -83,9 +83,9 @@ setMethod( signature = c(object = "JointModelSamples"), definition = function(object, patients = NULL, time_grid = NULL, ...) { - data <- object@data + data <- as.list(object@data) time_grid <- expand_time_grid(time_grid, max(data[["Tobs"]])) - patients <- expand_patients(patients, names(object@data$pt_to_ind)) + patients <- expand_patients(patients, names(data$pt_to_ind)) gq <- generateQuantities( object, patients = patients, @@ -122,65 +122,3 @@ setMethod( .LongitudinalSamples(results) } ) - - -# survival-JointModelSamples ---- - -#' @rdname survival -#' -#' @param patients (`character` or `NULL`)\cr optional subset of patients for -#' which the survival function samples should be extracted, the default `NULL` -#' meaning all patients. -#' -#' @param time_grid (`numeric`)\cr grid of time points to use for providing samples -#' of the survival model fit functions. If `NULL`, will be taken as a sequence of -#' 201 values from 0 to the maximum observed event time. -#' -#' @export -setMethod( - f = "survival", - signature = c(object = "JointModelSamples"), - definition = function(object, patients = NULL, time_grid = NULL, ...) { - - data <- object@data - time_grid <- expand_time_grid(time_grid, max(data[["Times"]])) - patients <- expand_patients(patients, names(object@data$pt_to_ind)) - gq <- generateQuantities( - object, - patients = patients, - time_grid_lm = numeric(0), - time_grid_sm = time_grid - ) - - log_surv_at_grid_samples <- gq$draws(format = "draws_matrix") - log_surv_at_obs_samples <- object@results$draws( - "log_surv_fit_at_obs_times", - format = "draws_matrix" - ) - - results <- list() - for (this_pt_ind in seq_along(patients)) { - this_pt <- patients[this_pt_ind] - this_result <- list() - patient_ind <- object@data$pt_to_ind[this_pt] - this_surv_fit_names <- sprintf( - "log_surv_fit_at_time_grid[%i,%i]", - this_pt_ind, - seq_along(time_grid) - ) - this_result$samples <- exp(log_surv_at_grid_samples[, this_surv_fit_names, drop = FALSE]) - this_result$summary <- data.frame( - time = time_grid, - samples_median_ci(this_result$samples) - ) - this_result$observed <- data.frame( - t = data$Times[patient_ind], - death = (patient_ind %in% object@data$dead_ind_index), - samples_median_ci(exp(log_surv_at_obs_samples[, patient_ind, drop = FALSE])) - ) - rownames(this_result$observed) <- this_pt - results[[this_pt]] <- this_result - } - .SurvivalSamples(results) - } -) diff --git a/R/StanModule.R b/R/StanModule.R index d095c869..112489ac 100755 --- a/R/StanModule.R +++ b/R/StanModule.R @@ -25,7 +25,6 @@ STAN_BLOCKS <- list( #' @keywords internal add_missing_stan_blocks <- function(x) { # STAN_BLOCKS is defined as a global variable in StanModule.R - # TODO - Make it an argument to the function for (block in names(STAN_BLOCKS)) { if (is.null(x[[block]])) { x[[block]] <- "" diff --git a/R/SurvivalSamples.R b/R/SurvivalSamples.R index 4ed6f477..1b542d94 100644 --- a/R/SurvivalSamples.R +++ b/R/SurvivalSamples.R @@ -1,111 +1,352 @@ + + +#' NULL Documentation page to house re-usable elements across SurvivalSamples methods/objects +#' +#' @param object (`SurvivalSamples`) \cr A [SurvivalSamples][SurvivalSamples-class] +#' object created by [SurvivalSamples()] +#' +#' @param patients (`character` or `list` or `NULL`)\cr which patients to calculate the desired +#' quantities for. +#' See "Patient Specification" for more details. +#' +#' @param type (`character`)\cr The quantity to be generated. +#' Must be one of `surv`, `haz`, `loghaz`, `cumhaz`. +#' +#' @param time_grid (`numeric`)\cr a vector of time points to calculate the desired quantity at. +#' +#' @name SurvivalSamples-Joint +#' +#' @section Patient Specification: +#' If `patients` is a character vector then quantities / summary statistics +#' will only be calculated for those specific patients +#' +#' If `patients` is a list then any elements with more than 1 patient ID will be grouped together +#' and their quantities / summary statistics (as selected by `type`) +#' will be calculated by taking the point-wise average. For example: +#' `patients = list("g1" = c("pt1", "pt2"), "g2" = c("pt3", "pt4"))` would result +#' in 2 groups being created whose values are the pointwise average +#' of `c("pt1", "pt2")` and `c("pt3", "pt4")` respectively. +#' +#' If `patients=NULL` then all patients from original dataset will be selected +#' @keywords internal +NULL + + # SurvivalSamples-class ---- -#' `SurvivalSamples` +#' `SurvivalSamples` Object and Constructor Function #' -#' This class is an extension of `list` so that we -#' can define specific survival postprocessing methods for it. +#' `SurvivalSamples()` creates a `SurvivalSamples` object. The `SurvivalSamples` class +#' is an extension of the [JointModelSamples][JointModelSamples-class] class so that we +#' can define specific survival postprocessing methods for it (e.g. there are no new +#' additional slots defined). #' -#' @aliases SurvivalSamples -#' @exportClass SurvivalSamples +#' @param object (`JointModelSamples`) \cr A [`JointModelSamples`][JointModelSamples-class] object +#' @family SurvivalSamples +#' @seealso [JointModelSamples][JointModelSamples-class] +#' @name SurvivalSamples-class +#' @export SurvivalSamples .SurvivalSamples <- setClass( "SurvivalSamples", - contains = "list" + contains = "JointModelSamples" ) -# SurvivalSamples-[ ---- - #' @rdname SurvivalSamples-class +SurvivalSamples <- function(object) { + .SurvivalSamples(object) +} + + +#' Predict #' -#' @param x (`SurvivalSamples`)\cr the samples object to subset. -#' @param i (`vector`)\cr the index vector. -#' @param j not used. -#' @param drop not used. -#' @param ... not used. +#' @inheritParams SurvivalSamples-Joint +#' @inheritSection SurvivalSamples-Joint Patient Specification #' -#' @returns The subsetted `SurvivalSamples` object. -#' @export +#' @description +#' This method returns a `data.frame` of key quantities (survival / log-hazard / etc) +#' for selected patients at a given set of time points. +#' +#' @family SurvivalSamples +#' @family predict setMethod( - f = "[", + f = "predict", signature = "SurvivalSamples", - definition = function(x, i, ...) { - # Note that we cannot use `callNextMethod()` here because `list` is S3. - x@.Data <- x@.Data[i] - x - } -) + definition = function( + object, + patients = NULL, + time_grid = NULL, + type = c("surv", "haz", "loghaz", "cumhaz") + ) { + type <- match.arg(type) -# SurvivalSamples-aggregate ---- + data <- as.list(object@data) + patients <- decompose_patients(patients, names(data$pt_to_ind)) -#' @rdname aggregate -#' @param groups (`list`)\cr defining into which groups to aggregate -#' individual samples, where the names are the new group labels and -#' the character vectors are the old individual sample labels. -setMethod( - f = "aggregate", - signature = c(x = "SurvivalSamples"), - definition = function(x, groups, ...) { - assert_that( - is.list(groups), - !is.null(names(groups)), - length(x) > 0 + time_grid <- expand_time_grid(time_grid, max(data[["Times"]])) + + gq <- generateQuantities( + object, + patients = patients$unique_values, + time_grid_lm = numeric(0), + time_grid_sm = time_grid + ) + + quantities <- extract_survival_quantities(gq, type) + + quantities_summarised <- lapply( + patients$indexes, + summarise_by_group, + time_index = seq_along(time_grid), + quantities = quantities ) - x_names <- names(x) - x <- as(x, "list") - names(x) <- x_names - time_grid <- x[[1]]$summary$time - results <- list() - for (this_group in names(groups)) { - this_result <- list() - this_ids <- groups[[this_group]] - # Samples. - this_ids_samples <- Map("[[", x[this_ids], i = "samples") - this_ids_samples <- lapply(this_ids_samples, "/", length(this_ids)) - this_result$samples <- Reduce(f = "+", x = this_ids_samples) - # Summary. - surv_fit <- samples_median_ci(this_result$samples) - this_result$summary <- cbind(time = time_grid, surv_fit) - # Observations. - this_ids_obs <- Map("[[", x[this_ids], i = "observed") - this_result$observed <- do.call(rbind, this_ids_obs) - # Save all. - results[[this_group]] <- this_result + + for (i in seq_along(quantities_summarised)) { + assert_that(nrow(quantities_summarised[[i]]) == length(time_grid)) + quantities_summarised[[i]][["time"]] <- time_grid + quantities_summarised[[i]][["group"]] <- names(patients$groups)[[i]] + quantities_summarised[[i]][["type"]] <- type } - .SurvivalSamples(results) + Reduce(rbind, quantities_summarised) } ) + +#' Summarise Quantities By Group +#' +#' This function takes a [posterior::draws_matrix()] (matrix of cmdstanr sample draws) and calculates +#' summary statistics (median / lower ci / upper ci) for selected columns. +#' A key feature is that it allows for columns to be aggregated together (see details). +#' +#' @param subject_index (`numeric`)\cr Which subject indices to extract from `quantities`. +#' See details. +#' +#' @param time_index (`numeric`)\cr Which time point indices to extract from `quantities`. +#' See details. +#' +#' @param quantities ([`posterior::draws_matrix`])\cr A matrix of sample draws. +#' See details. +#' +#' @details +#' It is assumed that `quantities` consists of the cartesian product +#' of subject indices and time indices. That is, if the matrix contains 4 subjects and 3 time +#' points then it should have 12 columns. +#' It is also assumed that each column of `quantities` are named as: +#' ``` +#' "quantity[x,y]" +#' ``` +#' Where +#' - `x` is the subject index +#' - `y` is the time point index +#' +#' The resulting `data.frame` that is created will have 1 row per value of `time_index` where +#' each row represents the summary statistics for that time point. +#' +#' Note that if multiple values are provided for `subject_index` then the pointwise average +#' will be calculated for each time point by taking the mean across the specified subjects +#' at that time point. +#' +#' @return A data frame containing 1 row per `time_index` (in order) with the following columns: +#' - `median` - The median value of the samples in `quantities` +#' - `lower` - The lower `95%` CI value of the samples in `quantities` +#' - `upper` - The upper `95%` CI value of the samples in `quantities` +#' +#' @keywords internal +summarise_by_group <- function(subject_index, time_index, quantities) { + assert_that( + is.numeric(subject_index), + is.numeric(time_index), + length(time_index) == length(unique(time_index)), + inherits(quantities, "draws_matrix") + ) + stacked_quantities <- array(dim = c( + nrow(quantities), + length(time_index), + length(subject_index) + )) + for (ind in seq_along(subject_index)) { + quantity_index <- sprintf( + "quantity[%i,%i]", + subject_index[ind], + time_index + ) + stacked_quantities[, , ind] <- quantities[, quantity_index] + } + averaged_quantities <- apply( + stacked_quantities, + c(1, 2), + mean, + simplify = TRUE + ) + samples_median_ci(averaged_quantities) +} + + + +#' Extract Survival Quantities +#' +#' Utility function to extract generated quantities from a [cmdstanr::CmdStanGQ] object. +#' Multiple quantities are generated by default so this is a convenience function to extract +#' the desired ones and return them them as a user friendly [posterior::draws_matrix] object +#' +#' @param gq (`CmdStanGQ`) \cr A [cmdstanr::CmdStanGQ] object created by [generateQuantities] +#' @inheritParams SurvivalSamples-Joint +#' @keywords internal +extract_survival_quantities <- function(gq, type = c("surv", "haz", "loghaz", "cumhaz")) { + type <- match.arg(type) + assert_that( + inherits(gq, "CmdStanGQ") + ) + meta <- switch(type, + surv = list("log_surv_fit_at_time_grid", exp), + cumhaz = list("log_surv_fit_at_time_grid", \(x) -x), + haz = list("log_haz_fit_at_time_grid", exp), + loghaz = list("log_haz_fit_at_time_grid", identity) + ) + result <- gq$draws(meta[[1]], format = "draws_matrix") + result_transformed <- meta[[2]](result) + cnames <- colnames(result_transformed) + colnames(result_transformed) <- gsub(meta[[1]], "quantity", cnames) + result_transformed +} + + # SurvivalSamples-autoplot ---- -#' @rdname autoplot -#' @param add_km (`flag`)\cr whether to add the Kaplan-Meier plot of the -#' survival data to the plots. +#' Automatic Plotting for SurvivalSamples +#' +#' @inheritParams SurvivalSamples-Joint +#' @inheritSection SurvivalSamples-Joint Patient Specification +#' @param add_km (`logical`) \cr If `TRUE` Kaplan-Meier curves will be added to the plot for +#' each group/patient as defined by `patients` +#' @param add_ci (`logical`) \cr If `TRUE` 95% CI will be added to the plot for +#' each group/patient as defined by `patients` +#' @param add_wrap (`logical`) \cr If `TRUE` will apply a [ggplot2::facet_wrap()] to the plot +#' by each group/patient as defined by `patients` +#' @param ... other arguments passed to plotting methods. +#' +#' @family autoplot +#' @family SurvivalSamples +#' setMethod( f = "autoplot", signature = c(object = "SurvivalSamples"), - function(object, add_km = TRUE, ...) { + function(object, + patients, + time_grid = NULL, + type = c("surv", "haz", "loghaz", "cumhaz"), + add_km = FALSE, + add_ci = TRUE, + add_wrap = TRUE, + ...) { assert_that(is.flag(add_km)) - - all_fit_dfs <- lapply(object, "[[", i = "summary") - all_fit_dfs_with_id <- Map(cbind, all_fit_dfs, id = names(object)) - all_fit_df <- do.call(rbind, all_fit_dfs_with_id) - - obs_dfs <- lapply(object, "[[", i = "observed") - obs_dfs_with_id <- Map(cbind, obs_dfs, id = names(object)) - all_obs_df <- do.call(rbind, obs_dfs_with_id) - # To avoid issues with logical status in the Kaplan-Meier layer. - all_obs_df$death_num <- as.numeric(all_obs_df$death) - - p <- ggplot() + - geom_line(aes(x = .data$time, y = .data$median), data = all_fit_df) + - geom_ribbon(aes(x = .data$time, ymin = .data$lower, ymax = .data$upper), data = all_fit_df, alpha = 0.3) + - xlab(expression(t)) + - ylab(expression(S(t))) + - facet_wrap(~ id) - if (add_km) { - p <- p + - ggplot2.utils::geom_km(aes(time = .data$t, status = .data$death_num), data = all_obs_df) + - ggplot2.utils::geom_km_ticks(aes(time = .data$t, status = .data$death_num), data = all_obs_df) - } - p + kmdf <- if (add_km) subset(object@data, patients) else NULL + type <- match.arg(type) + all_fit_df <- predict(object, patients, time_grid, type) + label <- switch(type, + "surv" = expression(S(t)), + "cumhaz" = expression(H(t)), + "haz" = expression(h(t)), + "loghaz" = expression(log(h(t))) + ) + survival_plot( + data = all_fit_df, + add_ci = add_ci, + add_wrap = add_wrap, + kmdf = kmdf, + y_label = label + ) } ) + + + +#' Survival Plot +#' +#' Internal plotting function to create survival plots with KM curve overlays +#' This function predominately exists to extract core logic into its own function +#' to enable easier unit testing. +#' +#' @param data (`data.frame`)\cr A `data.frame` of summary statistics for a survival +#' curve to be plotted. See details. +#' @param add_ci (`logical`)\cr Should confidence intervals be added? Default = `TRUE`. +#' @param add_wrap (`logical`)\cr Should the plots be wrapped by `data$group`? Default = `TRUE`. +#' @param kmdf (`data.frame` or `NULL`)\cr A `data.frame` of event times and status used to plot +#' overlaying KM curves. If `NULL` no KM curve will be plotted. See details. +#' @param y_label (`character` or `expression`) \cr Label to display on the y-axis. +#' Default = `expression(S(t))` +#' @param x_label (`character` or `expression`) \cr Label to display on the x-axis. +#' +#' @details +#' +#' ## `data` +#' Should contain the following columns: +#' - `time` - Time point +#' - `group` - The group in which the observation belongs to +#' - `median` - The median value for the summary statistic +#' - `upper` - The upper 95% CI for the summary statistic +#' - `lower` - The lower 95% CI for the summary statistic +#' +#' ## `kmdf` +#' Should contain the following columns: +#' - `time` - The time at which an event occurred +#' - `event` - 1/0 status indicator for the event +#' - `group` - Which group the event belongs to, should correspond to values in `data$group` +#' @keywords internal +survival_plot <- function( + data, + add_ci = TRUE, + add_wrap = TRUE, + kmdf = NULL, + y_label = expression(S(t)), + x_label = expression(t) +) { + assert_that( + is.flag(add_ci), + is.flag(add_wrap), + is.expression(y_label) || is.character(y_label), + is.expression(x_label) || is.character(x_label), + is.null(kmdf) | is.data.frame(kmdf) + ) + + p <- ggplot() + + xlab(x_label) + + ylab(y_label) + + theme_bw() + + if (add_wrap) { + p <- p + facet_wrap(~group) + aes_ci <- aes(x = .data$time, ymin = .data$lower, ymax = .data$upper) + aes_line <- aes(x = .data$time, y = .data$median) + aes_km <- aes(time = .data$time, status = .data$event) + } else { + aes_ci <- aes( + x = .data$time, + ymin = .data$lower, + ymax = .data$upper, + fill = .data$group, + group = .data$group + ) + aes_line <- aes( + x = .data$time, + y = .data$median, + colour = .data$group, + group = .data$group + ) + aes_km <- aes( + time = .data$time, + status = .data$event, + group = .data$group, + colour = .data$group + ) + } + p <- p + geom_line(aes_line, data = data) + if (add_ci) { + p <- p + geom_ribbon(aes_ci, data = data, alpha = 0.3) + } + if (!is.null(kmdf)) { + p <- p + + ggplot2.utils::geom_km(aes_km, data = kmdf) + + ggplot2.utils::geom_km_ticks(aes_km, data = kmdf) + } + p +} diff --git a/R/generics.R b/R/generics.R index 26692305..5204f0d2 100755 --- a/R/generics.R +++ b/R/generics.R @@ -67,20 +67,7 @@ NULL #' @export NULL -# aggregate ---- -#' Aggregation Methods for Different Classes -#' -#' These aggregation methods allow to group samples of different objects. -#' -#' @name aggregate -#' @aliases aggregate -#' -#' @param x what to aggregate. -#' @param ... other arguments passed to aggregation methods. -#' -#' @export -NULL # autoplot ---- @@ -89,14 +76,44 @@ NULL #' These plot methods visualize various objects. #' #' @name autoplot -#' @aliases autoplot #' #' @param object what to plot. #' @param ... other arguments passed to plotting methods. #' -#' @export +#' @family autoplot +#' +#' @export autoplot NULL + + +# predict ---- + +#' Model Predictions +#' +#' NOTE: This man page is for the `predict` S4 generic function defined within +#' jmpost. See [stats::predict()] for the default method. +#' @name predict +#' @inheritParams stats::predict +#' @family predict +#' @export predict +setGeneric("predict", predict, signature = c("object")) + + +# subset ---- + +#' Subsetting Vectors, Matrices and Data Frames +#' +#' NOTE: This man page is for the `subset` S4 generic function defined within +#' jmpost. See [base::subset()] for the default method. +#' @name subset +#' @inheritParams base::subset +#' @family subset +#' @export subset +setGeneric("subset", subset, signature = c("x")) + + + # show ---- #' Printing of Different Classes @@ -271,21 +288,6 @@ setGeneric( def = function(object, ...) standardGeneric("longitudinal") ) -# survival ---- - -#' `survival` -#' -#' Obtain the survival function samples from [`JointModelSamples`]. -#' -#' @param object samples to extract the survival function values from. -#' @param ... additional options. -#' -#' @export -setGeneric( - name = "survival", - def = function(object, ...) standardGeneric("survival") -) - # generateQuantities ---- diff --git a/R/utilities.R b/R/utilities.R index 7e5d9495..9d9e5c0b 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -162,11 +162,7 @@ replace_with_lookup <- function(sizes, data) { #' @param level (`number`)\cr credibility level to use for the credible intervals. #' #' @returns A `data.frame` with columns `median`, `lower` and `upper`. -#' @export -#' -#' @examples -#' set.seed(123) -#' samples <- cbind(rnorm(100, 0, 1), rexp(100, 0.5), rpois(100, 5)) +#' @keywords internal #' samples_median_ci(samples) samples_median_ci <- function(samples, level = 0.95) { assert_that(is.matrix(samples)) @@ -282,3 +278,68 @@ expand_patients <- function(patients, all_pts) { ) return(patients) } + + + +#' Decompose Patients into Relevant Components +#' +#' This function takes in a character vector or list of patients and decomposes it into a +#' structured format. +#' +#' The primary use of this function is to correctly setup indexing variables for +#' predicting survival quantities (see [`predict(SurvivalSamples)`][SurvivalSamples-class]) +#' +#' @param patients (`character` or `list`)\cr patient identifiers. If `NULL` will be set to `all_pts`. +#' +#' @param all_pts (`character`)\cr the set of allowable patient identifiers. +#' Will cause an error if any value of `patients` is not in this vector. +#' +#' @return A list containing three components: +#' - `groups`: (`list`)\cr each element of the list is a character vector +#' specifying which patients belong to a given "group" where the "group" is the element name +#' - `unique_values`: (`character`)\cr vector of the unique patients within `patients` +#' - `indexes`: (`list`)\cr each element is a named and is a numeric index vector +#' that maps the values of `grouped` to `unique_values` +#' @examples +#' \dontrun{ +#' result <- decompose_patients(c("A", "B"), c("A", "B", "C", "D")) +#' result <- decompose_patients( +#' list("g1" = c("A", "B"), "g2" = c("B", "C")), +#' c("A", "B", "C", "D") +#' ) +#' } +#' @seealso [expand_patients()], [`predict(SurvivalSamples)`][SurvivalSamples-class] +#' @keywords internal +decompose_patients <- function(patients, all_pts) { + if (is.character(patients) || is.null(patients)) { + patients <- expand_patients(patients, all_pts) + names(patients) <- patients + patients <- as.list(patients) + } + patients <- lapply( + patients, + expand_patients, + all_pts = all_pts + ) + assert_that( + is.list(patients), + length(unique(names(patients))) == length(patients), + all(vapply(patients, is.character, logical(1))) + ) + patients_vec_unordered <- unique(unlist(patients)) + patients_vec <- patients_vec_unordered[order(patients_vec_unordered)] + patients_lookup <- stats::setNames(seq_along(patients_vec), patients_vec) + patients_index <- lapply( + patients, + \(x) { + z <- patients_lookup[x] + names(z) <- NULL + z + } + ) + list( + groups = patients, + unique_values = patients_vec, + indexes = patients_index + ) +} diff --git a/R/zzz.R b/R/zzz.R index fa803427..013e0cf6 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -16,4 +16,15 @@ options(jmpost_opts[opt]) } } + + +} + + +# This only exists to silence the false positive R CMD CHECK warning about +# importing but not using the posterior package. posterior is a dependency +# of rcmdstan that we use a lot implicitly. Also we link to their documentation +# pages in ours +.never_run <- function() { + posterior::as_draws() } diff --git a/_pkgdown.yml b/_pkgdown.yml index 18ec54d3..7c84ba54 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -63,10 +63,8 @@ reference: - LongitudinalSamples - SurvivalSamples - longitudinal - - survival - - aggregate - - autoplot - - samples_median_ci + - starts_with("autoplot") + - starts_with("predict") - generateQuantities - title: Stan Code contents: @@ -85,3 +83,4 @@ reference: - merge - names - show + - starts_with("subset") diff --git a/design/tests/generated_quantities.R b/design/tests/generated_quantities.R index c7a4ad2c..fbca22f8 100644 --- a/design/tests/generated_quantities.R +++ b/design/tests/generated_quantities.R @@ -76,7 +76,7 @@ jdat <- DataJoint( ) ) -mp <- sampleStanModel( +stan_samples <- sampleStanModel( jm, data = jdat, iter_sampling = 400, @@ -86,13 +86,74 @@ mp <- sampleStanModel( ) -mp@results$summary() +stan_samples@results$summary() -pts <- sample(dat_os$pt, 4) +class(stan_samples@results) -longitudinal(mp, pts, c(0, 10, 40, 100, 200, 300)) |> - autoplot() +survival_samples <- SurvivalSamples(stan_samples) -survival(mp, pts, c(0, 10, 40, 100, 200, 300)) |> +longitudinal(stan_samples, sample(dat_os$pt, 5), c(0, 10, 40, 100, 200, 300)) |> autoplot() + +pts <- list( + "g1" = sample(dat_os$pt, 100), + "g2" = sample(dat_os$pt, 100) +) + +predict( + survival_samples, + patients = pts, + type = "haz", + time_grid = c(0, 100, 200) +) + + +pts <- sample(dat_os$pt, 4) +predict( + survival_samples, + patients = pts +) + +jdat@survival@data + + + + + + +autoplot( + survival_samples, + pts, + add_km = TRUE +) + +autoplot( + survival_samples, + pts, + add_wrap = FALSE +) + +pts <- list( + "g1" = sample(dat_os$pt, 4), + "g2" = sample(dat_os$pt, 4) +) + +autoplot( + survival_samples, + pts, + type = "cumhaz", + add_wrap = FALSE +) + +autoplot( + survival_samples, + pts, + type = "haz" +) + +autoplot( + survival_samples, + pts, + type = "loghaz" +) diff --git a/design/tests/os-loglogistic.R b/design/tests/os-loglogistic.R index 8a1f3c44..4ddc4a28 100644 --- a/design/tests/os-loglogistic.R +++ b/design/tests/os-loglogistic.R @@ -80,5 +80,3 @@ mcmc_trace(mp$draws("sm_logl_p")) # Surv(time, event) ~ cov_cat + cov_cont, # data = dat_os # ) - - diff --git a/inst/WORDLIST b/inst/WORDLIST index a2945945..d78633fe 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -78,3 +78,9 @@ xl xshift jinjar tempdir +ci +autoplot +cmdstanr +Summarise +SurvivalSamples +DataJoint diff --git a/inst/stan/base/survival.stan b/inst/stan/base/survival.stan index 908e24f1..eea7ca3c 100755 --- a/inst/stan/base/survival.stan +++ b/inst/stan/base/survival.stan @@ -79,7 +79,30 @@ functions { } - matrix generate_survival_quantities( + matrix generate_log_hazard_estimates( + array[] int pt_select_index, + vector sm_time_grid, + vector pars_os, + matrix pars_lm, + vector cov_contribution + ) { + int n_pt_select_index = num_elements(pt_select_index); + int n_sm_time_grid = num_elements(sm_time_grid); + matrix[n_pt_select_index, n_sm_time_grid] result; + for (i in 1:n_pt_select_index) { + int current_pt_index = pt_select_index[i]; + result[i, ] = to_row_vector(log_hazard( + rep_matrix(sm_time_grid, 1), + pars_os, + rep_matrix(pars_lm[current_pt_index, ], n_sm_time_grid), + rep_vector(cov_contribution[current_pt_index], n_sm_time_grid) + )); + } + return result; + } + + + matrix generate_log_survival_estimates( array[] int pt_select_index, vector sm_time_grid, vector pars_os, @@ -208,8 +231,9 @@ generated quantities { // Source - base/survival.stan // matrix[n_pt_select_index, n_sm_time_grid] log_surv_fit_at_time_grid; + matrix[n_pt_select_index, n_sm_time_grid] log_haz_fit_at_time_grid; if (n_sm_time_grid > 0) { - log_surv_fit_at_time_grid = generate_survival_quantities( + log_surv_fit_at_time_grid = generate_log_survival_estimates( pt_select_index, sm_time_grid, pars_os, @@ -218,5 +242,12 @@ generated quantities { weights, os_cov_contribution ); + log_haz_fit_at_time_grid = generate_log_hazard_estimates( + pt_select_index, + sm_time_grid, + pars_os, + pars_lm, + os_cov_contribution + ); } } diff --git a/man/DataJoint-class.Rd b/man/DataJoint.Rd similarity index 91% rename from man/DataJoint-class.Rd rename to man/DataJoint.Rd index 7d905db4..60258531 100644 --- a/man/DataJoint-class.Rd +++ b/man/DataJoint.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/DataJoint.R \docType{class} -\name{DataJoint-class} -\alias{DataJoint-class} -\alias{.DataJoint} +\name{DataJoint} \alias{DataJoint} +\alias{.DataJoint} +\alias{DataJoint-class} \title{\code{DataJoint}} \usage{ DataJoint(survival, longitudinal) @@ -32,3 +32,8 @@ for fitting a \code{\link[=JointModel]{JointModel()}}. See the vignette (TODO) f \item{\code{longitudinal}}{(\code{DataLongitudinal})\cr object created by \code{\link[=DataLongitudinal]{DataLongitudinal()}}.} }} +\seealso{ +Other DataJoint: +\code{\link{subset,DataJoint-method}} +} +\concept{DataJoint} diff --git a/man/JointModelSamples-class.Rd b/man/JointModelSamples-class.Rd index 7ed0b26f..fef291d1 100644 --- a/man/JointModelSamples-class.Rd +++ b/man/JointModelSamples-class.Rd @@ -16,8 +16,6 @@ Contains samples from a \code{\link{JointModel}}. \item{\code{data}}{(\code{list})\cr data input.} -\item{\code{init}}{(\code{list})\cr initial values.} - \item{\code{results}}{(\code{CmdStanMCMC})\cr the results from \code{\link[=sampleStanModel]{sampleStanModel()}}.} }} diff --git a/man/SurvivalSamples-Joint.Rd b/man/SurvivalSamples-Joint.Rd new file mode 100644 index 00000000..c2cf99c1 --- /dev/null +++ b/man/SurvivalSamples-Joint.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{SurvivalSamples-Joint} +\alias{SurvivalSamples-Joint} +\title{NULL Documentation page to house re-usable elements across SurvivalSamples methods/objects} +\arguments{ +\item{object}{(\code{SurvivalSamples}) \cr A \link[=SurvivalSamples-class]{SurvivalSamples} +object created by \code{\link[=SurvivalSamples]{SurvivalSamples()}}} + +\item{patients}{(\code{character} or \code{list} or \code{NULL})\cr which patients to calculate the desired +quantities for. +See "Patient Specification" for more details.} + +\item{type}{(\code{character})\cr The quantity to be generated. +Must be one of \code{surv}, \code{haz}, \code{loghaz}, \code{cumhaz}.} + +\item{time_grid}{(\code{numeric})\cr a vector of time points to calculate the desired quantity at.} +} +\description{ +NULL Documentation page to house re-usable elements across SurvivalSamples methods/objects +} +\section{Patient Specification}{ + +If \code{patients} is a character vector then quantities / summary statistics +will only be calculated for those specific patients + +If \code{patients} is a list then any elements with more than 1 patient ID will be grouped together +and their quantities / summary statistics (as selected by \code{type}) +will be calculated by taking the point-wise average. For example: +\code{patients = list("g1" = c("pt1", "pt2"), "g2" = c("pt3", "pt4"))} would result +in 2 groups being created whose values are the pointwise average +of \code{c("pt1", "pt2")} and \code{c("pt3", "pt4")} respectively. + +If \code{patients=NULL} then all patients from original dataset will be selected +} + +\keyword{internal} diff --git a/man/SurvivalSamples-class.Rd b/man/SurvivalSamples-class.Rd index 20788efb..69a53cea 100644 --- a/man/SurvivalSamples-class.Rd +++ b/man/SurvivalSamples-class.Rd @@ -5,26 +5,24 @@ \alias{SurvivalSamples-class} \alias{.SurvivalSamples} \alias{SurvivalSamples} -\alias{[,SurvivalSamples,ANY,ANY,ANY-method} -\title{\code{SurvivalSamples}} +\title{\code{SurvivalSamples} Object and Constructor Function} \usage{ -\S4method{[}{SurvivalSamples,ANY,ANY,ANY}(x, i, j, ..., drop = TRUE) +SurvivalSamples(object) } \arguments{ -\item{x}{(\code{SurvivalSamples})\cr the samples object to subset.} - -\item{i}{(\code{vector})\cr the index vector.} - -\item{j}{not used.} - -\item{...}{not used.} - -\item{drop}{not used.} -} -\value{ -The subsetted \code{SurvivalSamples} object. +\item{object}{(\code{JointModelSamples}) \cr A \code{\link[=JointModelSamples-class]{JointModelSamples}} object} } \description{ -This class is an extension of \code{list} so that we -can define specific survival postprocessing methods for it. +\code{SurvivalSamples()} creates a \code{SurvivalSamples} object. The \code{SurvivalSamples} class +is an extension of the \link[=JointModelSamples-class]{JointModelSamples} class so that we +can define specific survival postprocessing methods for it (e.g. there are no new +additional slots defined). +} +\seealso{ +\link[=JointModelSamples-class]{JointModelSamples} + +Other SurvivalSamples: +\code{\link{autoplot,SurvivalSamples-method}}, +\code{\link{predict,SurvivalSamples-method}} } +\concept{SurvivalSamples} diff --git a/man/aggregate.Rd b/man/aggregate.Rd deleted file mode 100644 index 4914116f..00000000 --- a/man/aggregate.Rd +++ /dev/null @@ -1,21 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/generics.R, R/SurvivalSamples.R -\name{aggregate} -\alias{aggregate} -\alias{aggregate,SurvivalSamples-method} -\title{Aggregation Methods for Different Classes} -\usage{ -\S4method{aggregate}{SurvivalSamples}(x, groups, ...) -} -\arguments{ -\item{x}{what to aggregate.} - -\item{groups}{(\code{list})\cr defining into which groups to aggregate -individual samples, where the names are the new group labels and -the character vectors are the old individual sample labels.} - -\item{...}{other arguments passed to aggregation methods.} -} -\description{ -These aggregation methods allow to group samples of different objects. -} diff --git a/man/as.list.Rd b/man/as.list.Rd index 44d3f543..dc828a71 100644 --- a/man/as.list.Rd +++ b/man/as.list.Rd @@ -27,7 +27,7 @@ \S4method{as.list}{StanModel}(x) } \arguments{ -\item{x}{what to convert.} +\item{x}{(\code{DataJoint}) \cr A \link[=DataJoint-class]{DataJoint} object created by \code{\link[=DataJoint]{DataJoint()}}} \item{...}{not used.} } diff --git a/man/autoplot-SurvivalSamples-method.Rd b/man/autoplot-SurvivalSamples-method.Rd new file mode 100644 index 00000000..f4624cd1 --- /dev/null +++ b/man/autoplot-SurvivalSamples-method.Rd @@ -0,0 +1,69 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{autoplot,SurvivalSamples-method} +\alias{autoplot,SurvivalSamples-method} +\title{Automatic Plotting for SurvivalSamples} +\usage{ +\S4method{autoplot}{SurvivalSamples}( + object, + patients, + time_grid = NULL, + type = c("surv", "haz", "loghaz", "cumhaz"), + add_km = FALSE, + add_ci = TRUE, + add_wrap = TRUE, + ... +) +} +\arguments{ +\item{object}{(\code{SurvivalSamples}) \cr A \link[=SurvivalSamples-class]{SurvivalSamples} +object created by \code{\link[=SurvivalSamples]{SurvivalSamples()}}} + +\item{patients}{(\code{character} or \code{list} or \code{NULL})\cr which patients to calculate the desired +quantities for. +See "Patient Specification" for more details.} + +\item{time_grid}{(\code{numeric})\cr a vector of time points to calculate the desired quantity at.} + +\item{type}{(\code{character})\cr The quantity to be generated. +Must be one of \code{surv}, \code{haz}, \code{loghaz}, \code{cumhaz}.} + +\item{add_km}{(\code{logical}) \cr If \code{TRUE} Kaplan-Meier curves will be added to the plot for +each group/patient as defined by \code{patients}} + +\item{add_ci}{(\code{logical}) \cr If \code{TRUE} 95\% CI will be added to the plot for +each group/patient as defined by \code{patients}} + +\item{add_wrap}{(\code{logical}) \cr If \code{TRUE} will apply a \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}} to the plot +by each group/patient as defined by \code{patients}} + +\item{...}{other arguments passed to plotting methods.} +} +\description{ +Automatic Plotting for SurvivalSamples +} +\section{Patient Specification}{ + +If \code{patients} is a character vector then quantities / summary statistics +will only be calculated for those specific patients + +If \code{patients} is a list then any elements with more than 1 patient ID will be grouped together +and their quantities / summary statistics (as selected by \code{type}) +will be calculated by taking the point-wise average. For example: +\code{patients = list("g1" = c("pt1", "pt2"), "g2" = c("pt3", "pt4"))} would result +in 2 groups being created whose values are the pointwise average +of \code{c("pt1", "pt2")} and \code{c("pt3", "pt4")} respectively. + +If \code{patients=NULL} then all patients from original dataset will be selected +} + +\seealso{ +Other autoplot: +\code{\link{autoplot}()} + +Other SurvivalSamples: +\code{\link{SurvivalSamples-class}}, +\code{\link{predict,SurvivalSamples-method}} +} +\concept{SurvivalSamples} +\concept{autoplot} diff --git a/man/autoplot.Rd b/man/autoplot.Rd index d3f170aa..7ede6b66 100644 --- a/man/autoplot.Rd +++ b/man/autoplot.Rd @@ -1,24 +1,22 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/generics.R, R/LongitudinalSamples.R, -% R/SurvivalSamples.R +% Please edit documentation in R/generics.R, R/LongitudinalSamples.R \name{autoplot} \alias{autoplot} \alias{autoplot,LongitudinalSamples-method} -\alias{autoplot,SurvivalSamples-method} \title{Plotting Methods for Different Classes} \usage{ \S4method{autoplot}{LongitudinalSamples}(object, ...) - -\S4method{autoplot}{SurvivalSamples}(object, add_km = TRUE, ...) } \arguments{ \item{object}{what to plot.} \item{...}{other arguments passed to plotting methods.} - -\item{add_km}{(\code{flag})\cr whether to add the Kaplan-Meier plot of the -survival data to the plots.} } \description{ These plot methods visualize various objects. } +\seealso{ +Other autoplot: +\code{\link{autoplot,SurvivalSamples-method}} +} +\concept{autoplot} diff --git a/man/decompose_patients.Rd b/man/decompose_patients.Rd new file mode 100644 index 00000000..d79d2fc2 --- /dev/null +++ b/man/decompose_patients.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utilities.R +\name{decompose_patients} +\alias{decompose_patients} +\title{Decompose Patients into Relevant Components} +\usage{ +decompose_patients(patients, all_pts) +} +\arguments{ +\item{patients}{(\code{character} or \code{list})\cr patient identifiers. If \code{NULL} will be set to \code{all_pts}.} + +\item{all_pts}{(\code{character})\cr the set of allowable patient identifiers. +Will cause an error if any value of \code{patients} is not in this vector.} +} +\value{ +A list containing three components: +\itemize{ +\item \code{groups}: (\code{list})\cr each element of the list is a character vector +specifying which patients belong to a given "group" where the "group" is the element name +\item \code{unique_values}: (\code{character})\cr vector of the unique patients within \code{patients} +\item \code{indexes}: (\code{list})\cr each element is a named and is a numeric index vector +that maps the values of \code{grouped} to \code{unique_values} +} +} +\description{ +This function takes in a character vector or list of patients and decomposes it into a +structured format. +} +\details{ +The primary use of this function is to correctly setup indexing variables for +predicting survival quantities (see \code{\link[=SurvivalSamples-class]{predict(SurvivalSamples)}}) +} +\examples{ +\dontrun{ +result <- decompose_patients(c("A", "B"), c("A", "B", "C", "D")) +result <- decompose_patients( + list("g1" = c("A", "B"), "g2" = c("B", "C")), + c("A", "B", "C", "D") +) +} +} +\seealso{ +\code{\link[=expand_patients]{expand_patients()}}, \code{\link[=SurvivalSamples-class]{predict(SurvivalSamples)}} +} +\keyword{internal} diff --git a/man/extract_survival_quantities.Rd b/man/extract_survival_quantities.Rd new file mode 100644 index 00000000..315888cb --- /dev/null +++ b/man/extract_survival_quantities.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{extract_survival_quantities} +\alias{extract_survival_quantities} +\title{Extract Survival Quantities} +\usage{ +extract_survival_quantities(gq, type = c("surv", "haz", "loghaz", "cumhaz")) +} +\arguments{ +\item{gq}{(\code{CmdStanGQ}) \cr A \link[cmdstanr:CmdStanGQ]{cmdstanr::CmdStanGQ} object created by \link{generateQuantities}} + +\item{type}{(\code{character})\cr The quantity to be generated. +Must be one of \code{surv}, \code{haz}, \code{loghaz}, \code{cumhaz}.} +} +\description{ +Utility function to extract generated quantities from a \link[cmdstanr:CmdStanGQ]{cmdstanr::CmdStanGQ} object. +Multiple quantities are generated by default so this is a convenience function to extract +the desired ones and return them them as a user friendly \link[posterior:draws_matrix]{posterior::draws_matrix} object +} +\keyword{internal} diff --git a/man/predict-SurvivalSamples-method.Rd b/man/predict-SurvivalSamples-method.Rd new file mode 100644 index 00000000..c8e5d2c6 --- /dev/null +++ b/man/predict-SurvivalSamples-method.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{predict,SurvivalSamples-method} +\alias{predict,SurvivalSamples-method} +\title{Predict} +\usage{ +\S4method{predict}{SurvivalSamples}( + object, + patients = NULL, + time_grid = NULL, + type = c("surv", "haz", "loghaz", "cumhaz") +) +} +\arguments{ +\item{object}{(\code{SurvivalSamples}) \cr A \link[=SurvivalSamples-class]{SurvivalSamples} +object created by \code{\link[=SurvivalSamples]{SurvivalSamples()}}} + +\item{patients}{(\code{character} or \code{list} or \code{NULL})\cr which patients to calculate the desired +quantities for. +See "Patient Specification" for more details.} + +\item{time_grid}{(\code{numeric})\cr a vector of time points to calculate the desired quantity at.} + +\item{type}{(\code{character})\cr The quantity to be generated. +Must be one of \code{surv}, \code{haz}, \code{loghaz}, \code{cumhaz}.} +} +\description{ +This method returns a \code{data.frame} of key quantities (survival / log-hazard / etc) +for selected patients at a given set of time points. +} +\section{Patient Specification}{ + +If \code{patients} is a character vector then quantities / summary statistics +will only be calculated for those specific patients + +If \code{patients} is a list then any elements with more than 1 patient ID will be grouped together +and their quantities / summary statistics (as selected by \code{type}) +will be calculated by taking the point-wise average. For example: +\code{patients = list("g1" = c("pt1", "pt2"), "g2" = c("pt3", "pt4"))} would result +in 2 groups being created whose values are the pointwise average +of \code{c("pt1", "pt2")} and \code{c("pt3", "pt4")} respectively. + +If \code{patients=NULL} then all patients from original dataset will be selected +} + +\seealso{ +Other SurvivalSamples: +\code{\link{SurvivalSamples-class}}, +\code{\link{autoplot,SurvivalSamples-method}} + +Other predict: +\code{\link{predict}()} +} +\concept{SurvivalSamples} +\concept{predict} diff --git a/man/predict.Rd b/man/predict.Rd new file mode 100644 index 00000000..62c833bb --- /dev/null +++ b/man/predict.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R +\name{predict} +\alias{predict} +\title{Model Predictions} +\usage{ +predict(object, ...) +} +\arguments{ +\item{object}{a model object for which prediction is desired.} + +\item{...}{additional arguments affecting the predictions produced.} +} +\description{ +NOTE: This man page is for the \code{predict} S4 generic function defined within +jmpost. See \code{\link[stats:predict]{stats::predict()}} for the default method. +} +\seealso{ +Other predict: +\code{\link{predict,SurvivalSamples-method}} +} +\concept{predict} diff --git a/man/samples_median_ci.Rd b/man/samples_median_ci.Rd index b6e17a90..bea5a03d 100644 --- a/man/samples_median_ci.Rd +++ b/man/samples_median_ci.Rd @@ -17,8 +17,5 @@ A \code{data.frame} with columns \code{median}, \code{lower} and \code{upper}. \description{ Obtain Median and Credible Intervals from MCMC samples } -\examples{ -set.seed(123) -samples <- cbind(rnorm(100, 0, 1), rexp(100, 0.5), rpois(100, 5)) -samples_median_ci(samples) -} +\keyword{internal} +\keyword{samples_median_ci(samples)} diff --git a/man/subset-DataJoint-method.Rd b/man/subset-DataJoint-method.Rd new file mode 100644 index 00000000..30406de0 --- /dev/null +++ b/man/subset-DataJoint-method.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/DataJoint.R +\name{subset,DataJoint-method} +\alias{subset,DataJoint-method} +\title{Subsetting \code{DataJoint} as a \code{data.frame}} +\usage{ +\S4method{subset}{DataJoint}(x, patients) +} +\arguments{ +\item{x}{(\code{DataJoint}) \cr A \link[=DataJoint-class]{DataJoint} object created by \code{\link[=DataJoint]{DataJoint()}}} + +\item{patients}{(\code{character} or \code{list})\cr the patients that you wish to subset the \code{data.frame} +to contain. See details.} +} +\description{ +Coerces the object into a \code{data.frame} containing just event times and status +filtering for specific patients. If \code{patients} is a list then an additional variable \code{group} will be added +onto the dataset specifying which group the row belongs to. +} +\examples{ +\dontrun{ +pts <- c("PT1", "PT3", "PT4") +subset(x, pts) + +groups <- list( + "g1" = c("PT1", "PT3", "PT4"), + "g2" = c("PT2", "PT3") +) +subset(x, groups) +} +} +\seealso{ +Other DataJoint: +\code{\link{DataJoint}} + +Other subset: +\code{\link{subset}()} +} +\concept{DataJoint} +\concept{subset} diff --git a/man/subset.Rd b/man/subset.Rd new file mode 100644 index 00000000..a2061433 --- /dev/null +++ b/man/subset.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R +\name{subset} +\alias{subset} +\title{Subsetting Vectors, Matrices and Data Frames} +\usage{ +subset(x, ...) +} +\arguments{ +\item{x}{object to be subsetted.} + +\item{...}{further arguments to be passed to or from other methods.} +} +\description{ +NOTE: This man page is for the \code{subset} S4 generic function defined within +jmpost. See \code{\link[base:subset]{base::subset()}} for the default method. +} +\seealso{ +Other subset: +\code{\link{subset,DataJoint-method}} +} +\concept{subset} diff --git a/man/subset_and_add_grouping.Rd b/man/subset_and_add_grouping.Rd new file mode 100644 index 00000000..c08c5812 --- /dev/null +++ b/man/subset_and_add_grouping.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/DataJoint.R +\name{subset_and_add_grouping} +\alias{subset_and_add_grouping} +\title{\code{subset_and_add_grouping}} +\usage{ +subset_and_add_grouping(dat, groupings) +} +\arguments{ +\item{dat}{(\code{data.frame}) \cr Must have a column called \code{patient} which corresponds to the +values passed to \code{groupings}} + +\item{groupings}{(\code{character} or \code{list})\cr the patients that you wish to subset the dataset +to contain. If \code{groupings} is a list then an additional variable \code{group} will be added +onto the dataset specifying which group the row belongs to.} +} +\description{ +\code{subset_and_add_grouping} +} +\details{ +Example of usage + +\if{html}{\out{
}}\preformatted{pts <- c("PT1", "PT3", "PT4") +subset_and_add_grouping(dat, pts) + +groups <- list( + "g1" = c("PT1", "PT3", "PT4"), + "g2" = c("PT2", "PT3") +) +subset_and_add_grouping(dat, groups) +}\if{html}{\out{
}} +} +\keyword{internal} diff --git a/man/summarise_by_group.Rd b/man/summarise_by_group.Rd new file mode 100644 index 00000000..54cce9f9 --- /dev/null +++ b/man/summarise_by_group.Rd @@ -0,0 +1,54 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{summarise_by_group} +\alias{summarise_by_group} +\title{Summarise Quantities By Group} +\usage{ +summarise_by_group(subject_index, time_index, quantities) +} +\arguments{ +\item{subject_index}{(\code{numeric})\cr Which subject indices to extract from \code{quantities}. +See details.} + +\item{time_index}{(\code{numeric})\cr Which time point indices to extract from \code{quantities}. +See details.} + +\item{quantities}{(\code{\link[posterior:draws_matrix]{posterior::draws_matrix}})\cr A matrix of sample draws. +See details.} +} +\value{ +A data frame containing 1 row per \code{time_index} (in order) with the following columns: +\itemize{ +\item \code{median} - The median value of the samples in \code{quantities} +\item \code{lower} - The lower \verb{95\%} CI value of the samples in \code{quantities} +\item \code{upper} - The upper \verb{95\%} CI value of the samples in \code{quantities} +} +} +\description{ +This function takes a \code{\link[posterior:draws_matrix]{posterior::draws_matrix()}} (matrix of cmdstanr sample draws) and calculates +summary statistics (median / lower ci / upper ci) for selected columns. +A key feature is that it allows for columns to be aggregated together (see details). +} +\details{ +It is assumed that \code{quantities} consists of the cartesian product +of subject indices and time indices. That is, if the matrix contains 4 subjects and 3 time +points then it should have 12 columns. +It is also assumed that each column of \code{quantities} are named as: + +\if{html}{\out{
}}\preformatted{"quantity[x,y]" +}\if{html}{\out{
}} + +Where +\itemize{ +\item \code{x} is the subject index +\item \code{y} is the time point index +} + +The resulting \code{data.frame} that is created will have 1 row per value of \code{time_index} where +each row represents the summary statistics for that time point. + +Note that if multiple values are provided for \code{subject_index} then the pointwise average +will be calculated for each time point by taking the mean across the specified subjects +at that time point. +} +\keyword{internal} diff --git a/man/survival.Rd b/man/survival.Rd deleted file mode 100644 index fe98a819..00000000 --- a/man/survival.Rd +++ /dev/null @@ -1,27 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/generics.R, R/JointModelSamples.R -\name{survival} -\alias{survival} -\alias{survival,JointModelSamples-method} -\title{\code{survival}} -\usage{ -survival(object, ...) - -\S4method{survival}{JointModelSamples}(object, patients = NULL, time_grid = NULL, ...) -} -\arguments{ -\item{object}{samples to extract the survival function values from.} - -\item{...}{additional options.} - -\item{patients}{(\code{character} or \code{NULL})\cr optional subset of patients for -which the survival function samples should be extracted, the default \code{NULL} -meaning all patients.} - -\item{time_grid}{(\code{numeric})\cr grid of time points to use for providing samples -of the survival model fit functions. If \code{NULL}, will be taken as a sequence of -201 values from 0 to the maximum observed event time.} -} -\description{ -Obtain the survival function samples from \code{\link{JointModelSamples}}. -} diff --git a/man/survival_plot.Rd b/man/survival_plot.Rd new file mode 100644 index 00000000..fbc256b0 --- /dev/null +++ b/man/survival_plot.Rd @@ -0,0 +1,60 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/SurvivalSamples.R +\name{survival_plot} +\alias{survival_plot} +\title{Survival Plot} +\usage{ +survival_plot( + data, + add_ci = TRUE, + add_wrap = TRUE, + kmdf = NULL, + y_label = expression(S(t)), + x_label = expression(t) +) +} +\arguments{ +\item{data}{(\code{data.frame})\cr A \code{data.frame} of summary statistics for a survival +curve to be plotted. See details.} + +\item{add_ci}{(\code{logical})\cr Should confidence intervals be added? Default = \code{TRUE}.} + +\item{add_wrap}{(\code{logical})\cr Should the plots be wrapped by \code{data$group}? Default = \code{TRUE}.} + +\item{kmdf}{(\code{data.frame} or \code{NULL})\cr A \code{data.frame} of event times and status used to plot +overlaying KM curves. If \code{NULL} no KM curve will be plotted. See details.} + +\item{y_label}{(\code{character} or \code{expression}) \cr Label to display on the y-axis. +Default = \code{expression(S(t))}} + +\item{x_label}{(\code{character} or \code{expression}) \cr Label to display on the x-axis.} +} +\description{ +Internal plotting function to create survival plots with KM curve overlays +This function predominately exists to extract core logic into its own function +to enable easier unit testing. +} +\details{ +\subsection{\code{data}}{ + +Should contain the following columns: +\itemize{ +\item \code{time} - Time point +\item \code{group} - The group in which the observation belongs to +\item \code{median} - The median value for the summary statistic +\item \code{upper} - The upper 95\% CI for the summary statistic +\item \code{lower} - The lower 95\% CI for the summary statistic +} +} + +\subsection{\code{kmdf}}{ + +Should contain the following columns: +\itemize{ +\item \code{time} - The time at which an event occurred +\item \code{event} - 1/0 status indicator for the event +\item \code{group} - Which group the event belongs to, should correspond to values in \code{data$group} +} +} +} +\keyword{internal} diff --git a/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci-km-ggplot2-integration.svg b/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci-km-ggplot2-integration.svg new file mode 100644 index 00000000..42eb1e64 --- /dev/null +++ b/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci-km-ggplot2-integration.svg @@ -0,0 +1,314 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0.00 +0.25 +0.50 +0.75 +1.00 + + + + + + + + + + +0 +100 +200 +300 +400 +t +h +d +3 +1 +2 + ++ +S +( +t +2 +) + +group + + + + + + + + + + + + + + + +A +B +C +survival_plot with no wrap and no ci + km + ggplot2 integration + + diff --git a/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci.svg b/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci.svg new file mode 100644 index 00000000..d84e451c --- /dev/null +++ b/tests/testthat/_snaps/survival_plot/survival-plot-with-no-wrap-and-no-ci.svg @@ -0,0 +1,100 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0.00 +0.25 +0.50 +0.75 +1.00 + + + + + + + + + + +0 +100 +200 +300 +400 +t +h +d +3 +1 +2 + ++ +S +( +t +2 +) + +group + + + + + + +A +B +C +survival_plot with no wrap and no ci + + diff --git a/tests/testthat/_snaps/survival_plot/survival-plot-with-wrap-and-ci.svg b/tests/testthat/_snaps/survival_plot/survival-plot-with-wrap-and-ci.svg new file mode 100644 index 00000000..586f17b2 --- /dev/null +++ b/tests/testthat/_snaps/survival_plot/survival-plot-with-wrap-and-ci.svg @@ -0,0 +1,209 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +A + + + + + + + + + + +B + + + + + + + + + + +C + + + + + + + +0 +100 +200 +300 +400 + + + + + +0 +100 +200 +300 +400 + + + + + +0 +100 +200 +300 +400 +0.00 +0.25 +0.50 +0.75 +1.00 + + + + + +t +h +d +3 +1 +2 + ++ +S +( +t +2 +) +survival_plot with wrap and ci + + diff --git a/tests/testthat/test-DataJoint.R b/tests/testthat/test-DataJoint.R index 40118b8e..0951a94c 100644 --- a/tests/testthat/test-DataJoint.R +++ b/tests/testthat/test-DataJoint.R @@ -80,3 +80,134 @@ test_that("DataJoint errors if subjects don't allign after", { regexp = "subjects in the longitudinal" ) }) + + + + + +test_that("subset(DataJoint) works as expected", { + + dat <- dplyr::tribble( + ~patient, ~time, ~event, + "a", 1, 1, + "b", 1, 0, + "c", 2, 1, + "d", 2, 0, + "e", 3, 1, + "f", 3, 0, + "g", 4, 1 + ) + pts <- list( + "g1" = c("a", "e", "f"), + "g2" = c("b", "c"), + "g3" = "d" + ) + expected <- dplyr::tribble( + ~patient, ~time, ~event, ~group, + "a", 1, 1, "g1", + "e", 3, 1, "g1", + "f", 3, 0, "g1", + "b", 1, 0, "g2", + "c", 2, 1, "g2", + "d", 2, 0, "g3" + ) + expect_equal( + subset_and_add_grouping(dat, pts), + expected + ) + + pts <- c("b", "d", "a") + expected <- dplyr::tribble( + ~patient, ~time, ~event, ~group, + "b", 1, 0, "b", + "d", 2, 0, "d", + "a", 1, 1, "a" + ) + expect_equal( + subset_and_add_grouping(dat, pts), + expected + ) + + # Duplicate values between groups should be ok + pts <- list( + "g1" = c("a", "b", "c"), + "g2" = c("a", "b", "c") + ) + expected <- dplyr::tribble( + ~patient, ~time, ~event, ~group, + "a", 1, 1, "g1", + "b", 1, 0, "g1", + "c", 2, 1, "g1", + "a", 1, 1, "g2", + "b", 1, 0, "g2", + "c", 2, 1, "g2" + ) + expect_equal( + subset_and_add_grouping(dat, pts), + expected + ) + + + # Should error for patients that dosn't exist in vector mode + pts <- c("b", "d", "a", "z") + expect_error( + subset_and_add_grouping(dat, pts), + regexp = "`patients`" + ) + # Should error for patients that don't exist in list mode + pts <- list("g1" = c("a", "z", "b")) + expect_error( + subset_and_add_grouping(dat, pts), + regexp = "`patients`" + ) + # Should error if we ask for the same patient multiple times + pts <- c("b", "d", "a", "a") + expect_error( + subset_and_add_grouping(dat, pts), + regexp = "`patients`" + ) + + + df_surv <- data.frame( + vpt = c("A", "B", "C"), + vtime = c(100, 200, 150), + vevent = c(0, 1, 1), + vct = c(5, 2, 4), + varm = c("A1", "A1", "A2"), + vstudy = c("S1", "S1", "S2") + ) + + df_long <- data.frame( + vpt = c("A", "A", "B", "B", "C", "B"), + vtime = c(10, 10, 20, 30, 40, 50), + vout = c(1, 2, 3, 4, 5, 6) + ) + + d_joint <- DataJoint( + survival = DataSurvival( + data = df_surv, + formula = Surv(vtime, vevent) ~ vct, + subject = "vpt", + arm = "varm", + study = "vstudy" + ), + longitudinal = DataLongitudinal( + data = df_long, + formula = vout ~ vtime, + subject = "vpt" + ) + ) + + expected <- data.frame( + time = c(150, 100), + event = c(1, 0), + patient = c("C", "A"), + group = c("C", "A"), + row.names = NULL + ) + expect_equal( + subset(d_joint, c("C", "A")), + expected + ) + +}) diff --git a/tests/testthat/test-JointModelSamples.R b/tests/testthat/test-JointModelSamples.R index 06661d47..3832768b 100644 --- a/tests/testthat/test-JointModelSamples.R +++ b/tests/testthat/test-JointModelSamples.R @@ -7,7 +7,7 @@ test_that("longitudinal works as expected for JointModelSamples", { expect_s4_class(mcmc_results, "JointModelSamples") result <- longitudinal(mcmc_results) expect_s4_class(result, "LongitudinalSamples") - expect_length(result, mcmc_results@data$Nind) + expect_length(result, as.list(mcmc_results@data)$Nind) expect_true(is.list(result)) one_result <- result[[3]] expect_type(one_result, "list") @@ -33,37 +33,3 @@ test_that("longitudinal allows to subset patients and times", { expect_identical(dim(one_result$samples), c(100L, length(time_grid))) expect_identical(nrow(one_result$summary), length(time_grid)) }) - - -# survival ---- - -test_that("survival works as expected for JointModelSamples", { - expect_s4_class(mcmc_results, "JointModelSamples") - result <- survival(mcmc_results) - expect_s4_class(result, "SurvivalSamples") - expect_length(result, mcmc_results@data$Nind) - expect_true(is.list(result)) - one_result <- result[[3]] - expect_type(one_result, "list") - expect_type(one_result$samples, "double") - expect_identical(dim(one_result$samples), c(100L, 201L)) - expect_s3_class(one_result$observed, "data.frame") - expect_identical(colnames(one_result$observed), c("t", "death", "median", "lower", "upper")) - expect_s3_class(one_result$summary, "data.frame") - expect_identical(colnames(one_result$summary), c("time", "median", "lower", "upper")) - expect_identical(nrow(one_result$summary), 201L) -}) - - -test_that("survival allows to subset patients", { - patients <- c("pt_00001", "pt_00005", "pt_00010", "pt_00022") - time_grid <- c(1, 40, 100) - result <- survival(mcmc_results, patients = patients, time_grid = time_grid) - expect_s4_class(result, "SurvivalSamples") - expect_length(result, length(patients)) - expect_identical(names(result), patients) - - one_result <- result[[3]] - expect_identical(dim(one_result$samples), c(100L, length(time_grid))) - expect_identical(nrow(one_result$summary), length(time_grid)) -}) diff --git a/tests/testthat/test-SurvivalSamples.R b/tests/testthat/test-SurvivalSamples.R index 07c36f02..e2764f53 100644 --- a/tests/testthat/test-SurvivalSamples.R +++ b/tests/testthat/test-SurvivalSamples.R @@ -1,127 +1,294 @@ -mcmc_results <- get_mcmc_results() -# constructor ---- +# Note that these are more just "smoke tests" e.g. we are looking for obvious signs +# that something has gone wrong. Given the dependence on complex objects generated +# by MCMC sampling this is hard to test deterministically. +# That being said all the individual components that comprise the function have been +# individually tested so this is regarded as being sufficient +test_that("smoke test for predict(SurvivalSamples) and autoplot(SurvivalSamples)", { + set.seed(739) + jlist <- simulate_joint_data( + n = c(250, 150), + times = 1:2000, + lambda_cen = 1 / 9000, + lm_fun = sim_lm_random_slope( + intercept = 30, + sigma = 3, + slope_mu = c(1, 3), + slope_sigma = 0.2, + phi = 0 + ), + os_fun = sim_os_exponential(1 / 100), + .debug = TRUE, + .silent = TRUE + ) -test_that("SurvivalSamples can be initialized", { - x <- list(pt_00001 = 5, pt_00002 = 10) - result <- .SurvivalSamples(x) - expect_s4_class(result, "SurvivalSamples") - expect_identical(names(result), names(x)) -}) + dat_os <- jlist$os + dat_lm <- jlist$lm |> + dplyr::filter(time %in% c(0, 1, 100, 200, 250, 300, 350)) |> + dplyr::arrange(pt, time) -# subset ---- -test_that("subsetting works as expected for SurvivalSamples", { - object <- .SurvivalSamples( - list(pt_00001 = 5, pt_00002 = 10) + jm <- JointModel( + longitudinal = LongitudinalRandomSlope( + intercept = prior_normal(30, 2), + slope_sigma = prior_lognormal(log(0.2), sigma = 0.5), + sigma = prior_lognormal(log(3), sigma = 0.5) + ), + survival = SurvivalExponential( + lambda = prior_lognormal(log(1 / 100), 1 / 100) + ) ) - result <- object["pt_00001"] - expect_s4_class(result, "SurvivalSamples") - expect_length(result, 1L) - expect_identical(names(result), "pt_00001") -}) -# aggregate ---- - -test_that("aggregate works as expected for SurvivalSamples", { - x <- .SurvivalSamples( - list( - id1 = list( - samples = matrix(1:4, 2, 2), - summary = data.frame(time = 1:2, median = 0:1, lower = -1:0, upper = 1:2), - observed = data.frame(t = 5, death = TRUE, median = 0.8, lower = 0.5, upper = 1.2) - ), - id2 = list( - samples = matrix(2:5, 2, 2), - summary = data.frame(time = 1:2, median = 0:1, lower = -1:0, upper = 1:2), - observed = data.frame(t = 2, death = FALSE, median = 0.2, lower = 0.3, upper = 0.9) - ), - id3 = list( - samples = matrix(3:6, 2, 2), - summary = data.frame(time = 1:2, median = 0:1, lower = -1:0, upper = 1:2), - observed = data.frame(t = 1, death = TRUE, median = 0.1, lower = -0.1, upper = 2) - ) + jdat <- DataJoint( + survival = DataSurvival( + data = dat_os, + formula = Surv(time, event) ~ cov_cat + cov_cont, + subject = "pt", + arm = "arm", + study = "study" + ), + longitudinal = DataLongitudinal( + data = dat_lm, + formula = sld ~ time, + subject = "pt", + threshold = 5 ) ) - result <- aggregate(x, groups = list(a = c("id3", "id1"), b = c("id1", "id2"))) - expect_s4_class(result, "SurvivalSamples") - expect_identical(names(result), c("a", "b")) - expect_identical(rownames(result[["a"]]$observed), c("id3", "id1")) - expect_identical(rownames(result[["b"]]$observed), c("id1", "id2")) -}) + mp <- sampleStanModel( + jm, + data = jdat, + iter_sampling = 100, + iter_warmup = 150, + chains = 1, + refresh = 0, + parallel_chains = 1 + ) + + survsamps <- SurvivalSamples(mp) + + + ##### Section for predict,SurvivalSamples + + expected_column_names <- c("median", "lower", "upper", "time", "group", "type") + + preds <- predict(survsamps, list("a" = c("pt_00001", "pt_00002")), c(10, 20, 200, 300)) + expect_equal(nrow(preds), 4) + expect_equal(length(unique(preds$group)), 1) + expect_equal(names(preds), expected_column_names) + expect_equal(unique(preds$type), "surv") + + + preds <- predict(survsamps, time_grid = c(10, 20, 200, 300)) + expect_equal(nrow(preds), 4 * nrow(dat_os)) # 4 timepoints for each subject in the OS dataset + expect_equal(names(preds), expected_column_names) + expect_equal(unique(preds$group), dat_os$pt) + -# autoplot ---- + preds <- predict(survsamps, c("pt_00001", "pt_00003")) + expect_equal(nrow(preds), 2 * 201) # 201 default time points for 2 subjects + expect_equal(names(preds), expected_column_names) -test_that("autoplot works as expected for SurvivalSamples", { - object <- survival(mcmc_results, patients = c("pt_00001", "pt_00022")) - result <- expect_silent(autoplot(object)) - data_layer1 <- layer_data(result) - expect_s3_class(data_layer1, "data.frame") - expect_identical( - names(data_layer1), - c("x", "y", "PANEL", "group", "flipped_aes", "colour", "linewidth", "linetype", "alpha") + preds1 <- predict(survsamps, "pt_00001", c(200, 300)) + preds2 <- predict(survsamps, "pt_00001", c(200, 300), type = "cumhaz") + preds3 <- predict(survsamps, "pt_00001", c(200, 300), type = "haz") + preds4 <- predict(survsamps, "pt_00001", c(200, 300), type = "loghaz") + expect_equal(unique(preds1$type), "surv") + expect_equal(unique(preds2$type), "cumhaz") + expect_equal(unique(preds3$type), "haz") + expect_equal(unique(preds4$type), "loghaz") + expect_equal(round(preds1$median, 5), round(exp(-preds2$median), 5)) + expect_equal(round(preds3$median, 5), round(exp(preds4$median), 5)) + expect_true(all(preds1$median != preds3$median)) + + + ##### Section for autoplot,SurvivalSamples + p <- autoplot( + survsamps, + add_wrap = FALSE, + add_ci = FALSE, + c("pt_00011", "pt_00061") + ) + + dat <- predict(survsamps, c("pt_00011", "pt_00061")) + + expect_length(p$layers, 1) + expect_equal(p$labels$y, expression(S(t))) + expect_true(inherits(p$facet, "FacetNull")) + expect_equal(dat, p$layers[[1]]$data) + expect_true(inherits(p$layers[[1]]$geom, "GeomLine")) + + p <- autoplot( + survsamps, + add_wrap = TRUE, + patients = list("a" = c("pt_00011", "pt_00061"), "b" = c("pt_00001", "pt_00002")), + time_grid = c(10, 20, 50, 200), + type = "loghaz" + ) + + dat <- predict( + survsamps, + patients = list("a" = c("pt_00011", "pt_00061"), "b" = c("pt_00001", "pt_00002")), + time_grid = c(10, 20, 50, 200), + type = "loghaz" ) - expect_identical( - length(data_layer1$x), - 201L * 2L + + expect_length(p$layers, 2) + expect_equal(p$labels$y, expression(log(h(t)))) + expect_true(inherits(p$facet, "FacetWrap")) + expect_equal(dat, p$layers[[1]]$data) + expect_true(inherits(p$layers[[1]]$geom, "GeomLine")) + expect_true(inherits(p$layers[[2]]$geom, "GeomRibbon")) + + + set.seed(39130) + ptgroups <- list( + gtpt1 = sample(dat_os$pt, 20), + gtpt2 = sample(dat_os$pt, 20), + gtpt3 = sample(dat_os$pt, 20) ) - expect_identical( - length(unique(data_layer1$x)), - 201L + times <- seq(0, 100, by = 10) + + p <- autoplot( + survsamps, + add_wrap = FALSE, + add_ci = FALSE, + add_km = TRUE, + patients = ptgroups, + time_grid = times, + type = "surv" ) - expect_identical( - data_layer1$y, - c(object[[1]]$summary$median, object[[2]]$summary$median) + + dat <- predict( + survsamps, + patients = ptgroups, + time_grid = times, + type = "surv" ) + expect_length(p$layers, 3) + expect_equal(p$labels$y, expression(S(t))) + expect_true(inherits(p$facet, "FacetNull")) + expect_equal(dat, p$layers[[1]]$data) + expect_true(inherits(p$layers[[1]]$geom, "GeomLine")) + expect_true(inherits(p$layers[[2]]$geom, "GeomKm")) + expect_true(inherits(p$layers[[3]]$geom, "GeomKmTicks")) - data_layer2 <- layer_data(result, i = 2) - expect_s3_class(data_layer2, "data.frame") - expect_identical( - names(data_layer2), - c( - "x", "ymin", "ymax", "PANEL", "group", "flipped_aes", "y", - "colour", "fill", "linewidth", "linetype", "alpha" +}) + + + + +test_that("summarise_by_group() works as expected", { + get_ci_summary <- function(vec) { + if (inherits(vec, "matrix")) { + vec <- rowMeans(vec) + } + x <- data.frame( + median = median(vec), + lower = quantile(vec, 0.025), + upper = quantile(vec, 0.975) ) + rownames(x) <- NULL + x + } + + x <- matrix( + c( + 1, 10, 1, 3, 9, 23, + 2, 20, 1, 3, 10, 23, + 3, 30, 1, 3, 9, 23, + 4, 40, 2, 4, 12, 23, + 5, 50, 2, 4, 14, 23, + 6, 60, 2, 4, 10, 23 + ), + byrow = TRUE, + ncol = 6 ) - expect_identical( - length(data_layer2$x), - 201L * 2L + colnames(x) <- sprintf( + "quantity[%i,%i]", + # 1 2 3 4 5 6 # Column index + c(1, 1, 2, 2, 3, 3), # Subject IDs + c(4, 5, 4, 5, 4, 5) # Time point IDs ) - expect_identical( - length(unique(data_layer2$x)), - 201L + draws_x <- posterior::as_draws_matrix(x) + + + ## 1 subject 1 timepoint + actual <- summarise_by_group( + subject_index = 1, + time_index = 4, + quantities = draws_x ) - expect_identical( - data_layer2$ymin, - c(object[[1]]$summary$lower, object[[2]]$summary$lower) + expect_equal( + actual, + get_ci_summary(x[, 1]) + ) + + actual <- summarise_by_group( + subject_index = 1, + time_index = 5, + quantities = draws_x ) - expect_identical( - data_layer2$ymax, - c(object[[1]]$summary$upper, object[[2]]$summary$upper) + expect_equal( + actual, + get_ci_summary(x[, 2]) ) - data_layer3 <- layer_data(result, i = 3) - expect_s3_class(data_layer3, "data.frame") - expect_identical( - names(data_layer3), - c("x", "y", "time", "survival", "status", "PANEL", "group", "colour", - "fill", "linewidth", "linetype", "weight", "alpha") + + ## Select multiple subjects to collapse into a single "aggregate subject" + ## at a single timepoint + actual <- summarise_by_group( + subject_index = c(1, 2), + time_index = 5, + quantities = draws_x + ) + expect_equal( + actual, + get_ci_summary(rowMeans(x[, c(2, 4)])) ) -}) -test_that("autoplot does not show the Kaplan-Meier plot if disabled", { - object <- survival(mcmc_results, patients = c("pt_00001", "pt_00022")) - result <- expect_silent(autoplot(object, add_km = FALSE)) - # Only 2 layers here, i.e. no Kaplan-Meier plot. - expect_identical(length(result$layers), 2L) -}) -test_that("autoplot works end to end with Kaplan-Meier plot", { - object <- survival(mcmc_results, patients = c("pt_00001", "pt_00022")) - result <- expect_silent(autoplot(object, add_km = TRUE)) - # 4 layers here, i.e. including Kaplan-Meier plot line and ticks. - expect_identical(length(result$layers), 4L) - # TODO - Need to rework when updating plotting functions - ## vdiffr::expect_d oppelganger("SurvivalSamples autoplot with KM", result) + ## 1 subject at multiple time points + actual <- summarise_by_group( + subject_index = c(3), + time_index = c(4, 5), + quantities = draws_x + ) + expect_equal( + actual, + dplyr::bind_rows( + get_ci_summary(x[, 5]), + get_ci_summary(x[, 6]) + ) + ) + + + ## Selecting multiple subjects to collapse into a single "agregate subject" + ## at multiple timepoints + actual <- summarise_by_group( + subject_index = c(1, 3), + time_index = c(4, 5), + quantities = draws_x + ) + expect_equal( + actual, + dplyr::bind_rows( + get_ci_summary(x[, c(1, 5)]), + get_ci_summary(x[, c(2, 6)]) + ) + ) + + ## Can select the same subject multiple times + actual <- summarise_by_group( + subject_index = c(3, 3, 3, 2), + time_index = c(4, 5), + quantities = draws_x + ) + expect_equal( + actual, + dplyr::bind_rows( + get_ci_summary(x[, c(5, 5, 5, 3)]), + get_ci_summary(x[, c(6, 6, 6, 4)]) + ) + ) }) diff --git a/tests/testthat/test-extract_survival_quantities.R b/tests/testthat/test-extract_survival_quantities.R new file mode 100644 index 00000000..594df990 --- /dev/null +++ b/tests/testthat/test-extract_survival_quantities.R @@ -0,0 +1,69 @@ + + + + + + +test_that("extract_survival_quantities() works as expected", { + log_surv <- log(c(0.1, 0.3, 0.5, 0.7, 0.2)) + log_haz <- log(c(10, 11, 4, 2.1, 9, 3)) + + + nullmodel <- StanModule(test_path("models", "null_model_insert.stan")) + stan_code <- sprintf(" +generated quantities { + vector[%i] log_surv_fit_at_time_grid; + vector[%i] log_haz_fit_at_time_grid; + log_surv_fit_at_time_grid = to_vector({%s}); + log_haz_fit_at_time_grid = to_vector({%s}); +}", + length(log_surv), + length(log_haz), + paste0(log_surv, collapse = ","), + paste0(log_haz, collapse = ",") + ) + + gq_code <- StanModule(stan_code) + + model_code <- merge( + gq_code, + nullmodel + ) + + mod <- cmdstanr::cmdstan_model( + stan_file = cmdstanr::write_stan_file( + as.character(model_code), + dir = CACHE_DIR, + basename = "test_extract_gq.stan" + ), + exe_file = file.path( + CACHE_DIR, + paste0("test_extract_gq") + ) + ) + + ## Mock data just to get function to run + stan_data <- list(null_y_data = c(0, 1)) + stan_fitted <- posterior::as_draws_matrix(list(null_y_mean = c(1))) + + gq <- mod$generate_quantities( + data = stan_data, + fitted_params = stan_fitted + ) + + + run_test <- function(vals, keyword) { + expected <- posterior::as_draws_matrix(matrix(vals, nrow = 1)) + colnames(expected) <- sprintf("quantity[%i]", seq_along(vals)) + expect_equal( + round(extract_survival_quantities(gq, type = keyword), 3), + round(expected, 3) + ) + } + + run_test(exp(log_surv), "surv") + run_test(-log_surv, "cumhaz") + run_test(exp(log_haz), "haz") + run_test(log_haz, "loghaz") + +}) diff --git a/tests/testthat/test-survival_plot.R b/tests/testthat/test-survival_plot.R new file mode 100644 index 00000000..8b9c2882 --- /dev/null +++ b/tests/testthat/test-survival_plot.R @@ -0,0 +1,101 @@ + + +test_that("survival_plot works as expected", { + set.seed(38132) + define_data <- function(i, group) { + n <- 120 + dplyr::tibble( + e_time = rexp(n, 1 / i), + c_time = rexp(n, 1 / 200), + event = ifelse(e_time <= c_time, 1, 0), + time = ifelse(e_time <= c_time, e_time, c_time), + group = group + )|> dplyr::select(time, event, group) + } + + dat <- dplyr::bind_rows( + define_data(100, "A"), + define_data(75, "B"), + define_data(50, "C") + ) + + mod <- survival::survreg( + survival::Surv(time, event) ~ group, + dist = "exponential", + data = dat + ) + + preds <- predict( + mod, + type = "response", + se.fit = TRUE, + newdata = data.frame(group = c("A", "B", "C")) + ) + + med <- preds$fit + lci <- preds$fit - preds$se.fit * 1.96 + uci <- preds$fit + preds$se.fit * 1.96 + + times <- seq(0, 400, by = 30) + + get_data <- function(i, group) { + data.frame( + time = times, + median = pexp(times, 1 / med[i], lower.tail = FALSE), + lower = pexp(times, 1 / lci[i], lower.tail = FALSE), + upper = pexp(times, 1 / uci[i], lower.tail = FALSE), + group = group + ) + } + + res <- dplyr::bind_rows( + get_data(1, "A"), + get_data(2, "B"), + get_data(3, "C") + ) + + p1 <- survival_plot( + res, + add_ci = TRUE, + add_wrap = TRUE, + kmdf = NULL, + y_label = expression(frac(1, 2) + S(t^2)), + x_label = expression(thd[3]) + ) + + vdiffr::expect_doppelganger( + "survival_plot with wrap and ci", + p1 + ) + + p2 <- survival_plot( + res, + add_ci = FALSE, + add_wrap = FALSE, + kmdf = NULL, + y_label = expression(frac(1, 2) + S(t^2)), + x_label = expression(thd[3]) + ) + + vdiffr::expect_doppelganger( + "survival_plot with no wrap and no ci", + p2 + ) + + + p3 <- survival_plot( + res, + add_ci = FALSE, + add_wrap = FALSE, + kmdf = dat, + y_label = expression(frac(1, 2) + S(t^2)), + x_label = expression(thd[3]) + ) + + theme(legend.position = "bottom") + + scale_y_continuous(trans = "sqrt") + + vdiffr::expect_doppelganger( + "survival_plot with no wrap and no ci + km + ggplot2 integration", + p3 + ) +}) diff --git a/tests/testthat/test-utilities.R b/tests/testthat/test-utilities.R index 610dd238..525dc78d 100644 --- a/tests/testthat/test-utilities.R +++ b/tests/testthat/test-utilities.R @@ -217,3 +217,73 @@ test_that("expand_patients() works as expected", { regex = "`patients`" ) }) + + +test_that("decompose_patients() works as expected", { + + # Basic vector format + actual <- decompose_patients(c("a", "b", "d"), c("a", "b", "c", "d")) + expected <- list( + groups = list( + "a" = "a", + "b" = "b", + "d" = "d" + ), + unique_values = c("a", "b", "d"), + indexes = list( + "a" = 1, + "b" = 2, + "d" = 3 + ) + ) + expect_equal(actual, expected) + + + + # list format + actual <- decompose_patients( + list("g1" = c("b", "a"), "g2" = c("a", "d")), + c("a", "b", "c", "d") + ) + expected <- list( + groups = list( + "g1" = c("b", "a"), + "g2" = c("a", "d") + ), + unique_values = c("a", "b", "d"), + indexes = list( + "g1" = c(2, 1), + "g2" = c(1, 3) + ) + ) + expect_equal(actual, expected) + + + # NULL is correctly expanded + actual <- decompose_patients( + NULL, + c("a", "d", "c", "b", "b", "b", "a") + ) + expected <- list( + groups = list( + "a" = "a", "d" = "d", "c" = "c", "b" = "b" + ), + unique_values = c("a", "b", "c", "d"), + indexes = list( + "a" = 1, "d" = 4, "c" = 3, "b" = 2 + ) + ) + expect_equal(actual, expected) + + # errors if patient doesn't exist + expect_error( + decompose_patients("e", c("a", "d", "c", "b", "b")), + regexp = "`patients`" + ) + # errors if group has same patient twice + expect_error( + decompose_patients(list("g1" = c("a", "a")), c("a", "d", "c", "b", "b")), + regexp = "`patients`" + ) + +}) diff --git a/vignettes/model_fitting.Rmd b/vignettes/model_fitting.Rmd index db7f74cc..63b1843a 100644 --- a/vignettes/model_fitting.Rmd +++ b/vignettes/model_fitting.Rmd @@ -329,8 +329,8 @@ And using the `survival()` method we can do the same for the estimated survival functions. We can do this for single patients: ```{r fig.width=6} -surv_samples <- survival(mcmc_results, patients = selected_patients) -autoplot(surv_samples, add_km = FALSE) +surv_samples <- SurvivalSamples(mcmc_results) +autoplot(surv_samples, add_km = FALSE, patients = selected_patients) ``` We can also aggregate the estimated survival curves from groups of patients, @@ -338,8 +338,7 @@ using the corresponding method. ```{r, warning=FALSE, fig.width=6} groups <- split(os_data$pt, os_data$arm) -surv_grouped_samples <- aggregate(survival(mcmc_results), groups = groups) -autoplot(surv_grouped_samples, add_km = TRUE) +autoplot(surv_samples, add_km = TRUE, patients = groups) ```