From f566977457f7e0bc29e764fa917c41716dc6985f Mon Sep 17 00:00:00 2001
From: Chirag Nagpal
Date: Thu, 2 Apr 2020 10:44:06 -0400
Subject: [PATCH] new file: dsm.py, new file: dsm_loss.py, new file: dsm_utilites.py

---
 dsm.py          | 128 +++++++++++++++++
 dsm_loss.py     | 363 ++++++++++++++++++++++++++++++++++++++++++++++++
 dsm_utilites.py | 211 ++++++++++++++++++++++++++++
 3 files changed, 702 insertions(+)
 create mode 100644 dsm.py
 create mode 100644 dsm_loss.py
 create mode 100644 dsm_utilites.py

diff --git a/dsm.py b/dsm.py
new file mode 100644
index 0000000..b150f30
--- /dev/null
+++ b/dsm.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+
+
+class DeepSurvivalMachines(nn.Module):
+    """Mixture of k parametric survival distributions whose per-instance
+    parameters are activation-squashed shifts of shared baseline parameters.
+    For dist='Weibull', `shape` stores log(shape) and `scale` stores
+    -log(scale) (the log rate); for dist='LogNormal' they store the mean and
+    log-std of log(t)."""
+
+    def __init__(self, inputdim, k, mlptyp=1, HIDDEN=False, init=False, dist='Weibull'):
+
+        super(DeepSurvivalMachines, self).__init__()
+
+        self.k = k
+        self.mlptype = mlptyp
+        self.HIDDEN = HIDDEN
+        self.dist = dist
+
+        # shared parameters of the k mixture components
+        self.scale = nn.Parameter(-torch.ones(k))
+        self.shape = nn.Parameter(-torch.ones(k))
+
+        # mlptyp == 1 applies the output heads directly to the input features;
+        # mlptyp == n (2..6) first passes them through an MLP built from the
+        # first n-1 widths in HIDDEN, with ReLU6 activations. The output heads
+        # read the width of the last hidden layer.
+        if mlptyp == 1:
+            lastdim = inputdim
+        else:
+            layers = []
+            prevdim = inputdim
+            for hidden in HIDDEN[:mlptyp - 1]:
+                layers += [nn.Linear(prevdim, hidden, bias=False), nn.ReLU6()]
+                prevdim = hidden
+            self.embedding = nn.Sequential(*layers)
+            lastdim = prevdim
+
+        # output heads: mixture logits and per-component shape/scale shifts
+        self.gate = nn.Sequential(nn.Linear(lastdim, k, bias=False))
+        self.scaleg = nn.Sequential(nn.Linear(lastdim, k, bias=True))
+        self.shapeg = nn.Sequential(nn.Linear(lastdim, k, bias=True))
+
+        if init is not False:
+            self.shape.data.fill_(init[0])
+            self.scale.data.fill_(init[1])
+
+        if self.dist == 'Weibull':
+            self.act = nn.SELU()
+        elif self.dist == 'LogNormal':
+            self.act = nn.Tanh()
+
+    def forward(self, x, adj=True):
+
+        if self.mlptype == 1:
+            embed = x
+        else:
+            embed = self.embedding(x)
+
+        if adj:
+            # instance-specific parameters: shared value plus an
+            # activation-squashed shift; logits are downscaled so the gate
+            # starts close to uniform
+            return (self.act(self.shapeg(embed)) + self.shape.expand(x.shape[0], -1),
+                    self.act(self.scaleg(embed)) + self.scale.expand(x.shape[0], -1),
+                    self.gate(embed) / 1000)
+        else:
+            # shared (covariate-independent) parameters, used while pretraining
+            return self.shape, self.scale, self.gate(embed) / 1000
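For orientation, a minimal usage sketch of the module above, assuming double-precision inputs as in the training utilities further down; the feature count, k=3, mlptyp=2 and the single 32-unit hidden layer are illustrative choices, not values prescribed by the patch.

import torch
from dsm import DeepSurvivalMachines

x = torch.randn(16, 10).double()                 # toy batch: 16 instances, 10 covariates
model = DeepSurvivalMachines(inputdim=10, k=3,   # mixture of 3 components
                             mlptyp=2, HIDDEN=[32]).double()

shape, scale, logits = model(x)                  # adj=True: instance-specific parameters
print(shape.shape, scale.shape, logits.shape)    # torch.Size([16, 3]) for each output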
diff --git a/dsm_loss.py b/dsm_loss.py
new file mode 100644
index 0000000..e90f384
--- /dev/null
+++ b/dsm_loss.py
@@ -0,0 +1,363 @@
+import numpy as np
+import torch
+from torch.nn import LogSoftmax, Softmax
+
+
+def _logNormalLoss(model, x, t, e):
+    """Log-normal log-likelihood of (t, e) under the shared (unconditional) parameters."""
+
+    shape, scale, logits = model.forward(x, adj=False)
+
+    k_ = shape.expand(x.shape[0], -1)   # mean of log(t)
+    b_ = scale.expand(x.shape[0], -1)   # log of the std of log(t)
+
+    ll = 0.
+    G = model.k
+
+    for g in range(G):
+
+        mu = k_[:, g]
+        sigma = b_[:, g]
+
+        # Normal log-likelihood of log(t); the -log(t) Jacobian term is
+        # constant in the parameters and omitted
+        f = - sigma - 0.5*np.log(2*np.pi) - torch.div((torch.log(t) - mu)**2, 2.*torch.exp(2*sigma))
+        s = torch.div(torch.log(t) - mu, torch.exp(sigma)*np.sqrt(2))
+        s = 0.5 - 0.5*torch.erf(s)
+        s = torch.log(s)                # log-survival
+
+        uncens = np.where(e.cpu().data.numpy() == 1)[0]
+        cens = np.where(e.cpu().data.numpy() == 0)[0]
+
+        ll += f[uncens].sum() + s[cens].sum()
+
+    return -ll.mean()
+
+
+# fitting a Weibull using only t and e
+def _weibullLoss(model, x, t, e):
+
+    torch.manual_seed(0)
+
+    shape, scale, logits = model.forward(x, adj=False)
+
+    G = model.k
+
+    k_ = shape.expand(x.shape[0], -1)   # log(shape)
+    b_ = scale.expand(x.shape[0], -1)   # log(rate) = -log(scale)
+
+    ll = 0.
+    for g in range(G):
+
+        k = k_[:, g]
+        b = b_[:, g]
+
+        # Weibull log-density and log-survival with shape exp(k) and rate exp(b)
+        f = k + b + ((torch.exp(k)-1)*(b+torch.log(t))) - (torch.pow(torch.exp(b)*t, torch.exp(k)))
+        s = - (torch.pow(torch.exp(b)*t, torch.exp(k)))
+
+        uncens = np.where(e.cpu().data.numpy() == 1)[0]
+        cens = np.where(e.cpu().data.numpy() == 0)[0]
+
+        ll += f[uncens].sum() + s[cens].sum()
+
+    return -ll.mean()
+
+
+def unConditionalLoss(model, x, t, e):
+
+    if model.dist == 'Weibull':
+        return _weibullLoss(model, x, t, e)
+    elif model.dist == 'LogNormal':
+        return _logNormalLoss(model, x, t, e)
+
+
+def _conditionalLogNormalLoss(model, x, t, e, ELBO=True, mean=True, lambd=1e-2, alpha=1.):
+
+    # k = mean of log(t), b = log of its standard deviation
+
+    torch.manual_seed(0)
+
+    G = model.k
+
+    shape, scale, logits = model.forward(x, adj=True)
+
+    lossf = []  # log-density terms (uncensored instances)
+    losss = []  # log-survival terms (censored instances)
+
+    k_ = shape
+    b_ = scale
+
+    for g in range(G):
+
+        mu = k_[:, g]
+        sigma = b_[:, g]
+
+        f = - sigma - 0.5*np.log(2*np.pi) - torch.div((torch.log(t) - mu)**2, 2.*torch.exp(2*sigma))
+        s = torch.div(torch.log(t) - mu, torch.exp(sigma)*np.sqrt(2))
+        s = 0.5 - 0.5*torch.erf(s)
+        s = torch.log(s)
+
+        lossf.append(f)
+        losss.append(s)
+
+    losss = torch.stack(losss, dim=1)
+    lossf = torch.stack(lossf, dim=1)
+
+    if ELBO:
+        # Jensen lower bound: gate-weighted average of component log-likelihoods
+        lossg = Softmax(dim=1)(logits)
+        losss = (lossg*losss).sum(dim=1)
+        lossf = (lossg*lossf).sum(dim=1)
+    else:
+        # exact mixture log-likelihood via log-sum-exp
+        lossg = LogSoftmax(dim=1)(logits)
+        losss = torch.logsumexp(lossg + losss, dim=1)
+        lossf = torch.logsumexp(lossg + lossf, dim=1)
+
+    uncens = np.where(e.cpu().data.numpy() == 1)[0]
+    cens = np.where(e.cpu().data.numpy() == 0)[0]
+
+    ll = lossf[uncens].sum() + alpha*losss[cens].sum()
+
+    return -ll/x.shape[0]
+
+
+def _conditionalWeibullLoss(model, x, t, e, ELBO=True, mean=True, lambd=1e-2, alpha=1.):
+
+    # k = log(shape), b = -log(scale), i.e. exp(-b) is the Weibull scale
+
+    torch.manual_seed(0)
+
+    G = model.k
+
+    shape, scale, logits = model.forward(x, adj=True)
+
+    lossf = []  # log-density terms (uncensored instances)
+    losss = []  # log-survival terms (censored instances)
+
+    k_ = shape
+    b_ = scale
+
+    for g in range(G):
+
+        k = k_[:, g]
+        b = b_[:, g]
+
+        f = k + b + ((torch.exp(k)-1)*(b+torch.log(t))) - (torch.pow(torch.exp(b)*t, torch.exp(k)))
+        s = - (torch.pow(torch.exp(b)*t, torch.exp(k)))
+
+        lossf.append(f)
+        losss.append(s)
+
+    losss = torch.stack(losss, dim=1)
+    lossf = torch.stack(lossf, dim=1)
+
+    if ELBO:
+        lossg = Softmax(dim=1)(logits)
+        losss = (lossg*losss).sum(dim=1)
+        lossf = (lossg*lossf).sum(dim=1)
+    else:
+        lossg = LogSoftmax(dim=1)(logits)
+        losss = torch.logsumexp(lossg + losss, dim=1)
+        lossf = torch.logsumexp(lossg + lossf, dim=1)
+
+    uncens = np.where(e.cpu().data.numpy() == 1)[0]
+    cens = np.where(e.cpu().data.numpy() == 0)[0]
+
+    ll = lossf[uncens].sum() + alpha*losss[cens].sum()
+
+    return -ll/x.shape[0]
+
+
+def conditionalLoss(model, x, t, e, ELBO=True, mean=True, lambd=1e-2, alpha=1.):
+
+    if model.dist == 'Weibull':
+        return _conditionalWeibullLoss(model, x, t, e, ELBO, mean, lambd, alpha)
+    elif model.dist == 'LogNormal':
+        return _conditionalLogNormalLoss(model, x, t, e, ELBO, mean, lambd, alpha)
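The Weibull terms above are easy to mis-read, so here is a small self-contained check, on made-up values, that the closed form used for f agrees with torch.distributions.Weibull under the reading k = log(shape), b = log(rate) (so exp(-b) is the scale); s is then the standard Weibull log-survival -(rate*t)^shape. The tensors are illustrative only.

import torch
from torch.distributions import Weibull

t = torch.tensor([1.5, 3.0, 7.0]).double()
k = torch.tensor([0.2, -0.1, 0.4]).double()    # log(shape)
b = torch.tensor([-1.0, 0.3, -0.5]).double()   # log(rate) = -log(scale)

# closed forms used in _weibullLoss / _conditionalWeibullLoss
f = k + b + (torch.exp(k) - 1) * (b + torch.log(t)) - torch.pow(torch.exp(b) * t, torch.exp(k))
s = -torch.pow(torch.exp(b) * t, torch.exp(k))

reference = Weibull(scale=torch.exp(-b), concentration=torch.exp(k))
assert torch.allclose(f, reference.log_prob(t))   # log-density agrees with the closed form
print(f, s)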
+
+
+def _weibull_cdf(model, x, t_horizon):
+    """Mixture log-survival at each horizon in t_horizon, one tensor per horizon."""
+
+    squish = LogSoftmax(dim=1)
+
+    G = model.k
+
+    shape, scale, logits = model.forward(x, adj=True)
+    logits = squish(logits)
+
+    k_ = shape
+    b_ = scale
+
+    t_horz = torch.tensor(t_horizon).double()
+    t_horz = t_horz.repeat(x.shape[0], 1)
+
+    cdfs = []
+    pdfs = []
+    hazards = []
+
+    for j in range(len(t_horizon)):
+
+        t = t_horz[:, j]
+
+        lcdfs = []
+        lpdfs = []
+
+        for g in range(G):
+
+            k = k_[:, g]
+            b = b_[:, g]
+
+            s = - (torch.pow(torch.exp(b)*t, torch.exp(k)))   # log-survival
+            f = k + b + ((torch.exp(k)-1)*(b+torch.log(t))) - (torch.pow(torch.exp(b)*t, torch.exp(k)))
+
+            lpdfs.append(f)
+            lcdfs.append(s)
+
+        lcdfs = torch.stack(lcdfs, dim=1)
+        lpdfs = torch.stack(lpdfs, dim=1)
+
+        # mixture log-survival / log-density via log-sum-exp over components
+        lcdfs = torch.logsumexp(lcdfs + logits, dim=1)
+        lpdfs = torch.logsumexp(lpdfs + logits, dim=1)
+
+        cdfs.append(lcdfs)
+        pdfs.append(lpdfs)
+        hazards.append(lpdfs - lcdfs)
+
+    return cdfs
+
+
+def _lognormal_cdf(model, x, t_horizon):
+    """Mixture log-survival at each horizon in t_horizon, one tensor per horizon."""
+
+    squish = LogSoftmax(dim=1)
+
+    G = model.k
+
+    shape, scale, logits = model.forward(x, adj=True)
+    logits = squish(logits)
+
+    k_ = shape
+    b_ = scale
+
+    t_horz = torch.tensor(t_horizon).double()
+    t_horz = t_horz.repeat(x.shape[0], 1)
+
+    cdfs = []
+    pdfs = []
+    hazards = []
+
+    for j in range(len(t_horizon)):
+
+        t = t_horz[:, j]
+
+        lcdfs = []
+        lpdfs = []
+
+        for g in range(G):
+
+            mu = k_[:, g]
+            sigma = b_[:, g]
+
+            f = - sigma - 0.5*np.log(2*np.pi) - torch.div((torch.log(t) - mu)**2, 2.*torch.exp(2*sigma))
+            s = torch.div(torch.log(t) - mu, torch.exp(sigma)*np.sqrt(2))
+            s = 0.5 - 0.5*torch.erf(s)
+            s = torch.log(s)
+
+            lpdfs.append(f)
+            lcdfs.append(s)
+
+        lcdfs = torch.stack(lcdfs, dim=1)
+        lpdfs = torch.stack(lpdfs, dim=1)
+
+        lcdfs = torch.logsumexp(lcdfs + logits, dim=1)
+        lpdfs = torch.logsumexp(lpdfs + logits, dim=1)
+
+        cdfs.append(lcdfs)
+        pdfs.append(lpdfs)
+        hazards.append(lpdfs - lcdfs)
+
+    return cdfs
+
+
+def predict_cdf(model, x, t_horizon):
+    """Despite the name, returns mixture *log-survival* values for each horizon."""
+
+    if model.dist == 'Weibull':
+        return _weibull_cdf(model, x, t_horizon)
+
+    if model.dist == 'LogNormal':
+        return _lognormal_cdf(model, x, t_horizon)
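A short sketch of how predict_cdf appears intended to be consumed, using an untrained toy model; since the returned values are mixture log-survival probabilities, they are exponentiated and subtracted from one to get event probabilities. The horizons, tensor sizes and model configuration are placeholders.

import torch
from dsm import DeepSurvivalMachines
from dsm_loss import predict_cdf

x = torch.randn(8, 10).double()
model = DeepSurvivalMachines(10, 3, mlptyp=2, HIDDEN=[32]).double()

t_horizon = [30., 90., 180.]                     # illustrative evaluation times
with torch.no_grad():
    log_surv = predict_cdf(model, x, t_horizon)  # one tensor of shape (8,) per horizon

surv = torch.stack([ls.exp() for ls in log_surv], dim=1)  # S(t | x), shape (8, 3)
event_prob = 1.0 - surv                                   # estimated probability of an event by t
print(event_prob)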
diff --git a/dsm_utilites.py b/dsm_utilites.py
new file mode 100644
index 0000000..2957354
--- /dev/null
+++ b/dsm_utilites.py
@@ -0,0 +1,211 @@
+import numpy as np
+import torch
+
+from dsm import DeepSurvivalMachines
+from dsm_loss import unConditionalLoss, conditionalLoss, predict_cdf
+
+
+def computeCIScores(model, quantiles, G, x_valid, t_valid, e_valid, t_train, e_train, risk=0):
+    """Time-dependent concordance (IPCW) of the predicted survival at the four
+    horizons in `quantiles`."""
+
+    from sksurv.metrics import concordance_index_ipcw
+
+    # e.g. quantiles = [43.68333435, 86.8666687, 146.33333588, 283.54268066]
+
+    cdf_preds = predict_cdf(model, x_valid, quantiles)
+    cdf_preds = [cdf.data.numpy() for cdf in cdf_preds]
+
+    t_valid = t_valid.cpu().data.numpy()
+    e_valid = e_valid.cpu().data.numpy()
+
+    t_train = t_train.cpu().data.numpy()
+    e_train = e_train.cpu().data.numpy()
+
+    et1 = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))], dtype=[('e', bool), ('t', int)])
+    et2 = np.array([(e_valid[i], t_valid[i]) for i in range(len(e_valid))], dtype=[('e', bool), ('t', int)])
+
+    # negated log-survival serves as the risk score
+    cdf_ci_25 = concordance_index_ipcw(et1, et2, -cdf_preds[0], tau=quantiles[0])
+    cdf_ci_50 = concordance_index_ipcw(et1, et2, -cdf_preds[1], tau=quantiles[1])
+    cdf_ci_75 = concordance_index_ipcw(et1, et2, -cdf_preds[2], tau=quantiles[2])
+    cdf_ci_m = concordance_index_ipcw(et1, et2, -cdf_preds[3], tau=quantiles[3])
+
+    return None, None, cdf_ci_25[0], cdf_ci_50[0], cdf_ci_75[0], cdf_ci_m[0]
+
+
+def increaseCensoring(e, t, p):
+    """Randomly censor a fraction p of the uncensored instances, drawing a new
+    censoring time uniformly before the original event time."""
+
+    np.random.seed(0)
+
+    uncens = np.where(e == 1)[0]
+    mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
+    toswitch = uncens[mask]
+
+    e[toswitch] = 0
+    t_ = t[toswitch]
+
+    newt = []
+    for t__ in t_:
+        newt.append(np.random.uniform(1, t__))
+
+    t[toswitch] = newt
+
+    return e, t
+
+
+def pretrainDSM(model, x_train, t_train, e_train, x_valid, t_valid, e_valid,
+                n_iter=10000, lr=1e-3, thres=1e-4):
+    """Fit a single-component, covariate-independent model on (t, e) to obtain
+    initial values for the shared shape and scale parameters."""
+
+    from tqdm import tqdm
+
+    premodel = DeepSurvivalMachines(x_train.shape[1], 1, init=False, dist=model.dist)
+    premodel.double()
+
+    torch.manual_seed(0)
+
+    # optimize the single-component premodel, whose loss is computed below
+    optimizer = torch.optim.Adam(premodel.parameters(), lr=lr)
+
+    oldcost = -float('inf')
+    patience = 0
+    costs = []
+
+    for i in tqdm(range(n_iter)):
+
+        optimizer.zero_grad()
+        loss = unConditionalLoss(premodel, x_train, t_train, e_train)  # not conditioned on x
+        loss.backward()
+        optimizer.step()
+
+        valid_loss = unConditionalLoss(premodel, x_valid, t_valid, e_valid)
+        valid_loss = valid_loss.detach().cpu().numpy()
+        costs.append(valid_loss)
+
+        # stop once the validation loss has plateaued for three checks
+        if np.abs(costs[-1] - oldcost) < thres:
+            patience += 1
+            if patience == 3:
+                break
+
+        oldcost = costs[-1]
+
+    return premodel
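A quick illustration of increaseCensoring on synthetic numpy arrays; the array sizes and the 30% switch probability are arbitrary choices for the sketch, and note that the helper mutates its inputs in place as well as returning them.

import numpy as np
from dsm_utilites import increaseCensoring

e = np.random.binomial(1, 0.8, size=10).astype(float)   # 1 = event observed, 0 = censored
t = np.random.exponential(100., size=10) + 1.

print("before:", e.sum(), "events")
e, t = increaseCensoring(e, t, p=0.3)   # censor roughly 30% of the observed events
print("after: ", e.sum(), "events")     # switched instances receive a new time < original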
+
+
+def trainDSM(model, x_train, t_train, e_train, premodel, x_valid, t_valid, e_valid, quantiles,
+             n_iter=10000, lr=1e-3,
+             ELBO=True, mean=True, lambd=1e-2, alpha=1., thres=1e-4, bs=100):
+    """Pretrain the shared parameters, then fit the full conditional mixture with
+    minibatch Adam, keeping the checkpoint with the best validation concordance.
+
+    NOTE: the `quantiles` argument (evaluation horizons passed to computeCIScores)
+    is assumed here; the original code validated through an undefined
+    `predict_valid` helper."""
+
+    from tqdm import tqdm_notebook as tqdm
+    from copy import deepcopy
+    import gc
+
+    G = model.k
+    mlptyp = model.mlptype
+    HIDDEN = model.HIDDEN
+    dist = model.dist
+
+    print("Pretraining the Underlying Distributions...")
+
+    premodel = pretrainDSM(model, x_train, t_train, e_train, x_valid, t_valid, e_valid,
+                           n_iter=10000, lr=1e-3, thres=1e-4)
+
+    # re-initialise the mixture with the pretrained shape/scale as starting values
+    model = DeepSurvivalMachines(x_train.shape[1], G, mlptyp=mlptyp, HIDDEN=HIDDEN,
+                                 init=(float(premodel.shape[0]), float(premodel.scale[0])),
+                                 dist=dist)
+    model.double()
+
+    torch.manual_seed(0)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+    oldcost = -float('inf')
+    patience = 0
+    nbatches = int(x_train.shape[0]/bs) + 1
+
+    dics = []
+    costs = []
+
+    for i in tqdm(range(n_iter)):
+
+        for j in range(nbatches):
+
+            xb = x_train[j*bs:(j+1)*bs]
+            tb = t_train[j*bs:(j+1)*bs]
+            eb = e_train[j*bs:(j+1)*bs]
+
+            if xb.shape[0] == 0:   # guard against an empty trailing batch
+                continue
+
+            optimizer.zero_grad()
+            loss = conditionalLoss(model, xb, tb, eb,
+                                   ELBO=ELBO, mean=mean, lambd=lambd, alpha=alpha)
+            loss.backward()
+            optimizer.step()
+
+        valid_loss = conditionalLoss(model, x_valid, t_valid, e_valid,
+                                     ELBO=True, mean=mean, lambd=lambd, alpha=alpha)
+        valid_loss = valid_loss.detach().cpu().numpy()
+
+        # validation score: mean time-dependent concordance at the supplied horizons
+        out = computeCIScores(model, quantiles, G, x_valid, t_valid, e_valid, t_train, e_train)
+        valid_loss = np.mean(out[2:])
+
+        costs.append(valid_loss)
+        dics.append(deepcopy(model.state_dict()))
+
+        if costs[-1] < oldcost:
+
+            print(valid_loss, out)
+
+            if patience == 2:
+
+                maxm = np.argmax(costs)
+                print("max:", maxm)
+
+                # roll back to the best-scoring checkpoint
+                model.load_state_dict(dics[maxm])
+
+                del dics
+                gc.collect()
+
+                return model, i
+
+            else:
+                patience += 1
+
+        else:
+            patience = 0
+
+        if i % 10 == 0:
+            print(valid_loss, out)
+
+        oldcost = costs[-1]
+
+    return model, i
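Putting the pieces together, a hedged end-to-end sketch of how these utilities appear intended to be called. The synthetic data, the percentile-based horizons and all hyperparameters are assumptions; the `quantiles` argument exists only because trainDSM, as edited above, validates through computeCIScores, and running this additionally needs scikit-survival and tqdm installed.

import numpy as np
import torch

from dsm import DeepSurvivalMachines
from dsm_loss import predict_cdf
from dsm_utilites import trainDSM

# synthetic right-censored data standing in for a real dataset
n, d = 1000, 10
x = torch.randn(n, d).double()
t = torch.tensor(np.random.exponential(50., size=n) + 1.).double()
e = torch.tensor(np.random.binomial(1, 0.7, size=n)).double()

x_tr, x_va = x[:800], x[800:]
t_tr, t_va = t[:800], t[800:]
e_tr, e_va = e[:800], e[800:]

# evaluation horizons for the concordance-based validation inside trainDSM
quantiles = np.percentile(t_tr[e_tr == 1].numpy(), [25, 50, 75, 99]).tolist()

model = DeepSurvivalMachines(d, k=3, mlptyp=2, HIDDEN=[32], dist='Weibull')
model, last_iter = trainDSM(model, x_tr, t_tr, e_tr, None,   # premodel is rebuilt inside trainDSM
                            x_va, t_va, e_va, quantiles,
                            n_iter=50, lr=1e-3, bs=128)

# per-horizon survival probabilities on the validation split
surv_at_horizons = [s.exp() for s in predict_cdf(model, x_va, quantiles)]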