-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimu_1_run.py
125 lines (99 loc) · 4.48 KB
/
simu_1_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Run every model on each simulated dataset.

The simulation grid is configured in ``n_patients_n_datasets``, which maps
n_patients to n_datasets:

n_patients : (int) Number of samples in the simulated datasets.
n_datasets : (int) Number of dataset simulations and trainings for that
    number of samples. This is used for measuring the variability of the
    models' performance.

In the whole project, the pid column is 'pid', the duration column is 'X'
and the event column is 'J'. It was much easier to run the pydts model with
this convention, so it was kept for the whole project.
"""
import datetime
from pathlib import Path
from data_simulation import load_simulated_datasets
from models.run_pydts import PyDTS_template
from models.run_transformer import Transformer_template
from models.run_coxph import CoxPH_template
from models.run_deephit import DeepHit_template
from utils.utils import Simu_rundir
# Simulation grid: number of samples per dataset -> number of datasets that
# are simulated (and trained on) for that size. Every size uses 10 datasets.
n_patients_n_datasets = dict.fromkeys(
    (2_000, 5_000, 10_000, 20_000, 50_000),
    10,
)

# Column-name keyword arguments handed to every model template.
cols = dict(pid_col='pid', duration_col='X', event_col='J')

# Extra keyword arguments passed only to the Transformer and DeepHit templates.
deep_kwargs = dict(d_model=64, dropout=0.1, batch_size=128, lr=1e-4)
class Run_simu:
    """Run every model template on each run directory of one simulation.

    simulated_datasets : Simulated datasets; stored on the instance as given.
    n_patients : Number of samples in the simulated datasets; only printed
        once the run is over.
    simu_params : (dict) Simulation parameters, forwarded as keyword
        arguments to every model template.
    cols : (dict) Column-name keyword arguments, forwarded to every model
        template.
    rundirs : Iterable of run directories. Each one must expose `train`,
        `test`, `hazard_gt`, `savedir` and `plotdir`.

    Creating an instance also creates a timestamped directory under
    'models/weights/' whose path is handed to every template as
    `weights_savedir`.
    """

    def __init__(self, simulated_datasets, n_patients, simu_params, cols, rundirs):
        self.simulated_datasets = simulated_datasets
        self.n_patients = n_patients
        self.simu_params = simu_params
        self.cols = cols
        self.rundirs = rundirs
        # One weights directory per instance, named after the creation time.
        self.weights_savedir = f'models/weights/{self._now_str()}/'
        Path(self.weights_savedir).mkdir(exist_ok=True, parents=True)

    def _now_str(self):
        """Return the current local time as 'YYYY_MM_DD_HH_MM_SS'."""
        return datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')

    def main(self):
        """Build and run the four model templates on every run directory.

        Order per run directory: Transformer, regularized CoxPH, PyDTS,
        DeepHit. Only instance state is read for the run directories, the
        simulation parameters and the column names, so the class can be used
        without the module-level variables of the script entry point.
        """
        for i, rd in enumerate(self.rundirs):
            print(f'run : {rd} - {i}')
            kwargs = {'train_df': rd.train,
                      'test_df': rd.test,
                      'hazard_gt': rd.hazard_gt,
                      'weights_savedir': self.weights_savedir}
            # Later operands win on duplicate keys: cols > simu_params > base.
            kwargs = kwargs | self.simu_params | self.cols

            self.template_transformer = Transformer_template(
                rd.savedir,
                plotdir=f'{rd.plotdir}transformer/',
                plot=False,
                num_heads=1,
                use_transformer_decoder=False,
                **kwargs, **deep_kwargs)
            self.template_transformer.run()

            self.coxph_template = CoxPH_template(
                rd.savedir,
                plotdir=f'{rd.plotdir}coxph_regularized/',
                schoenfeld_savepath='outputs/plots/table_schoenfeld/simu_residuals.parquet',
                compute_schoenfeld=False,
                name='regularized_CoxPH',
                l1_ratio=1,
                penalizer=0.01,
                **kwargs)
            self.coxph_template.run()

            pydts_template = PyDTS_template(rd.savedir,
                                            plotdir=f'{rd.plotdir}pydts/',
                                            **kwargs)
            pydts_template.run()

            deephit_template = DeepHit_template(rd.savedir,
                                                plotdir=f'{rd.plotdir}deephit/',
                                                plot=False,
                                                **kwargs, **deep_kwargs)
            deephit_template.run()

        # `dirname` is only defined by the script entry point and is not a
        # constructor argument; look it up tolerantly so that calling main()
        # from an importing module does not end in a NameError.
        print(globals().get('dirname'), self.n_patients)
if __name__ == '__main__':
    # Prefix of every output directory; the number of samples is appended.
    dirname = 'rundir'

    # One full pass (simulate, then train every model) per dataset size.
    for n_patients, n_datasets in n_patients_n_datasets.items():
        simulated_datasets, simu_params = load_simulated_datasets(
            n_patients, n_datasets, n_times=30)

        # All run directories of one dataset size share the same base folder
        # and are named after the index of their dataset.
        basedir = f'outputs/{dirname}_{n_patients}/'
        rundirs = []
        for idx, dataset in enumerate(simulated_datasets):
            rundirs.append(
                Simu_rundir(basedir=basedir, name=idx, simu_dataset=dataset))

        runner = Run_simu(simulated_datasets, n_patients, simu_params, cols,
                          rundirs)
        runner.main()