Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wp2 #88

Merged
merged 4 commits into from
Jun 28, 2022
Merged

Wp2 #88

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ Phenotyping and Knowledge Discovery

`auton_survival.phenotyping` allows extraction of latent clusters or subgroups
of patients that demonstrate similar outcomes. In the context of this package,
we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows:
we refer to this task as **phenotyping**. `auton_survival.phenotyping` provides
the following phenotyping utilities:

- **Intersectional Phenotyping**: Recovers groups, or phenotypes, of individuals
over exhaustive combinations of user-specified categorical and numerical features.
Expand Down Expand Up @@ -226,6 +227,8 @@ response to a specific intervention. Relies on the specially designed
`auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects` latent variable model.

```python
from auton_survival.models.cmhe import DeepCoxMixturesHeterogenousEffects

# Instantiate the CMHE model
model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers)

Expand All @@ -248,6 +251,13 @@ model = SurvivalVirtualTwins(horizon=365)
phenotypes = model.fit_predict(features, outcomes.time, outcomes.event, interventions)
```

DAG representations of the unsupervised, supervised, and counterfactual probabilistic
phenotypers in auton-survival are shown in the figure below. *X* represents the
covariates, *T* the time-to-event, and *Z* the phenotype to be inferred.

<p align="center"><img src="https://ndownloader.figshare.com/files/36056648" width=60%></p>


<a id="evaluation"></a>

Evaluation and Reporting
Expand Down Expand Up @@ -277,9 +287,14 @@ score = survival_regression_metric(metric='brs', outcomes_train,
```

- **Treatment Effect**: Used to compare treatment arms by computing the difference in the following metrics for treatment and control groups:
- **Time at Risk** (TaR)
- **Risk at Time**
- **Restricted Mean Survival Time** (RMST)
- **Time at Risk (TaR)** (left)
- **Risk at Time** (center)
- **Restricted Mean Survival Time (RMST)** (right)

<p align="center">
<img src="https://ndownloader.figshare.com/files/36056507" width=30%>
<img src="https://ndownloader.figshare.com/files/36056534" width=30%>
<img src="https://ndownloader.figshare.com/files/36056546" width=30%></p>

```python
from auton_survival.metrics import survival_diff_metric
Expand Down
41 changes: 21 additions & 20 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@

class SurvivalRegressionCV:
"""Universal interface to train Survival Analysis models in a cross-
validation or nested cross-validation fashion.
validation fashion.

Each of the model is trained in a CV fashion over the user specified
hyperparameter grid. The best model(s) in terms of user-specified metric
is selected.
The model is trained in a CV fashion over the user-specified
hyperparameter grid. Model hyperparameters are selected based on the
user-specified metric.

Parameters
-----------
Expand All @@ -65,9 +65,6 @@ class SurvivalRegressionCV:
num_folds : int, default=5
The number of folds.
Ignored if folds is specified.
num_nested_folds : int, default=None
The number of folds to use for nested cross-validation.
If None, then regular (unnested) CV is performed.
random_seed : int, default=0
Controls reproducibility of results.
hyperparam_grid : dict
Expand All @@ -92,12 +89,11 @@ class SurvivalRegressionCV:
"""

def __init__(self, model='dcph', folds=None, num_folds=5,
num_nested_folds=None, random_seed=0, hyperparam_grid={}):
random_seed=0, hyperparam_grid={}):

self.model = model
self.folds = folds
self.num_folds = num_folds
self.num_nested_folds = num_nested_folds
self.random_seed = random_seed
self.hyperparam_grid = list(ParameterGrid(hyperparam_grid))

Expand All @@ -116,7 +112,7 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
horizon : int or float or list
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Expand All @@ -125,12 +121,12 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index

Returns
-----------
Trained survival regression model(s).

"""


assert horizons is not None, "Horizons must be specified."
if isinstance(horizons, (int, float)):
Expand All @@ -156,10 +152,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
assert max(horizons) < time_max, "Horizons exceeds max time range."
assert min(horizons) > time_min, "Horizons exceeds min time range."

# if self.horizon is None:
# assert (self.metric == 'ibs'), "Horizon must be specified for the selected metric"
# self.horizon = time_max

hyper_param_scores = []
for i, hyper_param in enumerate(self.hyperparam_grid):
print("At hyper-param", hyper_param)
Expand Down Expand Up @@ -189,7 +181,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
**best_hyper_param).fit(features, outcomes)
return model


def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed):

"""Get cross-validation fold value for each sample.
Expand Down Expand Up @@ -288,7 +279,6 @@ class CounterfactualSurvivalRegressionCV:
model : str
A string that determines the choice of the survival analysis model.
Survival model choices include:

- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
Expand Down Expand Up @@ -341,10 +331,10 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)

def fit(self, features, outcomes, interventions, metric):
def fit(self, features, outcomes, interventions, horizons, metric):

r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
r"""Fits the Survival Regression Model to the data in a cross-
validation fashion.

Parameters
-----------
Expand All @@ -359,6 +349,15 @@ def fit(self, features, outcomes, interventions, metric):
interventions: pandas.Series
A pandas series containing the treatment status of each subject.
\( a_i = 1 \) if the subject is `treated`, else is considered control.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index

Returns
-----------
Expand All @@ -369,9 +368,11 @@ def fit(self, features, outcomes, interventions, metric):

treated_model = self.treated_experiment.fit(features.loc[interventions==1],
outcomes.loc[interventions==1],
horizons=horizons,
metric=metric)
control_model = self.control_experiment.fit(features.loc[interventions!=1],
outcomes.loc[interventions!=1],
horizons=horizons,
metric=metric)

return CounterfactualSurvivalModel(treated_model, control_model)
15 changes: 7 additions & 8 deletions auton_survival/phenotyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,7 @@ def __init__(self,

self.random_seed = random_seed

def fit(self, features, outcomes, interventions, metric,
horizon):
def fit(self, features, outcomes, interventions, horizons, metric):

"""Fit a counterfactual model and regress the difference of the estimated
counterfactual Restricted Mean Survival Time using a Random Forest regressor.
Expand All @@ -495,16 +494,15 @@ def fit(self, features, outcomes, interventions, metric,
interventions : np.array
Boolean numpy array of treatment indicators. True means individual
was assigned a specific treatment.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
horizon : np.float
The event horizon at which to compute the counterfacutal RMST for
regression.

Returns
-----------
Expand All @@ -515,12 +513,13 @@ def fit(self, features, outcomes, interventions, metric,
cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method,
hyperparam_grid=self.cf_hyperparams)

self.cf_model = cf_model.fit(features, outcomes, interventions, metric)
self.cf_model = cf_model.fit(features, outcomes, interventions,
horizons, metric)

times = np.unique(outcomes.time.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features,
times.tolist())

horizon = max(horizons)
ite_estimates = cf_predictions[1] - cf_predictions[0]
ite_estimates = [estimate[times < horizon] for estimate in ite_estimates]
times = times[times < horizon]
Expand Down Expand Up @@ -558,7 +557,7 @@ def predict_proba(self, features):

"""

phenotype_preds= self.pheno_model.predict(features)
phenotype_preds = self.pheno_model.predict(features)
preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min())
preds_surv_less = 1 - preds_surv_greater
preds = np.array([[preds_surv_less[i], preds_surv_greater[i]]
Expand Down
21 changes: 3 additions & 18 deletions examples/CV Survival Regression on SUPPORT Dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"horizons = [0.25, 0.5, 0.75]\n",
"times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()"
"times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()"
]
},
{
Expand All @@ -67,7 +66,7 @@
" 'layers' : [[100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
"model = experiment.fit(x, outcomes, metric='ctd')"
"model = experiment.fit(x, outcomes, times, metric='brs')"
]
},
{
Expand All @@ -80,13 +79,6 @@
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -122,7 +114,7 @@
"for fold in set(experiment.folds):\n",
" print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], \n",
" out_survival[experiment.folds==fold], \n",
" times=times))\n"
" times=times))"
]
},
{
Expand All @@ -136,13 +128,6 @@
" print(time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down