Skip to content

Commit

Permalink
fixed_syntax_small_bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Chufan Gao committed Dec 16, 2020
1 parent cb20d8d commit 36ec567
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitgnore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dsm/__pycache*
4 changes: 3 additions & 1 deletion dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,6 @@
"""

from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines, DeepConvolutionalSurvivalMachines
from dsm.dsm_api import DeepSurvivalMachines
from dsm.dsm_api import DeepConvolutionalSurvivalMachines
from dsm.dsm_api import DeepRecurrentSurvivalMachines
12 changes: 5 additions & 7 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"""

from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch, DeepConvolutionalSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
from dsm.losses import predict_cdf
import dsm.losses as losses
from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets
Expand Down Expand Up @@ -66,8 +67,7 @@ def _gen_torch_model(self, inputdim, optimizer, risks):

def fit(self, x, t, e, vsize=0.15,
iters=1, learning_rate=1e-3, batch_size=100,
elbo=True, optimizer="Adam", random_state=100,
cuda=False):
elbo=True, optimizer="Adam", random_state=100):

r"""This method is used to train an instance of the DSM model.
Expand Down Expand Up @@ -186,7 +186,7 @@ def predict_risk(self, x, t, risk=1):
"before calling `predict_risk`.")


def predict_survival(self, x, t, risk=1, cuda=False):
def predict_survival(self, x, t, risk=1):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Expand Down Expand Up @@ -336,8 +336,7 @@ class DeepConvolutionalSurvivalMachines(DSMBase):

def __init__(self, k=3, layers=None, hidden=None,
distribution='Weibull', temp=1000., discount=1.0, typ='ConvNet'):
super(DeepConvolutionalSurvivalMachines, self).__init__(k=k,
layers=layers,
super(DeepConvolutionalSurvivalMachines, self).__init__(k=k,
distribution=distribution,
temp=temp,
discount=discount)
Expand All @@ -347,7 +346,6 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepConvolutionalSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
Expand Down
75 changes: 49 additions & 26 deletions dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import torch.nn as nn
import torch
import torchvision
import torch.nn.functional as F
import numpy as np

__pdoc__ = {}

Expand Down Expand Up @@ -339,7 +339,50 @@ def forward(self, x, risk='1'):
def get_shape_scale(self, risk='1'):
return(self.shape[risk],
self.scale[risk])



def create_conv_representation(inputdim, hidden, typ='ConvNet'):
r"""Helper function to generate the representation function for DSM.
Deep Survival Machines learns a representation (\ Phi(X) \) for the input
data. This representation is parameterized using a Convolutional Neural
Network (`torch.nn.Module`). This is a helper function designed to
instantiate the representation for Deep Survival Machines.
.. warning::
Not designed to be used directly.
Parameters
----------
inputdim: int
Dimensionality of the input features.
hidden: int
The number of neurons in each hidden layer.
typ: str
Choice of convolutional neural network: One of 'ConvNet'
Returns
----------
an ConvNet with torch.nn.Module with the specfied structure.
"""

if typ == 'ConvNet':
inputdim = np.squeeze(inputdim)
linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2
linear_dim *= 16
embedding = nn.Sequential(
nn.Conv2d(1, 6, 3),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 3),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(linear_dim, 120),
nn.Linear(120, 84),
nn.Linear(84, hidden)
)
return embedding

class DeepConvolutionalSurvivalMachinesTorch(nn.Module):
"""A Torch implementation of Deep Convolutional Survival Machines model.
Expand All @@ -357,11 +400,9 @@ class DeepConvolutionalSurvivalMachinesTorch(nn.Module):
Parameters
----------
inputdim: int
Dimensionality of the input features.
Dimensionality of the input features. A tuple (height, width).
k: int
The number of underlying parametric distributions.
layers: int
The number of hidden layers in the LSTM or RNN cell.
hidden: int
The number of neurons in each hidden layer.
init: tuple
Expand All @@ -381,7 +422,7 @@ class DeepConvolutionalSurvivalMachinesTorch(nn.Module):
"""

def __init__(self, inputdim, k, typ='ResNet', layers=1,
def __init__(self, inputdim, k, typ='ConvNet',
hidden=None, dist='Weibull',
temp=1000., discount=1.0, optimizer='Adam', risks=1):
super(DeepConvolutionalSurvivalMachinesTorch, self).__init__()
Expand All @@ -392,7 +433,6 @@ def __init__(self, inputdim, k, typ='ResNet', layers=1,
self.discount = float(discount)
self.optimizer = optimizer
self.hidden = hidden
self.layers = layers
self.typ = typ
self.risks = risks

Expand Down Expand Up @@ -430,17 +470,7 @@ def __init__(self, inputdim, k, typ='ResNet', layers=1,
nn.Linear(hidden, k, bias=True)
) for r in range(self.risks)})

if self.typ == 'ConvNet':
# self.cnn = torchvision.models.resnet18(pretrained=True).float()
# self.cnn.conv1 = torch.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
# self.linear = torch.nn.Linear(1000, hidden)
self.conv1 = nn.Conv2d(1, 6, 3)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, hidden)

self.embedding = create_conv_representation(inputdim=inputdim, hidden=hidden, typ='ConvNet')

def forward(self, x, risk='1'):
"""The forward function that is called when data is passed through DSM.
Expand All @@ -450,14 +480,7 @@ def forward(self, x, risk='1'):
a torch.tensor of the input features.
"""
# xrep = self.linear(self.cnn(x))
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
xrep = self.fc3(x)

xrep = self.embedding(x)
dim = x.shape[0]
return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1),
self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1),
Expand Down
13 changes: 12 additions & 1 deletion examples/conv_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
"from sksurv.metrics import concordance_index_ipcw, brier_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# self.cnn = torchvision.models.resnet18(pretrained=True).float()\n",
"# self.cnn.conv1 = torch.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)\n",
"# self.linear = torch.nn.Linear(1000, hidden_dim)"
]
},
{
"cell_type": "code",
"execution_count": 13,
Expand Down Expand Up @@ -618,4 +629,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ torch>=1.0.0
numpy>=0.14
pandas>=1.0.0
tqdm>=4.0.0
scikit-learn>=0.18
scikit-learn>=0.18
torchvision>=0.7.0

0 comments on commit 36ec567

Please sign in to comment.