Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Feb 16, 2022
2 parents 77a615a + 7c6b014 commit 6df0106
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 3 additions & 3 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def _fit_dcph(features, outcomes, random_seed, **hyperparams):
epochs = hyperparams.get('epochs', 50)
activation = hyperparams.get('activation', 'relu')

if activation == 'relu': activation = torch.nn.ReLU()
elif activation == 'relu6': activation = torch.nn.ReLU6()
elif activation == 'tanh': activation = torch.nn.Tanh()
if activation == 'relu': activation = torch.nn.ReLU
elif activation == 'relu6': activation = torch.nn.ReLU6
elif activation == 'tanh': activation = torch.nn.Tanh
else: raise NotImplementedError("Activation function not implemented")

x = features.values.astype('float32')
Expand Down
5 changes: 5 additions & 0 deletions auton_survival/models/dsm/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from .dsm_torch import DeepSurvivalMachinesTorch
from .losses import unconditional_loss, conditional_loss

from sklearn.utils import shuffle

from tqdm import tqdm
from copy import deepcopy

Expand Down Expand Up @@ -149,6 +151,9 @@ def train_dsm(model,
costs = []
i = 0
for i in tqdm(range(n_iter)):

x_train, t_train, e_train = shuffle(x_train, t_train, e_train, random_state=i)

for j in range(nbatches):

xb = x_train[j*bs:(j+1)*bs]
Expand Down

0 comments on commit 6df0106

Please sign in to comment.