-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerateMessage.py
61 lines (33 loc) · 1.21 KB
/
generateMessage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Generate text samples from a fine-tuned GPT-2 model using gpt-2-simple.
#
# Usage:
#   generateMessage.py PREFIX NSAMPLES LENGTH BATCH_SIZE TEMPERATURE TOP_K TOP_P
#
#   PREFIX       text the generated samples start from
#   NSAMPLES     total number of samples; must be a multiple of BATCH_SIZE
#   LENGTH       number of tokens per sample (shorter is faster)
#   BATCH_SIZE   samples generated in parallel per step
#   TEMPERATURE  sampling temperature; higher = more random (0.7 is typical)
#   TOP_K        top-k sampling cutoff (0 disables it)
#   TOP_P        nucleus sampling cutoff (0.0 disables it)
#
# Exit status: 0 on success, 1 if no trained checkpoint exists, 2 on bad usage.
import os
import sys

# Checkpoints and base models are resolved relative to this script, so the
# script behaves the same regardless of the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CHECKPOINT_DIR = os.path.join(SCRIPT_DIR, "checkpoint")
MODEL_DIR = os.path.join(SCRIPT_DIR, "models")
RUN_NAME = "run1"

USAGE = (
    "usage: " + os.path.basename(__file__)
    + " PREFIX NSAMPLES LENGTH BATCH_SIZE TEMPERATURE TOP_K TOP_P"
)


def _parse_args(argv):
    """Convert the seven positional CLI arguments into generate() kwargs.

    argv: command-line arguments excluding the program name. Arguments past
    the seventh are ignored, as in earlier versions of this script.
    Returns a dict of keyword arguments for gpt2.generate().
    Raises ValueError with a human-readable message on bad input.
    """
    if len(argv) < 7:
        raise ValueError("expected 7 arguments, got {}".format(len(argv)))
    prefix, nsamples, length, batch_size, temperature, top_k, top_p = argv[:7]
    try:
        params = {
            "prefix": prefix,
            "nsamples": int(nsamples),
            "length": int(length),
            "batch_size": int(batch_size),
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
        }
    except ValueError as err:
        raise ValueError("invalid numeric argument: {}".format(err)) from err
    if min(params["nsamples"], params["length"], params["batch_size"]) < 1:
        raise ValueError("NSAMPLES, LENGTH and BATCH_SIZE must be positive integers")
    # gpt-2-simple asserts this inside generate(); checking it here gives a
    # clear message before the slow TensorFlow setup instead of an
    # AssertionError after it.
    if params["nsamples"] % params["batch_size"]:
        raise ValueError("NSAMPLES must be a multiple of BATCH_SIZE")
    return params


def main(argv=None):
    """Validate arguments, load the latest checkpoint and print samples.

    argv: positional arguments excluding the program name; defaults to
    sys.argv[1:].
    Returns the process exit status (0 ok, 1 no checkpoint, 2 bad usage).
    """
    print("Starting message generation...")
    if argv is None:
        argv = sys.argv[1:]
    try:
        params = _parse_args(argv)
    except ValueError as err:
        print("Error: {}".format(err), file=sys.stderr)
        print(USAGE, file=sys.stderr)
        return 2

    if not os.path.exists(os.path.join(CHECKPOINT_DIR, RUN_NAME)):
        print("No model found. Please train a model first.", file=sys.stderr)
        return 1

    # Imported lazily: importing TensorFlow is slow, so usage and missing-model
    # errors above are reported without paying that cost.
    import gpt_2_simple as gpt2

    print("Starting TensorFlow session...")
    sess = gpt2.start_tf_sess()
    # Load from the same script-relative directory checked above; the library
    # default ("checkpoint") would be resolved against the CWD instead.
    # Loads the most recent checkpoint; to pick a specific one, pass
    # checkpoint="<model number>".
    gpt2.load_gpt2(sess, run_name=RUN_NAME, checkpoint_dir=CHECKPOINT_DIR)
    gpt2.generate(
        sess,
        run_name=RUN_NAME,
        model_dir=MODEL_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        **params
    )
    print("\nSample generation complete.")
    return 0


if __name__ == "__main__":
    sys.exit(main())