Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Survival Plots #208

Merged
merged 20 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .github/workflows/check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ jobs:
CMDSTAN=/root/.cmdstan
CMDSTAN_PATH=/root/.cmdstan
CMDSTANR_NO_VER_CHECK=true
linter:
if: github.event_name == 'pull_request'
name: SuperLinter 🦸‍♀️
uses: insightsengineering/r.pkg.template/.github/workflows/linter.yaml@main
### Disabled pending https://github.com/insightsengineering/idr-tasks/issues/667
# linter:
# if: github.event_name == 'pull_request'
# name: SuperLinter 🦸‍♀️
# uses: insightsengineering/r.pkg.template/.github/workflows/linter.yaml@main
roxygen:
name: Roxygen 🅾
uses: insightsengineering/r.pkg.template/.github/workflows/roxygen.yaml@main
Expand Down
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ linters: linters_with_defaults(
line_length_linter(120),
object_name_linter = NULL,
object_usage_linter = NULL,
cyclocomp_linter = NULL
cyclocomp_linter = NULL,
indentation_linter(4)
)
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Imports:
tibble,
methods,
digest,
posterior,
stats
Suggests:
bayesplot,
Expand Down
8 changes: 3 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ export(Surv)
export(SurvivalExponential)
export(SurvivalLogLogistic)
export(SurvivalModel)
export(SurvivalSamples)
export(SurvivalWeibullPH)
export(addLink)
export(aggregate)
export(as.data.frame)
export(as.list)
export(autoplot)
Expand All @@ -37,6 +37,7 @@ export(link_gsf_identity)
export(link_gsf_ttg)
export(longitudinal)
export(merge)
export(predict)
export(prior_beta)
export(prior_cauchy)
export(prior_gamma)
Expand All @@ -46,15 +47,14 @@ export(prior_normal)
export(prior_std_normal)
export(read_stan)
export(sampleStanModel)
export(samples_median_ci)
export(show)
export(sim_lm_gsf)
export(sim_lm_random_slope)
export(sim_os_exponential)
export(sim_os_loglogistic)
export(sim_os_weibull)
export(simulate_joint_data)
export(survival)
export(subset)
export(write_stan)
exportClasses(DataJoint)
exportClasses(DataLongitudinal)
Expand All @@ -76,7 +76,6 @@ exportClasses(StanModule)
exportClasses(SurvivalExponential)
exportClasses(SurvivalLogLogistic)
exportClasses(SurvivalModel)
exportClasses(SurvivalSamples)
exportClasses(SurvivalWeibullPH)
exportClasses(link_gsf_abstract)
exportClasses(link_gsf_dsld)
Expand All @@ -87,7 +86,6 @@ exportMethods(as.character)
exportMethods(generateQuantities)
exportMethods(longitudinal)
exportMethods(names)
exportMethods(survival)
import(assertthat)
import(ggplot2)
import(methods)
Expand Down
104 changes: 93 additions & 11 deletions R/DataJoint.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ NULL
#' @slot survival (`DataSurvival`)\cr object created by [DataSurvival()].
#' @slot longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()].
#'
#' @name DataJoint
#' @aliases DataJoint-class
#'
#' @param survival (`DataSurvival`)\cr object created by [DataSurvival()].
#' @param longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()].
#' @details
#'
#' - `as.list(x)`, `as(x, "list")`: Coerces x into a list of data components required
#' for fitting a [JointModel()]. See the vignette (TODO) for more details.
#'
#' @family DataJoint
#'
#' @export DataJoint
#' @exportClass DataJoint
.DataJoint <- setClass(
Class = "DataJoint",
Expand All @@ -23,17 +36,7 @@ NULL

# DataJoint-constructors ----

#' @rdname DataJoint-class
#'
#' @param survival (`DataSurvival`)\cr object created by [DataSurvival()].
#' @param longitudinal (`DataLongitudinal`)\cr object created by [DataLongitudinal()].
#'
#' @details
#'
#' - `as.list(x)`, `as(x, "list")`: Coerces x into a list of data components required
#' for fitting a [JointModel()]. See the vignette (TODO) for more details.
#'
#' @export
#' @rdname DataJoint
DataJoint <- function(survival, longitudinal) {
.DataJoint(
survival = survival,
Expand Down Expand Up @@ -76,6 +79,7 @@ setMethod(

# coerce-DataJoint,list ----

#' @param x (`DataJoint`) \cr A [DataJoint][DataJoint-class] object created by [DataJoint()]
#' @rdname as.list
#'
#' @name coerce-DataJoint-list-method
Expand All @@ -85,3 +89,81 @@ setAs(
to = "list",
def = function(from) as.list(from)
)


#' Subsetting `DataJoint` as a `data.frame`
#'
#' @param x (`DataJoint`) \cr A [DataJoint][DataJoint-class] object created by [DataJoint()]
#' @param patients (`character` or `list`)\cr the patients that you wish to subset the `data.frame`
#' to contain. See details.
#'
#' @description
#'
#' Coerces the object into a `data.frame` containing just event times and status
#' filtering for specific patients. If `patients` is a list then an additional variable `group` will be added
#' onto the dataset specifying which group the row belongs to.
#'
#' @examples
#' \dontrun{
#' pts <- c("PT1", "PT3", "PT4")
#' subset(x, pts)
#'
#' groups <- list(
#' "g1" = c("PT1", "PT3", "PT4"),
#' "g2" = c("PT2", "PT3")
#' )
#' subset(x, groups)
#' }
#' @family DataJoint
#' @family subset
setMethod(
f = "subset",
signature = "DataJoint",
definition = function(x, patients) {
data <- as.list(x)
dat <- data.frame(
time = data[["Times"]],
event = as.numeric(seq_along(data[["Times"]]) %in% data[["dead_ind_index"]]),
patient = names(data[["pt_to_ind"]])
)
subset_and_add_grouping(dat, patients)
}
)


#' `subset_and_add_grouping`
#'
#' @param dat (`data.frame`) \cr Must have a column called `patient` which corresponds to the
#' values passed to `groupings`
#' @param groupings (`character` or `list`)\cr the patients that you wish to subset the dataset
#' to contain. If `groupings` is a list then an additional variable `group` will be added
#' onto the dataset specifying which group the row belongs to.
#'
#' @details
#' Example of usage
#' ```
#' pts <- c("PT1", "PT3", "PT4")
#' subset_and_add_grouping(dat, pts)
#'
#' groups <- list(
#' "g1" = c("PT1", "PT3", "PT4"),
#' "g2" = c("PT2", "PT3")
#' )
#' subset_and_add_grouping(dat, groups)
#' ```
#'
#' @keywords internal
subset_and_add_grouping <- function(dat, groupings) {
groupings <- decompose_patients(groupings, dat$patient)$groups
dat_subset_list <- lapply(
seq_along(groupings),
\(i) {
dat_reduced <- dat[dat$patient %in% groupings[[i]], , drop = FALSE]
dat_reduced[["group"]] <- names(groupings)[[i]]
dat_reduced
}
)
x <- Reduce(rbind, dat_subset_list)
row.names(x) <- NULL
x
}
1 change: 0 additions & 1 deletion R/DataLongitudinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ setMethod(
) |>
t()

# TODO - Maybe reimplement this using a more robust approach than magic number
adj_threshold <- if (is.null(vars$threshold)) {
-999999
} else {
Expand Down
12 changes: 2 additions & 10 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,7 @@ setMethod(
definition = function(object, data, ...) {

args <- list(...)

if (is(data, "DataJoint")) {
args[["data"]] <- as.list(data)
} else if (is(data, "list")) {
args[["data"]] <- data
} else {
stop("`data` must either be a list or a DataJoint object")
}
args[["data"]] <- as.list(data)

if (!"init" %in% names(args)) {
values_initial <- initialValues(object)
Expand All @@ -135,8 +128,7 @@ setMethod(

.JointModelSamples(
model = object,
data = args$data,
init = values_initial_expanded,
data = data,
results = results
)
}
Expand Down
76 changes: 7 additions & 69 deletions R/JointModelSamples.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#' @include JointModel.R
NULL

setOldClass("CmdStanMCMC")

# JointModelSamples-class ----

#' `JointModelSamples`
Expand All @@ -9,7 +11,6 @@ NULL
#'
#' @slot model (`JointModel`)\cr the original model.
#' @slot data (`list`)\cr data input.
#' @slot init (`list`)\cr initial values.
#' @slot results (`CmdStanMCMC`)\cr the results from [sampleStanModel()].
#'
#' @aliases JointModelSamples
Expand All @@ -18,9 +19,8 @@ NULL
"JointModelSamples",
slots = c(
model = "JointModel",
data = "list",
init = "list",
results = "ANY"
data = "DataJoint",
results = "CmdStanMCMC"
)
)

Expand All @@ -39,7 +39,7 @@ setMethod(
f = "generateQuantities",
signature = c(object = "JointModelSamples"),
definition = function(object, patients, time_grid_lm, time_grid_sm, ...) {
data <- object@data
data <- as.list(object@data)
data[["n_lm_time_grid"]] <- length(time_grid_lm)
data[["lm_time_grid"]] <- time_grid_lm
data[["n_sm_time_grid"]] <- length(time_grid_sm)
Expand Down Expand Up @@ -83,9 +83,9 @@ setMethod(
signature = c(object = "JointModelSamples"),
definition = function(object, patients = NULL, time_grid = NULL, ...) {

data <- object@data
data <- as.list(object@data)
time_grid <- expand_time_grid(time_grid, max(data[["Tobs"]]))
patients <- expand_patients(patients, names(object@data$pt_to_ind))
patients <- expand_patients(patients, names(data$pt_to_ind))
gq <- generateQuantities(
object,
patients = patients,
Expand Down Expand Up @@ -122,65 +122,3 @@ setMethod(
.LongitudinalSamples(results)
}
)


# survival-JointModelSamples ----

#' @rdname survival
#'
#' @param patients (`character` or `NULL`)\cr optional subset of patients for
#' which the survival function samples should be extracted, the default `NULL`
#' meaning all patients.
#'
#' @param time_grid (`numeric`)\cr grid of time points to use for providing samples
#' of the survival model fit functions. If `NULL`, will be taken as a sequence of
#' 201 values from 0 to the maximum observed event time.
#'
#' @export
setMethod(
f = "survival",
signature = c(object = "JointModelSamples"),
definition = function(object, patients = NULL, time_grid = NULL, ...) {

data <- object@data
time_grid <- expand_time_grid(time_grid, max(data[["Times"]]))
patients <- expand_patients(patients, names(object@data$pt_to_ind))
gq <- generateQuantities(
object,
patients = patients,
time_grid_lm = numeric(0),
time_grid_sm = time_grid
)

log_surv_at_grid_samples <- gq$draws(format = "draws_matrix")
log_surv_at_obs_samples <- object@results$draws(
"log_surv_fit_at_obs_times",
format = "draws_matrix"
)

results <- list()
for (this_pt_ind in seq_along(patients)) {
this_pt <- patients[this_pt_ind]
this_result <- list()
patient_ind <- object@data$pt_to_ind[this_pt]
this_surv_fit_names <- sprintf(
"log_surv_fit_at_time_grid[%i,%i]",
this_pt_ind,
seq_along(time_grid)
)
this_result$samples <- exp(log_surv_at_grid_samples[, this_surv_fit_names, drop = FALSE])
this_result$summary <- data.frame(
time = time_grid,
samples_median_ci(this_result$samples)
)
this_result$observed <- data.frame(
t = data$Times[patient_ind],
death = (patient_ind %in% object@data$dead_ind_index),
samples_median_ci(exp(log_surv_at_obs_samples[, patient_ind, drop = FALSE]))
)
rownames(this_result$observed) <- this_pt
results[[this_pt]] <- this_result
}
.SurvivalSamples(results)
}
)
1 change: 0 additions & 1 deletion R/StanModule.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ STAN_BLOCKS <- list(
#' @keywords internal
add_missing_stan_blocks <- function(x) {
# STAN_BLOCKS is defined as a global variable in StanModule.R
# TODO - Make it an argument to the function
for (block in names(STAN_BLOCKS)) {
if (is.null(x[[block]])) {
x[[block]] <- ""
Expand Down
Loading