Skip to content

Commit

Permalink
modified: dcm_api.py
Browse files Browse the repository at this point in the history
	modified:   dcm_torch.py
	modified:   dcm_utilities.py
  • Loading branch information
chiragnagpal committed Dec 24, 2021
1 parent 2fc30e6 commit 9b088d0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
23 changes: 21 additions & 2 deletions dsm/contrib/dcm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
from dsm.contrib.dcm_utilities import train_dcm, predict_survival
from dsm.contrib.dcm_utilities import train_dcm, predict_survival, predict_latent_z


class DeepCoxMixtures:
Expand Down Expand Up @@ -35,10 +35,13 @@ class DeepCoxMixtures:
>>> model.fit(x, t, e)
"""
def __init__(self, k=3, layers=None):
def __init__(self, k=3, layers=None, gamma=0.95, use_activation=False):

self.k = k
self.layers = layers
self.fitted = False
self.gamma = gamma
self.use_activation = use_activation

def __call__(self):
if self.fitted:
Expand Down Expand Up @@ -86,6 +89,8 @@ def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
return DeepCoxMixturesTorch(inputdim,
k=self.k,
gamma=self.gamma,
use_activation=self.use_activation,
layers=self.layers,
optimizer=optimizer)

Expand Down Expand Up @@ -175,3 +180,17 @@ def predict_survival(self, x, t):
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")

def predict_latent_z(self, x):

x = self._preprocess_test_data(x)

if self.fitted:
scores = predict_latent_z(self.torch_model, x)
return scores
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_latent_z`.")


14 changes: 9 additions & 5 deletions dsm/contrib/dcm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _init_dcm_layers(self, lastdim):
self.gate = torch.nn.Linear(lastdim, self.k, bias=False)
self.expert = torch.nn.Linear(lastdim, self.k, bias=False)

def __init__(self, inputdim, k, layers=None, optimizer='Adam'):
def __init__(self, inputdim, k, gamma=0.95, use_activation=False, layers=None, optimizer='Adam'):

super(DeepCoxMixturesTorch, self).__init__()

Expand All @@ -44,14 +44,18 @@ def __init__(self, inputdim, k, layers=None, optimizer='Adam'):

self._init_dcm_layers(lastdim)
self.embedding = create_representation(inputdim, layers, 'ReLU6')
self.gamma = np.log(gamma)
self.use_activation = use_activation

def forward(self, x):

x = self.embedding(x)
gamma = self.gamma

log_hazard_ratios = torch.clamp(self.expert(x), min=-7e-1, max=7e-1)
#log_hazard_ratios = self.expert(x)
#log_hazard_ratios = torch.nn.Tanh()(self.expert(x))
x = self.embedding(x)
if self.use_activation:
log_hazard_ratios = gamma*torch.nn.Tanh()(self.expert(x))
else:
log_hazard_ratios = torch.clamp(self.expert(x), min=-gamma, max=gamma)
log_gate_prob = torch.nn.LogSoftmax(dim=1)(self.gate(x))

return log_gate_prob, log_hazard_ratios
32 changes: 21 additions & 11 deletions dsm/contrib/dcm_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fit_spline(t, surv, s=1e-4):
return UnivariateSpline(t, surv, s=s, ext=3, k=1)

def smooth_bl_survival(breslow, smoothing_factor):

blsurvival = breslow.baseline_survival_
x, y = blsurvival.x, blsurvival.y
return fit_spline(x, y, s=smoothing_factor)
Expand Down Expand Up @@ -84,12 +84,12 @@ def sample_hard_z(gates_prob):
return torch.multinomial(gates_prob.exp(), num_samples=1)[:, 0]

def repair_probs(probs):
probs[torch.isnan(probs)] = -10
probs[torch.isnan(probs)] = -10
probs[probs<-10] = -10
return probs

def get_likelihood(model, breslow_splines, x, t, e, log=False):

# Function requires numpy/torch

gates, lrisks = model(x)
Expand Down Expand Up @@ -143,9 +143,9 @@ def e_step(model, breslow_splines, x, t, e, log=False):
posteriors = get_posteriors(repair_probs(probs))

return posteriors

def m_step(model, optimizer, x, t, e, posteriors, typ='soft'):

optimizer.zero_grad()
loss = q_function(model, x, t, e, posteriors, typ)
loss.backward()
Expand Down Expand Up @@ -210,7 +210,7 @@ def train_step(model, x, t, e, breslow_splines, optimizer,
if use_posteriors:
posteriors = e_step(model, breslow_splines, x, t, e)
breslow_splines = fit_breslow(model, x, t, e, posteriors=posteriors, typ='soft')
else:
else:
breslow_splines = fit_breslow(model, x, t, e, posteriors=None, typ='soft')
# print(f'Duration of Breslow spline estimation: {time.time() - estimate_breslow_start}')
except Exception as exce:
Expand Down Expand Up @@ -262,10 +262,10 @@ def train_dcm(model, train_data, val_data, epochs=50,
for epoch in tqdm(range(epochs)):

# train_step_start = time.time()
breslow_splines = train_step(model, xt, tt, et, breslow_splines,
breslow_splines = train_step(model, xt, tt, et, breslow_splines,
optimizer, bs=bs, seed=epoch, typ=typ,
use_posteriors=use_posteriors,
update_splines_after=update_splines_after,
use_posteriors=use_posteriors,
update_splines_after=update_splines_after,
smoothing_factor=smoothing_factor)
# print(f'Duration of train-step: {time.time() - train_step_start}')
# test_step_start = time.time()
Expand All @@ -281,8 +281,8 @@ def train_dcm(model, train_data, val_data, epochs=50,
else: patience_ = 0

if patience_ == patience:
if return_losses: return (model, breslow_splines), losses
else: return (model, breslow_splines)
if return_losses: return (model, breslow_splines), losses
else: return (model, breslow_splines)

valc = valcn

Expand All @@ -307,3 +307,13 @@ def predict_survival(model, x, t):
predictions.append((gate_probs*expert_output).sum(axis=1))

return np.array(predictions).T

def predict_latent_z(model, x):

model, _ = model
gates, _ = model(x)

gate_probs = torch.exp(gates).detach().numpy()

return gate_probs

0 comments on commit 9b088d0

Please sign in to comment.