-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_madlibs.py
194 lines (161 loc) · 6.24 KB
/
generate_madlibs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import argparse
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import openai
import pandas as pd
from tqdm import tqdm
# Command-line configuration for the dataset generator.
parser = argparse.ArgumentParser(
    description='Generate a synthetic madlibs QA dataset with the OpenAI API.')
parser.add_argument('--save_here',
                    # __file__ is this script itself, so the data directory has to be
                    # resolved next to it, not appended to the script's file path.
                    default=os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                         'data', 'madlibs', 'madlibs1.csv'),
                    help='Path of the CSV file to write the generated examples to.')
parser.add_argument('--num_people',
                    default=500,
                    type=int,
                    help='How many fictional people to request (duplicate or '
                         'over-long names are dropped, so fewer may be kept).')
parser.add_argument('--openai_key',
                    # Fall back to the environment so omitting the flag does not
                    # leave the client without a key.
                    default=os.environ.get('OPENAI_API_KEY'),
                    help='Your OpenAI key (defaults to the OPENAI_API_KEY '
                         'environment variable).')
args = parser.parse_args()
openai.api_key = args.openai_key
def askGPT(prompt, model='gpt-4', retry_limit=10, wait_seconds=60):
    """Send ``prompt`` to the OpenAI chat API and return the reply text.

    The request carries a fixed system message followed by ``prompt`` as the
    user message, and the content of the first returned choice is returned.
    Any exception from the API call is retried (the usual cause is a rate
    limit), waiting ``wait_seconds`` between attempts, for at most
    ``retry_limit`` attempts in total.

    Args:
        prompt: User message text.
        model: Chat model name passed to ``openai.ChatCompletion.create``.
        retry_limit: Maximum number of attempts.
        wait_seconds: Pause between a failed attempt and the next one.

    Returns:
        The ``content`` string of ``response['choices'][0]['message']``.

    Raises:
        RuntimeError: If every attempt failed; the last API error is chained
            as ``__cause__``.
    """
    ## JUST KEEP TRYING!
    last_error = None
    for attempt in range(1, retry_limit + 1):
        try:
            response = openai.ChatCompletion.create(
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant bot.'},
                    {'role': 'user', 'content': prompt},
                ],
                model=model,
            )
        except Exception as e:
            last_error = e
            print(f"Error occurred: {e}. Retrying...")
            # Probably hit rate limit, so let's just wait -- but not after the
            # final attempt, when there is nothing left to wait for.
            if attempt < retry_limit:
                time.sleep(wait_seconds)
        else:
            return response['choices'][0]['message']['content']
    print("Reached maximum retry limit.")
    # Previously execution fell through to an unbound `response` here.
    raise RuntimeError(
        f"askGPT failed after {retry_limit} attempts") from last_error
def _gather_unique_items(prompt, max_len, rounds=15):
    """Pool the comma-separated lists gpt-3.5-turbo returns for ``prompt``.

    The prompt is sent ``rounds`` times. Each reply is split on commas and
    newlines (the model does not reliably separate items with exactly ", "),
    each entry is stripped of surrounding whitespace and a trailing period,
    and empty entries are dropped. Returns the sorted, de-duplicated entries
    that are shorter than ``max_len`` characters.
    """
    pooled = []
    for k in range(rounds):
        print(k)
        reply = askGPT(prompt, model='gpt-3.5-turbo')
        print(reply)
        for item in reply.replace('\n', ',').split(','):
            item = item.strip().rstrip('.').strip()
            if item:
                pooled.append(item)
    # Hack to make sure each entry is just the item, not a sentence the model
    # added despite the "list only" instruction.
    return [s for s in sorted(set(pooled)) if len(s) < max_len]


print("GETTING ORIGINS")
origins_prompt = "Give me a list of 50 origins for people (example: Canadian, Texan, European, Monagasque), output it as a comma separated list. Output the list only, nothing else."
origins = _gather_unique_items(origins_prompt, max_len=30)
print(len(origins))
print("GETTING JOBS")
jobs_prompt = "Give me a list of 50 jobs that people can become famous for (example: programmer, entrepreneur, taxi driver, basketball player), output it as a comma separated list. Output the list only, nothing else."
jobs = _gather_unique_items(jobs_prompt, max_len=40)
print(len(jobs))
# Ask gpt-3.5-turbo for up to args.num_people fictional names, each for a
# randomly chosen (origin, job) pair, keeping the three lists index-aligned.
print("Geting names")
non_duped_example_jobs = []
non_duped_example_orgins = []
non_duped_example_names = []
if not origins or not jobs:
    # random.choice would otherwise fail with an opaque IndexError.
    raise RuntimeError("Need at least one origin and one job to generate names.")
# Mirror of non_duped_example_names for O(1) duplicate checks.
seen_names = set()
for _ in range(args.num_people):
    an_origin = random.choice(origins)
    a_job = random.choice(jobs)
    name_prompt = f"I am writing a novel, help me come up with a name for a famous {an_origin} {a_job}. Do not output the name of someone who already exists. Output only the name, no explanation."
    # Drop surrounding whitespace and quotes the model may wrap the name in,
    # so identical names are recognised as duplicates and read cleanly later.
    name = askGPT(name_prompt, model='gpt-3.5-turbo').strip().strip('"\'“”').strip()
    # Hack to make sure it's just the name.
    if name and len(name) < 30:
        print(name)
        if name not in seen_names:
            seen_names.add(name)
            non_duped_example_names.append(name)
            non_duped_example_orgins.append(an_origin)
            non_duped_example_jobs.append(a_job)
# Question templates for the generated QA pairs. The literal token "PERSON" is
# replaced with a generated name before use; doOne samples from this list.
# (A former `madlibs = pd.DataFrame({})` placeholder that followed this list
# was removed: the final concat always overwrites it and nothing read it.)
questions = [
    "In which city was PERSON born?",
    "What year was PERSON born?",
    "Where did PERSON go to college?",
    "What is the name of PERSON's spouse?",
    "What is the name of the first company PERSON worked at?",
    "What is the company PERSON founded called?",
    "What is the title of the film PERSON directed?",
    "Who is PERSON's idol?",
    "What is the name of PERSON's pet?",
    "What is PERSON's favorite color?",
    "Where did PERSON go to high school?",
    "What is the name of PERSON's best friend?",
    "What is the title of PERSON's favorite movie?",
    "In what year did PERSON get married?",
    "What is the title of PERSON's favorite book?",
    "What is the name of PERSON's first child?",
    "What is the name of PERSON's favorite sports team?",
    "In which country was PERSON born?",
    "What was the title of PERSON's PhD thesis?",
    "What sport does PERSON play?",
]
def _qa_prompt(wiki, question):
    """Build the prompt asking the model to answer ``question`` from ``wiki``."""
    return f"""Here is a short wikipedia article.
##ARTICLE
{wiki}
Can you answer the following question?
##QUESTION
{question}
Keep the answer as short as possible. If you can answer in one or two words, do that.
"""


def doOne(name, job, origin, num_questions=2):
    """Generate one fictional wiki paragraph plus QA pairs grounded in it.

    ``num_questions`` templates are sampled without replacement from the
    module-level ``questions`` list, and "PERSON" is replaced with ``name``.
    askGPT (default model) is then asked for a one-paragraph article about a
    famous ``origin`` ``job`` named ``name`` that can answer those questions.
    Each question is answered by askGPT from that article.

    Returns:
        A DataFrame with columns ``context``, ``question`` and ``answer``, one
        row per sampled question (index 0..num_questions-1). Every row has
        the generated article as its context.
    """
    sampled = [q.replace("PERSON", name)
               for q in random.sample(questions, num_questions)]
    # One question per line; for two questions this is identical to the
    # original "{question1}\n{question2}" layout.
    question_lines = "\n".join(sampled)
    wiki_prompt = f"""Please write a one paragraph wikipedia article for a famous {origin} {job} named {name}.
Make sure the article contains information that can answer the following questions:
{question_lines}
Output the article only, no extraneous explanation.
"""
    wiki = askGPT(wiki_prompt)
    print("WIKI EXAMPLE")
    print(wiki)
    rows = []
    for question in sampled:
        # Strip stray whitespace/newlines so the stored answer is just the
        # short answer text.
        answer = askGPT(_qa_prompt(wiki, question)).strip()
        rows.append({'context': wiki, 'question': question, 'answer': answer})
    return pd.DataFrame(rows, columns=['context', 'question', 'answer'])
def worker(n, j, o):
    """Thread-pool entry point: build the QA rows for one (name, job, origin).

    Positional arguments are forwarded to doOne by keyword, so the order of
    the tuple fields in the caller does not need to match doOne's signature.
    """
    return doOne(origin=o, job=j, name=n)
# Generate the QA rows for every kept person concurrently and write them to CSV.
data = list(zip(non_duped_example_names, non_duped_example_jobs, non_duped_example_orgins))
results = []
failed = []
print("RUNNING WHOLE THING")
# The work is dominated by waiting on API calls, so a small thread pool
# overlaps that latency.
with ThreadPoolExecutor(max_workers=4) as executor:
    with tqdm(total=len(data), desc='Processing data') as pbar:
        futures = {executor.submit(worker, n, j, o): (n, j, o) for (n, j, o) in data}
        for future in as_completed(futures):
            try:
                results.append(future.result())
            except Exception as e:
                # One person's failure (e.g. askGPT running out of retries)
                # must not throw away every row already generated.
                failed.append(futures[future])
                print(f"Failed for {futures[future]}: {e}")
            pbar.update(1)
if failed:
    print(f"{len(failed)} of {len(data)} people failed and were skipped.")
if results:
    madlibs = pd.concat(results, ignore_index=True)
else:
    # pd.concat raises on an empty list; write a header-only CSV instead.
    madlibs = pd.DataFrame(columns=['context', 'question', 'answer'])
save_dir = os.path.dirname(args.save_here)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
madlibs.to_csv(args.save_here, index=False)