diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index 0a5a309..39b4758 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -118,62 +118,6 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams): return model - # if len(layers): model = DeepCoxMixture(k=k, inputdim=features.shape[1], hidden=layers[0]) - # else: model = CoxMixture(k=k, inputdim=features.shape[1]) - - # x = torch.from_numpy(features.values.astype('float32')) - # t = torch.from_numpy(outcomes['time'].values.astype('float32')) - # e = torch.from_numpy(outcomes['event'].values.astype('float32')) - - # vidx = _get_valid_idx(x.shape[0], 0.15, random_seed) - - # train_data = (x[~vidx], t[~vidx], e[~vidx]) - # val_data = (x[vidx], t[vidx], e[vidx]) - - # (model, breslow_splines, unique_times) = train(model, - # train_data, - # val_data, - # epochs=epochs, - # lr=lr, bs=bs, - # use_posteriors=True, - # patience=5, - # return_losses=False, - # smoothing_factor=smoothing_factor) - - #return (model, breslow_splines, unique_times) - -# THIS IS 1 OF 2 _PREDICT_DCM FUNCTIONS HERE BUT THIS ONE THROWS A BUG SO I USE _PREDICT_DCM FUNCTION BELOW -# def _predict_dcm(model, features, times): - -# """Predict survival probabilities at specified time(s) using the -# Deep Cox Mixtures model. - -# Parameters -# ----------- -# model : Trained instance of the Deep Cox Mixtures model. -# features : pd.DataFrame -# A pandas dataframe with rows corresponding to individual -# samples and columns as covariates. -# times: float or list -# A float or list of the times at which to compute -# the survival probability. - -# Returns -# ----------- -# np.array : An array of the survival probabilites at each -# time point in times. 
- -# """ - -# #raise NotImplementedError() - -# survival_predictions = model.predict_survival(features, times) -# if len(times)>1: -# survival_predictions = pd.DataFrame(survival_predictions, columns=times).T -# return __interpolate_missing_times(survival_predictions, times) -# else: -# return survival_predictions - def _fit_dcph(features, outcomes, random_seed, **hyperparams): """Fit a Deep Cox Proportional Hazards Model/Farragi Simon Network [1,2] @@ -228,55 +172,6 @@ def _fit_dcph(features, outcomes, random_seed, **hyperparams): return model - #raise NotImplementedError() - # import torch - # import torchtuples as ttup - - # from pycox.models import CoxPH - - # torch.manual_seed(random_seed) - # np.random.seed(random_seed) - - # layers = hyperparams.get('layers', [100]) - # lr = hyperparams.get('lr', 1e-3) - # bs = hyperparams.get('bs', 100) - # epochs = hyperparams.get('epochs', 50) - # activation = hyperparams.get('activation', 'relu') - - # if activation == 'relu': activation = torch.nn.ReLU - # elif activation == 'relu6': activation = torch.nn.ReLU6 - # elif activation == 'tanh': activation = torch.nn.Tanh - # else: raise NotImplementedError("Activation function not implemented") - - # x = features.values.astype('float32') - # t = outcomes['time'].values.astype('float32') - # e = outcomes['event'].values.astype('bool') - - # in_features = x.shape[1] - # out_features = 1 - # batch_norm = False - # dropout = 0.0 - - # net = ttup.practical.MLPVanilla(in_features, layers, - # out_features, batch_norm, dropout, - # activation=activation, - # output_bias=False) - - # model = CoxPH(net, torch.optim.Adam) - - # vidx = _get_valid_idx(x.shape[0], 0.15, random_seed) - - # y_train, y_val = (t[~vidx], e[~vidx]), (t[vidx], e[vidx]) - # val_data = x[vidx], y_val - - # callbacks = [ttup.callbacks.EarlyStopping()] - # model.fit(x[~vidx], y_train, bs, epochs, callbacks, True, - # val_data=val_data, - # val_batch_size=bs) - # model.compute_baseline_hazards() - - # return model 
- def __interpolate_missing_times(survival_predictions, times): """Interpolate survival probabilities at missing time points. @@ -771,14 +666,14 @@ def __init__(self, treated_model, control_model): def predict_counterfactual_survival(self, features, times): - control_outcomes = self.control_model.predict_survival(features, times) treated_outcomes = self.treated_model.predict_survival(features, times) + control_outcomes = self.control_model.predict_survival(features, times) return treated_outcomes, control_outcomes def predict_counterfactual_risk(self, features, times): - control_outcomes = self.control_model.predict_risk(features, times) treated_outcomes = self.treated_model.predict_risk(features, times) + control_outcomes = self.control_model.predict_risk(features, times) return treated_outcomes, control_outcomes \ No newline at end of file diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py index 50b1cc3..0676fcd 100644 --- a/auton_survival/experiments.py +++ b/auton_survival/experiments.py @@ -1,3 +1,29 @@ +# coding=utf-8 +# MIT License + +# Copyright (c) 2022 Carnegie Mellon University, Auton Lab + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities to perform cross-validation.""" + +from copy import deepcopy import numpy as np from auton_survival.estimators import SurvivalModel, CounterfactualSurvivalModel @@ -94,17 +120,23 @@ def fit(self, features, outcomes, ret_trained_model=True): self.folds = folds - unique_times = np.unique(outcomes['time'].values) - - time_min, time_max = unique_times.min(), unique_times.max() - + unique_times = np.unique(outcomes.time.values) + time_max, time_min = unique_times.max(), unique_times.min() + for fold in range(self.cv_folds): - - fold_outcomes = outcomes.loc[folds==fold, 'time'] - - if fold_outcomes.min() > time_min: time_min = fold_outcomes.min() - if fold_outcomes.max() < time_max: time_max = fold_outcomes.max() + + time_test = outcomes.loc[folds==fold, 'time'] + time_train = outcomes.loc[folds!=fold, 'time'] + if time_test.min() > time_min: + time_min = time_test.min() + + if (time_test.max() < time_max)|(time_train.max() < time_max): + if time_test.max() > time_train.max(): + time_max = max(time_test[time_test < time_train.max()]) + else: + time_max = max(time_test[time_test < time_test.max()]) + unique_times = unique_times[unique_times>=time_min] unique_times = unique_times[unique_times follow-up time + max_follow_up = outcomes_train.time.max() + predictions_test = predictions_test[outcomes_test.time.values < max_follow_up] + outcomes_test = outcomes_test.loc[outcomes_test.time.values < max_follow_up] # Compute IBS - score = survival_regression_metric('ibs', outcomes_train, outcomes_test, - predictions_test, unique_times) + score = survival_regression_metric('ibs', outcomes_train, predictions_test, + unique_times, outcomes_test) score_per_fold.append(score) current_score = 
np.mean(score_per_fold) @@ -235,6 +273,8 @@ class CounterfactualSurvivalRegressionCV: """ + _VALID_CF_METHODS = ['dsm', 'dcph', 'dcm', 'rsf', 'cph'] + def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}): self.model = model @@ -277,11 +317,9 @@ def fit(self, features, outcomes, interventions): """ - - treated, control = interventions==1, interventions!=1 - treated_model = self.treated_experiment.fit(features.loc[treated], - outcomes.loc[treated]) - control_model = self.control_experiment.fit(features.loc[control], - outcomes.loc[control]) + treated_model = self.treated_experiment.fit(features.loc[interventions==1], + outcomes.loc[interventions==1]) + control_model = self.control_experiment.fit(features.loc[interventions!=1], + outcomes.loc[interventions!=1]) return CounterfactualSurvivalModel(treated_model, control_model) diff --git a/auton_survival/metrics.py b/auton_survival/metrics.py index a88e828..b22b341 100644 --- a/auton_survival/metrics.py +++ b/auton_survival/metrics.py @@ -24,22 +24,20 @@ """Tools to compute metrics used to assess survival outcomes and survival model performance.""" -from sksurv import metrics, util from lifelines import KaplanMeierFitter, CoxPHFitter - -from sklearn.metrics import auc - import pandas as pd import numpy as np - +from sksurv import metrics, util +from scipy.optimize import fsolve +from sklearn.metrics import auc from tqdm import tqdm - import warnings -def survival_diff_metric(metric, outcomes, treatment_indicator, - weights=None, horizon=None, interpolate=True, - weights_clip=1e-2, n_bootstrap=None, - size_bootstrap=1.0, random_seed=0): +def treatment_effect(metric, outcomes, treatment_indicator, + weights=None, horizons=None, risks=None, + interpolate=True, weights_clip=1e-2, + n_bootstrap=None, size_bootstrap=1.0, + random_seed=0): """Compute metrics for comparing population level survival outcomes across treatment arms. 
@@ -50,7 +48,7 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, The metric to evalute for comparing survival outcomes. Options include: - `median` - - `time_to` + - `tar` - `hazard_ratio` - `restricted_mean` - `survival_at` @@ -63,10 +61,14 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, weights : pd.Series, default=None Treatment assignment propensity scores, \( \widehat{\mathbb{P}}(A|X=x) \). If `None`, all weights are set to \( 0.5 \). Default is `None`. - horizon : float - The time horizon at which to compare the survival curves. + horizons : float or int or array of floats or ints, default=None + Event horizon(s) at which to compute the metric. Must be specified for metric 'restricted_mean' and 'survival_at'. For 'hazard_ratio' this is ignored. + risks : float or array of floats + The risk level (0-1) at which to compare times between treatment arms. + Must be specified for metric 'tar'. + Ignored for other metrics. interpolate : bool, default=True Whether to interpolate the survival curves. weights_clip : float @@ -88,10 +90,17 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, """ assert metric in ['median', 'hazard_ratio', 'restricted_mean', - 'survival_at', 'time_to'] - - if metric in ['restricted_mean', 'survival_at', 'time_to']: - assert horizon is not None, "Please specify Event Horizon" + 'survival_at', 'tar'] + + if metric in ['restricted_mean', 'survival_at']: + assert horizons is not None, "Please specify Event Horizon" + assert risks is None, "Risks must be non for 'restricted_mean' and \ +'survival_at' metrics" + + if metric in ['tar']: + assert risks is not None, "Please specify risk level(s) at \ +which to compare time-to-event." + assert horizons is None, "Horizons must be none for 'tar' metric." 
if metric == 'hazard_ratio': warnings.warn("WARNING: You are computing Hazard Ratios.\n Make sure you have tested the PH Assumptions.") @@ -101,6 +110,12 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, # Bootstrapping ... if n_bootstrap is not None: assert isinstance(n_bootstrap, int), '`bootstrap` must be None or int' + + if isinstance(horizons, (int, float)): + horizons = [horizons] + + if isinstance(risks, (int, float)): + risks = [risks] if isinstance(n_bootstrap, int): print('Bootstrapping... ', n_bootstrap, @@ -120,12 +135,12 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, if metric == 'survival_at': _metric = _survival_at_diff - elif metric == 'time_to': - _metric = _time_to_diff + elif metric == 'tar': + _metric = _tar elif metric == 'restricted_mean': _metric = _restricted_mean_diff elif metric == 'median': - _metric = _time_to_diff + _metric = _median # Lifelines .median_survival_time_? elif metric == 'hazard_ratio': _metric = _hazard_ratio else: raise NotImplementedError() @@ -133,22 +148,24 @@ def survival_diff_metric(metric, outcomes, treatment_indicator, if n_bootstrap is None: return _metric(treated_outcomes, control_outcomes, - horizon=horizon, + horizons=horizons, + risks=risks, interpolate=interpolate, treated_weights=iptw_weights[treatment_indicator], control_weights=iptw_weights[~treatment_indicator]) else: return [_metric(treated_outcomes, control_outcomes, - horizon=horizon, + horizons=horizons, + risks=risks, interpolate=interpolate, treated_weights=iptw_weights[treatment_indicator], control_weights=iptw_weights[~treatment_indicator], size_bootstrap=size_bootstrap, random_seed=i) for i in range(n_bootstrap)] -def survival_regression_metric(metric, outcomes_train, outcomes_test, - predictions, times): +def survival_regression_metric(metric, outcomes_train, predictions, + times, outcomes_test=None): """Compute metrics to assess survival model performance. 
Parameters @@ -178,6 +195,19 @@ def survival_regression_metric(metric, outcomes_train, outcomes_test, """ + if isinstance(times, (float,int)): + times = [times] + + if outcomes_test is None: + outcomes_test = outcomes_train + warnings.warn("You are are evaluating model performance on the \ +same data used to train the model.") + + assert max(times) < outcomes_train.time.max(), "Times should \ +be within the range of training set times to avoid exterpolation." + assert max(times) < outcomes_test.time.max(), "Times \ +must be within the range of test set times." + survival_train = util.Surv.from_dataframe('event', 'time', outcomes_train) survival_test = util.Surv.from_dataframe('event', 'time', outcomes_test) predictions_test = predictions @@ -201,10 +231,10 @@ def survival_regression_metric(metric, outcomes_train, outcomes_test, else: raise NotImplementedError() - + def phenotype_purity(phenotypes_train, outcomes_train, phenotypes_test=None, outcomes_test=None, - strategy='instantaneous', horizon=None, + strategy='instantaneous', horizons=None, bootstrap=None): """Compute the brier score to assess survival model performance for phenotypes. @@ -227,7 +257,7 @@ def phenotype_purity(phenotypes_train, outcomes_train, Options include: - `instantaneous` : Compute the brier score. - `integrated` : Compute the integrated brier score. - horizon : float or int or np.array of floats or ints, default=None + horizons : float or int or an array of floats or ints, default=None Event horizon(s) at which to compute the metric bootstrap : integer, default=None The number of bootstrap iterations. @@ -240,8 +270,6 @@ def phenotype_purity(phenotypes_train, outcomes_train, """ - # CODE UPDATE: enable phenotype purity to be computed for the test set... 
- # without specifying folds to determine the train/test sets when folds are inapplicable (no CV) np.random.seed(0) if (outcomes_test is None) & (phenotypes_test is not None): @@ -249,17 +277,19 @@ def phenotype_purity(phenotypes_train, outcomes_train, if (outcomes_test is not None) & (phenotypes_test is None): raise Exception("Specify phenotypes for test set.") - assert horizon is not None, "Please specify Event Horizon" + assert horizons is not None, "Please specify Event Horizon" - if isinstance(horizon, float) | isinstance(horizon, int): - horizon = [horizon] + if isinstance(horizons, (float,int)): + horizons = [horizons] if outcomes_test is None: phenotypes_test = phenotypes_train outcomes_test = outcomes_train warnings.warn("You are are estimating survival probabilities for \ - the same dataset used to estimate the censoring \ - distribution.") +the same dataset used to estimate the censoring distribution.") + + assert outcomes_test.time.max() >= outcomes_train.time.max(), "Test \ +set times must be within the range of training set follow-up times." 
survival_curves = {} for phenotype in np.unique(phenotypes_train): @@ -272,28 +302,28 @@ def phenotype_purity(phenotypes_train, outcomes_train, if strategy == 'instantaneous': - predictions = np.zeros((len(survival_test), len(horizon))) + predictions = np.zeros((len(survival_test), len(horizons))) for phenotype in set(phenotypes_test): - predictions[phenotypes_test==phenotype, :] = survival_curves[phenotype].predict(times=horizon, + predictions[phenotypes_test==phenotype, :] = survival_curves[phenotype].predict(times=horizons, interpolate=True) if bootstrap is None: return metrics.brier_score(survival_train, survival_test, - predictions, horizon)[1] + predictions, horizons)[1] else: scores = [] for i in tqdm(range(bootstrap)): idx = np.random.choice(n, size=n, replace=True) score = metrics.brier_score(survival_train, survival_test[idx], - predictions[idx], horizon)[1] + predictions[idx], horizons)[1] scores.append(score) return scores elif strategy == 'integrated': horizon_scores = [] - for time in horizon: + for horizon in horizons: times = np.unique(outcomes_test['time']) - times = times[timesthres][0], y[y>thres][0] + root = fsolve(func, x0=x1, args=(thres, x1, y1, x2, y2))[0] + return root + def func(x, y, x1, y1, x2, y2): + return y1 + (x-x1)*((y2-y1)/(x2-x1)) - y + + if random_seed is not None: + treated_outcomes = treated_outcomes.sample(n=int(size_bootstrap*len(treated_outcomes)), + weights=treated_weights, + random_state=random_seed, replace=True) + control_outcomes = control_outcomes.sample(n=int(size_bootstrap*len(control_outcomes)), + weights=control_weights, + random_state=random_seed, replace=True) - treatment_survival = KaplanMeierFitter().fit(treated_outcomes['time'], - treated_outcomes['event']) + treated_survival = KaplanMeierFitter().fit(treated_outcomes['time'], + treated_outcomes['event']) control_survival = KaplanMeierFitter().fit(control_outcomes['time'], control_outcomes['event']) + treated_horizons = np.linspace(treated_outcomes.time.min(), 
+ treated_outcomes.time.max(), + round((treated_outcomes.time.max()-treated_outcomes.time.min())*20)) + control_horizons = np.linspace(control_outcomes.time.min(), + control_outcomes.time.max(), + round((control_outcomes.time.max()-control_outcomes.time.min())*20)) + + treated_risk = 1-treated_survival.predict(treated_horizons, interpolate).values + control_risk = 1-control_survival.predict(control_horizons, interpolate).values + + tar_diff = [] + for risk in risks: + treated_tar = interp_x(treated_risk, treated_horizons, risk) + control_tar = interp_x(control_risk, control_horizons, risk) + tar_diff.append(treated_tar - control_tar) + + return np.array(tar_diff) + def _hazard_ratio(treated_outcomes, control_outcomes, treated_weights, control_weights, size_bootstrap=1.0, random_seed=None, **kwargs): diff --git a/auton_survival/phenotyping.py b/auton_survival/phenotyping.py index dd8f4ee..08b9513 100644 --- a/auton_survival/phenotyping.py +++ b/auton_survival/phenotyping.py @@ -31,6 +31,7 @@ from copy import deepcopy from sklearn import cluster, decomposition, mixture +from sklearn.metrics import auc from auton_survival.utils import _get_method_kwargs from auton_survival.experiments import CounterfactualSurvivalRegressionCV @@ -419,72 +420,171 @@ def fit_phenotype(self, features): return self.fit(features).phenotype(features) -class SurvivalVirtualTwinsPhenotyper(object): +class SurvivalVirtualTwinsPhenotyper(Phenotyper): - """"Not Yet Implemented""" + """Phenotyper that estimates the potential outcomes under treatment and + control using a counterfactual Deep Cox Proportional Hazards model, + followed by regressing the difference of the estimated counterfactual + Restricted Mean Survival Times using a Random Forest regressor.""" - - _VALID_PHENO_METHODS = ['rsf'] + _VALID_PHENO_METHODS = ['rfr'] _DEFAULT_PHENO_HYPERPARAMS = {} - _DEFAULT_PHENO_HYPERPARAMS['rsf'] = {'n_estimators': 50, + _DEFAULT_PHENO_HYPERPARAMS['rfr'] = {'n_estimators': 50, 'max_depth': 5} def 
__init__(self, cf_method='dcph', - phenotyping_method='rsf', + phenotyping_method='rfr', cf_hyperparams=None, phenotyper_hyperparams=None, random_seed=0): - - raise NotImplementedError() - - assert cf_method in CounterfactualSurvivalRegressionCV._VALID_CF_METHODS, "Invalid Counterfactual Method: "+cf_method - assert phenotyping_method in self._VALID_PHENO_METHODS, "Invalid Phenotyping Method: "+phenotyping_method + + assert cf_method in CounterfactualSurvivalRegressionCV._VALID_CF_METHODS, "\ + Invalid Counterfactual Method: "+cf_method + assert phenotyping_method in self._VALID_PHENO_METHODS, "Invalid Phenotyping Method:\ + "+phenotyping_method self.cf_method = cf_method self.phenotyping_method = phenotyping_method - if cf_method_hyperparams is None: - cf_method_hyperparams = {} + if cf_hyperparams is None: + cf_hyperparams = {} if phenotyper_hyperparams is None: phenotyper_hyperparams = {} - - phenotyper_hyperparams = deepcopy(SurvivalVirtualTwinsPhenotyper._DEFAULT_PHENO_HYPERPARAMS[phenotyping_method]).update(phenotyper_hyperparams) + self.phenotyper_hyperparams = phenotyper_hyperparams - - cf_hyperparams = deepcopy(SurvivalVirtualTwinsPhenotyper._DEFAULT_PHENO_HYPERPARAMS[cf_method]).update(cf_hyperparams) self.cf_hyperparams = cf_hyperparams self.random_seed = random_seed def fit(self, features, outcomes, interventions, horizon): + + """Fit a counterfactual model and regress the difference of the estimated + counterfactual RMST using a Random Forest regressor. + + Parameters + ----------- + features: pd.DataFrame + A pandas dataframe with rows corresponding to individual samples + and columns as covariates. + outcomes : pd.DataFrame + A pandas dataframe with rows corresponding to individual samples + and columns 'time' and 'event'. + interventions : np.array + Boolean numpy array of treatment indicators. True means individual + was assigned a specific treatment. 
+ horizon : float + The event horizon at which to compute the counterfactual RMST for + regression. - raise NotImplementedError() + Returns + ----------- + Trained instance of Survival Virtual Twins Phenotyper. + + """ - cf_model = CounterfactualSurvivalRegressionCV(**self.cf_method_hyperparams) + cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method, + hyperparam_grid=self.cf_hyperparams) self.cf_model = cf_model.fit(features, outcomes, interventions) - times = np.unique(outcomes.times.values) + times = np.unique(outcomes.time.values) cf_predictions = self.cf_model.predict_counterfactual_survival(features, - interventions, - times) + times.tolist()) ite_estimates = cf_predictions[1] - cf_predictions[0] + ite_estimates = [estimate[times < horizon] for estimate in ite_estimates] + times = times[times < horizon] + # Compute rmst for each sample based on user-specified event-horizon + rmst = np.array([auc(times, i) for i in ite_estimates]) - if self.phenotyping_method == 'rsf': + if self.phenotyping_method == 'rfr': from sklearn.ensemble import RandomForestRegressor - pheno_model = RandomForestRegressor(**self.phenotyping_method_hyperparams) - pheno_model.fit(features.values, ite_estimates) + pheno_model = RandomForestRegressor(**self.phenotyper_hyperparams) + pheno_model.fit(features.values, rmst) self.pheno_model = pheno_model + self.fitted = True + + return self + + def predict_proba(self, features): + + """Estimate the probability that the treatment group RMST is greater than + that of the control group. + + Parameters + ----------- + features: pd.DataFrame + a pandas dataframe with rows corresponding to individual samples + and columns as covariates. + + Returns + ----------- + np.array + a numpy array of the phenogroup probabilities. 
+ + """ + + phenotype_preds= self.pheno_model.predict(features) + preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min()) + preds_surv_less = 1 - preds_surv_greater + preds = np.array([[preds_surv_less[i], preds_surv_greater[i]] + for i in range(len(features))]) + + return preds def predict(self, features): - raise NotImplementedError() + """Predict phenogroups. + + Parameters + ----------- + features: pd.DataFrame + a pandas dataframe with rows corresponding to individual samples + and columns as covariates. + Returns + ----------- + np.array + a numpy array of the phenogroup labels + + """ + phenotype_preds= self.pheno_model.predict(features) - phenotype_preds = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min()) - return phenotype_preds + preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min()) + preds_surv_less = 1 - preds_surv_greater + preds = np.array([[preds_surv_less[i], preds_surv_greater[i]] + for i in range(len(features))]) + + return np.argmax(preds, axis=1) + + def fit_predict(self, features, outcomes, interventions, horizon): + + """Fit and perform phenotyping on a given dataset. + + Parameters + ----------- + features: pd.DataFrame + A pandas dataframe with rows corresponding to individual samples + and columns as covariates. + outcomes : pd.DataFrame + A pandas dataframe with rows corresponding to individual samples + and columns 'time' and 'event'. + interventions : np.array + Boolean numpy array of treatment indicators. True means individual + was assigned a specific treatment. + horizon : float + The event horizon at which to compute the counterfactual RMST for + regression. + + Returns + ----------- + np.array + a numpy array of the phenogroup labels. 
+ + """ + + return self.fit(features, outcomes, interventions, horizon).predict(features) \ No newline at end of file diff --git a/examples/Demo of CMHE on Synthetic Data.ipynb b/examples/Demo of CMHE on Synthetic Data.ipynb index a4160b2..0b284db 100644 --- a/examples/Demo of CMHE on Synthetic Data.ipynb +++ b/examples/Demo of CMHE on Synthetic Data.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "7e3e61af", + "id": "5ded7b8d", "metadata": {}, "source": [ "# Deep Cox Mixtures with Heterogenous Effects (CMHE) Demo\n", @@ -25,30 +25,26 @@ "### 2. [Synthetic Data](#syndata) \n", "####               2.1 [Generative Process for the Synthetic Dataset.](#gensyndata)\n", "####               2.2 [Loading and Visualizing the Dataset.](#vissyndata)\n", - "####               2.2 [Split Dataset into Train and Test.](#splitdata)\n", + "####               2.3 [Split Dataset into Train and Test.](#splitdata)\n", "\n", " \n", "### 3. [Counterfactual Phenotyping](#phenotyping)\n", "\n", "####               3.1 [Phenotyping with CMHE](#phenocmhe)\n", - "\n", - "####               3.1 [Comparison with Clustering](#clustering)\n", - "\n", + "####               3.2 [Phenotyping with Virtual Twins Survival Regression](#vtsp)\n", + "####               3.3 [Comparison with Clustering](#clustering)\n", "\n", "\n", "### 4. [Factual Regression](#regression)\n", - "\n", "####               4.1 [Factual Regression with CMHE](#regcmhe)\n", - "\n", - "\n", - "####               4.1 [Comparison with a Deep Cox Proportional Hazards Model](#deepcph)\n", + "####               4.2 [Comparison with a Deep Cox Proportional Hazards Model](#deepcph)\n", "\n", "
\n" ] }, { "cell_type": "markdown", - "id": "02e1d71c", + "id": "929af912", "metadata": {}, "source": [ "\n", @@ -78,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "20d71b4e", + "id": "ec3e5eed", "metadata": {}, "source": [ "\n", @@ -88,8 +84,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "bbf84a5a", + "execution_count": 8, + "id": "4522fb58", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +101,7 @@ }, { "cell_type": "markdown", - "id": "d7ba3d8a", + "id": "676d208e", "metadata": {}, "source": [ "\n", @@ -114,7 +110,7 @@ }, { "cell_type": "markdown", - "id": "bac7bba7", + "id": "6b126859", "metadata": {}, "source": [ "1. Features $x_1$, $x_2$ and the base survival phenotypes $Z$ are sampled from $\\texttt{scikit-learn's make_blobs(...)}$ function which generates isotropic Gaussian blobs:\n", @@ -133,8 +129,8 @@ }, { "cell_type": "code", - "execution_count": 19, - "id": "826af396", + "execution_count": 9, + "id": "45d328ce", "metadata": {}, "outputs": [ { @@ -244,7 +240,7 @@ "4 0.748930 " ] }, - "execution_count": 19, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -259,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "83ca0fea", + "id": "72b3fb0d", "metadata": {}, "source": [ "\n", @@ -268,8 +264,8 @@ }, { "cell_type": "code", - "execution_count": 20, - "id": "cd1754ef", + "execution_count": 10, + "id": "2838615d", "metadata": {}, "outputs": [ { @@ -289,7 +285,7 @@ }, { "cell_type": "markdown", - "id": "7156670f", + "id": "b7bb0f4a", "metadata": {}, "source": [ "\n", @@ -298,8 +294,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "c65b81bb", + "execution_count": 11, + "id": "e58d6786", "metadata": {}, "outputs": [ { @@ -353,8 +349,8 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "0c2628c6", + "execution_count": 25, + "id": "2e419169", "metadata": {}, "outputs": [], "source": [ @@ -380,7 +376,7 @@ }, { "cell_type": "markdown", - "id": "8e54dfd5", + "id": "b1d39c9c", "metadata": {}, 
"source": [ "\n", @@ -389,7 +385,7 @@ }, { "cell_type": "markdown", - "id": "2c21b829", + "id": "8f947f28", "metadata": {}, "source": [ "\n", @@ -398,8 +394,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "8e2f5ff3", + "execution_count": 19, + "id": "755c82c0", "metadata": {}, "outputs": [], "source": [ @@ -408,9 +404,9 @@ "g = 2 # number of underlying treatment effect phenotypes.\n", "layers = [50, 50] # number of neurons in each hidden layer.\n", "\n", - "random_seed = 3\n", + "random_seed =10\n", "iters = 100 # number of training epochs\n", - "learning_rate = 0.001\n", + "learning_rate = 0.01\n", "batch_size = 128 \n", "vsize = 0.15 # size of the validation split\n", "patience = 3\n", @@ -419,8 +415,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "5fe68dfb", + "execution_count": 20, + "id": "267164ba", "metadata": { "scrolled": true }, @@ -429,7 +425,13 @@ "name": "stderr", "output_type": "stream", "text": [ - " 28%|██████████████████████████████▊ | 28/100 [00:15<00:41, 1.75it/s]\n" + " 0%| | 0/100 [00:00" ] @@ -534,9 +536,352 @@ "plot_phenotypes_roc(outcomes_te, zeta_probs_test_CMHE[:, max_treat_idx_CMHE])" ] }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ed3809ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[2.1216183e-05, 9.9997878e-01],\n", + " [9.9602604e-01, 3.9739967e-03],\n", + " [9.9793184e-01, 2.0681797e-03],\n", + " ...,\n", + " [3.6357585e-04, 9.9963641e-01],\n", + " [1.0352042e-04, 9.9989653e-01],\n", + " [5.2588820e-01, 4.7411180e-01]], dtype=float32)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "zeta_probs_test_CMHE" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "9100d5fb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.9999788 , 0.003974 , 0.00206818, ..., 0.9996364 , 0.9998965 ,\n", + " 0.4741118 ], dtype=float32)" + ] + }, + "execution_count": 
23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "zeta_probs_test_CMHE[:, max_treat_idx_CMHE]" + ] + }, + { + "cell_type": "markdown", + "id": "87c7bc1c", + "metadata": {}, + "source": [ + "\n", + "### 3.2 Phenotyping with Virtual Twins Survival Regression" + ] + }, { "cell_type": "markdown", - "id": "90471bba", + "id": "b4a0e802", + "metadata": {}, + "source": [ + "A Virtual Twins model as first proposed in [1] predicts response probabilities for each sample using models trained separately on treatment and control groups. `auton-survival` fits a counterfactual model and regresses the difference of the estimated RMST using a Random Forest regressor.\n", + "\n", + "*For more information on Virtual Twins models [1], please refer to the following paper*:\n", + "\n", + "[1] [Subgroup identification from randomized clinical trial data](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3880775/)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a37bbb0c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Now let us evaluate our performance\n", + "plot_phenotypes_roc(outcomes_tr, phenogroups)" + ] + }, + { + "cell_type": "markdown", + "id": "e5e36371", "metadata": {}, "source": [ "\n", @@ -545,7 +890,7 @@ }, { "cell_type": "markdown", - "id": "5e4b3b79", + "id": "555a7366", "metadata": {}, "source": [ "We compare the ability of CMHE against dimensionality reduction followed by clustering for counterfactual phenotyping. Specifically, we first perform dimensionality reduction of the input confounders, $\\mathbf{x}$, followed by clustering. Due to a small number of confounders in the synthetic data, in the following experiment, we directly perform clustering using a Gaussian Mixture Model (GMM) with 2 components and diagonal covariance matrices." 
@@ -553,8 +898,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "id": "08816e3e", + "execution_count": null, + "id": "de6e8dd9", "metadata": {}, "outputs": [], "source": [ @@ -570,41 +915,16 @@ "phenotyper = ClusteringPhenotyper(clustering_method=clustering_method, \n", " dim_red_method=dim_red_method, \n", " n_components=n_components, \n", - " n_clusters=n_clusters)" + " n_clusters=n_clusters,\n", + " random_seed=36) " ] }, { "cell_type": "code", - "execution_count": 30, - "id": "2e3e5bec", + "execution_count": null, + "id": "eaa74cbe", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No Dimensionaity reduction specified...\n", - " Proceeding to learn clusters with the raw features...\n", - "Fitting the following Clustering Model:\n", - " GaussianMixture(covariance_type='diag', n_components=3, random_state=0)\n", - "Distribution of individuals in each treatement phenotype in the training data: [1306 1162 1431]\n", - "\n", - "Group 2 has the maximum restricted mean survival time on the training data!\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\Willa Potosnak\\miniconda3\\envs\\localenv\\lib\\site-packages\\lifelines\\fitters\\__init__.py:204: ApproximationWarning: Approximating using linear interpolation`.\n", - "\n", - " warnings.warn(\"Approximating using linear interpolation`.\\n\", exceptions.ApproximationWarning)\n", - "C:\\Users\\Willa Potosnak\\miniconda3\\envs\\localenv\\lib\\site-packages\\lifelines\\fitters\\__init__.py:204: ApproximationWarning: Approximating using linear interpolation`.\n", - "\n", - " warnings.warn(\"Approximating using linear interpolation`.\\n\", exceptions.ApproximationWarning)\n" - ] - } - ], + "outputs": [], "source": [ "zeta_probs_train = phenotyper.fit_phenotype(features_tr.values)\n", "zeta_train = np.argmax(zeta_probs_train, axis=1)\n", @@ -618,7 +938,7 @@ }, { "cell_type": "markdown", - "id": "f5dd6035", + "id": "782f9ed5", 
"metadata": {}, "source": [ "### Evaluate Clustering Phenotyper on Test Data" @@ -626,28 +946,10 @@ }, { "cell_type": "code", - "execution_count": 31, - "id": "028bef76", + "execution_count": null, + "id": "be4e1558", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Distribution of individuals in each treatement phenotype in the test data: [380 330 391]\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Now for each individual in the test data, let's find the probability that \n", "# they belong to the max treatment effect group\n", @@ -664,7 +966,7 @@ }, { "cell_type": "markdown", - "id": "3dd82f69", + "id": "8d8062d9", "metadata": {}, "source": [ "\n", @@ -673,7 +975,7 @@ }, { "cell_type": "markdown", - "id": "1ae0a9f1", + "id": "25baff78", "metadata": {}, "source": [ "For completeness, we further evaluate the performance of CMHE in estimating factual risk over multiple time horizons using the standard survival analysis metrics, including: \n", @@ -694,7 +996,7 @@ }, { "cell_type": "markdown", - "id": "9c009a2c", + "id": "81aaa75b", "metadata": {}, "source": [ "\n", @@ -704,19 +1006,10 @@ }, { "cell_type": "code", - "execution_count": 32, - "id": "80b4a724", + "execution_count": null, + "id": "70503c63", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Concordance Index (1 Year): 0.6548 (3 Year) 0.6656: (5 Year): 0.6665\n", - "Integrated Brier Score: 0.1597\n" - ] - } - ], + "outputs": [], "source": [ "horizons = [1, 3, 5]\n", "\n", @@ -729,42 +1022,9 @@ "print(f'Integrated Brier Score: {np.around(IBS, 4)}')" ] }, - { - "cell_type": "code", - "execution_count": 33, - "id": "63c26319", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0.1487453 , 1.8924836 , 0.19525401, ..., 0.4836966 ,\n", - " 0.33955073, 0.37479353],\n", - " [ 2.2802675 , -0.73033774, -1.7158557 , ..., 0.47515342,\n", - " 0.8169348 , 0.59493804],\n", - " [ 1.4902165 , -0.91186345, 0.7905248 , ..., 0.08489922,\n", - " 0.37587634, 0.62569857],\n", - " ...,\n", - " [ 2.0787652 , -2.0501418 , -0.36273366, ..., 0.63403505,\n", - " 0.97710544, 0.81890947],\n", - " [-1.5852283 , -0.79666543, 0.9420089 , ..., 0.6343607 ,\n", - " 0.5449456 , 0.01124444],\n", - " [ 0.25414062, 1.5835027 , -0.41139466, ..., 0.369376 ,\n", - " 
0.05497523, 0.5677395 ]], dtype=float32)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x_te" - ] - }, { "cell_type": "markdown", - "id": "a247dd6d", + "id": "0f284076", "metadata": {}, "source": [ "\n", @@ -773,20 +1033,12 @@ }, { "cell_type": "code", - "execution_count": 34, - "id": "4b944e83", + "execution_count": null, + "id": "3aaf3504", "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 47/50 [00:03<00:00, 14.81it/s]\n" - ] - } - ], + "outputs": [], "source": [ "from auton_survival.estimators import SurvivalModel\n", "\n", @@ -804,7 +1056,7 @@ }, { "cell_type": "markdown", - "id": "bf5ee5fe", + "id": "b4b93e25", "metadata": {}, "source": [ "### Evaluate DCPH on Test Data" @@ -812,16 +1064,19 @@ }, { "cell_type": "code", - "execution_count": 35, - "id": "0d3cf5e9", + "execution_count": 1, + "id": "487dcfc1", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Concordance Index (1 Year): 0.6919 (3 Year) 0.6947: (5 Year): 0.6981\n", - "Integrated Brier Score: 0.153\n" + "ename": "NameError", + "evalue": "name 'dcph_model' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m# Find suvival scores in the test data\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mpredictions_test_DCPH\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mdcph_model\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict_survival\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeatures_te_dcph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhorizons\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m CI1, CI3, CI5, IBS = factual_evaluate((x_tr, t_tr, e_tr, a_tr), (x_te, t_te, e_te, a_te), \n\u001b[0;32m 5\u001b[0m horizons, predictions_test_DCPH)\n", + "\u001b[1;31mNameError\u001b[0m: name 'dcph_model' is not defined" ] } ], @@ -838,7 +1093,117 @@ { "cell_type": "code", "execution_count": null, - "id": "a792ce60", + "id": "8eb2a718", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "49d0f1f7", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import torch\n", + "from tqdm import tqdm \n", + "import sys\n", + "sys.path.append('../')\n", + "from auton_survival.datasets import load_dataset\n", + "from cmhe_demo_utils import * \n", + "\n", + "# Load the synthetic dataset\n", + "outcomes, features, interventions = load_dataset(dataset='SYNTHETIC')\n", + "\n", + "x = features.iloc[:100]\n", + "y = pd.DataFrame(outcomes, columns=['event', 'time']).iloc[:100]\n", + "#y = outcomes.iloc[:100]\n", + "i = interventions.astype('float64').iloc[:100]\n", + "#i = interventions.iloc[:100]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2fda4247", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fa511f30", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00\n" ] }, { "cell_type": "markdown", - "id": "d3949e98", + "id": "e984fe01", "metadata": {}, "source": [ "\n", @@ -52,7 +53,7 @@ }, { "cell_type": "markdown", - "id": "6ec156cf", + "id": "636416cc", "metadata": {}, "source": [ 
"`auton-survival` offers utilities to phenotype, or group, samples for use in assessing differential survival probabilities across groups. Phenotyping can aid clinical decision makers by offering insight into groups of patients for which differential survival probabilities exist. This insight can influence clinical practices applied to these groups.\n", @@ -68,7 +69,7 @@ }, { "cell_type": "markdown", - "id": "eda5bb46", + "id": "24d8e2be", "metadata": {}, "source": [ "\n", @@ -78,7 +79,7 @@ }, { "cell_type": "markdown", - "id": "6bf472e1", + "id": "fc152936", "metadata": {}, "source": [ "*For the original datasource, please refer to the following [website](https://biostat.app.vumc.org/wiki/Main/SupportDesc).*\n", @@ -88,8 +89,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "50b7cf7f", + "execution_count": 2, + "id": "4a479b1e", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +103,7 @@ }, { "cell_type": "markdown", - "id": "21c52d23", + "id": "6ea602d5", "metadata": {}, "source": [ "\n", @@ -111,8 +112,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "7bc0c498", + "execution_count": 3, + "id": "edbd1c2c", "metadata": {}, "outputs": [ { @@ -400,7 +401,7 @@ }, { "cell_type": "markdown", - "id": "598804cf", + "id": "0ff79b8f", "metadata": {}, "source": [ "Here we perform imputation and scaling on the entire dataset but in practice we recommend that preprocessing tools be fitted solely to training data." 
@@ -408,8 +409,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "61e2f725", + "execution_count": 4, + "id": "ae30d5f4", "metadata": {}, "outputs": [], "source": [ @@ -422,8 +423,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "1445e99a", + "execution_count": 5, + "id": "823bab64", "metadata": {}, "outputs": [ { @@ -448,7 +449,7 @@ }, { "cell_type": "markdown", - "id": "bb8354d1", + "id": "3d407452", "metadata": {}, "source": [ "\n", @@ -457,7 +458,7 @@ }, { "cell_type": "markdown", - "id": "bb77eff8", + "id": "03430868", "metadata": {}, "source": [ "The intersectional Phenotyper performs an exhaustive cartesian product on the user-specified set of categorical and numerical variables to obtain the phenotypes. Numeric variables are binned based on user-specified quantiles." @@ -465,7 +466,7 @@ }, { "cell_type": "markdown", - "id": "2298969c", + "id": "449b1e4f", "metadata": {}, "source": [ "\n", @@ -474,7 +475,7 @@ }, { "cell_type": "markdown", - "id": "5f117fdd", + "id": "a9c7397a", "metadata": {}, "source": [ "Here we fit the phenotyper on the entire dataset but in practice we recommend that the phenotyper be fitted solely to training data." @@ -482,8 +483,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "83e21fcb", + "execution_count": 6, + "id": "2db693a2", "metadata": {}, "outputs": [ { @@ -496,7 +497,7 @@ " dtype='\n", @@ -528,7 +529,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "33ffa96e", + "id": "b6685ba0", "metadata": {}, "outputs": [ { @@ -557,7 +558,7 @@ }, { "cell_type": "markdown", - "id": "68ff0142", + "id": "a9a5b347", "metadata": {}, "source": [ "As you can see, patients ages 18 to 64 without cancer have the highest survival rates. Alternatively, patients ages 64 to 101 with metastatic cancer have the lowest survival rates." 
@@ -565,7 +566,7 @@ }, { "cell_type": "markdown", - "id": "7061f353", + "id": "0639e04a", "metadata": {}, "source": [ "\n", @@ -574,7 +575,7 @@ }, { "cell_type": "markdown", - "id": "ff594be0", + "id": "645e23c4", "metadata": {}, "source": [ "Dimensionality reduction of the input covariates, $\\mathbf{x}$, is performed followed by clustering. Learned clusters are considered phenotypes and used to group samples based on similarity in the covariate space. The estimated probability of sample cluster association is computed as the sample distance to a cluster center normalized by the sum of distances to other clusters.\n", @@ -588,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "bd475419", + "id": "73e8d1f8", "metadata": {}, "source": [ "\n", @@ -597,7 +598,7 @@ }, { "cell_type": "markdown", - "id": "bcf510b1", + "id": "d7bace22", "metadata": {}, "source": [ " " @@ -606,7 +607,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "2c9b5abd", + "id": "003e3213", "metadata": {}, "outputs": [ { @@ -654,7 +655,7 @@ }, { "cell_type": "markdown", - "id": "7701e2ba", + "id": "9f706f8e", "metadata": {}, "source": [ "\n", @@ -663,8 +664,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "fb8d3fb3", + "execution_count": 6, + "id": "f6557355", "metadata": {}, "outputs": [ { @@ -694,7 +695,7 @@ }, { "cell_type": "markdown", - "id": "3d7ed290", + "id": "161418c6", "metadata": {}, "source": [ "Intersecting survival rates indicate that the SUPPORT dataset follows non-proportional hazards which violates assumptions of the Cox Model." 
@@ -702,7 +703,7 @@ }, { "cell_type": "markdown", - "id": "1a19711f", + "id": "b36d1092", "metadata": {}, "source": [ "\n", @@ -711,7 +712,7 @@ }, { "cell_type": "markdown", - "id": "b9c39451", + "id": "740ca8ad", "metadata": {}, "source": [ "To measure a phenotyper's ability to extract subgroups with differential survival rates, we estimate the (Integrated) Brier Score by fitting a Kaplan-Meier estimator within each phenogroup and employing it to estimate the survival rate within each phenogroup. We refer to this as the *phenotyping purity.*" @@ -719,8 +720,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "9d26d606", + "execution_count": 12, + "id": "6a6fbe40", "metadata": {}, "outputs": [ { @@ -734,7 +735,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "C:\\Users\\Willa Potosnak\\OneDrive\\Documents\\CMU Research\\CMU_Projects\\auton-survival\\examples\\..\\auton_survival\\metrics.py:260: UserWarning: You are are estimating survival probabilities for the same dataset used to estimate the censoring distribution.\n", + "C:\\Users\\Willa Potosnak\\OneDrive\\Documents\\CMU Research\\CMU_Projects\\auton-survival\\examples\\..\\auton_survival\\metrics.py:261: UserWarning: You are are estimating survival probabilities for the same dataset used to estimate the censoring distribution.\n", " warnings.warn(\"You are are estimating survival probabilities for \\\n", "C:\\Users\\Willa Potosnak\\miniconda3\\envs\\localenv\\lib\\site-packages\\lifelines\\fitters\\__init__.py:204: ApproximationWarning: Approximating using linear interpolation`.\n", "\n", @@ -748,7 +749,7 @@ "# Estimate the Integrated Brier Score at event horizons of 1, 2 and 5 years.\n", "metric = phenotype_purity(phenotypes_train=phenotypes, outcomes_train=y_tr, \n", " phenotypes_test=None, outcomes_test=None,\n", - " strategy='instantaneous', horizon=[365, 730, 1825], \n", + " strategy='instantaneous', horizons=[365, 730, 1825], \n", " bootstrap=None)\n", "\n", "print(f'Phenotyping 
purity for event horizon of 1 year: {metric[0]} | 2 years: {metric[1]} | 5 years: {metric[2]}')" @@ -756,7 +757,7 @@ }, { "cell_type": "markdown", - "id": "978c303b", + "id": "1c2e576c", "metadata": {}, "source": [ "\n", @@ -770,7 +771,7 @@ } }, "cell_type": "markdown", - "id": "296b1097", + "id": "5cd0b913", "metadata": {}, "source": [ "\n", @@ -790,7 +791,7 @@ }, { "cell_type": "markdown", - "id": "744aa357", + "id": "5d0ce8ae", "metadata": {}, "source": [ "\n", @@ -799,7 +800,7 @@ }, { "cell_type": "markdown", - "id": "b093362e", + "id": "506ebbed", "metadata": {}, "source": [ "Fit DCM model to training data. Perform hyperparameter tuning by selecting model parameters that minimize the brier score computed for the validation set.\n", @@ -813,7 +814,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "cec70532", + "id": "51bb7267", "metadata": {}, "outputs": [ { @@ -857,7 +858,7 @@ }, { "cell_type": "markdown", - "id": "00068611", + "id": "5253ad49", "metadata": {}, "source": [ "\n", @@ -870,7 +871,7 @@ { "cell_type": "code", "execution_count": 15, - "id": "1c37f5ac", + "id": "b3783f2d", "metadata": {}, "outputs": [ { @@ -899,7 +900,7 @@ }, { "cell_type": "markdown", - "id": "3e213acf", + "id": "f8a1067f", "metadata": {}, "source": [ "\n", @@ -909,7 +910,7 @@ { "cell_type": "code", "execution_count": 16, - "id": "724422d6", + "id": "d91f5ab9", "metadata": {}, "outputs": [ { @@ -938,7 +939,7 @@ }, { "cell_type": "markdown", - "id": "b8bb50c2", + "id": "ef7be826", "metadata": {}, "source": [ "Intersecting survival rates indicate that the SUPPORT dataset follows non-proportional hazards which violates assumptions of the Cox Model." 
@@ -946,7 +947,7 @@ }, { "cell_type": "markdown", - "id": "9d460add", + "id": "aa407976", "metadata": {}, "source": [ "\n", @@ -955,7 +956,7 @@ }, { "cell_type": "markdown", - "id": "efaf0ea0", + "id": "b933a72b", "metadata": {}, "source": [ "To measure a phenotyper's ability to extract subgroups with differential survival rates, we estimate the (Integrated) Brier Score by fitting a Kaplan-Meier estimator within each phenogroup and employing it to estimate the survival rate within each phenogroup. We refer to this as the *phenotyping purity.*" @@ -964,7 +965,7 @@ { "cell_type": "code", "execution_count": 17, - "id": "adfaac68", + "id": "0a4ed63e", "metadata": {}, "outputs": [ { @@ -992,7 +993,7 @@ "# Estimate the Integrated Brier Score at event horizons of 1, 2 and 5 years\n", "metric = phenotype_purity(phenotypes_train=phenotypes, outcomes_train=y_tr, \n", " phenotypes_test=None, outcomes_test=None,\n", - " strategy='instantaneous', horizon=[365, 730, 1825], \n", + " strategy='instantaneous', horizons=[365, 730, 1825], \n", " bootstrap=None)\n", "\n", "print(f'Phenotyping purity for event horizon of 1 year: {metric[0]} | 2 years: {metric[1]} | 5 years: {metric[2]}')" @@ -1000,7 +1001,7 @@ }, { "cell_type": "markdown", - "id": "850e812b", + "id": "ead8730c", "metadata": {}, "source": [ "It can be observed the phenotyping purity is lower for supervised phenotyping compared to unsupervised phenotyping. This indicates that the supervised phenotyper is able extract phenogroups with higher discriminative power in terms of the observed survival rates.\n", @@ -1010,16 +1011,16 @@ }, { "cell_type": "markdown", - "id": "69bcb6e4", + "id": "ea3046ef", "metadata": {}, "source": [ - "\n", + "\n", "## 5. 
Counterfactual Phenotyping" ] }, { "cell_type": "markdown", - "id": "ba772dea", + "id": "ee6bdffd", "metadata": {}, "source": [ "*For examples of counterfactual phenotyping with Deep Cox Mixtures with Heterogeneous Effects (CMHE) [1], please refer to the following paper and example jupyter notebook*:\n", @@ -1032,7 +1033,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4756c39", + "id": "81df3294", "metadata": {}, "outputs": [], "source": [] diff --git a/examples/Survival Regression with Auton-Survival.ipynb b/examples/Survival Regression with Auton-Survival.ipynb index b436c85..cdd2897 100644 --- a/examples/Survival Regression with Auton-Survival.ipynb +++ b/examples/Survival Regression with Auton-Survival.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "1fa2753c", + "id": "d7efe2f9", "metadata": {}, "source": [ "# Survival Regression with `estimators.SurvivalModel`\n", @@ -47,7 +47,7 @@ }, { "cell_type": "markdown", - "id": "b8b5cbe0", + "id": "85369414", "metadata": {}, "source": [ "\n", @@ -57,7 +57,7 @@ }, { "cell_type": "markdown", - "id": "dfa5ae91", + "id": "cd41f8b9", "metadata": {}, "source": [ "The `SurvivalModels` class offers a steamlined approach to train two `auton-survival` models and three baseline survival models for right-censored time-to-event data. 
The fit method requires the same inputs across all five models, however, model parameter types vary and must be defined and tuned for the specified model.\n", @@ -103,7 +103,7 @@ }, { "cell_type": "markdown", - "id": "7762a8bf", + "id": "c2099ee9", "metadata": {}, "source": [ "\n", @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "453efa17", + "id": "87369d9d", "metadata": {}, "source": [ "*For the original datasource, please refer to the following [website](https://biostat.app.vumc.org/wiki/Main/SupportDesc).*\n", @@ -123,8 +123,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "78ceb1eb", + "execution_count": 2, + "id": "26e6b532", "metadata": {}, "outputs": [], "source": [ @@ -137,8 +137,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "edfdc7cb", + "execution_count": 3, + "id": "532b3549", "metadata": {}, "outputs": [ { @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "4eb21858", + "id": "f2e9776e", "metadata": {}, "source": [ "\n", @@ -435,8 +435,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "0d108642", + "execution_count": 4, + "id": "de6456e1", "metadata": {}, "outputs": [ { @@ -464,8 +464,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "b69dc7da", + "execution_count": 5, + "id": "38684f59", "metadata": {}, "outputs": [], "source": [ @@ -482,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "2e937840", + "id": "00683f5e", "metadata": {}, "source": [ "\n", @@ -491,7 +491,7 @@ }, { "cell_type": "markdown", - "id": "d0713974", + "id": "34a6c3f7", "metadata": {}, "source": [ "CPH [2] model assumes that individuals across the population have constant proportional hazards overtime. In this model, the estimator of the survival function conditional on $X, S(·|X) , P(T > t|X)$, is assumed to have constant proportional hazard. 
Thus, the relative proportional hazard between individuals is constant across time.\n", @@ -503,7 +503,7 @@ }, { "cell_type": "markdown", - "id": "44ff4aad", + "id": "be77b767", "metadata": {}, "source": [ "\n", @@ -512,8 +512,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "32552d7d", + "execution_count": 7, + "id": "41c56557", "metadata": {}, "outputs": [], "source": [ @@ -525,7 +525,7 @@ "param_grid = {'l2' : [1e-3, 1e-4]}\n", "params = ParameterGrid(param_grid)\n", "\n", - "# Define the times for tuning the model hyperparameters and for evaluating the model\n", + "# Define the times for model evaluation\n", "times = np.quantile(y_tr['time'][y_tr['event']==1], np.linspace(0.1, 1, 10)).tolist()\n", "\n", "# Perform hyperparameter tuning \n", @@ -538,7 +538,7 @@ "\n", " # Obtain survival probabilities for validation set and compute the Integrated Brier Score \n", " predictions_val = model.predict_survival(x_val, times)\n", - " metric_val = survival_regression_metric('ibs', y_tr, y_val, predictions_val, times)\n", + " metric_val = survival_regression_metric('ibs', y_tr, predictions_val, times, y_val)\n", " models.append([metric_val, model])\n", " \n", "# Select the best model based on the mean metric value computed for the validation set\n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "6b697d6c", + "id": "630de5a1", "metadata": {}, "source": [ "\n", @@ -558,15 +558,15 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "cfde11cc", + "execution_count": 21, + "id": "62c21074", "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, "metadata": {}, @@ -581,16 +581,16 @@ "\n", "# Compute the Brier Score and time-dependent concordance index for the test set to assess model performance\n", "results = dict()\n", - "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", - "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", + "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", + "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", "plot_performance_metrics(results, times)" ] }, { "cell_type": "markdown", - "id": "8dc206f6", + "id": "09ad9584", "metadata": {}, "source": [ "\n", @@ -599,7 +599,7 @@ }, { "cell_type": "markdown", - "id": "a05dab92", + "id": "83d13bf2", "metadata": {}, "source": [ "DCPH [2], [3] is an extension to the CPH model. 
DCPH involves modeling the proportional hazard ratios over the individuals with Deep Neural Networks allowing the ability to learn non linear hazard ratios.\n", @@ -613,7 +613,7 @@ }, { "cell_type": "markdown", - "id": "1f5123e4", + "id": "ee882560", "metadata": {}, "source": [ "\n", @@ -623,7 +623,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "8db562b5", + "id": "0eb815b5", "metadata": {}, "outputs": [ { @@ -667,7 +667,7 @@ "\n", " # Obtain survival probabilities for validation set and compute the Integrated Brier Score \n", " predictions_val = model.predict_survival(x_val, times)\n", - " metric_val = survival_regression_metric('ibs', y_tr, y_val, predictions_val, times)\n", + " metric_val = survival_regression_metric('ibs', y_tr, predictions_val, times, y_val)\n", " models.append([metric_val, model])\n", " \n", "# Select the best model based on the mean metric value computed for the validation set\n", @@ -678,7 +678,7 @@ }, { "cell_type": "markdown", - "id": "65df1701", + "id": "8983b804", "metadata": {}, "source": [ "\n", @@ -687,7 +687,7 @@ }, { "cell_type": "markdown", - "id": "7332567b", + "id": "28aad20c", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." 
@@ -696,7 +696,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "940a4b53", + "id": "dd8caa11", "metadata": {}, "outputs": [ { @@ -718,10 +718,10 @@ "\n", "# Compute the Brier Score and time-dependent concordance index for the test set to assess model performance\n", "results = dict()\n", - "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", - "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", + "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", + "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", "plot_performance_metrics(results, times)" ] }, @@ -732,7 +732,7 @@ } }, "cell_type": "markdown", - "id": "0adde949", + "id": "b875cc6c", "metadata": {}, "source": [ "\n", @@ -753,7 +753,7 @@ }, { "cell_type": "markdown", - "id": "00577b19", + "id": "68ee298a", "metadata": {}, "source": [ "\n", @@ -763,7 +763,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "c3f56926", + "id": "2f1e347f", "metadata": {}, "outputs": [ { @@ -823,7 +823,7 @@ "\n", " # Obtain survival probabilities for validation set and compute the Integrated Brier Score \n", " predictions_val = model.predict_survival(x_val, times)\n", - " metric_val = survival_regression_metric('ibs', y_tr, y_val, predictions_val, times)\n", + " metric_val = survival_regression_metric('ibs', y_tr, predictions_val, times, y_val)\n", " models.append([metric_val, model])\n", " \n", "# Select the best model based on the mean metric value computed for the validation set\n", @@ -834,7 +834,7 @@ }, { "cell_type": "markdown", - "id": "d07820c7", + "id": "81c45ead", "metadata": {}, "source": [ 
"\n", @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "6a1a8b6d", + "id": "ad13ad21", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -852,7 +852,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "53a5177f", + "id": "d0d5b400", "metadata": {}, "outputs": [ { @@ -874,16 +874,16 @@ "\n", "# Compute the Brier Score and time-dependent concordance index for the test set to assess model performance\n", "results = dict()\n", - "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", - "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", + "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", + "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", "plot_performance_metrics(results, times)" ] }, { "cell_type": "markdown", - "id": "6ca3f1aa", + "id": "8185b74e", "metadata": {}, "source": [ "\n", @@ -897,7 +897,7 @@ } }, "cell_type": "markdown", - "id": "10a8218a", + "id": "82daa2f4", "metadata": {}, "source": [ "DCM [2] generalizes the proportional hazards assumption via a mixture model, by assuming that there are latent groups and within each, the proportional hazards assumption holds. 
DCM allows the hazard ratio in each latent group, as well as the latent group membership, to be flexibly modeled by a deep neural network.\n", @@ -915,7 +915,7 @@ }, { "cell_type": "markdown", - "id": "c499b948", + "id": "03951373", "metadata": {}, "source": [ "\n", @@ -925,7 +925,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "841bf58c", + "id": "065c68c6", "metadata": {}, "outputs": [ { @@ -1013,7 +1013,7 @@ "\n", " # Obtain survival probabilities for validation set and compute the Integrated Brier Score \n", " predictions_val = model.predict_survival(x_val, times)\n", - " metric_val = survival_regression_metric('ibs', y_tr, y_val, predictions_val, times)\n", + " metric_val = survival_regression_metric('ibs', y_tr, predictions_val, times, y_val)\n", " models.append([metric_val, model])\n", " \n", "# Select the best model based on the mean metric value computed for the validation set\n", @@ -1024,7 +1024,7 @@ }, { "cell_type": "markdown", - "id": "9bf28113", + "id": "4794d4f9", "metadata": {}, "source": [ "\n", @@ -1033,7 +1033,7 @@ }, { "cell_type": "markdown", - "id": "0ab4b8cf", + "id": "b6248233", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." 
@@ -1042,7 +1042,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "a2acb333", + "id": "2ee57c19", "metadata": {}, "outputs": [ { @@ -1064,16 +1064,16 @@ "\n", "# Compute the Brier Score and time-dependent concordance index for the test set to assess model performance\n", "results = dict()\n", - "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", - "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", + "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", + "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", "plot_performance_metrics(results, times)" ] }, { "cell_type": "markdown", - "id": "bac6b75f", + "id": "006724a1", "metadata": {}, "source": [ "\n", @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "7f054819", + "id": "c26416ad", "metadata": {}, "source": [ "\n", @@ -1100,7 +1100,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "7f9358ae", + "id": "8e039745", "metadata": {}, "outputs": [], "source": [ @@ -1129,7 +1129,7 @@ "\n", " # Obtain survival probabilities for validation set and compute the Integrated Brier Score \n", " predictions_val = model.predict_survival(x_val, times)\n", - " metric_val = survival_regression_metric('ibs', y_tr, y_val, predictions_val, times)\n", + " metric_val = survival_regression_metric('ibs', y_tr, predictions_val, times, y_val)\n", " models.append([metric_val, model])\n", " \n", "# Select the best model based on the mean metric value computed for the validation set\n", @@ -1140,7 +1140,7 @@ }, { "cell_type": "markdown", - "id": "1bd45b34", + "id": "b0680104", "metadata": {}, "source":
[ "\n", @@ -1149,7 +1149,7 @@ }, { "cell_type": "markdown", - "id": "4b44cd85", + "id": "5135b083", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -1158,7 +1158,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "e7710fb4", + "id": "27fa6905", "metadata": {}, "outputs": [ { @@ -1180,17 +1180,17 @@ "\n", "# Compute the Brier Score and time-dependent concordance index for the test set to assess model performance\n", "results = dict()\n", - "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", - "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, outcomes_test=y_te, \n", - " predictions=predictions_te, times=times)\n", + "results['Brier Score'] = survival_regression_metric('brs', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", + "results['Concordance Index'] = survival_regression_metric('ctd', outcomes_train=y_tr, predictions=predictions_te, \n", + " times=times, outcomes_test=y_te)\n", "plot_performance_metrics(results, times)" ] }, { "cell_type": "code", "execution_count": null, - "id": "e71903bf", + "id": "7cfaac62", "metadata": {}, "outputs": [], "source": []