-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMainDriver.m
126 lines (101 loc) · 3.54 KB
/
MainDriver.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
%%%%%%%%%%% USER INPUTS %%%%%%%%%%%%
% Hyperparameters: forwarded unchanged to ModelBuild for both the
% categorical model and the protein model.
NLearn = 20;
LearnRate = 1;
KFold = 2;
% Weight for protein forest. The combined prediction is
%   proteinWeight*protein + (1 - proteinWeight)*categorical,
% so this must be a scalar in [0, 1].
proteinWeight = 0.1;
% Number of PR patients used in each balanced set
NUM_PR_USED = 50; % Max = 55, there are 55 PR patients in set
%%%% Feature Selection %%%
% Rows of {column name, weight}; passed to weightTable for both the
% training table and the test table.
weightCell = {
    'cyto_cat_score',1;
    'PRIOR_MAL',1;
    'ITD',1;
    'D835',1;
    'Age_at_Dx',1;
    'WBC',1;
    'BM_MONOCYTES',1;
    'PB_BLAST',1;
    'PB_MONO',1;
    'BM_ABS_RATIO',1;
    'SEX',1;
    'ChemoAraC',1;
    'ChemoAnthra',1;
    'ChemoFlu',1;
    'ChemoHDACPlus',1;
    };
% Marginal patient tolerance
% NOTE(review): patientTol is not referenced elsewhere in this script --
% confirm it is still consumed somewhere before relying on it.
patientTol = 0.2;
% Fail fast on invalid settings, before the data read and model training.
% The negated range test also rejects NaN.
if ~isscalar(proteinWeight) || ~(proteinWeight >= 0 && proteinWeight <= 1)
    error('Invalid protein weight')
end
if ~isscalar(NUM_PR_USED) || NUM_PR_USED < 1 || NUM_PR_USED ~= fix(NUM_PR_USED)
    error('NUM_PR_USED must be a positive integer')
end
%%%%%%%%%%%% ADD APPROPRIATE PATHS %%%%%%%%%
% Register the project sub-folders on the MATLAB search path, one addpath
% call per folder and in the same order, so path precedence is unchanged.
for projectDir = {'HelperFuns', 'Input', 'Outputs'}
    addpath(projectDir{1})
end
clear projectDir
%%%%%%%%%%%%% READING DATA %%%%%%%%%%%%%%%%%
% Each Excel file is read and cleaned only when its result is not already
% in the workspace, so re-running the script skips the slow reads.
% NOTE(review): the leading '/' makes these absolute paths unless
% ReadExcel prepends a base folder -- confirm against HelperFuns/ReadExcel.
% Read training data
if ~exist('tblBackup', 'var')
    trainingFile = '/Input/trainingData-release_cytocat.xlsx';
    tbl = ReadExcel(trainingFile);
    tbl = tableClean(tbl,'train');
    % Keep an untouched copy so later runs can restore tbl without re-reading
    tblBackup = tbl;
else
    % Cached: start from the pristine cleaned training table
    tbl = tblBackup;
end
% Read test data
if ~exist('testTable', 'var')
    testFile = '/Input/scoringData-release_cytocat.xlsx';
    testTable = ReadExcel(testFile);
    testTable = tableClean(testTable,'test');
end
%%%%%%%%%%%%%%%% SC1 %%%%%%%%%%%%%%%%%%%%%%
% Convert the cleaned tables into the model-ready tables passed to ModelBuild.
% Training: tblPrep also returns the response variable used by both models.
[trainingTableAll,responseVar] = tblPrep(tbl,'train');
% Apply weights to table to create categorical data table
trainingTableCat = weightTable(trainingTableAll,weightCell);
% Protein predictor table for the protein model
trainingTableProtein = makeProteinTable(trainingTableAll);
% Y-scramble sanity check (uncomment to shuffle the training labels):
%responseVar = responseVar(randperm(length(responseVar)));
% Testing: same preparation; only the prepared table is taken from tblPrep
testTableAll = tblPrep(testTable,'test');
% Apply the same weightCell to the test table
testTableCat = weightTable(testTableAll,weightCell);
% Protein predictor table for the test set
testTableProtein = makeProteinTable(testTableAll);
%%%%%% CATEGORICAL MODEL %%%%%%%
% Fit on the categorical training table and predict the categorical test table.
% NOTE(review): originally labelled an "RF model"; confirm the model type
% inside ModelBuild (a LearnRate argument is passed).
[prediction, BackLabel, importance, foldLossCat] = ModelBuild( ...
    trainingTableCat, responseVar, ...
    NLearn, LearnRate, KFold, ...
    testTableCat, NUM_PR_USED);
%%%%%%% PROTEIN MODEL %%%%%%%%
% Same hyperparameters and balancing, using the protein predictor tables.
[predictionProt, BackLabelProt, importanceProt, foldLossProt] = ModelBuild( ...
    trainingTableProtein, responseVar, ...
    NLearn, LearnRate, KFold, ...
    testTableProtein, NUM_PR_USED);
%%%%%% COMBINE MODELS %%%%%%%
% Blend the two models linearly: proteinWeight on the protein model and the
% remainder on the categorical model. The negated range test also rejects NaN.
if ~isscalar(proteinWeight) || ~(proteinWeight >= 0 && proteinWeight <= 1)
    error('Invalid protein weight %g; must be a scalar in [0, 1]', proteinWeight)
end
categoryWeight = 1 - proteinWeight;
% Combine test-set predictions
predictionCombined = proteinWeight.*predictionProt + categoryWeight.*prediction;
% Combine back labels with the same weights for self-scoring
backLabelCombined = proteinWeight.*BackLabelProt + categoryWeight.*BackLabel;
% Convert -1 to 0 in the response var for scoring
respVarScore = responseVar;
respVarScore(respVarScore == -1) = 0;
% Score the combined back labels against the training response
[BAC,AUROC] = score(backLabelCombined,respVarScore);
fprintf('=====================\n')
fprintf('BAC: %.3f\n',BAC)
fprintf('AUROC: %.3f\n',AUROC)
fprintf('Cat Fold Loss: %.3f\n',foldLossCat)
fprintf('Protein Fold Loss: %.3f\n',foldLossProt)
% Test predictions below 0.5 counted as PR: categorical model alone ...
fprintf('Num of PR Predictions: %.0f\n',sum(prediction < 0.5))
% ... and the blended prediction (same weighting as the scored back labels)
fprintf('Num of PR Predictions (combined): %.0f\n',sum(predictionCombined < 0.5))
fprintf('---------------------\n')
% Feature importance: average along dimension 2, then scale so the largest
% entry is 1. Scaling is skipped when the maximum is not positive, which
% avoids 0/0 = NaN when every importance is zero.
importance = mean(importance,2);
maxImportance = max(importance);
if maxImportance > 0
    importance = importance./maxImportance;
end
importanceProt = mean(importanceProt,2);
maxImportanceProt = max(importanceProt);
if maxImportanceProt > 0
    importanceProt = importanceProt./maxImportanceProt;
end
% Label importances with the training-table columns ModelBuild was given
importanceTableCat = table(trainingTableCat.Properties.VariableNames',importance);
importanceTableProt = table(trainingTableProtein.Properties.VariableNames',importanceProt);
disp(sortrows(importanceTableCat,2,{'descend'}))
disp(sortrows(importanceTableProt,2,{'descend'}))
% NOTE(review): the line below writes the categorical-only prediction, not
% predictionCombined -- confirm which one is meant to be submitted.
%writeToOutput('Outputs/HewesTanZhuJahn_Week7_SC1.txt',prediction)
%%%%%%%%%%%%%%%%%%