//! GPT-2 in Zig, inspired by Karpathy
//! An implementation of LLM training in Zig, based on Andrej Karpathy's llm.c (train_gpt2.c)
//!
const std = @import("std");
const math = @import("std").math;
const builtin = @import("builtin");
const LlmFloat = f32; // To easily change the precision of the model
const FloatList = std.ArrayList(LlmFloat);
pub const UIntList = std.ArrayList(u32);
extern fn exit() noreturn;
const VectorSize: usize = 8; // Seems to be the best choice for vectorization on an M2 MacBook Pro
const RndGen = std.rand.DefaultPrng;
pub const VecType = @Vector(VectorSize, LlmFloat);
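// Note: the vectorized (`*_vec`) kernels below walk their innermost dimension in chunks of N lanes
// and assume that dimension (C, the head size, or the flat buffer length) is a multiple of N;
// remainder elements are not processed. For GPT-2 sizes (C = 768, head size 64) this holds with VectorSize = 8.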
const Errors = error{InvalidModelHeader, InvalidModelVersion, InvalidTokensFile};
const program_name = "LLM GPT-2";
const NumParameterTensors = 16;
const NumActivationTensors = 23;
const FILE_HEADER = 20240326;
const FILE_VERSION = 1;
const FILE_HEADER_SIZE = 256;
const cfuncs = @cImport({
@cInclude("math.h");
});
const time = std.time;
const Instant = time.Instant;
const Timer = time.Timer;
const ParameterTensors = struct {
wte: []LlmFloat, // (V, C)
wpe: []LlmFloat, // (maxT, C)
ln1w: []LlmFloat, // (L, C)
ln1b: []LlmFloat, // (L, C)
qkvw: []LlmFloat, // (L, 3*C, C)
qkvb: []LlmFloat, // (L, 3*C)
attprojw: []LlmFloat, // (L, C, C)
attprojb: []LlmFloat, // (L, C)
ln2w: []LlmFloat, // (L, C)
ln2b: []LlmFloat, // (L, C)
fcw: []LlmFloat, // (L, 4*C, C)
fcb: []LlmFloat, // (L, 4*C)
fcprojw: []LlmFloat, // (L, C, 4*C)
fcprojb: []LlmFloat, // (L, C)
lnfw: []LlmFloat, // (L, C)
lnfb: []LlmFloat, // (L, C)
};
const ActivationTensors = struct {
encoded: []LlmFloat, // (B, T, C)
ln1: []LlmFloat, // (L, B, T, C)
ln1_mean: []LlmFloat, // (L, B, T)
ln1_rstd: []LlmFloat, // (L, B, T)
qkv: []LlmFloat, // (L, B, T, 3*C)
atty: []LlmFloat, // (L, B, T, C)
preatt: []LlmFloat, // (L, B, NH, T, T)
att: []LlmFloat, // (L, B, NH, T, T)
attproj: []LlmFloat, // (L, B, T, C)
residual2: []LlmFloat, // (L, B, T, C)
ln2: []LlmFloat, // (L, B, T, C)
ln2_mean: []LlmFloat, // (L, B, T)
ln2_rstd: []LlmFloat, // (L, B, T)
fch: []LlmFloat, // (L, B, T, 4*C)
fch_gelu: []LlmFloat, // (L, B, T, 4*C)
fcproj: []LlmFloat, // (L, B, T, C)
residual3: []LlmFloat, // (L, B, T, C)
lnf: []LlmFloat, // (B, T, C)
lnf_mean: []LlmFloat, // (B, T)
lnf_rstd: []LlmFloat, // (B, T)
logits: []LlmFloat, // (B, T, V)
probs: []LlmFloat, // (B, T, V)
losses: []LlmFloat, // (B, T)
};
const GPT2Config = struct {
model_header: u32, // header
model_version: u32, // version
max_seq_len: u32, // max sequence length, e.g. 1024
vocab_size: u32, // vocab size, e.g. 50257
num_layers: u32, // number of layers, e.g. 12
num_heads: u32, // number of heads in attention, e.g. 12
channels: u32, // number of channels, e.g. 768
};
const GPT2 = struct {
init_params: bool,
init_grads: bool,
init_grads_acts: bool,
config: GPT2Config,
params: ParameterTensors,
params_sizes: [NumParameterTensors]u32, // TODO: consider using usize for the sizes
params_memory: []LlmFloat,
num_parameters: u32,
// gradients of the weights
grads: ParameterTensors,
grads_memory: []LlmFloat,
// buffers for the AdamW optimizer
m_memory: []LlmFloat,
v_memory: []LlmFloat,
init_adam: bool,
// the activations of the model, and their sizes
acts: ActivationTensors,
act_sizes: [NumActivationTensors]u32,
acts_memory: []LlmFloat,
num_activations: u32,
// gradients of the activations
grads_acts: ActivationTensors,
grads_acts_memory: []LlmFloat,
// other run state configuration
batch_size: u32, // the batch size (B) of current forward pass
seq_len: u32, // the sequence length (T) of current forward pass
inputs: []u32, // the input tokens for the current forward pass
targets: []u32, // the target tokens for the current forward pass
mean_loss: LlmFloat, // after a forward pass with targets, will be populated with the mean loss
};
const DataLoader = struct {
B: u32, // batch size
T: u32, // sequence length
tokens_file: std.fs.File,
file_size: u64,
current_position: u64,
batch: []u32,
inputs: []u32,
targets: []u32,
num_batches: u32,
raw_data: []u8,
};
pub const Token = []u8;
pub const Tokenizer = struct {
vocab_size: u32,
vocab_map: [*][]u8,
init_ok: bool,
};
pub fn read_n_parameters_from_file(comptime T:type, file_name: [] const u8, N :usize, offset:usize) !std.ArrayList(T) {
var file = try std.fs.cwd().openFile(file_name, .{ .mode = .read_only, });
defer file.close();
const file_size = try file.getEndPos();
if (file_size == 0) {
return error.FileEmpty;
}
// Account for the read offset when checking that enough elements are available.
if (file_size < (offset + N) * @sizeOf(T)) {
return error.NotEnoughParameters;
}
try file.seekTo(offset * @sizeOf(T));
// Create an ArrayList to hold the data read.
var data = try std.ArrayList(T).initCapacity(std.heap.page_allocator, N);
try data.resize(N);
const bytes = std.mem.sliceAsBytes(data.items);
// readAll keeps reading until the buffer is full (a single read may return fewer bytes).
_ = try file.readAll(bytes);
return data;
}
pub fn tokenizer_free(allocator: std.mem.Allocator, tokenizer:*Tokenizer) void{
if(tokenizer.init_ok){
for(0..tokenizer.vocab_size) |i| {
allocator.free(tokenizer.vocab_map[i]);
}
// Also free the array of token slices itself (allocated in tokenizer_init).
allocator.free(tokenizer.vocab_map[0..tokenizer.vocab_size]);
tokenizer.init_ok = false;
}
}
pub fn tokenizer_init(allocator: std.mem.Allocator, tokenizer:*Tokenizer, filename:[] const u8) !void{
const model_header: UIntList = try read_n_parameters_from_file(u32,filename, FILE_HEADER_SIZE, 0);
defer model_header.deinit();
if (model_header.items[0] != 20240328){
return Errors.InvalidModelHeader;
}
if (model_header.items[1] != 1){
return Errors.InvalidModelVersion;
}
var file = try std.fs.cwd().openFile(filename, .{});
defer file.close();
try file.seekTo(FILE_HEADER_SIZE * @sizeOf(u32));
tokenizer.vocab_size = model_header.items[2];
const val = try allocator.alloc([]u8, tokenizer.vocab_size);
tokenizer.vocab_map = val.ptr;
for(0..tokenizer.vocab_size) |i| {
var token_length: [1]u8 = undefined;
//var token_data: [64]u8 = undefined;
const read_token_length = try file.read(&token_length);
if (@as(u32, @intCast(read_token_length)) == 0){
return Errors.InvalidTokensFile;
}
const dyn_buffer = try allocator.alloc(u8, @as(u32, @intCast(token_length[0])));
const read_token_bytes = try file.read(dyn_buffer);
if (@as(u32, @intCast(read_token_bytes)) == 0){
return Errors.InvalidTokensFile;
}
tokenizer.vocab_map[i] = dyn_buffer;
}
// Print one token as a sanity check, guarding against smaller vocabularies.
if (tokenizer.vocab_size > 50001) {
const sample_token = tokenizer.vocab_map[50001];
std.debug.print("sample token: {s}\n", .{sample_token});
}
tokenizer.init_ok = true;
}
pub fn tokenizer_decode(tokenizer:Tokenizer, input:u32) []u8{
if(tokenizer.init_ok){
if (input < tokenizer.vocab_size){
return tokenizer.vocab_map[input];
}
}
return undefined;
}
/// Prints the size and the first and last element of each parameter tensor, mostly for debugging.
pub fn printParams(model: GPT2) void {
std.debug.print("params.wte size {} first {d:.3} last {d:.3}\n", .{model.params.wte.len, model.params.wte[0], model.params.wte[model.params.wte.len - 1]});
std.debug.print("params.wpe size {} first {d:.3} last {d:.3}\n", .{model.params.wpe.len, model.params.wpe[0], model.params.wpe[model.params.wpe.len - 1]});
std.debug.print("params.ln1w size {} first {d:.3} last {d:.3}\n", .{model.params.ln1w.len, model.params.ln1w[0], model.params.ln1w[model.params.ln1w.len - 1]});
std.debug.print("params.ln1b size {} first {d:.3} last {d:.3}\n", .{model.params.ln1b.len, model.params.ln1b[0], model.params.ln1b[model.params.ln1b.len - 1]});
std.debug.print("params.qkvw size {} first {d:.3} last {d:.3}\n", .{model.params.qkvw.len, model.params.qkvw[0], model.params.qkvw[model.params.qkvw.len - 1]});
std.debug.print("params.qkvb size {} first {d:.3} last {d:.3}\n", .{model.params.qkvb.len, model.params.qkvb[0], model.params.qkvb[model.params.qkvb.len - 1]});
std.debug.print("params.attprojw size {} first {d:.3} last {d:.3}\n", .{model.params.attprojw.len, model.params.attprojw[0], model.params.attprojw[model.params.attprojw.len - 1]});
std.debug.print("params.attprojb size {} first {d:.3} last {d:.3}\n", .{model.params.attprojb.len, model.params.attprojb[0], model.params.attprojb[model.params.attprojb.len - 1]});
std.debug.print("params.ln2w size {} first {d:.3} last {d:.3}\n", .{model.params.ln2w.len, model.params.ln2w[0], model.params.ln2w[model.params.ln2w.len - 1]});
std.debug.print("params.ln2b size {} first {d:.3} last {d:.3}\n", .{model.params.ln2b.len, model.params.ln2b[0], model.params.ln2b[model.params.ln2b.len - 1]});
std.debug.print("params.fcw size {} first {d:.3} last {d:.3}\n", .{model.params.fcw.len, model.params.fcw[0], model.params.fcw[model.params.fcw.len - 1]});
std.debug.print("params.fcb size {} first {d:.3} last {d:.3}\n", .{model.params.fcb.len, model.params.fcb[0], model.params.fcb[model.params.fcb.len - 1]});
std.debug.print("params.fcprojw size {} first {d:.3} last {d:.3}\n", .{model.params.fcprojw.len, model.params.fcprojw[0], model.params.fcprojw[model.params.fcprojw.len - 1]});
std.debug.print("params.fcprojb size {} first {d:.3} last {d:.3}\n", .{model.params.fcprojb.len, model.params.fcprojb[0], model.params.fcprojb[model.params.fcprojb.len - 1]});
std.debug.print("params.lnfw size {} first {d:.3} last {d:.3}\n", .{model.params.lnfw.len, model.params.lnfw[0], model.params.lnfw[model.params.lnfw.len - 1]});
std.debug.print("params.lnfb size {} first {d:.3} last {d:.3}\n", .{model.params.lnfb.len, model.params.lnfb[0], model.params.lnfb[model.params.lnfb.len - 1]});
}
/// Encodes the input tokens into the model's input tensor by combining token embeddings and position embeddings.
/// This is often the first step in transformer models like GPT, where the input sequence is converted into
/// a more richly represented format for further processing.
/// @param output: The output buffer where the encoded tensor will be stored.
/// @param input: An array of token indices representing the input sequence.
/// @param wte: The token embedding weights (Vocabulary size x Embedding dimension).
/// @param wpe: The positional embedding weights (Max sequence length x Embedding dimension).
/// @param B: The batch size, indicating the number of sequences being processed simultaneously.
/// @param T: The sequence length of the input.
/// @param C: The embedding dimension or channel size of the embeddings.
pub fn encoder_forward( output : []LlmFloat, input: []u32, wte: []LlmFloat, wpe: []LlmFloat, B:u32, T:u32, C:u32) void
{
for(0..B) |b| {
for(0..T) |t| {
var out_bt = output[b * T * C + t * C..];
// Get the index of the token at inp[b, t]
const ix: u32 = input[b * T + t];
// Seek to the position in wte corresponding to the token
const wte_ix = wte[ix * C..];
// Seek to the position in wpe corresponding to the position
const wpe_t = wpe[t * C..];
for(0..C) |c| {
out_bt[c] = wte_ix[c] + wpe_t[c];
}
}
}
}
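// Illustrative check, not part of the original file: a minimal sketch showing that encoder_forward
// writes wte[token] + wpe[position] into every (b, t) row. Uses Zig's built-in test runner; the tiny
// shapes (B=1, T=2, C=2, V=2) are made up for the example.
test "encoder_forward adds token and position embeddings (sketch)" {
    var out = [_]LlmFloat{0} ** 4; // (B, T, C)
    var tokens = [_]u32{ 1, 0 }; // (B, T)
    var wte = [_]LlmFloat{ 0.1, 0.2, 0.3, 0.4 }; // (V, C)
    var wpe = [_]LlmFloat{ 1.0, 2.0, 3.0, 4.0 }; // (maxT, C)
    encoder_forward(out[0..], tokens[0..], wte[0..], wpe[0..], 1, 2, 2);
    // position 0 holds token 1: wte[1,:] + wpe[0,:] = {1.3, 2.4}
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 1.3), out[0], 1e-6);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 2.4), out[1], 1e-6);
}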
/// Vectorized
/// Encodes the input tokens into the model's input tensor by combining token embeddings and position embeddings.
/// This is often the first step in transformer models like GPT, where the input sequence is converted into
/// a more richly represented format for further processing.
/// @param output: The output buffer where the encoded tensor will be stored.
/// @param input: An array of token indices representing the input sequence.
/// @param wte: The token embedding weights (Vocabulary size x Embedding dimension).
/// @param wpe: The positional embedding weights (Max sequence length x Embedding dimension).
/// @param B: The batch size, indicating the number of sequences being processed simultaneously.
/// @param T: The sequence length of the input.
/// @param C: The embedding dimension or channel size of the embeddings.
pub fn encoder_forward_vec( comptime N: usize, output : []LlmFloat, input: []u32, wte: []LlmFloat, wpe: []LlmFloat,
B:u32, T:u32, C:u32) void
{
for(0..B) |b| {
for(0..T) |t| {
for(0..C/N) |i| {
const ix: u32 = input[b * T + t];
const wte_ix : @Vector(N , LlmFloat) = wte[ix * C + i*N..][0..N].*;
const wpe_t : @Vector(N , LlmFloat) = wpe[t * C + i*N..][0..N].*;
const res: [N]LlmFloat = wte_ix + wpe_t;
const start = b * T * C + t * C + i * N;
const end = start + N;
@memcpy(output[start..end], &res);
}
}
}
}
/// Computes gradients for the encoder's forward pass. This is crucial for training, as it helps in optimizing
/// the token and position embeddings by backpropagating errors from the output towards the inputs.
/// @param dwte: Gradient with respect to the token embeddings.
/// @param dwpe: Gradient with respect to the positional embeddings.
/// @param dout: Gradient coming from the next layer (upstream gradients).
/// @param inp: Array of token indices (input to the forward function).
/// @param B: Batch size, as defined in the forward function.
/// @param T: Sequence length, as defined in the forward function.
/// @param C: Embedding dimension, as defined in the forward function.
pub fn encoder_backward( dwte: []LlmFloat, dwpe: []LlmFloat, dout: []LlmFloat, inp: []u32, B: u32, T: u32, C: u32) void
{
for (0..B) |b| {
for (0..T) |t| {
const dout_bt = dout[b * T * C + t * C ..];
const ix: u32 = inp[b * T + t];
var dwte_ix = dwte[ix * C ..];
var dwpe_t = dwpe[t * C ..];
for (0..C) |c| {
const d = dout_bt[c];
dwte_ix[c] += d;
dwpe_t[c] += d;
}
}
}
}
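// Illustrative check, not part of the original file: encoder_backward scatters the upstream gradient
// into the dwte row selected by the token and the dwpe row selected by the position. Shapes are made up.
test "encoder_backward accumulates dout into dwte[token] and dwpe[pos] (sketch)" {
    var dwte = [_]LlmFloat{0} ** 4; // (V=2, C=2)
    var dwpe = [_]LlmFloat{0} ** 2; // (maxT=1, C=2)
    var dout = [_]LlmFloat{ 0.5, -1.0 }; // (B=1, T=1, C=2)
    var tokens = [_]u32{1};
    encoder_backward(dwte[0..], dwpe[0..], dout[0..], tokens[0..], 1, 1, 2);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 0.5), dwte[2], 1e-6); // row of token 1, channel 0
    try std.testing.expectApproxEqAbs(@as(LlmFloat, -1.0), dwpe[1], 1e-6); // row of position 0, channel 1
}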
/// Vectorized:
/// Computes gradients for the encoder's forward pass. This is crucial for training, as it helps in optimizing
/// the token and position embeddings by backpropagating errors from the output towards the inputs.
/// @param dwte: Gradient with respect to the token embeddings.
/// @param dwpe: Gradient with respect to the positional embeddings.
/// @param dout: Gradient coming from the next layer (upstream gradients).
/// @param inp: Array of token indices (input to the forward function).
/// @param B: Batch size, as defined in the forward function.
/// @param T: Sequence length, as defined in the forward function.
/// @param C: Embedding dimension, as defined in the forward function.
pub fn encoder_backward_vec( comptime N: u32, dwte: []LlmFloat, dwpe: []LlmFloat, dout: []LlmFloat, inp: []u32, B: u32,
T: u32, C: u32) void {
for (0..B) |b| {
for (0..T) |t| {
for(0..C/N) |i| {
const idx = b * T * C + t * C + i * N;
const dout_bt = dout[idx..][0..N].*;
const ix: usize = inp[b * T + t];
const dwte_ix: @Vector(N , LlmFloat) = dwte[ix * C + i * N..][0..N].*;
const dwpe_t: @Vector(N , LlmFloat) = dwpe[t * C + i * N..][0..N].*;
const res_wte: [N]LlmFloat = dout_bt + dwte_ix;
const res_wpe: [N]LlmFloat = dout_bt + dwpe_t;
@memcpy(dwte[ix * C + i * N..][0..N], &res_wte);
@memcpy(dwpe[t * C + i * N..][0..N], &res_wpe);
}
}
}
}
/// Applies layer normalization to the input tensor. Layer normalization is a standard technique in neural networks
/// to stabilize and accelerate training. It normalizes inputs across the features instead of the batch dimension.
/// @param output: The buffer where the normalized output will be stored.
/// @param mean: Buffer to store the computed mean of the input tensor for each sequence.
/// @param rstd: Buffer to store the reciprocal of the standard deviation (inverse standard deviation).
/// @param input: The input tensor to be normalized.
/// @param weight: The gamma parameter for scaling in the normalization.
/// @param bias: The beta parameter for shifting in the normalization.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Number of features (embedding dimension in the context of transformers).
pub fn layernorm_forward(output: []LlmFloat, mean: []LlmFloat, rstd: []LlmFloat, input: []LlmFloat, weight: []LlmFloat,
bias: []LlmFloat,B: u32,T: u32,C: u32) void {
const eps: LlmFloat = 1e-5;
for (0..B) |b| {
for (0..T) |t| {
const x = input[b * T * C + t * C ..];
var m: LlmFloat = 0.0;
for (0..C) |i| {
m += x[i];
}
m /= @floatFromInt( C);
var v: LlmFloat = 0.0;
for (0..C) |i| {
const diff = x[i] - m;
v += diff * diff;
}
v /= @floatFromInt( C);
const s = 1.0 / math.sqrt(v + eps);
var out = output[b * T * C + t * C ..];
for (0..C) |c| {
out[c] = (x[c] - m) * s * weight[c] + bias[c];
}
mean[b * T + t] = m;
rstd[b * T + t] = s;
}
}
}
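// Illustrative check, not part of the original file: with unit weight and zero bias a C=2 row
// {1, 3} normalizes to roughly {-1, +1} (mean 2, variance 1, up to the eps term). Shapes are made up.
test "layernorm_forward normalizes each (b,t) row (sketch)" {
    var out = [_]LlmFloat{ 0.0, 0.0 };
    var mean = [_]LlmFloat{0.0};
    var rstd = [_]LlmFloat{0.0};
    var x = [_]LlmFloat{ 1.0, 3.0 };
    var gamma = [_]LlmFloat{ 1.0, 1.0 };
    var beta = [_]LlmFloat{ 0.0, 0.0 };
    layernorm_forward(out[0..], mean[0..], rstd[0..], x[0..], gamma[0..], beta[0..], 1, 1, 2);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 2.0), mean[0], 1e-6);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, -1.0), out[0], 1e-3);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 1.0), out[1], 1e-3);
}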
/// Vectorized:
/// Applies layer normalization to the input tensor. Layer normalization is a standard technique in neural networks
/// to stabilize and accelerate training. It normalizes inputs across the features instead of the batch dimension.
/// @param output: The buffer where the normalized output will be stored.
/// @param mean: Buffer to store the computed mean of the input tensor for each sequence.
/// @param rstd: Buffer to store the reciprocal of the standard deviation (inverse standard deviation).
/// @param input: The input tensor to be normalized.
/// @param weight: The gamma parameter for scaling in the normalization.
/// @param bias: The beta parameter for shifting in the normalization.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Number of features (embedding dimension in the context of transformers).
pub fn layernorm_forward_vec(comptime N: usize, output: []LlmFloat, mean: []LlmFloat, rstd: []LlmFloat, input: []LlmFloat,
weight: []LlmFloat, bias: []LlmFloat, B: u32, T: u32, C: u32) void {
const eps: LlmFloat = 1e-5;
const op :std.builtin.ReduceOp = .Add;
for (0..B) |b| {
for (0..T) |t| {
const start:usize = b * T * C + t * C;
var m: LlmFloat = 0.0; // mean
var v: LlmFloat = 0.0; // variance
for (0..C/N)|i|{
const x:@Vector(N , LlmFloat) = input[start + i*N ..][0..N].*;
m += @reduce(op, x);
}
m /= @floatFromInt( C);
const v_m : @Vector(N, LlmFloat) = @splat(m);
for (0..C/N)|i|{
var x:@Vector(N , LlmFloat) = input[start + i*N ..][0..N].*;
x -= v_m;
v += @reduce(op, x * x);
}
v /= @floatFromInt( C);
const s = 1.0 / math.sqrt(v + eps);
const s_vector : @Vector(N, LlmFloat) = @splat(s);
for (0..C/N)|i|{
var diff:@Vector(N , LlmFloat) = input[start + i*N ..][0..N].*;
diff -= v_m;
var w_vector : @Vector(N, LlmFloat) = weight[i*N..][0..N].*;
const bias_vector : @Vector(N, LlmFloat) = bias[i*N..][0..N].*;
w_vector *= s_vector;
const res: [N]LlmFloat = @mulAdd(@Vector(N, LlmFloat),diff, w_vector, bias_vector);
@memcpy(output[start + i*N ..start + i*N + N], &res);
}
mean[b * T + t] = m;
rstd[b * T + t] = s;
}
}
}
/// Computes the gradients for the layer normalization layer during the backward pass of training.
/// Layer normalization is used to stabilize the learning process by normalizing the inputs across the features.
/// @param dinp: Gradient buffer for the input activations where the computed gradients will be accumulated.
/// @param dweight: Gradient buffer for the scale parameters (gamma in some contexts), where gradients will be accumulated.
/// @param dbias: Gradient buffer for the shift parameters (beta in some contexts), where gradients will be accumulated.
/// @param doutput: The gradient received from the downstream layer, which needs to be back-propagated.
/// @param inp: The input tensor to the layer normalization function; required for computing gradients.
/// @param weight: The scale tensor applied during the forward pass.
/// @param mean: Computed mean of the inputs during the forward pass, required for backpropagation.
/// @param rstd: Computed reciprocal of the standard deviation (1/std) of the inputs during the forward pass.
/// @param B: Batch size, indicating the number of separate input sets.
/// @param T: Sequence or temporal length in the input.
/// @param C: Number of channels or features in the input, which corresponds to the dimensionality of the mean and std dev calculations.
pub fn layernorm_backward(dinp: []LlmFloat, dweight: []LlmFloat, dbias: []LlmFloat,doutput: []LlmFloat,inp:
[]LlmFloat,weight: []LlmFloat,mean: []LlmFloat,rstd: []LlmFloat,B: u32,T: u32,C: u32) void {
for (0..B) |b| {
for (0..T) |t| {
const dout = doutput[b * T * C + t * C ..];
const inp_bt = inp[b * T * C + t * C..];
var dinp_bt = dinp[b * T * C + t * C..];
const mean_bt = mean[b * T + t];
const rstd_bt = rstd[b * T + t];
var dnorm_mean:LlmFloat = 0.0;
var dnorm_norm_mean:LlmFloat = 0.0;
for (0..C) |c| {
const norm_bti = (inp_bt[c] - mean_bt) * rstd_bt;
const dnorm_i = weight[c] * dout[c];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
dnorm_mean /= @floatFromInt(C);
dnorm_norm_mean /= @floatFromInt(C);
for (0..C) |c| {
const norm_bti = (inp_bt[c] - mean_bt) * rstd_bt;
const dnorm_i = weight[c] * dout[c];
dbias[c] += dout[c];
dweight[c] += norm_bti * dout[c];
var dval = dnorm_i - dnorm_mean - norm_bti * dnorm_norm_mean;
dval *= rstd_bt;
dinp_bt[c] += dval;
}
}
}
}
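// For reference, the gradient implemented above, per (b, t) row with norm_i = (x_i - mean) * rstd:
//   dbias_i   += dout_i
//   dweight_i += norm_i * dout_i
//   dx_i      += rstd * (w_i * dout_i - mean_j(w_j * dout_j) - norm_i * mean_j(w_j * dout_j * norm_j))
// The two subtracted mean terms are exactly dnorm_mean and dnorm_norm_mean computed in the first inner loop.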
/// Computes the forward pass of a matrix multiplication for a mini-batch. This operation is central in
/// neural networks, especially in fully connected and convolutional layers (transformed into matrix multiplications).
/// @param output: The output buffer where the results will be stored.
/// @param input: The input tensor that contains data from the previous layer.
/// @param weight: The weight matrix to be multiplied with the input.
/// @param bias: The bias vector to be added after matrix multiplication, can be optional.
/// @param B: Batch size, number of separate data entries processed.
/// @param T: Sequence length, typically the number of time steps in sequence data.
/// @param C: Number of input channels or features per time step.
/// @param OC: Number of output channels, corresponding to the number of neurons in a fully connected layer.
pub fn matmul_forward( output: []LlmFloat, input: []LlmFloat, weight: []LlmFloat, bias: []LlmFloat, B: u32, T: u32,
C: u32, OC: u32,
) void{
for(0..B) |b| {
for(0..T) |t| {
var out_bt:[]LlmFloat = output[b * T * OC + t * OC..];
const inp_bt:[]LlmFloat = input[b * T * C + t * C..];
for(0..OC) |o| {
var val:LlmFloat = 0.0;
if(bias.len != 0) {
val = bias[o];
}
const wrow:[]LlmFloat = weight[o * C..];
for(0..C) |i| {
val += inp_bt[i] * wrow[i];
}
out_bt[o] = val;
}
}
}
}
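// Illustrative check, not part of the original file: out[b,t,o] = bias[o] + dot(inp[b,t,:], weight[o,:]),
// with weight stored row-major as (OC, C). The tiny shapes are made up for the example.
test "matmul_forward computes out = inp * W^T + bias (sketch)" {
    var out = [_]LlmFloat{ 0.0, 0.0 }; // (B=1, T=1, OC=2)
    var inp = [_]LlmFloat{ 1.0, 2.0 }; // (B=1, T=1, C=2)
    var w = [_]LlmFloat{ 3.0, 4.0, 5.0, 6.0 }; // (OC=2, C=2)
    var bias = [_]LlmFloat{ 0.5, -0.5 };
    matmul_forward(out[0..], inp[0..], w[0..], bias[0..], 1, 1, 2, 2);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 11.5), out[0], 1e-6); // 0.5 + 1*3 + 2*4
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 16.5), out[1], 1e-6); // -0.5 + 1*5 + 2*6
}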
/// Vectorized
/// Computes the forward pass of a matrix multiplication for a mini-batch. This operation is central in
/// neural networks, especially in fully connected and convolutional layers (transformed into matrix multiplications).
/// @param output: The output buffer where the results will be stored.
/// @param input: The input tensor that contains data from the previous layer.
/// @param weight: The weight matrix to be multiplied with the input.
/// @param bias: The bias vector to be added after matrix multiplication, can be optional.
/// @param B: Batch size, number of separate data entries processed.
/// @param T: Sequence length, typically the number of time steps in sequence data.
/// @param C: Number of input channels or features per time step.
/// @param OC: Number of output channels, corresponding to the number of neurons in a fully connected layer.
pub fn matmul_forward_vec( comptime N : usize,output: []LlmFloat,input: []LlmFloat,weight: []LlmFloat,bias: []LlmFloat,
B: usize,T: usize,C: usize,OC: usize
) void {
for (0..B) |b| {
for (0..T) |t| {
var out_bt = output[b * T * OC + t * OC ..];
for (0..OC) |o| {
// Accumulate into a local value so the output is overwritten rather than accumulated across calls,
// matching the scalar reference above.
var val: LlmFloat = 0.0;
if (bias.len != 0) {
val = bias[o];
}
for (0..C/N) |i| {
const inp_bt_v : @Vector(N , LlmFloat) = input[b * T * C + t * C + i*N ..][0..N].*;
const wrow_v : @Vector(N , LlmFloat) = weight[o * C + i*N ..][0..N].*;
val += @reduce(.Add, inp_bt_v * wrow_v);
}
out_bt[o] = val;
}
}
}
}
/// Computes the gradients for the matrix multiplication operation in the backward pass of training.
/// It back-propagates gradients through the network, adjusting weights and biases based on the error
/// relative to the expected output.
/// @param dinp: Gradient buffer for the inputs, where gradients will be accumulated.
/// @param dweight: Gradient buffer for the weights, where gradients will be accumulated.
/// @param dbias: Gradient buffer for the biases, where gradients will be accumulated (if bias is used).
/// @param dout: The gradient received from the downstream layer (back-propagated error).
/// @param inp: The input tensor to the forward function, required for gradient computation.
/// @param weight: The weight matrix used in the forward pass, required for gradient computation.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Number of input channels or features.
/// @param OC: Number of output channels.
pub fn matmul_backward( dinp: []LlmFloat, dweight: []LlmFloat, dbias: []LlmFloat, dout: []LlmFloat, inp: []LlmFloat,
weight: []LlmFloat, B: u32, T: u32, C: u32, OC: u32) void{
for(0..B) |b| {
for(0..T) |t| {
const dout_bt = dout[b * T * OC + t * OC..];
var dinp_bt = dinp[b * T * C + t * C..];
for(0..OC) |o| {
const wrow = weight[o * C..];
const d:LlmFloat = dout_bt[o];
for(0..C) |i| {
dinp_bt[i] += wrow[i] * d;
}
}
}
}
for(0..OC) |o| {
for(0..B) |b| {
for(0..T) |t| {
const dout_bt = dout[b * T * OC + t * OC..];
const inp_bt = inp[b * T * C + t * C..];
var dwrow = dweight[o * C..];
const d:LlmFloat = dout_bt[o];
if(dbias.len != 0) {
dbias[o] += d;
}
for(0..C) |i| {
dwrow[i] += inp_bt[i] * d;
}
}
}
}
}
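// For reference, with out = inp * W^T + bias the loops above accumulate:
//   dinp[b,t,:]  += dout[b,t,o] * weight[o,:]      (first loop nest)
//   dweight[o,:] += dout[b,t,o] * inp[b,t,:]       (second loop nest)
//   dbias[o]     += dout[b,t,o]                    (second loop nest, summed over b and t)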
/// Vectorized
/// Computes the gradients for the matrix multiplication operation in the backward pass of training.
/// It back-propagates gradients through the network, adjusting weights and biases based on the error
/// relative to the expected output.
/// @param dinp: Gradient buffer for the inputs, where gradients will be accumulated.
/// @param dweight: Gradient buffer for the weights, where gradients will be accumulated.
/// @param dbias: Gradient buffer for the biases, where gradients will be accumulated (if bias is used).
/// @param dout: The gradient received from the downstream layer (back-propagated error).
/// @param inp: The input tensor to the forward function, required for gradient computation.
/// @param weight: The weight matrix used in the forward pass, required for gradient computation.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Number of input channels or features.
/// @param OC: Number of output channels.
pub fn matmul_backward_vec( comptime N: usize, dinp: []LlmFloat, dweight: []LlmFloat, dbias: []LlmFloat, dout: []LlmFloat,
inp: []LlmFloat, weight: []LlmFloat, B: u32, T: u32, C: u32, OC: usize) void{
for(0..B) |b| {
for(0..T) |t| {
const dout_bt = dout[b * T * OC + t * OC..];
var dinp_bt_v : @Vector(N , LlmFloat) = undefined;
for(0..OC) |o| {
for(0..C/N) |i| {
dinp_bt_v = dinp[b * T * C + t * C + i * N..][0..N].*;
const v_wrow : @Vector(N, LlmFloat) = weight[o * C + i * N..][0..N].*;
const d_t : @Vector(N, LlmFloat) = @splat(dout_bt[o]);
const zero : @Vector(N, LlmFloat) = @splat(0);
dinp_bt_v += @mulAdd(@Vector(N, LlmFloat),v_wrow, d_t,zero);
const tmp:[N]LlmFloat = dinp_bt_v;
@memcpy(dinp[b * T * C + t * C + i * N..b * T * C + t * C + (i+1) * N], &tmp);
}
}
}
}
for(0..OC) |o| {
for(0..B) |b| {
for(0..T) |t| {
const dout_bt = dout[b * T * OC + t * OC..];
for (0..C/N) |i| {
var dv_wrow : @Vector(N, LlmFloat) = dweight[o * C + i * N..][0..N].*;
const inp_bt_v : @Vector(N , LlmFloat) = inp[b * T * C + t * C + i * N ..][0..N].*;
const d_t : @Vector(N, LlmFloat) = @splat(dout_bt[o]);
dv_wrow += inp_bt_v * d_t;
var tmp:[N]LlmFloat = dv_wrow;
@memcpy(dweight[o * C + i * N ..o * C + (i+1)*N], &tmp);
}
if(dbias.len != 0) {
dbias[o] += dout_bt[o];
}
}
}
}
}
/// Computes the attention mechanism for a transformer model.
/// The function handles multi-head self-attention which is a core component of the transformer architecture.
/// It calculates the weighted sum of values based on the softmax of the dot products of queries and keys.
///
/// @param output: The resulting output after applying attention and weighted sum.
/// @param preatt: The pre-softmax attention scores which are computed as dot products of queries and keys.
/// @param att: The post-softmax attention scores.
/// @param inp: Concatenated query, key, and value vectors for all heads.
/// @param B: Batch size, denoting the number of sequences.
/// @param T: Sequence length.
/// @param C: Dimension of each input token, which is split into parts for query, key, and value.
/// @param NH: Number of attention heads.
pub fn attention_forward( output : []LlmFloat, preatt: []LlmFloat, att: []LlmFloat, inp: []LlmFloat, B: u32, T: u32,
C: u32, NH: u32
) void{
const C3:u32 = C*3;
const hs:u32 = C / NH; // head size
const hs_float:LlmFloat = @floatFromInt(hs);
const scale:LlmFloat = 1.0 / math.sqrt(hs_float);
for(0..B) |b| {
for(0..T) |t| {
for(0..NH) |h| {
const query_t = inp[b * T * C3 + t * C3 + h * hs..];
var preatt_bth = preatt[b * NH * T * T + h * T * T + t * T..];
var att_bth = att[b * NH * T * T + h * T * T + t * T..];
// Start from the most negative finite float so the first score always becomes the running max.
var maxval: LlmFloat = -math.floatMax(LlmFloat);
for(0..t+1) |t2| {
const key_t2 = inp[b * T * C3 + t2 * C3 + h * hs + C..];
var val:LlmFloat = 0.0;
for(0..hs) |i| {
const q:LlmFloat = query_t[i];
const k:LlmFloat = key_t2[i];
val += k * q;
}
val *= scale;
if(val > maxval) {
maxval = val;
}
preatt_bth[t2] = val;
}
var expsum:LlmFloat = 0.0;
for(0..t+1) |t2| {
const expv:LlmFloat = math.exp(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
var expsum_inv:LlmFloat = 0.0;
if(expsum != 0.0) {
expsum_inv = 1.0 / expsum;
}
for(0..T) |t2| {
if(t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
att_bth[t2] = 0.0;
}
}
var out_bth = output[b * T * C + t * C + h * hs..];
for(0..hs) |i| {
out_bth[i] = 0.0;
}
for(0..t+1) |t2| {
const value_t2 = inp[b * T * C3 + t2 * C3 + h * hs + C * 2..];
const att_btht2:LlmFloat = att_bth[t2];
for(0..hs) |i| {
out_bth[i] += att_btht2 * value_t2[i];
}
}
}
}
}
}
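// Illustrative check, not part of the original file: with one head, one channel and a single time
// step, the causal softmax spans a single score, so the attention weight is 1 and the output equals
// the value. The shapes (B=T=C=NH=1) are made up for the example.
test "attention_forward with T=1 returns the value vector (sketch)" {
    var out = [_]LlmFloat{0.0}; // (B, T, C)
    var preatt = [_]LlmFloat{0.0}; // (B, NH, T, T)
    var att = [_]LlmFloat{0.0};
    var qkv = [_]LlmFloat{ 0.7, -0.3, 2.5 }; // (B, T, 3*C): query, key, value
    attention_forward(out[0..], preatt[0..], att[0..], qkv[0..], 1, 1, 1, 1);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 1.0), att[0], 1e-5);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 2.5), out[0], 1e-5);
}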
/// Vectorized
/// Computes the attention mechanism for a transformer model.
/// The function handles multi-head self-attention which is a core component of the transformer architecture.
/// It calculates the weighted sum of values based on the softmax of the dot products of queries and keys.
///
/// @param output: The resulting output after applying attention and weighted sum.
/// @param preatt: The pre-softmax attention scores which are computed as dot products of queries and keys.
/// @param att: The post-softmax attention scores.
/// @param inp: Concatenated query, key, and value vectors for all heads.
/// @param B: Batch size, denoting the number of sequences.
/// @param T: Sequence length.
/// @param C: Dimension of each input token, which is split into parts for query, key, and value.
/// @param NH: Number of attention heads.
pub fn attention_forward_vec( comptime N: usize, output : []LlmFloat, preatt: []LlmFloat, att: []LlmFloat, inp: []LlmFloat,
B: usize, T: usize, C: usize, NH: usize
) void{
const C3:usize = C*3;
const hs:usize = C / NH; // head size
const hs_float:LlmFloat = @floatFromInt(hs);
const scale:LlmFloat = 1.0 / math.sqrt(hs_float);
const add_op : std.builtin.ReduceOp = std.builtin.ReduceOp.Add;
for(0..B) |b| {
for(0..T) |t| {
for(0..NH) |h| {
const query_t = inp[b * T * C3 + t * C3 + h * hs..];
var preatt_bth = preatt[b * NH * T * T + h * T * T + t * T..];
var att_bth = att[b * NH * T * T + h * T * T + t * T..];
// Start from the most negative finite float so the first score always becomes the running max.
var maxval: LlmFloat = -math.floatMax(LlmFloat);
for(0..t+1) |t2| {
const key_t2 = inp[b * T * C3 + t2 * C3 + h * hs + C..];
var val:LlmFloat = 0.0;
for(0..hs/N) |i| {
const q:@Vector(N,LlmFloat) = query_t[i*N..][0..N].*;
const k:@Vector(N,LlmFloat) = key_t2[i*N..][0..N].*;
val +=@reduce(add_op,q*k);
}
val *= scale;
if(val > maxval) {
maxval = val;
}
preatt_bth[t2] = val;
}
var expsum:LlmFloat = 0.0;
for(0..t+1) |t2| {
const expv:LlmFloat = math.exp(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
var expsum_inv:LlmFloat = 0.0;
if(expsum != 0.0) {
expsum_inv = 1.0 / expsum;
}
for(0..T) |t2| {
if(t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
att_bth[t2] = 0.0;
}
}
var out_bth = output[b * T * C + t * C + h * hs..];
for(0..hs) |i| {
out_bth[i] = 0.0;
}
for(0..t+1) |t2| {
const att_btht2:@Vector(N,LlmFloat) = @splat(att_bth[t2]);
for(0..hs/N) |i| {
const value_t2:@Vector(N,LlmFloat) = inp[b * T * C3 + t2 * C3 + h * hs + C * 2 + i*N..][0..N].*;
const out_bth_v:@Vector(N,LlmFloat) = out_bth[i*N..][0..N].*;
const res:[N]LlmFloat = @mulAdd(@Vector(N,LlmFloat), att_btht2, value_t2, out_bth_v);
@memcpy(out_bth[i*N..][0..N], &res);
}
}
}
}
}
}
/// Computes gradients for the attention mechanism during the backward pass.
/// This function backpropagates errors from the output of the attention layer to the inputs,
/// which include queries, keys, and values, and computes gradients for these components.
///
/// @param dinp: Gradient buffer for the input activations.
/// @param dpreatt: Gradient buffer for pre-softmax attention scores.
/// @param datt: Gradient buffer for post-softmax attention scores.
/// @param dout: Gradient buffer for the output activations.
/// @param inp: Input activations to the attention layer.
/// @param att: Post-softmax attention scores computed during the forward pass.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Channel size of the input.
/// @param NH: Number of attention heads.
pub fn attention_backward( dinp: []LlmFloat, dpreatt: []LlmFloat, datt: []LlmFloat, dout: []LlmFloat, inp: []LlmFloat,
att: []LlmFloat, B: u32, T: u32, C: u32, NH: u32) void{
const C3:u32 = C * 3;
const hs:u32 = @intCast(C / NH);
const hs_float:LlmFloat = @floatFromInt(hs);
const scale:LlmFloat = 1.0 / math.sqrt(hs_float);
for(0..B)|b|{
for(0..T)|t|{
for(0..NH)|h|{
const att_bth = att[b*NH*T*T + h*T*T + t*T..];
var datt_bth = datt[b*NH*T*T + h*T*T + t*T..];
var dpreatt_bth = dpreatt[b*NH*T*T + h*T*T + t*T..];
var dquery_t = dinp[b*T*C3 + t*C3 + h*hs..];
const query_t = inp[b*T*C3 + t*C3 + h*hs..];
// backward pass 4, through the value accumulation
const dout_bth = dout[b * T * C + t * C + h * hs..];
for(0..t+1)|t2|{
const value_t2 = inp[b*T*C3 + t2*C3 + h*hs + C*2..];
var dvalue_t2 = dinp[b*T*C3 + t2*C3 + h*hs + C*2..];
for(0..hs)|i|{
datt_bth[t2] += value_t2[i] * dout_bth[i];
dvalue_t2[i] += att_bth[t2] * dout_bth[i];
}
}
for (0..t+1) |t2| {
for (0..t+1) |t3|{
var indicator:LlmFloat = 0.0;
if(t3 == t2){
indicator = 1.0;
}
const local_derivative:LlmFloat = att_bth[t2] * (indicator - att_bth[t3]);
dpreatt_bth[t3] += local_derivative * datt_bth[t2];
}
}
for(0..t+1)|t2|{
const key_t2 = inp[b * T * C3 + t2 * C3 + h * hs + C..];
var dkey_t2 = dinp[b * T * C3 + t2 * C3 + h * hs + C..];
for(0..hs)|i|{
dquery_t[i] += key_t2[i] * dpreatt_bth[t2]*scale;
dkey_t2[i] += query_t[i] * dpreatt_bth[t2]*scale;
}
}
}
}
}
}
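// For reference, the middle loop above backpropagates through the causal softmax via its Jacobian:
//   dpreatt[t3] += att[t2] * (indicator(t2 == t3) - att[t3]) * datt[t2]
// while the first loop is the backward of the value accumulation and the last loop is the backward
// of the scaled query-key dot product (hence the extra factor `scale`).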
/// Vectorized
/// Computes gradients for the attention mechanism during the backward pass.
/// This function backpropagates errors from the output of the attention layer to the inputs,
/// which include queries, keys, and values, and computes gradients for these components.
///
/// @param dinp: Gradient buffer for the input activations.
/// @param dpreatt: Gradient buffer for pre-softmax attention scores.
/// @param datt: Gradient buffer for post-softmax attention scores.
/// @param dout: Gradient buffer for the output activations.
/// @param inp: Input activations to the attention layer.
/// @param att: Post-softmax attention scores computed during the forward pass.
/// @param B: Batch size.
/// @param T: Sequence length.
/// @param C: Channel size of the input.
/// @param NH: Number of attention heads.
pub fn attention_backward_vec( comptime NChannelsPerHead: usize, dinp: []LlmFloat, dpreatt: []LlmFloat, datt: []LlmFloat,
dout: []LlmFloat, inp: []LlmFloat, att: []LlmFloat, B: u32, T: u32, C: u32, NH: u32) void{
const C3:u32 = C * 3;
const hs:u32 = @intCast(C / NH);
const hs_float:LlmFloat = @floatFromInt(hs);
const scale:LlmFloat = 1.0 / math.sqrt(hs_float);
for(0..B)|b|{
for(0..T)|t|{
for(0..NH)|h|{
const att_bth = att[b*NH*T*T + h*T*T + t*T..];
var datt_bth = datt[b*NH*T*T + h*T*T + t*T..];
var dpreatt_bth = dpreatt[b*NH*T*T + h*T*T + t*T..];
var dquery_t = dinp[b*T*C3 + t*C3 + h*hs..];
const query_t = inp[b*T*C3 + t*C3 + h*hs..];
// backward pass 4, through the value accumulation
const dout_bth = dout[b * T * C + t * C + h * hs..];
for(0..t+1)|t2|{
var value_t2 = inp[b*T*C3 + t2*C3 + h*hs + C*2..];
var dvalue_t2 = dinp[b*T*C3 + t2*C3 + h*hs + C*2..];
const op_add :std.builtin.ReduceOp = .Add;
for (0..hs/NChannelsPerHead)|i|{
// Use the comptime lane count for the vector type so the kernel works for any NChannelsPerHead, not just VectorSize.
const v_dout_bth : @Vector(NChannelsPerHead, LlmFloat) = dout_bth[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_value_t2 : @Vector(NChannelsPerHead, LlmFloat) = value_t2[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_dvalue_t2 : @Vector(NChannelsPerHead, LlmFloat) = dvalue_t2[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const res_value_outatt = @reduce(op_add,v_value_t2 * v_dout_bth);
datt_bth[t2] += res_value_outatt;
const v_att_bth : @Vector(NChannelsPerHead, LlmFloat) = @splat(att_bth[t2]);
const res:[NChannelsPerHead]LlmFloat = v_dvalue_t2 + v_dout_bth * v_att_bth;
const start = b*T*C3 + t2*C3 + h*hs + C*2 + i*NChannelsPerHead;
const end = b*T*C3 + t2*C3 + h*hs + C*2 + (i+1)*NChannelsPerHead;
@memcpy(dinp[start..end], &res);
}
}
for (0..t+1) |t2| {
for (0..t+1) |t3|{
var indicator:LlmFloat = 0.0;
if(t3 == t2){
indicator = 1.0;
}
const local_derivative:LlmFloat = att_bth[t2] * (indicator - att_bth[t3]);
dpreatt_bth[t3] += local_derivative * datt_bth[t2];
}
}
for(0..t+1)|t2|{
const key_t2 = inp[b * T * C3 + t2 * C3 + h * hs + C..];
var dkey_t2 = dinp[b * T * C3 + t2 * C3 + h * hs + C..];
for (0..hs/NChannelsPerHead)|i|{
const v_query_t : @Vector(NChannelsPerHead, LlmFloat) = query_t[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_key_t2 : @Vector(NChannelsPerHead, LlmFloat) = key_t2[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_dkey_t2 : @Vector(NChannelsPerHead, LlmFloat) = dkey_t2[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_dquery_t : @Vector(NChannelsPerHead, LlmFloat) = dquery_t[i*NChannelsPerHead..][0..NChannelsPerHead].*;
const v_scale_dpreatt : @Vector(NChannelsPerHead, LlmFloat) = @splat(dpreatt_bth[t2]*scale);
var res_dquery:[NChannelsPerHead]LlmFloat = v_dquery_t + v_key_t2 * v_scale_dpreatt;
@memcpy(dquery_t[i*NChannelsPerHead..][0..NChannelsPerHead], &res_dquery);
var res_dkey_t2:[NChannelsPerHead]LlmFloat = v_dkey_t2 + v_query_t * v_scale_dpreatt;
@memcpy(dkey_t2[i*NChannelsPerHead..][0..NChannelsPerHead], &res_dkey_t2);
}
}
}
}
}
}
/// Applies the Gaussian Error Linear Unit (GELU) activation function on the input data.
/// GELU is used in transformer models and is defined as x * Φ(x), where Φ(x) is the cumulative distribution function
/// of the standard normal distribution.
///
/// @param output: The array where the output will be stored.
/// @param input: The input array on which the GELU function is applied.
/// @param N: The number of elements in the input and output arrays.
pub fn gelu_forward( output: []LlmFloat, input: []LlmFloat, N: u32) void {
const s = math.sqrt(2.0 / math.pi);
for (0..N) |i| {
const x = input[i];
const cdf = 0.5 * (1.0 + math.tanh(s * (x + 0.044715 * math.pow(LlmFloat,x, 3))));
output[i] = x * cdf;
}
}
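// Illustrative check, not part of the original file: the tanh-approximated GELU is exactly 0 at
// x = 0 and close to the identity for moderately large positive x, since Phi(x) -> 1.
test "gelu_forward: gelu(0) = 0 and gelu(3) ~ 3 (sketch)" {
    var out = [_]LlmFloat{ 0.0, 0.0 };
    var x = [_]LlmFloat{ 0.0, 3.0 };
    gelu_forward(out[0..], x[0..], 2);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 0.0), out[0], 1e-6);
    try std.testing.expectApproxEqAbs(@as(LlmFloat, 3.0), out[1], 2e-2);
}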
/// Vectorized:
/// Applies the Gaussian Error Linear Unit (GELU) activation function on the input data.
/// GELU is used in transformer models and is defined as x * Φ(x), where Φ(x) is the cumulative distribution function
/// of the standard normal distribution.
///
/// @param output: The array where the output will be stored.
/// @param input: The input array on which the GELU function is applied.
/// @param N: The number of elements in the input and output arrays.
// Note: Zig has no vector tanh builtin, so tanh is computed from @exp below; the algebraic form
// 1 - 2/(exp(2z) + 1) is used to avoid the inf/inf = NaN that the naive (e^z - e^-z)/(e^z + e^-z) form produces for large |z|.
pub fn gelu_forward_vec(comptime N:usize, output: []LlmFloat, input: []LlmFloat, BT4C: usize) void {
const s:@Vector(N , LlmFloat) = @splat(math.sqrt(2.0 / math.pi));
const half:@Vector(N , LlmFloat) = @splat(0.5);
const one: @Vector(N , LlmFloat) = @splat(1.0);
const two: @Vector(N , LlmFloat) = @splat(2.0);
const g_coeff: @Vector(N , LlmFloat) = @splat(0.044715);
for(0..BT4C/N) |i| {
const x: @Vector(N , LlmFloat) = input[i*N..][0..N].*;
const x_cube: @Vector(N , LlmFloat) = x * x * x; // x^3
const x_par: @Vector(N , LlmFloat) = s * (x + g_coeff * x_cube); // s*(x+0.044715*x^3)
const tanh_x: @Vector(N , LlmFloat) = one - two / (@exp(x_par + x_par) + one); // tanh(z) = 1 - 2/(exp(2z)+1)
const cdf = half * (one + tanh_x);
const res: [N]LlmFloat = x * cdf;
@memcpy(output[i*N..(i+1)*N], &res);
}
}
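// Illustrative check, not part of the original file: the vectorized GELU should agree with the
// scalar reference on a buffer whose length is a multiple of VectorSize.
test "gelu_forward_vec matches gelu_forward (sketch)" {
    var x = [_]LlmFloat{ -3.0, -1.0, -0.5, 0.0, 0.25, 1.0, 2.0, 4.0 };
    var expected = [_]LlmFloat{0} ** 8;
    var actual = [_]LlmFloat{0} ** 8;
    gelu_forward(expected[0..], x[0..], 8);
    gelu_forward_vec(VectorSize, actual[0..], x[0..], 8);
    for (0..8) |i| {
        try std.testing.expectApproxEqAbs(expected[i], actual[i], 1e-4);
    }
}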
/// Computes the gradient of the GELU activation function with respect to the input tensor, using the gradients
/// from the next layer. This function is used in the backpropagation process to propagate gradients back through
/// the network for the GELU activation nodes.
///
/// @param dinput: The gradient with respect to the input of the GELU function, which this function computes.
/// @param input: The input tensor to the GELU function from the forward pass.
/// @param doutput: The gradient with respect to the output of the GELU function, received from the next layer.
/// @param N: The number of elements in the input and output gradient arrays.
pub fn gelu_backward( dinput: []LlmFloat, input: []LlmFloat, doutput: []LlmFloat, N: u32)
void {
const s = math.sqrt(2.0 / math.pi);
for (0..N) |i| {
const x = input[i];
const square = x * x * 0.044715;
const cube = square * x;
const tanh_arg = s * (x + cube);
const tanh_out = math.tanh(tanh_arg);
const coshf_out = math.cosh(tanh_arg);
const sech2 = 1.0 / (coshf_out * coshf_out);