forked from karpathy/llama2.c
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,616 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
|
||
## llama2.c | ||
|
||
 | ||
|
||
Have you ever wanted to inference a baby [Llama 2](https://ai.meta.com/llama/) model in pure C? No? Well, now you can! | ||
|
||
Code in this repo first lets you train the Llama 2 architecture from scratch in PyTorch, then save the weights to a raw binary file, then load that into one ~simple 500-line C file that inferences the model, simply in fp32 for now. | ||
|
||
Of course, this is not super fast, but it's not too bad either. E.g. on my cloud Linux devbox a dim 288 6-layer 6-head model (~15M params) inferences at ~18 tok/s in fp32, and about the same on my M1 MacBook Air. | ||
|
||
Please note that this is just a weekend project where I took nanoGPT, gutted it to implement the Llama-2 architecture (instead of GPT-2), and then wrote the C inference engine for it in `run.c`. So this is not really meant to be a production-grade library right now. | ||
|
||
Hat tip to [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. I wanted something super minimal so I chose to hard-code the llama-2 architecture, stick to fp32, and just roll one inference file of pure C with no dependencies. | ||
|
||
## howto | ||
|
||
It should be possible to load the weights released by Meta but I haven't tried because the inference speed, even of the 7B model, would probably be not great with this baby single-threaded C program. So in this repo we focus on more narrow applications, and train the same architecture but from scratch, in this case on the TinyStories dataset for fun. | ||
|
||
First let's download and pretokenize the TinyStories dataset: | ||
|
||
```bash | ||
python tinystories.py download | ||
python tinystories.py pretokenize | ||
``` | ||
|
||
Then train our model: | ||
|
||
```bash | ||
python train.py | ||
``` | ||
|
||
See the train.py script for more exotic launches and hyperparameter overrides. I didn't tune the hyperparameters, I expect simple hyperparameter exploration should give better models. Totally understand if you want to skip model training, for simple demo just download my pretrained model: | ||
|
||
```bash | ||
wget TODOhoweasiesthmm | ||
``` | ||
|
||
Once we have the model.bin file, we can inference in C. Compile the C code first: | ||
|
||
```bash | ||
gcc -o run run.c -lm | ||
``` | ||
|
||
You can now run it simply as | ||
|
||
```bash | ||
./run | ||
``` | ||
|
||
But note that this only emits the SentencePiece tokens. To decode the tokens into text too, run this script through a simple wrapper: | ||
|
||
```bash | ||
python run_wrap.py | ||
``` | ||
|
||
I hope to delete this script soon though. Anyway, watch the tokens stream by, fun! | ||
|
||
To verify correctness, we can also run the PyTorch inference script: | ||
|
||
```bash | ||
python sample.py | ||
``` | ||
|
||
Which gives the same results. I'd love to find some time to create actual tests, one day maybe. For now I just manually inspected activations and verified that they match, and that the samples are identical at temperature 0. If someone wishes to help me with tests I welcome PRs. | ||
|
||
## unsorted todos | ||
|
||
- why SentencePiece can't iteratively decode properly? | ||
- would love to delete run_wrap.py and just directly use C code to string, help welcome | ||
- todo multiquery support? doesn't seem as useful for smaller models that run on CPU | ||
- todo support inferencing beyond max_seq_len steps, have to think through the kv cache | ||
- why is MFU so low (~20%) on my A100 40GB for training? | ||
- weird errors with torch.compile and wandb when using DDP | ||
- make tests to decrease yolo | ||
|
||
## License | ||
MIT |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Poor Man's Configurator. Probably a terrible idea. Example usage: | ||
$ python train.py config/override_file.py --batch_size=32 | ||
this will first run config/override_file.py, then override batch_size to 32 | ||
The code in this file will be run as follows from e.g. train.py: | ||
>>> exec(open('configurator.py').read()) | ||
So it's not a Python module, it's just shuttling this code away from train.py | ||
The code in this script then overrides the globals() | ||
I know people are not going to love this, I just really dislike configuration | ||
complexity and having to prepend config. to every single variable. If someone | ||
comes up with a better simple Python solution I am all ears. | ||
""" | ||
|
||
import sys | ||
from ast import literal_eval | ||
|
||
for arg in sys.argv[1:]: | ||
if '=' not in arg: | ||
# assume it's the name of a config file | ||
assert not arg.startswith('--') | ||
config_file = arg | ||
print(f"Overriding config with {config_file}:") | ||
with open(config_file) as f: | ||
print(f.read()) | ||
exec(open(config_file).read()) | ||
else: | ||
# assume it's a --key=value argument | ||
assert arg.startswith('--') | ||
key, val = arg.split('=') | ||
key = key[2:] | ||
if key in globals(): | ||
try: | ||
# attempt to eval it it (e.g. if bool, number, or etc) | ||
attempt = literal_eval(val) | ||
except (SyntaxError, ValueError): | ||
# if that goes wrong, just use the string | ||
attempt = val | ||
# ensure the types match ok | ||
assert type(attempt) == type(globals()[key]) | ||
# cross fingers | ||
print(f"Overriding: {key} = {attempt}") | ||
globals()[key] = attempt | ||
else: | ||
raise ValueError(f"Unknown config key: {key}") |
Oops, something went wrong.