From 5252d6f28c482197c3665c573665d30773285c5a Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Wed, 27 Jan 2021 00:32:12 +0530 Subject: [PATCH] modified: dsm/dsm_api.py --- dsm/dsm_api.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index b96734e..bbd3c4b 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -174,15 +174,22 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state): t_train = torch.from_numpy(t_train).double() e_train = torch.from_numpy(e_train).double() - vsize = int(vsize*x_train.shape[0]) + if vsize is not None: - x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] - x_train = x_train[:-vsize] - t_train = t_train[:-vsize] - e_train = e_train[:-vsize] + vsize = int(vsize*x_train.shape[0]) + + x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] + x_train = x_train[:-vsize] + t_train = t_train[:-vsize] + e_train = e_train[:-vsize] + + return (x_train, t_train, e_train, + x_val, t_val, e_val) + + else: + return (x_train, t_train, e_train, + x_train, t_train, e_train) - return (x_train, t_train, e_train, - x_val, t_val, e_val) def predict_mean(self, x, risk=1): r"""Returns the mean Time-to-Event \( t \) @@ -359,15 +366,20 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state): t_train = torch.from_numpy(t_train).double() e_train = torch.from_numpy(e_train).double() - vsize = int(vsize*x_train.shape[0]) - x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] + if vsize is not None: + + vsize = int(vsize*x_train.shape[0]) + x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] - x_train = x_train[:-vsize] - t_train = t_train[:-vsize] - e_train = e_train[:-vsize] + x_train = x_train[:-vsize] + t_train = t_train[:-vsize] + e_train = e_train[:-vsize] - return (x_train, t_train, e_train, - x_val, t_val, e_val) + return (x_train, t_train, e_train, + x_val, t_val, e_val) + else: + return (x_train, t_train, e_train, + x_train, t_train, e_train) class DeepConvolutionalSurvivalMachines(DSMBase):