-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathalphasearch.py
72 lines (56 loc) · 2.03 KB
/
alphasearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from models.model_utils import makeANNModel
from utils.keras_utils import set_keras_growth
from utils.storage_utils import createdir
from utils.plotting import plotNResults
from datetime import datetime
from keras.wrappers.scikit_learn import KerasClassifier
from utils.runutils import runalldatasets,getArgs
import sys
import numpy as np
import pandas as pd
import json
import random
def main():
    """Run the alpha-search experiment from the command line.

    Steps, in order:

    1. Parse the command-line arguments with ``getArgs``.
    2. Seed Python's ``random`` module, generating a fresh seed (stored back
       on ``args.seed``) when none was supplied.
    3. Create an output directory named ``<prefix><timestamp>`` and store the
       raw command-line arguments in ``settings.json`` inside it.
    4. Call ``runalldatasets`` ``args.n`` times, passing the alpha values
       ``np.arange(0.10, 1.0, 0.05)`` as ``alphalist``.
    5. Write the collected results to ``data.json``, plot them with
       ``plotNResults`` and export them as a CSV file.
    """
    args = getArgs()

    # Pick a seed when the user gave none, and keep it on ``args`` so the
    # value that was actually used travels with the rest of the settings.
    if args.seed is None:
        args.seed = random.randrange(sys.maxsize)
        print(f"generating new random seed: {args.seed}")
    else:
        print(f"setting random seed to: {args.seed}")
    # NOTE(review): only Python's ``random`` module is seeded here; whether
    # NumPy / Keras are seeded from ``args.seed`` downstream is not visible
    # from this function -- confirm in ``runalldatasets``.
    random.seed(args.seed)

    datasetlist = args.datasets
    print(f"doing experiment with {datasetlist} in that order")

    runlist = args.methods
    set_keras_growth(args.gpu)

    # Output directory: "<prefix><YYYY_mm_dd_HH_MM_SS>", default prefix "runner".
    prefix = "runner" if args.prefix is None else args.prefix
    rootpath = prefix + datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    createdir(rootpath)

    # Record the exact command-line arguments next to the results.
    writejson(f"{rootpath}/settings.json", sys.argv[1:])

    callbacks = [] if args.callbacks is None else args.callbacks

    # Half-open range: starts at 0.10 and stops before 1.0, step 0.05.
    alpharange = np.arange(0.10, 1.0, 0.05)

    nresults = []
    for i in range(args.n):
        nresults.append(runalldatasets(args, callbacks,
                                       datasetlist, rootpath,
                                       runlist, alphalist=alpharange,
                                       n=i, printcvresults=args.cvsummary,
                                       printcv=args.printcv))
        # Rewritten after every repetition so the results gathered so far
        # are on disk even if a later repetition fails.
        writejson(f"{rootpath}/data.json", nresults)

    plotNResults(datasetlist, nresults, rootpath, args.kfold, args.n)

    resdf = pd.DataFrame(nresults)
    resdf.to_csv(f"{rootpath}/results_{args.kfold}kfold_{args.epochs}epochs_{args.onehot}onehot.csv")
def writejson(filename, data):
    """Serialize ``data`` as JSON and write it to ``filename``.

    The data is serialized to a string *before* the file is opened, so a
    serialization error does not truncate or partially overwrite an
    existing file at ``filename``.

    Args:
        filename: Path of the file to create or overwrite.
        data: Any object accepted by ``json.dumps``.

    Raises:
        TypeError: If ``data`` contains a value that is not JSON serializable.
        ValueError: If ``data`` contains a circular reference.
        OSError: If the file cannot be opened or written.
    """
    # Serialize first: opening with "w" truncates the file immediately, so
    # any failure must happen before that point to keep old contents intact.
    payload = json.dumps(data)
    with open(filename, 'w', encoding='utf-8') as outfile:
        outfile.write(payload)
# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()