Skip to content

Commit

Permalink
Merge 18aaae5 into e8a56d4
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Feb 6, 2024
2 parents e8a56d4 + 18aaae5 commit abe9d91
Show file tree
Hide file tree
Showing 78 changed files with 914 additions and 359 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Collate:
'DataLongitudinal.R'
'DataSurvival.R'
'DataJoint.R'
'constants.R'
'StanModule.R'
'Prior.R'
'Parameter.R'
Expand All @@ -88,6 +89,7 @@ Collate:
'defaults.R'
'external-exports.R'
'jmpost-package.R'
'settings.R'
'simulations.R'
'simulations_gsf.R'
'simulations_os.R'
Expand Down
12 changes: 11 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ S3method(extractVariableNames,DataSurvival)
S3method(generateQuantities,JointModelSamples)
S3method(getParameters,default)
S3method(initialValues,JointModel)
S3method(initialValues,Link)
S3method(initialValues,Parameter)
S3method(initialValues,ParameterList)
S3method(initialValues,Prior)
S3method(initialValues,StanModel)
S3method(names,Parameter)
S3method(names,ParameterList)
S3method(sampleStanModel,JointModel)
Expand Down Expand Up @@ -100,6 +102,7 @@ export(generateQuantities)
export(gsf_dsld)
export(gsf_sld)
export(gsf_ttg)
export(initialValues)
export(link_gsf_abstract)
export(link_gsf_dsld)
export(link_gsf_identity)
Expand All @@ -112,7 +115,6 @@ export(prior_invgamma)
export(prior_logistic)
export(prior_loglogistic)
export(prior_lognormal)
export(prior_none)
export(prior_normal)
export(prior_std_normal)
export(prior_student_t)
Expand Down Expand Up @@ -162,5 +164,13 @@ importFrom(ggplot2,autoplot)
importFrom(ggplot2.utils,geom_km)
importFrom(glue,as_glue)
importFrom(stats,acf)
importFrom(stats,rbeta)
importFrom(stats,rcauchy)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,rlogis)
importFrom(stats,rnorm)
importFrom(stats,rt)
importFrom(stats,runif)
importFrom(survival,Surv)
importFrom(tibble,add_case)
74 changes: 64 additions & 10 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @include LongitudinalModel.R
#' @include SurvivalModel.R
#' @include Link.R
#' @include constants.R
NULL


Expand Down Expand Up @@ -136,6 +137,8 @@ compileStanModel.JointModel <- function(object) {
#' @export
sampleStanModel.JointModel <- function(object, data, ...) {

assert_class(data, "DataJoint")

if (!is.null(object@survival)) {
assert_that(
!is.null(data@survival),
Expand All @@ -150,21 +153,37 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}

args <- list(...)

args[["data"]] <- append(
as_stan_list(data),
as_stan_list(object@parameters)
)

if (!"init" %in% names(args)) {
values_initial <- initialValues(object)
values_sizes <- size(object@parameters)
values_sizes_complete <- replace_with_lookup(values_sizes, args[["data"]])
values_initial_expanded <- expand_initial_values(values_initial, values_sizes_complete)
args[["init"]] <- function() values_initial_expanded
args[["chains"]] <- if ("chains" %in% names(args)) {
args[["chains"]]
} else {
# Magic constant from R/constants.R
CMDSTAN_DEFAULT_CHAINS
}

initial_values <- if ("init" %in% names(args)) {
args[["init"]]
} else {
initialValues(object, n_chains = args[["chains"]])
}

args[["init"]] <- ensure_initial_values(
initial_values,
args[["data"]],
object@parameters
)

model <- compileStanModel(object)
results <- do.call(model$sample, args)

results <- do.call(
model$sample,
args
)

.JointModelSamples(
model = object,
Expand All @@ -174,12 +193,47 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}


# initialValues-JointModel ----
#' Ensure that initial values are correctly specified
#'
#' @param initial_values (`list`)\cr A list of lists containing the initial values
#' must be 1 list per desired chain. All elements should have identical names
#' @param data (`list`)\cr specifies the size to expand each of our initial values to be.
#' That is elements of size 1 in `initial_values` will be expanded to be the same
#' size as the corresponding element in `data` by broadcasting the value.
#' @param parameters ([`ParameterList`])\cr the parameters object
#'
#' @details
#' This function is mostly a thin wrapper around `expand_initial_values` to
#' enable easier unit testing.
#'
#' @keywords internal
ensure_initial_values <- function(initial_values, data, parameters) {
if (is.function(initial_values)) {
return(initial_values)
}

assert_class(data, "list")
assert_class(parameters, "ParameterList")
assert_class(initial_values, "list")

values_sizes <- size(parameters)
values_sizes_complete <- replace_with_lookup(
values_sizes,
data
)
lapply(
initial_values,
expand_initial_values,
sizes = values_sizes_complete
)
}



#' @rdname initialValues
#' @export
initialValues.JointModel <- function(object) {
initialValues(object@parameters)
initialValues.JointModel <- function(object, n_chains, ...) {
initialValues(object@parameters, n_chains)
}


Expand Down
5 changes: 3 additions & 2 deletions R/Link.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ setMethod(
# initialValues-Link ----

#' @rdname initialValues
initialValues.Link <- function(object) {
initialValues(object@parameters)
#' @export
initialValues.Link <- function(object, n_chains, ...) {
initialValues(object@parameters, n_chains)
}


Expand Down
6 changes: 3 additions & 3 deletions R/LinkGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ link_gsf_abstract <- function(
#'
#' @export
link_gsf_ttg <- function(
gamma = prior_normal(0, 5, init = 0)
gamma = prior_normal(0, 5)
) {
.link_gsf_ttg(
name = "TTG",
Expand Down Expand Up @@ -182,7 +182,7 @@ link_gsf_ttg <- function(
#'
#' @export
link_gsf_dsld <- function(
beta = prior_normal(0, 5, init = 0)
beta = prior_normal(0, 5)
) {
.link_gsf_dsld(
name = "dSLD",
Expand Down Expand Up @@ -215,7 +215,7 @@ link_gsf_dsld <- function(
#' @param tau (`Prior`)\cr prior for the link coefficient `tau`.
#'
#' @export
link_gsf_identity <- function(tau = prior_normal(0, 5, init = 0)) {
link_gsf_identity <- function(tau = prior_normal(0, 5)) {
.link_gsf_identity(
name = "Identity",
stan = StanModule("lm-gsf/link_identity.stan"),
Expand Down
2 changes: 1 addition & 1 deletion R/LinkRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ NULL
#'
#' @export
LinkRandomSlope <- function(
link_lm_phi = prior_normal(0.2, 0.5, init = 0.02)
link_lm_phi = prior_normal(0.2, 0.5)
) {
.LinkRandomSlope(
Link(
Expand Down
56 changes: 29 additions & 27 deletions R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,23 @@ NULL
#' @param a_phi (`Prior`)\cr for the alpha parameter for the fraction of cells that respond to treatment.
#' @param b_phi (`Prior`)\cr for the beta parameter for the fraction of cells that respond to treatment.
#'
#' @param psi_bsld (`Prior`)\cr for the baseline value random effect `psi_bsld`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_ks (`Prior`)\cr for the shrinkage rate random effect `psi_ks`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_kg (`Prior`)\cr for the growth rate random effect `psi_kg`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_phi (`Prior`)\cr for the shrinkage proportion random effect `psi_phi`. Only used in the
#' centered parameterization to set the initial value.
#'
#' @param centered (`logical`)\cr whether to use the centered parameterization.
#'
#' @export
LongitudinalGSF <- function(

mu_bsld = prior_normal(log(60), 1, init = 60),
mu_ks = prior_normal(log(0.5), 1, init = 0.5),
mu_kg = prior_normal(log(0.3), 1, init = 0.3),

omega_bsld = prior_lognormal(log(0.2), 1, init = 0.2),
omega_ks = prior_lognormal(log(0.2), 1, init = 0.2),
omega_kg = prior_lognormal(log(0.2), 1, init = 0.2),
mu_bsld = prior_normal(log(60), 1),
mu_ks = prior_normal(log(0.5), 1),
mu_kg = prior_normal(log(0.3), 1),

a_phi = prior_lognormal(log(5), 1, init = 5),
b_phi = prior_lognormal(log(5), 1, init = 5),
omega_bsld = prior_lognormal(log(0.2), 1),
omega_ks = prior_lognormal(log(0.2), 1),
omega_kg = prior_lognormal(log(0.2), 1),

sigma = prior_lognormal(log(0.1), 1, init = 0.1),
a_phi = prior_lognormal(log(5), 1),
b_phi = prior_lognormal(log(5), 1),

psi_bsld = prior_none(init = 60),
psi_ks = prior_none(init = 0.5),
psi_kg = prior_none(init = 0.5),
psi_phi = prior_none(init = 0.5),
sigma = prior_lognormal(log(0.1), 1),

centered = FALSE
) {
Expand All @@ -87,17 +73,33 @@ LongitudinalGSF <- function(

Parameter(name = "lm_gsf_a_phi", prior = a_phi, size = "n_arms"),
Parameter(name = "lm_gsf_b_phi", prior = b_phi, size = "n_arms"),
Parameter(name = "lm_gsf_psi_phi", prior = psi_phi, size = "Nind"),
Parameter(
name = "lm_gsf_psi_phi",
prior = prior_init_only(prior_beta(a_phi@init, b_phi@init)),
size = "Nind"
),

Parameter(name = "lm_gsf_sigma", prior = sigma, size = 1)
)

assert_flag(centered)
parameters_extra <- if (centered) {
list(
Parameter(name = "lm_gsf_psi_bsld", prior = psi_bsld, size = "Nind"),
Parameter(name = "lm_gsf_psi_ks", prior = psi_ks, size = "Nind"),
Parameter(name = "lm_gsf_psi_kg", prior = psi_kg, size = "Nind")
Parameter(
name = "lm_gsf_psi_bsld",
prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)),
size = "Nind"
),
Parameter(
name = "lm_gsf_psi_ks",
prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)),
size = "Nind"
),
Parameter(
name = "lm_gsf_psi_kg",
prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)),
size = "Nind"
)
)
} else {
list(
Expand Down
23 changes: 9 additions & 14 deletions R/LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,19 @@ NULL
#' @param slope_mu (`Prior`)\cr for the population slope `slope_mu`.
#' @param slope_sigma (`Prior`)\cr for the random slope standard deviation `slope_sigma`.
#' @param sigma (`Prior`)\cr for the variance of the longitudinal values `sigma`.
#' @param random_slope (`Prior`)\cr must be `prior_none()`, just used to specify initial values.
#'
#' @export
LongitudinalRandomSlope <- function(
intercept = prior_normal(30, 10, init = 30),
slope_mu = prior_normal(0, 15, init = 0.001),
slope_sigma = prior_lognormal(1, 5, init = 1),
sigma = prior_lognormal(1, 5, init = 1),
random_slope = prior_none(init = 0)
intercept = prior_normal(30, 10),
slope_mu = prior_normal(0, 15),
slope_sigma = prior_lognormal(0, 1.5),
sigma = prior_lognormal(0, 1.5)
) {

stan <- StanModule(
x = "lm-random-slope/model.stan"
)

assert_that(
inherits(random_slope, "Prior"),
random_slope@repr_data == "",
random_slope@repr_model == "",
msg = "`random_slope` must be a `prior_none()`"
)

.LongitudinalRandomSlope(
LongitudinalModel(
name = "Random Slope",
Expand All @@ -55,7 +46,11 @@ LongitudinalRandomSlope <- function(
Parameter(name = "lm_rs_slope_mu", prior = slope_mu, size = "n_arms"),
Parameter(name = "lm_rs_slope_sigma", prior = slope_sigma, size = 1),
Parameter(name = "lm_rs_sigma", prior = sigma, size = 1),
Parameter(name = "lm_rs_ind_rnd_slope", prior = random_slope, size = "Nind")
Parameter(
name = "lm_rs_ind_rnd_slope",
prior = prior_init_only(prior_normal(slope_mu@init, slope_sigma@init)),
size = "Nind"
)
)
)
)
Expand Down
3 changes: 2 additions & 1 deletion R/Parameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ as_stan_list.Parameter <- function(object, ...) {
#'
#' @param x (`Paramater`) \cr A model parameter
#' @param object (`Paramater`) \cr A model parameter
#' @param ... Not used.
#'
#' @description
#' Getter functions for the slots of a [`Parameter`] object
Expand All @@ -115,7 +116,7 @@ names.Parameter <- function(x) x@name

#' @describeIn Parameter-Getter-Methods The parameter's initial values
#' @export
initialValues.Parameter <- function(object) initialValues(object@prior)
initialValues.Parameter <- function(object, ...) initialValues(object@prior)

#' @describeIn Parameter-Getter-Methods The parameter's dimensionality
#' @export
Expand Down
19 changes: 14 additions & 5 deletions R/ParameterList.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ as.list.ParameterList <- function(x, ...) {
#' Getter functions for the slots of a [`ParameterList`] object
#' @inheritParams ParameterList-Shared
#' @family ParameterList
#' @param n_chains (`integer`) \cr the number of chains.
#' @name ParameterList-Getter-Methods
NULL

Expand All @@ -145,11 +146,19 @@ names.ParameterList <- function(x) {

#' @describeIn ParameterList-Getter-Methods The parameter-list's parameter initial values
#' @export
initialValues.ParameterList <- function(object) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
return(vals)
initialValues.ParameterList <- function(object, n_chains, ...) {
# Generate initial values as a list of lists. This is to ensure it is in the required
# format as specified by cmdstanr see the `init` argument of
# `help("model-method-sample", "cmdstanr")` for more details
lapply(
seq_len(n_chains),
\(i) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
vals
}
)
}


Expand Down
Loading

0 comments on commit abe9d91

Please sign in to comment.