diff --git a/README.md b/README.md index ec45025..eb07e2f 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,8 @@ Phenotyping and Knowledge Discovery `auton_survival.phenotyping` allows extraction of latent clusters or subgroups of patients that demonstrate similar outcomes. In the context of this package, -we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows: +we refer to this task as **phenotyping**. `auton_survival.phenotyping` provides +the following phenotyping utilities: - **Intersectional Phenotyping**: Recovers groups, or phenotypes, of individuals over exhaustive combinations of user-specified categorical and numerical features. @@ -226,6 +227,8 @@ response to a specific intervention. Relies on the specially designed `auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects` latent variable model. ```python +from auton_survival.models.cmhe DeepCoxMixturesHeterogenousEffects + # Instantiate the CMHE model model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers) @@ -248,6 +251,13 @@ model = SurvivalVirtualTwins(horizon=365) phenotypes = model.fit_predict(features, outcomes.time, outcomes.event, interventions) ``` +DAG representations of the unsupervised, supervised, and counterfactual probabilitic +phenotypers in auton-survival are shown in the below figure. *X* represents the +covariates, *T* the time-to-event and *Z* is the phenotype to be inferred. + +

+ + Evaluation and Reporting @@ -277,9 +287,14 @@ score = survival_regression_metric(metric='brs', outcomes_train, ``` - **Treatment Effect**: Used to compare treatment arms by computing the difference in the following metrics for treatment and control groups: - - **Time at Risk** (TaR) - - **Risk at Time** - - **Restricted Mean Survival Time** (RMST) + - **Time at Risk (TaR)** (left) + - **Risk at Time** (center) + - **Restricted Mean Survival Time (RMST)** (right) + +

+ + +

```python from auton_survival.metrics import survival_diff_metric diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py index d17b17b..d73d757 100644 --- a/auton_survival/experiments.py +++ b/auton_survival/experiments.py @@ -39,11 +39,11 @@ class SurvivalRegressionCV: """Universal interface to train Survival Analysis models in a cross- - validation or nested cross-validation fashion. + validation fashion. - Each of the model is trained in a CV fashion over the user specified - hyperparameter grid. The best model(s) in terms of user-specified metric - is selected. + The model is trained in a CV fashion over the user-specified + hyperparameter grid. Model hyperparameters are selected based on the + user-specified metric. Parameters ----------- @@ -65,9 +65,6 @@ class SurvivalRegressionCV: num_folds : int, default=5 The number of folds. Ignored if folds is specified. - num_nested_folds : int, default=None - The number of folds to use for nested cross-validation. - If None, then regular (unnested) CV is performed. random_seed : int, default=0 Controls reproducibility of results. hyperparam_grid : dict @@ -92,12 +89,11 @@ class SurvivalRegressionCV: """ def __init__(self, model='dcph', folds=None, num_folds=5, - num_nested_folds=None, random_seed=0, hyperparam_grid={}): + random_seed=0, hyperparam_grid={}): self.model = model self.folds = folds self.num_folds = num_folds - self.num_nested_folds = num_nested_folds self.random_seed = random_seed self.hyperparam_grid = list(ParameterGrid(hyperparam_grid)) @@ -116,7 +112,7 @@ def fit(self, features, outcomes, horizons, metric='ibs'): outcomes : pd.DataFrame A pandas dataframe with columns 'time' and 'event' that contain the survival time and censoring status \( \delta_i = 1 \), respectively. - horizon : int or float or list + horizons : int or float or list Event-horizons at which to evaluate model performance. metric : str, default='ibs' Metric used to evaluate model performance and tune hyperparameters. @@ -125,12 +121,12 @@ def fit(self, features, outcomes, horizons, metric='ibs'): - 'brs' : Brier Score - 'ibs' : Integrated Brier Score - 'ctd' : Concordance Index + Returns ----------- Trained survival regression model(s). """ - assert horizons is not None, "Horizons must be specified." if isinstance(horizons, (int, float)): @@ -156,10 +152,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'): assert max(horizons) < time_max, "Horizons exceeds max time range." assert min(horizons) > time_min, "Horizons exceeds min time range." - # if self.horizon is None: - # assert (self.metric == 'ibs'), "Horizon must be specified for the selected metric" - # self.horizon = time_max - hyper_param_scores = [] for i, hyper_param in enumerate(self.hyperparam_grid): print("At hyper-param", hyper_param) @@ -189,7 +181,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'): **best_hyper_param).fit(features, outcomes) return model - def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed): """Get cross-validation fold value for each sample. @@ -288,7 +279,6 @@ class CounterfactualSurvivalRegressionCV: model : str A string that determines the choice of the surival analysis model. Survival model choices include: - - 'dsm' : Deep Survival Machines [3] model - 'dcph' : Deep Cox Proportional Hazards [2] model - 'dcm' : Deep Cox Mixtures [4] model @@ -341,10 +331,10 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}): random_seed=random_seed, hyperparam_grid=hyperparam_grid) - def fit(self, features, outcomes, interventions, metric): + def fit(self, features, outcomes, interventions, horizons, metric): - r"""Fits the Survival Regression Model to the data in a Cross - Validation fashion. + r"""Fits the Survival Regression Model to the data in a cross- + validation fashion. Parameters ----------- @@ -359,6 +349,15 @@ def fit(self, features, outcomes, interventions, metric): interventions: pandas.Series A pandas series containing the treatment status of each subject. \( a_i = 1 \) if the subject is `treated`, else is considered control. + horizons : int or float or list + Event-horizons at which to evaluate model performance. + metric : str, default='ibs' + Metric used to evaluate model performance and tune hyperparameters. + Options include: + - 'auc': Dynamic area under the ROC curve + - 'brs' : Brier Score + - 'ibs' : Integrated Brier Score + - 'ctd' : Concordance Index Returns ----------- @@ -369,9 +368,11 @@ def fit(self, features, outcomes, interventions, metric): treated_model = self.treated_experiment.fit(features.loc[interventions==1], outcomes.loc[interventions==1], + horizons=horizons, metric=metric) control_model = self.control_experiment.fit(features.loc[interventions!=1], outcomes.loc[interventions!=1], + horizons=horizons, metric=metric) return CounterfactualSurvivalModel(treated_model, control_model) diff --git a/auton_survival/phenotyping.py b/auton_survival/phenotyping.py index faf8ca5..f5fbaf1 100644 --- a/auton_survival/phenotyping.py +++ b/auton_survival/phenotyping.py @@ -478,8 +478,7 @@ def __init__(self, self.random_seed = random_seed - def fit(self, features, outcomes, interventions, metric, - horizon): + def fit(self, features, outcomes, interventions, horizons, metric): """Fit a counterfactual model and regress the difference of the estimated counterfactual Restricted Mean Survival Time using a Random Forest regressor. @@ -495,6 +494,8 @@ def fit(self, features, outcomes, interventions, metric, interventions : np.array Boolean numpy array of treatment indicators. True means individual was assigned a specific treatment. + horizons : int or float or list + Event-horizons at which to evaluate model performance. metric : str, default='ibs' Metric used to evaluate model performance and tune hyperparameters. Options include: @@ -502,9 +503,6 @@ def fit(self, features, outcomes, interventions, metric, - 'brs' : Brier Score - 'ibs' : Integrated Brier Score - 'ctd' : Concordance Index - horizon : np.float - The event horizon at which to compute the counterfacutal RMST for - regression. Returns ----------- @@ -515,12 +513,13 @@ def fit(self, features, outcomes, interventions, metric, cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method, hyperparam_grid=self.cf_hyperparams) - self.cf_model = cf_model.fit(features, outcomes, interventions, metric) + self.cf_model = cf_model.fit(features, outcomes, interventions, + horizons, metric) times = np.unique(outcomes.time.values) cf_predictions = self.cf_model.predict_counterfactual_survival(features, times.tolist()) - + horizon = max(horizons) ite_estimates = cf_predictions[1] - cf_predictions[0] ite_estimates = [estimate[times < horizon] for estimate in ite_estimates] times = times[times < horizon] @@ -558,7 +557,7 @@ def predict_proba(self, features): """ - phenotype_preds= self.pheno_model.predict(features) + phenotype_preds = self.pheno_model.predict(features) preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min()) preds_surv_less = 1 - preds_surv_greater preds = np.array([[preds_surv_less[i], preds_surv_greater[i]] diff --git a/examples/CV Survival Regression on SUPPORT Dataset.ipynb b/examples/CV Survival Regression on SUPPORT Dataset.ipynb index e1410c9..f6564d5 100644 --- a/examples/CV Survival Regression on SUPPORT Dataset.ipynb +++ b/examples/CV Survival Regression on SUPPORT Dataset.ipynb @@ -49,8 +49,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "horizons = [0.25, 0.5, 0.75]\n", - "times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()" + "times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()" ] }, { @@ -67,7 +66,7 @@ " 'layers' : [[100]]}\n", "\n", "experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n", - "model = experiment.fit(x, outcomes, metric='ctd')" + "model = experiment.fit(x, outcomes, times, metric='brs')" ] }, { @@ -80,13 +79,6 @@ "model" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -122,7 +114,7 @@ "for fold in set(experiment.folds):\n", " print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], \n", " out_survival[experiment.folds==fold], \n", - " times=times))\n" + " times=times))" ] }, { @@ -136,13 +128,6 @@ " print(time)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null,