From 2a8b295c97eb6d60815f489d38db31dedc628e7a Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Thu, 14 Dec 2023 06:48:12 +0000
Subject: [PATCH] Add GPT-2 text generation example and reference
 implementation

---
 rten-examples/Cargo.toml            |   4 +
 rten-examples/README.md             |   2 +
 rten-examples/src/generator.rs      |   2 +-
 rten-examples/src/gpt2.rs           | 126 ++++++++++++++++++++++++++++
 rten-examples/src/gpt2_reference.py |  24 ++++++
 5 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 rten-examples/src/gpt2.rs
 create mode 100644 rten-examples/src/gpt2_reference.py

diff --git a/rten-examples/Cargo.toml b/rten-examples/Cargo.toml
index 3f0b31ac..5fc81c2f 100644
--- a/rten-examples/Cargo.toml
+++ b/rten-examples/Cargo.toml
@@ -60,6 +60,10 @@ path = "src/depth_anything.rs"
 name = "bert_qa"
 path = "src/bert_qa.rs"
 
+[[bin]]
+name = "gpt2"
+path = "src/gpt2.rs"
+
 [[bin]]
 name = "jina_similarity"
 path = "src/jina_similarity.rs"
diff --git a/rten-examples/README.md b/rten-examples/README.md
index 36c3a69e..998e5e60 100644
--- a/rten-examples/README.md
+++ b/rten-examples/README.md
@@ -62,6 +62,8 @@ The examples have been chosen to cover common tasks and popular models.
 - **bert_qa** - Extractive question answering using
   [BERT](https://arxiv.org/abs/1810.04805)-based models which have been
   fine-tuned on the [SQuAD](https://paperswithcode.com/dataset/squad) dataset
+- **gpt2** - Text generation using the [GPT-2](https://openai.com/index/better-language-models/)
+  language model.
 - **jina_similarity** - Sentence similarity using vector embeddings of sentences
 
 ### Audio
diff --git a/rten-examples/src/generator.rs b/rten-examples/src/generator.rs
index c99d812f..fe44df7c 100644
--- a/rten-examples/src/generator.rs
+++ b/rten-examples/src/generator.rs
@@ -72,7 +72,7 @@ pub struct Generator<'a> {
 }
 
 impl<'a> Generator<'a> {
-    /// Create a generator that runs a given model.
+    /// Create a generator that iteratively produces tokens using a model.
     ///
     /// The model is expected to have the following inputs:
     ///
diff --git a/rten-examples/src/gpt2.rs b/rten-examples/src/gpt2.rs
new file mode 100644
index 00000000..fb510dcd
--- /dev/null
+++ b/rten-examples/src/gpt2.rs
@@ -0,0 +1,126 @@
+use std::collections::VecDeque;
+use std::error::Error;
+use std::fs;
+use std::io::prelude::*;
+
+use rten::Model;
+use rten_examples::generator::Generator;
+use rten_text::tokenizers::Tokenizer;
+
+struct Args {
+    model: String,
+    tokenizer_config: String,
+    prompt: String,
+    output_length: usize,
+}
+
+fn parse_args() -> Result<Args, lexopt::Error> {
+    use lexopt::prelude::*;
+
+    let mut values = VecDeque::new();
+    let mut parser = lexopt::Parser::from_env();
+    let mut output_length = 30;
+
+    while let Some(arg) = parser.next()? {
+        match arg {
+            Short('l') | Long("length") => {
+                output_length = parser.value()?.parse()?;
+            }
+            Value(val) => values.push_back(val.string()?),
+            Long("help") => {
+                println!(
+                    "Generate text using a prompt.
+
+Usage: {bin_name} [options] <model> <tokenizer> <prompt>
+
+Args:
+
+  <model>     - Input GPT-2 model
+  <tokenizer> - `tokenizer.json` file
+  <prompt>    - Text generation prompt
+
+Options:
+
+  -l, --length <length> - Set max output length (in tokens)
+",
+                    bin_name = parser.bin_name().unwrap_or("gpt2")
+                );
+                std::process::exit(0);
+            }
+            _ => return Err(arg.unexpected()),
+        }
+    }
+
+    let model = values.pop_front().ok_or("missing `model` arg")?;
+    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
+    let prompt = values.make_contiguous().join(" ");
+
+    let args = Args {
+        model,
+        tokenizer_config,
+        prompt,
+        output_length,
+    };
+
+    Ok(args)
+}
+
+/// Generates text using GPT-2 [1] and a prompt.
+///
+/// To obtain the model from Hugging Face, use Optimum [2], then convert it:
+///
+/// ```sh
+/// optimum-cli export onnx --model gpt2 gpt2_onnx/
+/// rten-convert gpt2_onnx/model.onnx
+/// ```
+///
+/// Run the converted model with a prompt:
+///
+/// ```sh
+/// cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json <prompt>
+/// ```
+///
+/// Where `<prompt>` is the start of a sentence that the model should complete.
+///
+/// [1] https://openai.com/research/better-language-models
+/// [2] https://huggingface.co/docs/optimum/index
+fn main() -> Result<(), Box<dyn Error>> {
+    let args = parse_args()?;
+    let model_bytes = fs::read(args.model)?;
+    let model = Model::load(model_bytes)?;
+
+    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
+    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;
+
+    let prompt = args.prompt.as_str();
+    let encoded_prompt = tokenizer.encode(prompt.into(), Default::default())?;
+    let token_ids: Vec<u32> = encoded_prompt
+        .token_ids()
+        .iter()
+        .map(|id| *id as u32)
+        .collect();
+
+    // The output starts with the user's prompt.
+    print!("{}", prompt);
+
+    // Buffer that holds model output tokens until they form a valid UTF-8
+    // sequence.
+    let mut token_buf = Vec::new();
+
+    let generator = Generator::from_model(&model)?.with_prompt(&token_ids);
+    for token in generator.take(args.output_length) {
+        let token = token?;
+        token_buf.push(token as usize);
+
+        let token_strings = tokenizer.encoder().get_tokens(&token_buf);
+        if let Ok(strings) = token_strings {
+            for s in strings {
+                print!("{}", s);
+            }
+            let _ = std::io::stdout().flush();
+            token_buf.clear();
+        }
+    }
+
+    Ok(())
+}
diff --git a/rten-examples/src/gpt2_reference.py b/rten-examples/src/gpt2_reference.py
new file mode 100644
index 00000000..60517bf8
--- /dev/null
+++ b/rten-examples/src/gpt2_reference.py
@@ -0,0 +1,24 @@
+from argparse import ArgumentParser
+from transformers import pipeline, set_seed
+
+
+def main():
+    parser = ArgumentParser(description="Generate text using GPT-2 and a prompt")
+    parser.add_argument("prompt", nargs="*")
+    parser.add_argument("--seed", type=int, help="Random seed")
+    args = parser.parse_args()
+
+    prompt = " ".join(args.prompt)
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    print(f'prompt: "{prompt}"')
+    generator = pipeline("text-generation", model="gpt2")
+
+    sequences = generator(prompt, max_length=30, num_return_sequences=1, do_sample=False)
+    for seq in sequences:
+        print(seq)
+
+
+if __name__ == "__main__":
+    main()
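
For an end-to-end check, the Rust example can be run side by side with the Hugging Face reference script added above. A minimal sketch, assuming the export steps from the `gpt2.rs` docs have been run and that `transformers` plus a PyTorch backend are installed; the prompt text is purely illustrative:

```sh
# Export GPT-2 via Optimum and convert it to .rten format.
optimum-cli export onnx --model gpt2 gpt2_onnx/
rten-convert gpt2_onnx/model.onnx

# Generate a continuation with the Rust example...
cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json \
    "The weather today is"

# ...and with the reference pipeline, which decodes greedily (do_sample=False).
python rten-examples/src/gpt2_reference.py "The weather today is"
```

The two outputs should be broadly comparable, up to any differences in the sampling strategy used by `Generator`.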