Skip to content

Commit

Permalink
reviewed updates
Browse files Browse the repository at this point in the history
  • Loading branch information
PotosnakW committed May 17, 2022
1 parent f142a03 commit 9513068
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 366 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ outcomes, features = load_dataset(dataset='SUPPORT')
cat_feats = ['sex', 'income', 'race']
num_feats = ['age', 'resp', 'glucose']

from auton_survival.experiments import SurvivalCVRegressionCV
from auton_survival.experiments import SurvivalRegressionCV
# Instantiate an auton_survival Experiment
experiment = SurvivalCVRegressionCV(model='cph', num_folds=5,
experiment = SurvivalRegressionCV(model='cph', num_folds=5,
hyperparam_grid=hyperparam_grid)

# Fit the `experiment` object with the specified Cox model.
Expand Down
90 changes: 46 additions & 44 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _get_valid_idx(n, size, random_seed):

return vidx

def _fit_dcm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
def _fit_dcm(features, outcomes, val_data, random_seed, **hyperparams):

r"""Fit the Deep Cox Mixtures (DCM) [1] model to a given dataset.
Expand All @@ -76,13 +76,9 @@ def _fit_dcm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
and columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event'.
vsize : float, default=0.15
Amount of data to set aside as the validation set.
Not applicable to 'rsf' and 'cph' models.
val_data : tuple
A tuple of the validation dataset features and outcomes of
'time' and 'event'.
If passed, vsize is ignored.
random_seed : int
Controls the reproducibility of fitted estimators.
hyperparams : Optional arguments
Expand Down Expand Up @@ -120,13 +116,13 @@ def _fit_dcm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
gamma=gamma,
smoothing_factor=smoothing_factor,
random_seed=random_seed)
model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
model.fit(x=features, t=outcomes.time, e=outcomes.event,
val_data=val_data, iters=epochs, batch_size=bs,
learning_rate=lr)

return model

def _fit_dcph(features, outcomes, vsize, val_data, random_seed, **hyperparams):
def _fit_dcph(features, outcomes, val_data, random_seed, **hyperparams):

"""Fit a Deep Cox Proportional Hazards Model/Farragi Simon Network [1,2]
model to a given dataset.
Expand All @@ -146,13 +142,9 @@ def _fit_dcph(features, outcomes, vsize, val_data, random_seed, **hyperparams):
and columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event'.
vsize : float, default=0.15
Amount of data to set aside as the validation set.
Not applicable to 'rsf' and 'cph' models.
val_data : tuple
A tuple of the validation dataset features and outcomes of
'time' and 'event'.
If passed, vsize is ignored.
random_seed : int
Controls the reproducibility of called functions.
hyperparams : Optional arguments
Expand Down Expand Up @@ -181,7 +173,7 @@ def _fit_dcph(features, outcomes, vsize, val_data, random_seed, **hyperparams):

model = DeepCoxPH(layers=layers, random_seed=random_seed)

model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
model.fit(x=features, t=outcomes.time, e=outcomes.event,
val_data=val_data, iters=epochs, batch_size=bs,
learning_rate=lr)

Expand Down Expand Up @@ -238,7 +230,7 @@ def _predict_dcph(model, features, times):

return model.predict_survival(x=features.values, t=times)

def _fit_cph(features, outcomes, random_seed, **hyperparams):
def _fit_cph(features, outcomes, val_data, random_seed, **hyperparams):
"""Fit a linear Cox Proportional Hazards model to a given dataset.
Parameters
Expand All @@ -248,6 +240,9 @@ def _fit_cph(features, outcomes, random_seed, **hyperparams):
columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event'.
val_data : tuple
A tuple of the validation dataset features and outcomes of
'time' and 'event'.
random_seed : int
Controls the reproducibility of called functions.
hyperparams : Optional arguments
Expand All @@ -270,7 +265,7 @@ def _fit_cph(features, outcomes, random_seed, **hyperparams):
duration_col='time',
event_col='event')

def _fit_rsf(features, outcomes, random_seed, **hyperparams):
def _fit_rsf(features, outcomes, val_data, random_seed, **hyperparams):

"""Fit the Random Survival Forests (RSF) [1] model to a given dataset.
RSF is an extension of Random Forests to the survival settings where
Expand All @@ -289,6 +284,9 @@ def _fit_rsf(features, outcomes, random_seed, **hyperparams):
columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event'.
val_data : tuple
A tuple of the validation dataset features and outcomes of
'time' and 'event'.
random_seed : int
Controls the reproducibility of called functions.
hyperparams : Optional arguments
Expand Down Expand Up @@ -326,7 +324,7 @@ def _fit_rsf(features, outcomes, random_seed, **hyperparams):
return rsf


def _fit_dsm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
def _fit_dsm(features, outcomes, val_data, random_seed, **hyperparams):

"""Fit the Deep Survival Machines (DSM) [1] model to a given dataset.
Expand All @@ -348,13 +346,9 @@ def _fit_dsm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event'.
vsize : float, default=0.15
Amount of data to set aside as the validation set.
Not applicable to 'rsf' and 'cph' models.
val_data : tuple
A tuple of the validation dataset features and outcomes of
'time' and 'event'.
If passed, vsize is ignored.
random_seed : int
Controls the reproducibility of called functions.
hyperparams : Optional arguments
Expand Down Expand Up @@ -395,8 +389,8 @@ def _fit_dsm(features, outcomes, vsize, val_data, random_seed, **hyperparams):
temp=temperature,
random_seed=random_seed)

model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
val_data=val_data, iters=epochs, learning_rate=lr, batch_size=bs)
model.fit(x=features, t=outcomes.time, e=outcomes.event, val_data=val_data,
iters=epochs, learning_rate=lr, batch_size=bs)

return model

Expand Down Expand Up @@ -595,71 +589,79 @@ def fit(self, features, outcomes, vsize=0.15, val_data=None,
"""

data = features.join(outcomes)

if val_data is None:
assert weights_val is None, "Weights for validation data \
must be None if validation data is not specified."

train_data = data.sample(frac=1-vsize, random_state=self.random_seed)
val_data = data[~data.index.isin(train_data.index)]
val_data = (val_data[features.columns], val_data[outcomes.columns])

else:
train_data = data

if weights is not None:
assert len(weights) == features.shape[0], "Size of passed weights \
must match size of training data."
assert (weights>0.).any(), "All weights must be positive."

data = features.join(outcomes)
weights = pd.Series(weights, index=data.index)
val_data = val_data[0].join(val_data[1])

if val_data is not None:
if weights_val is None:
weights_train = weights[train_data.index]
weights_val = weights[val_data.index]

else:
assert weights_val is not None, "Validation set weights must be \
specified."
assert len(weights_val) == val_data[0].shape[0], "Size of passed \
weights_val must match size of validation data."
assert (weights_val>0.).any(), "All weights_val must be positive."

data_train = data
data_val = val_data[0].join(val_data[1])
weights_train = weights

else:
data_train = data.sample(frac=1-vsize, random_state=self.random_seed)
data_val = data[~data.index.isin(data_train.index)]
weights_train = weights[data_train.index]
weights_val = weights[data_val.index]

data_train_resampled = data_train.sample(weights = weights_train,
train_data_resampled = train_data.sample(weights = weights_train,
frac = resample_size,
replace = True,
random_state = self.random_seed)

data_val_resampled = data_val.sample(weights = weights_val,
val_data_resampled = val_data.sample(weights = weights_val,
frac = resample_size,
replace = True,
random_state = self.random_seed)

features = data_train_resampled[features.columns]
outcomes = data_train_resampled[outcomes.columns]
features = train_data_resampled[features.columns]
outcomes = train_data_resampled[outcomes.columns]

val_data = (data_val_resampled[features.columns],
data_val_resampled[outcomes.columns])
val_data = (val_data_resampled[features.columns],
val_data_resampled[outcomes.columns])

if val_data is not None:
val_data = (val_data[0], val_data[1].time, val_data[1].event)
val_data = (val_data[0], val_data[1].time, val_data[1].event)

if self.model == 'cph':
self._model = _fit_cph(features, outcomes,
self.random_seed,
val_data, self.random_seed,
**self.hyperparams)
elif self.model == 'rsf':
self._model = _fit_rsf(features, outcomes,
self.random_seed,
val_data, self.random_seed,
**self.hyperparams)
elif self.model == 'dsm':
self._model = _fit_dsm(features, outcomes,
vsize, val_data,
val_data,
self.random_seed,
**self.hyperparams)
elif self.model == 'dcph':
self._model = _fit_dcph(features, outcomes,
vsize, val_data,
val_data,
self.random_seed,
**self.hyperparams)
elif self.model == 'dcm':
self._model = _fit_dcm(features, outcomes,
vsize, val_data,
val_data,
self.random_seed,
**self.hyperparams)

Expand Down
Loading

0 comments on commit 9513068

Please sign in to comment.