diff --git a/README.md b/README.md
index ec45025..eb07e2f 100644
--- a/README.md
+++ b/README.md
@@ -166,7 +166,8 @@ Phenotyping and Knowledge Discovery
`auton_survival.phenotyping` allows extraction of latent clusters or subgroups
of patients that demonstrate similar outcomes. In the context of this package,
-we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows:
+we refer to this task as **phenotyping**. `auton_survival.phenotyping` provides
+the following phenotyping utilities:
- **Intersectional Phenotyping**: Recovers groups, or phenotypes, of individuals
over exhaustive combinations of user-specified categorical and numerical features.
@@ -226,6 +227,8 @@ response to a specific intervention. Relies on the specially designed
`auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects` latent variable model.
```python
+from auton_survival.models.cmhe DeepCoxMixturesHeterogenousEffects
+
# Instantiate the CMHE model
model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers)
@@ -248,6 +251,13 @@ model = SurvivalVirtualTwins(horizon=365)
phenotypes = model.fit_predict(features, outcomes.time, outcomes.event, interventions)
```
+DAG representations of the unsupervised, supervised, and counterfactual probabilitic
+phenotypers in auton-survival are shown in the below figure. *X* represents the
+covariates, *T* the time-to-event and *Z* is the phenotype to be inferred.
+
+

+
+
Evaluation and Reporting
@@ -277,9 +287,14 @@ score = survival_regression_metric(metric='brs', outcomes_train,
```
- **Treatment Effect**: Used to compare treatment arms by computing the difference in the following metrics for treatment and control groups:
- - **Time at Risk** (TaR)
- - **Risk at Time**
- - **Restricted Mean Survival Time** (RMST)
+ - **Time at Risk (TaR)** (left)
+ - **Risk at Time** (center)
+ - **Restricted Mean Survival Time (RMST)** (right)
+
+
+
+
+
```python
from auton_survival.metrics import survival_diff_metric
diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py
index d17b17b..d73d757 100644
--- a/auton_survival/experiments.py
+++ b/auton_survival/experiments.py
@@ -39,11 +39,11 @@
class SurvivalRegressionCV:
"""Universal interface to train Survival Analysis models in a cross-
- validation or nested cross-validation fashion.
+ validation fashion.
- Each of the model is trained in a CV fashion over the user specified
- hyperparameter grid. The best model(s) in terms of user-specified metric
- is selected.
+ The model is trained in a CV fashion over the user-specified
+ hyperparameter grid. Model hyperparameters are selected based on the
+ user-specified metric.
Parameters
-----------
@@ -65,9 +65,6 @@ class SurvivalRegressionCV:
num_folds : int, default=5
The number of folds.
Ignored if folds is specified.
- num_nested_folds : int, default=None
- The number of folds to use for nested cross-validation.
- If None, then regular (unnested) CV is performed.
random_seed : int, default=0
Controls reproducibility of results.
hyperparam_grid : dict
@@ -92,12 +89,11 @@ class SurvivalRegressionCV:
"""
def __init__(self, model='dcph', folds=None, num_folds=5,
- num_nested_folds=None, random_seed=0, hyperparam_grid={}):
+ random_seed=0, hyperparam_grid={}):
self.model = model
self.folds = folds
self.num_folds = num_folds
- self.num_nested_folds = num_nested_folds
self.random_seed = random_seed
self.hyperparam_grid = list(ParameterGrid(hyperparam_grid))
@@ -116,7 +112,7 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
- horizon : int or float or list
+ horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
@@ -125,12 +121,12 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
+
Returns
-----------
Trained survival regression model(s).
"""
-
assert horizons is not None, "Horizons must be specified."
if isinstance(horizons, (int, float)):
@@ -156,10 +152,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
assert max(horizons) < time_max, "Horizons exceeds max time range."
assert min(horizons) > time_min, "Horizons exceeds min time range."
- # if self.horizon is None:
- # assert (self.metric == 'ibs'), "Horizon must be specified for the selected metric"
- # self.horizon = time_max
-
hyper_param_scores = []
for i, hyper_param in enumerate(self.hyperparam_grid):
print("At hyper-param", hyper_param)
@@ -189,7 +181,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
**best_hyper_param).fit(features, outcomes)
return model
-
def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed):
"""Get cross-validation fold value for each sample.
@@ -288,7 +279,6 @@ class CounterfactualSurvivalRegressionCV:
model : str
A string that determines the choice of the surival analysis model.
Survival model choices include:
-
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
@@ -341,10 +331,10 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)
- def fit(self, features, outcomes, interventions, metric):
+ def fit(self, features, outcomes, interventions, horizons, metric):
- r"""Fits the Survival Regression Model to the data in a Cross
- Validation fashion.
+ r"""Fits the Survival Regression Model to the data in a cross-
+ validation fashion.
Parameters
-----------
@@ -359,6 +349,15 @@ def fit(self, features, outcomes, interventions, metric):
interventions: pandas.Series
A pandas series containing the treatment status of each subject.
\( a_i = 1 \) if the subject is `treated`, else is considered control.
+ horizons : int or float or list
+ Event-horizons at which to evaluate model performance.
+ metric : str, default='ibs'
+ Metric used to evaluate model performance and tune hyperparameters.
+ Options include:
+ - 'auc': Dynamic area under the ROC curve
+ - 'brs' : Brier Score
+ - 'ibs' : Integrated Brier Score
+ - 'ctd' : Concordance Index
Returns
-----------
@@ -369,9 +368,11 @@ def fit(self, features, outcomes, interventions, metric):
treated_model = self.treated_experiment.fit(features.loc[interventions==1],
outcomes.loc[interventions==1],
+ horizons=horizons,
metric=metric)
control_model = self.control_experiment.fit(features.loc[interventions!=1],
outcomes.loc[interventions!=1],
+ horizons=horizons,
metric=metric)
return CounterfactualSurvivalModel(treated_model, control_model)
diff --git a/auton_survival/phenotyping.py b/auton_survival/phenotyping.py
index faf8ca5..f5fbaf1 100644
--- a/auton_survival/phenotyping.py
+++ b/auton_survival/phenotyping.py
@@ -478,8 +478,7 @@ def __init__(self,
self.random_seed = random_seed
- def fit(self, features, outcomes, interventions, metric,
- horizon):
+ def fit(self, features, outcomes, interventions, horizons, metric):
"""Fit a counterfactual model and regress the difference of the estimated
counterfactual Restricted Mean Survival Time using a Random Forest regressor.
@@ -495,6 +494,8 @@ def fit(self, features, outcomes, interventions, metric,
interventions : np.array
Boolean numpy array of treatment indicators. True means individual
was assigned a specific treatment.
+ horizons : int or float or list
+ Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Options include:
@@ -502,9 +503,6 @@ def fit(self, features, outcomes, interventions, metric,
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
- horizon : np.float
- The event horizon at which to compute the counterfacutal RMST for
- regression.
Returns
-----------
@@ -515,12 +513,13 @@ def fit(self, features, outcomes, interventions, metric,
cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method,
hyperparam_grid=self.cf_hyperparams)
- self.cf_model = cf_model.fit(features, outcomes, interventions, metric)
+ self.cf_model = cf_model.fit(features, outcomes, interventions,
+ horizons, metric)
times = np.unique(outcomes.time.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features,
times.tolist())
-
+ horizon = max(horizons)
ite_estimates = cf_predictions[1] - cf_predictions[0]
ite_estimates = [estimate[times < horizon] for estimate in ite_estimates]
times = times[times < horizon]
@@ -558,7 +557,7 @@ def predict_proba(self, features):
"""
- phenotype_preds= self.pheno_model.predict(features)
+ phenotype_preds = self.pheno_model.predict(features)
preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min())
preds_surv_less = 1 - preds_surv_greater
preds = np.array([[preds_surv_less[i], preds_surv_greater[i]]
diff --git a/examples/CV Survival Regression on SUPPORT Dataset.ipynb b/examples/CV Survival Regression on SUPPORT Dataset.ipynb
index e1410c9..f6564d5 100644
--- a/examples/CV Survival Regression on SUPPORT Dataset.ipynb
+++ b/examples/CV Survival Regression on SUPPORT Dataset.ipynb
@@ -49,8 +49,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
- "horizons = [0.25, 0.5, 0.75]\n",
- "times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()"
+ "times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()"
]
},
{
@@ -67,7 +66,7 @@
" 'layers' : [[100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
- "model = experiment.fit(x, outcomes, metric='ctd')"
+ "model = experiment.fit(x, outcomes, times, metric='brs')"
]
},
{
@@ -80,13 +79,6 @@
"model"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "code",
"execution_count": null,
@@ -122,7 +114,7 @@
"for fold in set(experiment.folds):\n",
" print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], \n",
" out_survival[experiment.folds==fold], \n",
- " times=times))\n"
+ " times=times))"
]
},
{
@@ -136,13 +128,6 @@
" print(time)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "code",
"execution_count": null,