-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_training.py
executable file
·146 lines (114 loc) · 6.11 KB
/
model_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import uuid
import skorch.callbacks as scb
import neptune
from skorch.callbacks.logging import NeptuneLogger
import pandas as pd
from data_loading import HistopathDataset
import torch
from skorch.helper import predefined_split
def _scheduler_params(classifier):
    """Return scheduler hyper-parameters read from the classifier's first callback.

    The first entry of ``classifier.callbacks`` is inspected for ``policy``,
    ``step_size`` and ``gamma`` attributes; each one present is returned under
    a ``scheduler_``-prefixed key. Entries registered as ``(name, callback)``
    tuples are unwrapped first. Returns an empty dict when there is no
    classifier, no callbacks, or none of the attributes exist.
    """
    params = {}
    if not classifier or not classifier.callbacks:
        return params
    callback = classifier.callbacks[0]
    # Callbacks can be given as (name, callback) pairs (as train_model does for
    # its metric callbacks); probe the callback object, not the tuple.
    if isinstance(callback, tuple) and len(callback) == 2:
        callback = callback[1]
    if not callback:
        return params
    for attr in ("policy", "step_size", "gamma"):
        if hasattr(callback, attr):
            params["scheduler_" + attr] = getattr(callback, attr)
    return params


def _init_neptune_logger(classifier, logger_config):
    """Start a Neptune experiment and wrap it in a skorch ``NeptuneLogger``.

    Args:
        classifier: the net whose ``get_params()`` (plus any scheduler
            parameters) are recorded as the experiment's params.
        logger_config: dict providing ``api_token``,
            ``project_qualified_name`` and ``experiment_name``.

    Returns:
        A ``NeptuneLogger`` bound to the newly created experiment.
    """
    scheduler_params = _scheduler_params(classifier)
    neptune.init(api_token=logger_config["api_token"],
                 project_qualified_name=logger_config["project_qualified_name"])
    experiment = neptune.create_experiment(
        name=logger_config["experiment_name"],
        params={**classifier.get_params(), **scheduler_params})
    # close_after_train=False keeps the experiment open after fit() so the
    # saved model files can still be uploaded; train_model stops it itself.
    return NeptuneLogger(experiment, close_after_train=False)


def _metric_callbacks():
    """Return the named epoch-scoring callbacks followed by a progress bar.

    Training-set scores: accuracy, f1, roc_auc, precision, recall.
    Validation-set scores: f1, roc_auc, precision, recall (validation
    accuracy is presumably covered by the net's own default scoring —
    confirm). All are registered as higher-is-better.
    """
    callbacks = []
    for metric in ("accuracy", "f1", "roc_auc", "precision", "recall"):
        name = "train_acc" if metric == "accuracy" else "train_" + metric
        callbacks.append((name, scb.EpochScoring(metric,
                                                 name=name,
                                                 lower_is_better=False,
                                                 on_train=True)))
    for metric in ("f1", "roc_auc", "precision", "recall"):
        name = "valid_" + metric
        callbacks.append((name, scb.EpochScoring(metric,
                                                 name=name,
                                                 lower_is_better=False)))
    callbacks.append(scb.ProgressBar())
    return callbacks


def _print_training_summary(classifier):
    """Print the module and main training hyper-parameters of *classifier*."""
    print(("Starting Training for {}\n"
           "\033[1mModel-Params:\033[0m\n"
           "\033[1mCriterion:\033[0m {}\n"
           "\033[1mOptimizer:\033[0m {}\n"
           "\033[1mLearning Rate:\033[0m {}\n"
           "\033[1mEpochs:\033[0m {}\n"
           "\033[1mBatch size:\033[0m {}\n")
          .format(classifier.module,
                  classifier.criterion,
                  classifier.optimizer,
                  classifier.lr,
                  classifier.max_epochs,
                  classifier.batch_size))


def train_model(classifier, train_labels, test_labels, file_dir, train_transform, test_transform, in_memory, output_path, logger = None):
    """Train *classifier* on a histopathology dataset and save its state.

    Args:
        classifier: skorch-style net; its ``callbacks`` and ``train_split``
            attributes are replaced in place before ``fit`` is called.
        train_labels: path to the training label CSV. Its ``label`` column is
            used as the fit target.
        test_labels: path to the label CSV of the split used for validation
            (via ``predefined_split``).
        file_dir: root directory of the images for both splits.
        train_transform: transform passed to the training ``HistopathDataset``.
        test_transform: transform passed to the validation ``HistopathDataset``.
        in_memory: forwarded unchanged to both ``HistopathDataset`` objects.
        output_path: directory receiving ``<uid>-model.pkl``,
            ``<uid>-opt.pkl`` and ``<uid>-history.json``; created if missing.
        logger: optional dict with ``api_token``, ``project_qualified_name``
            and ``experiment_name``. When given, a Neptune experiment records
            params, metrics and the saved files, and is stopped at the end —
            also when loading, training or saving raises.
    """
    ######################################
    #           NEPTUNE SETUP            #
    ######################################
    neptune_logger = _init_neptune_logger(classifier, logger) if logger else None
    try:
        ######################################
        #           DATA LOADING             #
        ######################################
        dataset_train = HistopathDataset(
            label_file=os.path.abspath(train_labels),
            root_dir=os.path.abspath(file_dir),
            transform=train_transform,
            in_memory=in_memory)
        dataset_test = HistopathDataset(
            label_file=os.path.abspath(test_labels),
            root_dir=os.path.abspath(file_dir),
            transform=test_transform,
            in_memory=in_memory)

        ######################################
        #         METRIC CALLBACKS           #
        ######################################
        # Build a fresh list instead of extending in place: ``callbacks`` may
        # be None (it is truthiness-checked during Neptune setup), which has no
        # ``extend``, and the caller's own list object stays untouched.
        classifier.callbacks = list(classifier.callbacks or []) + _metric_callbacks()
        classifier.train_split = predefined_split(dataset_test)
        if neptune_logger:
            classifier.callbacks.append(neptune_logger)

        ######################################
        #          MODEL TRAINING            #
        ######################################
        _print_training_summary(classifier)
        target = pd.read_csv(train_labels)["label"]
        classifier.fit(X=dataset_train, y=torch.Tensor(target.to_numpy()))

        ######################################
        #           MODEL SAVING             #
        ######################################
        # Save the model and its history locally, then upload the files to
        # Neptune when logging is enabled.
        print("Saving model...")
        uid = uuid.uuid4()
        if output_path:
            os.makedirs(output_path, exist_ok=True)
        model_file = os.path.join(output_path, "{}-model.pkl".format(uid))
        opt_file = os.path.join(output_path, "{}-opt.pkl".format(uid))
        history_file = os.path.join(output_path, "{}-history.json".format(uid))
        classifier.save_params(f_params=model_file,
                               f_optimizer=opt_file,
                               f_history=history_file)
        if neptune_logger:
            experiment = neptune_logger.experiment
            experiment.log_text('uid', str(uid))
            # NOTE(review): logged under 'test_name' but the value is the
            # training-label path — confirm which one is intended.
            experiment.log_text('test_name', train_labels)
            for artifact in (model_file, opt_file, history_file):
                experiment.log_artifact(artifact)
    finally:
        # Always close the experiment so a failed run does not leave it open.
        if neptune_logger:
            neptune_logger.experiment.stop()
    print("Saving completed...")