Skip to content

Commit

Permalink
new file: dsm/contrib/__init__.py
Browse files Browse the repository at this point in the history
	new file:   dsm/contrib/dcm_api.py
	new file:   dsm/contrib/dcm_torch.py
	new file:   dsm/contrib/dcm_utilities.py
	modified:   dsm/dsm_api.py
	modified:   dsm/dsm_torch.py
	modified:   dsm/utilities.py
  • Loading branch information
chiragnagpal committed Oct 25, 2021
1 parent 76120c6 commit ddb01c5
Show file tree
Hide file tree
Showing 7 changed files with 628 additions and 27 deletions.
61 changes: 61 additions & 0 deletions dsm/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# coding=utf-8
# MIT License

# Copyright (c) 2020 Carnegie Mellon University, Auton Lab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


r"""
`dsm` includes extended functionality for survival analysis as part
of `dsm.contrib`.
Contributed Modules
--------------------
This submodule incorporates contributed survival analysis methods.
Deep Cox Mixtures
------------------
The Cox Mixture involves the assumption that the survival function
of the individual to be a mixture of K Cox Models. Conditioned on each
subgroup Z=k; the PH assumptions are assumed to hold and the baseline
hazard rates is determined non-parametrically using an spline-interpolated
Breslow's estimator.
For full details on Deep Cox Mixture, refer to the paper [1].
References
----------
[1] <a href="https://arxiv.org/abs/2101.06536">Deep Cox Mixtures
for Survival Regression. Machine Learning in Health Conference (2021)</a>
```
@article{nagpal2021dcm,
title={Deep Cox mixtures for survival regression},
author={Nagpal, Chirag and Yadlowsky, Steve and Rostamzadeh, Negar and Heller, Katherine},
journal={arXiv preprint arXiv:2101.06536},
year={2021}
}
```
"""

from dsm.contrib.dcm_api import DeepCoxMixtures
181 changes: 181 additions & 0 deletions dsm/contrib/dcm_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@

import torch
import numpy as np

from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
from dsm.contrib.dcm_utilities import train_dcm, predict_survival


class DeepCoxMixtures():
"""A Deep Cox Mixture model.
This is the main interface to a Deep Cox Mixture model.
A model is instantiated with approporiate set of hyperparameters and
fit on numpy arrays consisting of the features, event/censoring times
and the event/censoring indicators.
For full details on Deep Cox Mixture, refer to the paper [1].
References
----------
[1] <a href="https://arxiv.org/abs/2101.06536">Deep Cox Mixtures
for Survival Regression. Machine Learning in Health Conference (2021)</a>
Parameters
----------
k: int
The number of underlying Cox distributions.
layers: list
A list of integers consisting of the number of neurons in each
hidden layer.
Example
-------
>>> from dsm.contrib import DeepCoxMixtures
>>> model = DeepCoxMixtures()
>>> model.fit(x, t, e)
"""
def __init__(self, k=3, layers=None, distribution="Weibull",
temp=1000., discount=1.0):
self.k = k
self.layers = layers
self.dist = distribution
self.temp = temp
self.discount = discount
self.fitted = False

def __call__(self):
if self.fitted:
print("A fitted instance of the Deep Cox Mixtures model")
else:
print("An unfitted instance of the Deep Cox Mixtures model")

print("Number of underlying cox distributions (k):", self.k)
print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x):
return torch.from_numpy(x).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):

idx = list(range(x.shape[0]))
np.random.seed(random_state)
np.random.shuffle(idx)
x_train, t_train, e_train = x[idx], t[idx], e[idx]

x_train = torch.from_numpy(x_train).float()
t_train = torch.from_numpy(t_train).float()
e_train = torch.from_numpy(e_train).float()

if val_data is None:

vsize = int(vsize*x_train.shape[0])
x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]

x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]

else:

x_val, t_val, e_val = val_data

x_val = torch.from_numpy(x_val).float()
t_val = torch.from_numpy(t_val).float()
e_val = torch.from_numpy(e_val).float()

return (x_train, t_train, e_train, x_val, t_val, e_val)

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
return DeepCoxMixturesTorch(inputdim,
k=self.k,
layers=self.layers,
optimizer=optimizer)

def fit(self, x, t, e, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
optimizer="Adam", random_state=100):

r"""This method is used to train an instance of the DSM model.
Parameters
----------
x: np.ndarray
A numpy array of the input features, \( x \).
t: np.ndarray
A numpy array of the event/censoring times, \( t \).
e: np.ndarray
A numpy array of the event/censoring indicators, \( \delta \).
\( \delta = 1 \) means the event took place.
vsize: float
Amount of data to set aside as the validation set.
val_data: tuple
A tuple of the validation dataset. If passed vsize is ignored.
iters: int
The maximum number of training iterations on the training dataset.
learning_rate: float
The learning rate for the `Adam` optimizer.
batch_size: int
learning is performed on mini-batches of input data. this parameter
specifies the size of each mini-batch.
optimizer: str
The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random seed that determines how the validation set is chosen.
"""

processed_data = self._preprocess_training_data(x, t, e,
vsize, val_data,
random_state)
x_train, t_train, e_train, x_val, t_val, e_val = processed_data

#Todo: Change this somehow. The base design shouldn't depend on child

inputdim = x_train.shape[-1]

model = self._gen_torch_model(inputdim, optimizer)

model, _ = train_dcm(model,
(x_train, t_train, e_train),
(x_val, t_val, e_val),
epochs=iters,
lr=learning_rate,
bs=batch_size,
return_losses=True,
smoothing_factor=None,
use_posteriors=True,)

self.torch_model = (model[0].eval(), model[1])
self.fitted = True

return self


def predict_survival(self, x, t):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Parameters
----------
x: np.ndarray
A numpy array of the input features, \( x \).
t: list or float
a list or float of the times at which survival probability is
to be computed
Returns:
np.array: numpy array of the survival probabilites at each time in t.
"""
x = self._preprocess_test_data(x)
if not isinstance(t, list):
t = [t]
if self.fitted:
scores = predict_survival(self.torch_model, x, t)
return scores
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")
57 changes: 57 additions & 0 deletions dsm/contrib/dcm_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn

import numpy as np

from scipy.interpolate import UnivariateSpline
from sksurv.linear_model.coxph import BreslowEstimator

import time
from tqdm import tqdm

from dsm.dsm_torch import create_representation

class DeepCoxMixturesTorch(nn.Module):
"""PyTorch model definition of the Deep Cox Mixture Survival Model.
The Cox Mixture involves the assumption that the survival function
of the individual to be a mixture of K Cox Models. Conditioned on each
subgroup Z=k; the PH assumptions are assumed to hold and the baseline
hazard rates is determined non-parametrically using an spline-interpolated
Breslow's estimator.
"""

def _init_dcm_layers(self, lastdim):

self.gate = torch.nn.Linear(lastdim, self.k, bias=False)
self.expert = torch.nn.Linear(lastdim, self.k, bias=False)

def __init__(self, inputdim, k, layers=None, optimizer='Adam'):

super(DeepCoxMixturesTorch, self).__init__()

if not isinstance(k, int):
raise ValueError(f'k must be int, but supplied k is {type(k)}')

self.k = k
self.optimizer = optimizer

if layers is None: layers = []
self.layers = layers

if len(layers) == 0: lastdim = inputdim
else: lastdim = layers[-1]

self._init_dcm_layers(lastdim)
self.embedding = create_representation(inputdim, layers, 'ReLU6')

def forward(self, x):

x = self.embedding(x)

log_hazard_ratios = torch.clamp(self.expert(x), min=-7e-1, max=7e-1)
#log_hazard_ratios = self.expert(x)
#log_hazard_ratios = torch.nn.Tanh()(self.expert(x))
log_gate_prob = torch.nn.LogSoftmax(dim=1)(self.gate(x))

return log_gate_prob, log_hazard_ratios
Loading

0 comments on commit ddb01c5

Please sign in to comment.