Skip to content

Commit 9aea96f

Browse files
committed
talk.wasm : polishing + adding many AI personalities
1 parent 385236d commit 9aea96f

File tree

4 files changed

+383
-48
lines changed

4 files changed

+383
-48
lines changed

bindings/javascript/whisper.js

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/talk.wasm/README.md

+9
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ In order to run this demo efficiently, you need to have the following:
3131
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
3232
- The web-page uses about 1.4GB of RAM
3333

34+
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
35+
Also, the prompting strategy can likely be improved to achieve better results.
36+
37+
The demo is quite computationally heavy - it is not common to run these transformer models in a browser. Typically, they
38+
run on powerful GPU hardware. So for better experience, you do need to have a powerful computer.
39+
40+
Probably in the near future, mobile browsers will start to support the WASM SIMD capabilities and this will make it
41+
possible to run the demo on your phone or tablet. But for now it does not seem to be supported (at least on iPhone).
42+
3443
## Feedback
3544

3645
If you have any comments or ideas for improvement, please drop a comment in the following discussion:

examples/talk.wasm/emscripten.cpp

+29-13
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ std::atomic<bool> g_running(false);
988988

989989
bool g_force_speak = false;
990990
std::string g_text_to_speak = "";
991-
std::string g_status = "idle";
991+
std::string g_status = "";
992992
std::string g_status_forced = "";
993993

994994
std::string gpt2_gen_text(const std::string & prompt) {
@@ -997,7 +997,7 @@ std::string gpt2_gen_text(const std::string & prompt) {
997997
std::vector<float> embd_w;
998998

999999
// tokenize the prompt
1000-
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, g_gpt2.prompt_base + prompt);
1000+
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, prompt);
10011001

10021002
g_gpt2.n_predict = std::min(g_gpt2.n_predict, g_gpt2.model.hparams.n_ctx - (int) embd_inp.size());
10031003

@@ -1088,6 +1088,8 @@ void talk_main(size_t index) {
10881088
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
10891089
}
10901090

1091+
printf("talk: using %d threads\n", N_THREAD);
1092+
10911093
std::vector<float> pcmf32;
10921094

10931095
auto & ctx = g_contexts[index];
@@ -1214,53 +1216,60 @@ void talk_main(size_t index) {
12141216
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
12151217

12161218
std::string text_to_speak;
1219+
std::string prompt_base;
1220+
1221+
{
1222+
std::lock_guard<std::mutex> lock(g_mutex);
1223+
prompt_base = g_gpt2.prompt_base;
1224+
}
12171225

12181226
if (tokens.size() > 0) {
1219-
text_to_speak = gpt2_gen_text(text_heard + "\n");
1227+
text_to_speak = gpt2_gen_text(prompt_base + text_heard + "\n");
12201228
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
12211229
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
12221230

12231231
std::lock_guard<std::mutex> lock(g_mutex);
12241232

12251233
// remove first 2 lines of base prompt
12261234
{
1227-
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1235+
const size_t pos = prompt_base.find_first_of("\n");
12281236
if (pos != std::string::npos) {
1229-
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1237+
prompt_base = prompt_base.substr(pos + 1);
12301238
}
12311239
}
12321240
{
1233-
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1241+
const size_t pos = prompt_base.find_first_of("\n");
12341242
if (pos != std::string::npos) {
1235-
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1243+
prompt_base = prompt_base.substr(pos + 1);
12361244
}
12371245
}
1238-
g_gpt2.prompt_base += text_heard + "\n" + text_to_speak + "\n";
1246+
prompt_base += text_heard + "\n" + text_to_speak + "\n";
12391247
} else {
1240-
text_to_speak = gpt2_gen_text("");
1248+
text_to_speak = gpt2_gen_text(prompt_base);
12411249
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
12421250
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
12431251

12441252
std::lock_guard<std::mutex> lock(g_mutex);
12451253

1246-
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1254+
const size_t pos = prompt_base.find_first_of("\n");
12471255
if (pos != std::string::npos) {
1248-
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1256+
prompt_base = prompt_base.substr(pos + 1);
12491257
}
1250-
g_gpt2.prompt_base += text_to_speak + "\n";
1258+
prompt_base += text_to_speak + "\n";
12511259
}
12521260

12531261
printf("gpt-2: %s\n", text_to_speak.c_str());
12541262

12551263
//printf("========================\n");
1256-
//printf("gpt-2: prompt_base:\n'%s'\n", g_gpt2.prompt_base.c_str());
1264+
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
12571265
//printf("========================\n");
12581266

12591267
{
12601268
std::lock_guard<std::mutex> lock(g_mutex);
12611269
t_last = std::chrono::high_resolution_clock::now();
12621270
g_text_to_speak = text_to_speak;
12631271
g_pcmf32.clear();
1272+
g_gpt2.prompt_base = prompt_base;
12641273
}
12651274

12661275
talk_set_status("speaking ...");
@@ -1376,4 +1385,11 @@ EMSCRIPTEN_BINDINGS(talk) {
13761385
g_status_forced = status;
13771386
}
13781387
}));
1388+
1389+
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
1390+
{
1391+
std::lock_guard<std::mutex> lock(g_mutex);
1392+
g_gpt2.prompt_base = prompt;
1393+
}
1394+
}));
13791395
}

0 commit comments

Comments
 (0)