//---------------------------
// Dependencies
//---------------------------
// Node deps
const path = require("path")
const util = require("util")
// Get the koffi
const koffi = require("koffi");
//---------------------------
// Lib selection
//---------------------------
// The lib path to use
let rwkvCppLibPath = null;
// Check which platform we're on
if( process.arch === 'arm64' ) {
if( process.platform === 'darwin' ) {
rwkvCppLibPath = './lib/librwkv-arm64.dylib';
} else if( process.platform === 'linux' ) {
rwkvCppLibPath = './lib/librwkv-arm64.so';
} else {
throw new Error('Unsupported RWKV.cpp platform / arch: ' + process.platform + ' / ' + process.arch);
}
} else if( process.arch === 'x64' ) {
if( process.platform === 'win32' ) {
// We only do CPU feature detection on Windows,
// due to the different library builds with varying AVX support
//
// Note as this is an optional dependency,
// it can fail to load/compile for random reasons
let cpuFeatures = null;
try {
cpuFeatures = require('cpu-features')();
} catch( err ) {
// Silently ignore, we assume only avx is supported
}
// Load the highest AVX supported CPU when possible
if( cpuFeatures == null ) {
// console.warn("cpu-features failed to load, assuming AVX CPU is supported")
rwkvCppLibPath = './lib/rwkv-avx.dll';
} else if( cpuFeatures.flags.avx512 ) {
rwkvCppLibPath = './lib/rwkv-avx512.dll';
} else if( cpuFeatures.flags.avx2 ) {
rwkvCppLibPath = './lib/rwkv-avx2.dll';
} else {
// AVX detection is not reliable, so if we fail to detect, we downgrade to lowest avx version
rwkvCppLibPath = './lib/rwkv-avx.dll';
}
} else if( process.platform === 'darwin' ) {
rwkvCppLibPath = './lib/librwkv.dylib';
} else if( process.platform === 'linux' ) {
rwkvCppLibPath = './lib/librwkv.so';
} else {
throw new Error('Unsupported RWKV.cpp platform / arch: ' + process.platform + ' / ' + process.arch);
}
} else {
throw new Error("Unsupported RWKV.cpp arch: " + process.arch);
}
// The lib path to use
const rwkvCppFullLibPath = path.resolve(__dirname, "..", rwkvCppLibPath);
//---------------------------
// Lib binding loading
//---------------------------
const rwkvKoffiBind = koffi.load(rwkvCppFullLibPath);
// Custom pointers, to avoid copying data to JS land
const ctx_pointer = koffi.pointer('CTX_HANDLE', koffi.opaque());
// Initializing / cloning process
const rwkv_init_from_file = rwkvKoffiBind.func('CTX_HANDLE rwkv_init_from_file(const char * model_file_path, uint32_t n_threads)');
const rwkv_clone_context = rwkvKoffiBind.func('CTX_HANDLE rwkv_clone_context(CTX_HANDLE ctx, uint32_t n_threads)');
const rwkv_gpu_offload_layers = rwkvKoffiBind.func('bool rwkv_gpu_offload_layers(CTX_HANDLE ctx, uint32_t n_gpu_layers)');
// Model info extraction
const rwkv_get_n_vocab = rwkvKoffiBind.func('size_t rwkv_get_n_vocab(CTX_HANDLE ctx)');
const rwkv_get_n_embed = rwkvKoffiBind.func('size_t rwkv_get_n_embed(CTX_HANDLE ctx)');
const rwkv_get_n_layer = rwkvKoffiBind.func('size_t rwkv_get_n_layer(CTX_HANDLE ctx)');
const rwkv_get_state_len = rwkvKoffiBind.func('size_t rwkv_get_state_len(CTX_HANDLE ctx)');
const rwkv_get_logits_len = rwkvKoffiBind.func('size_t rwkv_get_logits_len(CTX_HANDLE ctx)');
// Eval sequence
const rwkv_eval = rwkvKoffiBind.func('bool rwkv_eval(CTX_HANDLE ctx, int32_t token, const float * state_in, _Out_ float * state_out, _Out_ float * logits_out)');
const rwkv_eval_sequence = rwkvKoffiBind.func('bool rwkv_eval_sequence(CTX_HANDLE ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, _Out_ float * state_out, _Out_ float * logits_out)');
// // Unsupported functions (due to API integration limitation)
// const rwkv_init_state = rwkvKoffiBind.func('void rwkv_init_state(CTX_HANDLE ctx, float * state)');
// const rwkv_set_print_errors = rwkvKoffiBind.func('void rwkv_set_print_errors(CTX_HANDLE ctx, bool print_errors)');
// const rwkv_get_print_errors = rwkvKoffiBind.func('bool rwkv_get_print_errors(CTX_HANDLE ctx)');
// const rwkv_get_last_error = rwkvKoffiBind.func('enum rwkv_error_flags rwkv_get_last_error(CTX_HANDLE ctx)');
// const rwkv_get_system_info_string = rwkvKoffiBind.func('const char * rwkv_get_system_info_string()');
// Quantizing models
const rwkv_quantize_model_file = rwkvKoffiBind.func('bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name)');
// Context destruction
const rwkv_free = rwkvKoffiBind.func('void rwkv_free(CTX_HANDLE ctx)');
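//
// Note: each koffi binding above can be called in two ways: directly, for a
// blocking (synchronous) call, or via its `.async(...args, callback)` form with a
// node-style `(err, result)` callback. The Promise wrappers and util.promisify
// calls below are built on the `.async` form. A minimal sketch (illustrative only,
// `someCtx` is a placeholder context handle):
//
//   const n_layer = rwkv_get_n_layer(someCtx);                  // blocking call
//   rwkv_get_n_layer.async(someCtx, (err, n) => { /* ... */ }); // async callback form
//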
//---------------------------
// Module export
//---------------------------
module.exports = {
// The path to the lib used
_libPath: rwkvCppFullLibPath,
/**
* Loads the model from a file and prepares it for inference.
* Returns NULL on any error. Error messages would be printed to stderr.
*
* @param {String} model_file_path path to model file in ggml format.
* @param {Number} n_threads number of threads to use for inference.
*
* @returns {ffi_pointer} Pointer to the RWKV context.
*/
async rwkv_init_from_file(model_file_path, n_threads) {
return new Promise((resolve, reject) => {
rwkv_init_from_file.async(
model_file_path,
n_threads,
(err, ctx) => {
if (err) {
reject(err);
} else {
resolve(ctx);
}
}
);
});
},
/**
* Offloads the specified layers to the GPU.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} n_gpu_layers - number of GPU layers to offload
*/
async rwkv_gpu_offload_layers(ctx, n_gpu_layers) {
return new Promise((resolve, reject) => {
rwkv_gpu_offload_layers.async(ctx, n_gpu_layers, (err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
});
});
},
/**
* Frees all allocated memory and the context.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
**/
async rwkv_free(ctx) {
return new Promise((resolve, reject) => {
rwkv_free.async(ctx, (err) => {
if (err) {
reject(err);
} else {
resolve();
}
});
});
},
/**
* Evaluates the model for a single token.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} token - The token to evaluate.
* @param {ffi_pointer} state_in - The input state.
* @param {ffi_pointer} state_out - The output state.
* @param {ffi_pointer} logits_out - The output logits.
*
* @returns {Boolean} True if successful, false if not.
**/
async rwkv_eval(ctx, token, state_in, state_out, logits_out) {
return new Promise((resolve, reject) => {
rwkv_eval.async(
ctx,
token,
state_in,
state_out,
logits_out,
(err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
}
);
});
},
/** Evaluates the model for a sequence of tokens.
* Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
* Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
* Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
* Returns false on any error.
* @param tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
* @param sequence_len: number of tokens to read from the array.
* @param state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
* @param state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
* @param logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
**/
async rwkv_eval_sequence(
ctx,
tokens,
sequence_len,
state_in,
state_out,
logits_out
) {
return new Promise((resolve, reject) => {
rwkv_eval_sequence.async(
ctx,
tokens,
sequence_len,
state_in,
state_out,
logits_out,
(err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
}
);
});
},
/**
* Returns count of FP32 elements in state buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the state buffer.
**/
async rwkv_get_state_buffer_element_count(ctx) {
return new Promise((resolve, reject) => {
rwkv_get_state_len.async(ctx, (err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
});
});
},
/**
* Returns count of FP32 elements in logits buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the logits buffer.
**/
async rwkv_get_logits_buffer_element_count(ctx) {
return new Promise((resolve, reject) => {
rwkv_get_logits_len.async(ctx, (err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
});
});
},
/**
* Quantizes the model file.
* Returns false on any error. Error messages would be printed to stderr.
*
* Available format names:
* - Q4_0
* - Q4_1
* - Q4_2
* - Q5_0
* - Q5_1
* - Q8_0
*
* @param {String} model_file_path_in - Path to the input model file in ggml format.
* @param {String} model_file_path_out - Path to the output model file in ggml format.
* @param {String} format_name - The format to use for quantization.
*
* @returns {Boolean} True if successful, false if not.
**/
async rwkv_quantize_model_file(
model_file_path_in,
model_file_path_out,
format_name
) {
return new Promise((resolve, reject) => {
rwkv_quantize_model_file.async(
model_file_path_in,
model_file_path_out,
format_name,
(err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
}
);
});
},
// Initializing / cloning process
// ---
/**
* Loads the model from a file and prepares it for inference.
* Returns NULL on any error. Error messages would be printed to stderr.
*
* @param {String} model_file_path - path to model file in ggml format.
* @param {Number} n_threads - number of CPU threads to use for inference.
*
* @returns {ffi_pointer} Pointer to the RWKV context.
*/
rwkv_init_from_file: rwkv_init_from_file,
/**
* Creates a new context from an existing one.
* This allows multiple rwkv_eval calls to run in parallel, without loading the same model multiple times.
* Each rwkv_context can have one eval running at a time.
* Every rwkv_context must be freed using rwkv_free.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} n_threads - number of CPU threads to use for inference.
*
* @returns {ffi_pointer} Pointer to the new RWKV context.
*/
rwkv_clone_context: rwkv_clone_context,
/**
* Offloads the specified layers to the GPU.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} n_gpu_layers - number of GPU layers to offload
*/
rwkv_gpu_offload_layers: rwkv_gpu_offload_layers,
// Model info extraction
// ---
/**
* Returns count of FP32 elements in state buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the state buffer.
**/
rwkv_get_state_len: rwkv_get_state_len,
/**
* Returns count of FP32 elements in logits buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the logits buffer.
**/
rwkv_get_logits_len: rwkv_get_logits_len,
/**
* Returns the number of layers in the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of layers
**/
rwkv_get_n_layer: rwkv_get_n_layer,
/**
* Returns the embedding size (n_embed) of the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The embedding size.
**/
rwkv_get_n_embed: rwkv_get_n_embed,
/**
* Returns the vocabulary size (n_vocab) of the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The vocabulary size.
**/
rwkv_get_n_vocab: rwkv_get_n_vocab,
// Eval sequences
// ---
/**
* Evaluates the model for a single token.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} token - The token to evaluate.
* @param {ffi_pointer} state_in - The input state.
* @param {ffi_pointer} state_out - The output state.
* @param {ffi_pointer} logits_out - The output logits.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_eval: rwkv_eval,
/**
* Evaluates the model for a sequence of tokens.
* Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
* Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
* Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
* Returns false on any error.
*
* @param tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
* @param sequence_len: number of tokens to read from the array.
* @param state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
* @param state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
* @param logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_eval_sequence : rwkv_eval_sequence,
// Quantizing models
// ---
/**
* Quantizes the model file.
* Returns false on any error. Error messages would be printed to stderr.
*
* Available format names:
* - Q4_0
* - Q4_1
* - Q4_2
* - Q5_0
* - Q5_1
* - Q8_0
*
* (For async use, call the <function-name>.async variant)
*
* @param {String} model_file_path_in - Path to the input model file in ggml format.
* @param {String} model_file_path_out - Path to the output model file in ggml format.
* @param {String} format_name - The quantization format to use.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_quantize_model_file: rwkv_quantize_model_file,
// Context destruction
// ---
/**
* Frees all allocated memory and the context.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
**/
rwkv_free: rwkv_free,
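// ---
// Usage sketch for the raw (blocking) exports above (illustrative only, not executed).
// The model path and token ID are hypothetical placeholders; per the rwkv.cpp API,
// state/logits buffers are Float32Array's sized via rwkv_get_state_len() and
// rwkv_get_logits_len(), and state_in may be null on the first pass.
//
//   const ctx = rwkv_init_from_file("./model.bin", 4);
//   const state = new Float32Array(rwkv_get_state_len(ctx));
//   const logits = new Float32Array(rwkv_get_logits_len(ctx));
//   rwkv_eval(ctx, 123, null, state, logits); // first token, no prior state
//   rwkv_free(ctx);
// ---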
// ====
// Promise Variant
// ====
promises: {
/**
* @async
*
* Loads the model from a file and prepares it for inference.
* Returns NULL on any error. Error messages would be printed to stderr.
*
* @param {String} model_file_path - path to model file in ggml format.
* @param {Number} n_threads - number of CPU threads to use for inference.
*
* @returns {ffi_pointer} Pointer to the RWKV context.
*/
rwkv_init_from_file: util.promisify(rwkv_init_from_file.async),
/**
* @async
*
* Creates a new context from an existing one.
* This allows multiple rwkv_eval calls to run in parallel, without loading the same model multiple times.
* Each rwkv_context can have one eval running at a time.
* Every rwkv_context must be freed using rwkv_free.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} n_threads - number of CPU threads to use for inference.
*
* @returns {ffi_pointer} Pointer to the new RWKV context.
*/
rwkv_clone_context: util.promisify(rwkv_clone_context.async),
/**
* @async
*
* Offloads the specified layers to the GPU.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} n_gpu_layers - number of GPU layers to offload
*/
rwkv_gpu_offload_layers: util.promisify(rwkv_gpu_offload_layers.async),
/**
* @async
*
* Returns count of FP32 elements in state buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the state buffer.
**/
rwkv_get_state_len: util.promisify(rwkv_get_state_len.async),
/**
* @async
*
* Returns count of FP32 elements in logits buffer.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of elements in the logits buffer.
**/
rwkv_get_logits_len: util.promisify(rwkv_get_logits_len.async),
/**
* @async
*
* Returns the number of layers in the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The number of layers
**/
rwkv_get_n_layer: util.promisify(rwkv_get_n_layer.async),
/**
* @async
*
* Returns the embedding size (n_embed) of the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The embedding size.
**/
rwkv_get_n_embed: util.promisify(rwkv_get_n_embed.async),
/**
* @async
*
* Returns the vocabulary size (n_vocab) of the model.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
*
* @returns {Number} The vocabulary size.
**/
rwkv_get_n_vocab: util.promisify(rwkv_get_n_vocab.async),
/**
* @async
*
* Evaluates the model for a single token.
* Returns false on any error. Error messages would be printed to stderr.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
* @param {Number} token - The token to evaluate.
* @param {ffi_pointer} state_in - The input state.
* @param {ffi_pointer} state_out - The output state.
* @param {ffi_pointer} logits_out - The output logits.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_eval: util.promisify(rwkv_eval.async),
/**
* @async
*
* Evaluates the model for a sequence of tokens.
* Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
* Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
* Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
* Returns false on any error.
*
* @param tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
* @param sequence_len: number of tokens to read from the array.
* @param state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
* @param state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
* @param logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_eval_sequence: util.promisify(rwkv_eval_sequence.async),
/**
* @async
*
* Quantizes the model file.
* Returns false on any error. Error messages would be printed to stderr.
*
* Available format names:
* - Q4_0
* - Q4_1
* - Q4_2
* - Q5_0
* - Q5_1
* - Q8_0
*
* @param {String} model_file_path_in - Path to the input model file in ggml format.
* @param {String} model_file_path_out - Path to the output model file in ggml format.
* @param {String} format_name - The quantization format to use.
*
* @returns {Boolean} True if successful, false if not.
**/
rwkv_quantize_model_file: util.promisify(rwkv_quantize_model_file.async),
/**
* @async
*
* Frees all allocated memory and the context.
*
* @param {ffi_pointer} ctx - Pointer to the RWKV context.
**/
rwkv_free: util.promisify(rwkv_free.async),
}
}
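//---------------------------
// Example usage (promises variant)
//---------------------------
// A minimal, illustrative sketch of driving the promise-based exports; it is not part
// of the module itself. The model path and token IDs below are hypothetical placeholders;
// buffers are Float32Array's sized per rwkv_get_state_len() / rwkv_get_logits_len().
//
//   const rwkv = require("./cpp_bind");
//
//   async function main() {
//     const ctx = await rwkv.promises.rwkv_init_from_file("./model.bin", 4);
//     const state = new Float32Array(await rwkv.promises.rwkv_get_state_len(ctx));
//     const logits = new Float32Array(await rwkv.promises.rwkv_get_logits_len(ctx));
//
//     // Evaluate a short token sequence in one pass (state_in is null on the first pass)
//     const tokens = new Uint32Array([10, 20, 30]);
//     await rwkv.promises.rwkv_eval_sequence(ctx, tokens, tokens.length, null, state, logits);
//
//     await rwkv.promises.rwkv_free(ctx);
//   }
//
//   main();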