@@ -988,7 +988,7 @@ std::atomic<bool> g_running(false);
988
988
989
989
bool g_force_speak = false ;
990
990
std::string g_text_to_speak = " " ;
991
- std::string g_status = " idle " ;
991
+ std::string g_status = " " ;
992
992
std::string g_status_forced = " " ;
993
993
994
994
std::string gpt2_gen_text (const std::string & prompt) {
@@ -997,7 +997,7 @@ std::string gpt2_gen_text(const std::string & prompt) {
997
997
std::vector<float > embd_w;
998
998
999
999
// tokenize the prompt
1000
- std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize (g_gpt2.vocab , g_gpt2. prompt_base + prompt);
1000
+ std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize (g_gpt2.vocab , prompt);
1001
1001
1002
1002
g_gpt2.n_predict = std::min (g_gpt2.n_predict , g_gpt2.model .hparams .n_ctx - (int ) embd_inp.size ());
1003
1003
@@ -1088,6 +1088,8 @@ void talk_main(size_t index) {
1088
1088
printf (" gpt-2: model loaded in %d ms\n " , (int ) (t_load_us/1000 ));
1089
1089
}
1090
1090
1091
+ printf (" talk: using %d threads\n " , N_THREAD);
1092
+
1091
1093
std::vector<float > pcmf32;
1092
1094
1093
1095
auto & ctx = g_contexts[index ];
@@ -1214,53 +1216,60 @@ void talk_main(size_t index) {
1214
1216
printf (" whisper: number of tokens: %d, '%s'\n " , (int ) tokens.size (), text_heard.c_str ());
1215
1217
1216
1218
std::string text_to_speak;
1219
+ std::string prompt_base;
1220
+
1221
+ {
1222
+ std::lock_guard<std::mutex> lock (g_mutex);
1223
+ prompt_base = g_gpt2.prompt_base ;
1224
+ }
1217
1225
1218
1226
if (tokens.size () > 0 ) {
1219
- text_to_speak = gpt2_gen_text (text_heard + " \n " );
1227
+ text_to_speak = gpt2_gen_text (prompt_base + text_heard + " \n " );
1220
1228
text_to_speak = std::regex_replace (text_to_speak, std::regex (" [^a-zA-Z0-9\\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
1221
1229
text_to_speak = text_to_speak.substr (0 , text_to_speak.find_first_of (" \n " ));
1222
1230
1223
1231
std::lock_guard<std::mutex> lock (g_mutex);
1224
1232
1225
1233
// remove first 2 lines of base prompt
1226
1234
{
1227
- const size_t pos = g_gpt2. prompt_base .find_first_of (" \n " );
1235
+ const size_t pos = prompt_base.find_first_of (" \n " );
1228
1236
if (pos != std::string::npos) {
1229
- g_gpt2. prompt_base = g_gpt2. prompt_base .substr (pos + 1 );
1237
+ prompt_base = prompt_base.substr (pos + 1 );
1230
1238
}
1231
1239
}
1232
1240
{
1233
- const size_t pos = g_gpt2. prompt_base .find_first_of (" \n " );
1241
+ const size_t pos = prompt_base.find_first_of (" \n " );
1234
1242
if (pos != std::string::npos) {
1235
- g_gpt2. prompt_base = g_gpt2. prompt_base .substr (pos + 1 );
1243
+ prompt_base = prompt_base.substr (pos + 1 );
1236
1244
}
1237
1245
}
1238
- g_gpt2. prompt_base += text_heard + " \n " + text_to_speak + " \n " ;
1246
+ prompt_base += text_heard + " \n " + text_to_speak + " \n " ;
1239
1247
} else {
1240
- text_to_speak = gpt2_gen_text (" " );
1248
+ text_to_speak = gpt2_gen_text (prompt_base );
1241
1249
text_to_speak = std::regex_replace (text_to_speak, std::regex (" [^a-zA-Z0-9\\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
1242
1250
text_to_speak = text_to_speak.substr (0 , text_to_speak.find_first_of (" \n " ));
1243
1251
1244
1252
std::lock_guard<std::mutex> lock (g_mutex);
1245
1253
1246
- const size_t pos = g_gpt2. prompt_base .find_first_of (" \n " );
1254
+ const size_t pos = prompt_base.find_first_of (" \n " );
1247
1255
if (pos != std::string::npos) {
1248
- g_gpt2. prompt_base = g_gpt2. prompt_base .substr (pos + 1 );
1256
+ prompt_base = prompt_base.substr (pos + 1 );
1249
1257
}
1250
- g_gpt2. prompt_base += text_to_speak + " \n " ;
1258
+ prompt_base += text_to_speak + " \n " ;
1251
1259
}
1252
1260
1253
1261
printf (" gpt-2: %s\n " , text_to_speak.c_str ());
1254
1262
1255
1263
// printf("========================\n");
1256
- // printf("gpt-2: prompt_base:\n'%s'\n", g_gpt2. prompt_base.c_str());
1264
+ // printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
1257
1265
// printf("========================\n");
1258
1266
1259
1267
{
1260
1268
std::lock_guard<std::mutex> lock (g_mutex);
1261
1269
t_last = std::chrono::high_resolution_clock::now ();
1262
1270
g_text_to_speak = text_to_speak;
1263
1271
g_pcmf32.clear ();
1272
+ g_gpt2.prompt_base = prompt_base;
1264
1273
}
1265
1274
1266
1275
talk_set_status (" speaking ..." );
@@ -1376,4 +1385,11 @@ EMSCRIPTEN_BINDINGS(talk) {
1376
1385
g_status_forced = status;
1377
1386
}
1378
1387
}));
1388
+
1389
+ emscripten::function (" set_prompt" , emscripten::optional_override ([](const std::string & prompt) { // embind-registered setter: stores `prompt` as the base prompt that talk_main prepends to the heard text
1390
+ {
1391
+ std::lock_guard<std::mutex> lock (g_mutex); // same mutex talk_main holds when it copies and writes back g_gpt2.prompt_base
1392
+ g_gpt2.prompt_base = prompt; // NOTE(review): talk_main snapshots prompt_base and later assigns its updated local copy back, so a prompt set in between appears to be overwritten - confirm this is intended
1393
+ }
1394
+ }));
1379
1395
}
0 commit comments