From bbe71c5fd88406c3bbaf0bd798c91f1960e4e073 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 2 Jun 2024 17:53:25 +0100
Subject: [PATCH 1/2] Add `Generator` struct to handle generation loop for
 transformer models

---
 rten-examples/src/generator.rs | 258 +++++++++++++++++++++++++++++++++
 rten-examples/src/lib.rs       |   1 +
 2 files changed, 259 insertions(+)
 create mode 100644 rten-examples/src/generator.rs
 create mode 100644 rten-examples/src/lib.rs

diff --git a/rten-examples/src/generator.rs b/rten-examples/src/generator.rs
new file mode 100644
index 00000000..c99d812f
--- /dev/null
+++ b/rten-examples/src/generator.rs
@@ -0,0 +1,258 @@
+use std::error::Error;
+use std::fmt;
+
+use rten::{Dimension, Input, Model, NodeId, Operators};
+use rten_tensor::prelude::*;
+use rten_tensor::{NdTensor, Tensor};
+
+/// Errors that occur when creating or running a [`Generator`].
+#[derive(Debug)]
+pub enum GeneratorError {
+    /// An expected model input was not found.
+    InputNotFound(String),
+
+    /// An expected model output was not found.
+    OutputNotFound(String),
+
+    /// An input or output did not have the expected shape.
+    ShapeMismatch(String),
+
+    /// An error occurred while generating the next token.
+    GenerateError(Box<dyn Error>),
+}
+
+impl fmt::Display for GeneratorError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            GeneratorError::InputNotFound(name) => write!(f, "model input not found: {}", name),
+            GeneratorError::OutputNotFound(name) => write!(f, "model output not found: {}", name),
+            GeneratorError::ShapeMismatch(err) => write!(f, "shape mismatch: {}", err),
+            GeneratorError::GenerateError(err) => write!(f, "generation error: {}", err),
+        }
+    }
+}
+
+impl Error for GeneratorError {}
+
+/// Key-value cache for a single layer of a transformer model.
+struct KvCache {
+    /// Input ID for this cache entry.
+    input_id: NodeId,
+
+    /// Output ID for this cache entry.
+    output_id: NodeId,
+
+    /// The cached keys and values, with shape (batch, heads, seq_len, size).
+    cache: NdTensor<f32, 4>,
+}
+
+/// Generates a token sequence using an auto-regressive language model.
+///
+/// This is an iterator that runs the model on each call to [`Iterator::next`]
+/// and yields a result containing the next token ID or an error.
+pub struct Generator<'a> {
+    model: &'a Model,
+
+    /// Input token IDs for the next run of the model.
+    input_ids: Vec<u32>,
+
+    // Input node IDs
+    input_ids_input: NodeId,
+    attention_mask_input: NodeId,
+    position_ids_input: NodeId,
+
+    // Output node IDs
+    logits_output: NodeId,
+
+    /// Length of the sequence generated so far.
+    seq_len: u32,
+
+    /// Key-value cache.
+    kv_cache: Vec<KvCache>,
+}
+
+impl<'a> Generator<'a> {
+    /// Create a generator that runs a given model.
+    ///
+    /// The model is expected to have the following inputs:
+    ///
+    /// - `input_ids` - (batch, sequence) tensor of token IDs
+    /// - `attention_mask` - (batch, sequence) tensor of booleans
+    /// - `position_ids` - (batch, sequence) tensor of position indices
+    /// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache,
+    ///   where `N` is the layer index
+    /// - `past_key_values.N.value` - (batch, head, past_seq_len, size) value vector cache,
+    ///   where `N` is the layer index
+    ///
+    /// The model is expected to have the following outputs:
+    ///
+    /// - `logits` - (batch, sequence, vocab) tensor of next token probabilities
+    /// - `present.N.key` - (batch, head, past_seq_len + 1, size) updated key vector cache
+    /// - `present.N.value` - (batch, head, past_seq_len + 1, size) updated value vector cache
+    pub fn from_model(model: &'a Model) -> Result<Generator<'a>, GeneratorError> {
+        let input_ids_input = model
+            .find_node("input_ids")
+            .ok_or(GeneratorError::InputNotFound("input_ids".to_string()))?;
+        let attention_mask_input = model
+            .find_node("attention_mask")
+            .ok_or(GeneratorError::InputNotFound("attention_mask".to_string()))?;
+        let position_ids_input = model
+            .find_node("position_ids")
+            .ok_or(GeneratorError::InputNotFound("position_ids".to_string()))?;
+
+        let logits_output = model
+            .find_node("logits")
+            .ok_or(GeneratorError::OutputNotFound("logits".to_string()))?;
+
+        // Find inputs and corresponding outputs for key-value cache.
+        let batch_size = 1;
+        let mut kv_cache = Vec::new();
+        for &input_id in model.input_ids() {
+            let input_info = model
+                .node_info(input_id)
+                .ok_or(GeneratorError::InputNotFound(format!(
+                    "input ID {}",
+                    input_id
+                )))?;
+            let Some(name) = input_info.name() else {
+                continue;
+            };
+
+            if !name.starts_with("past_key_values.") {
+                continue;
+            }
+
+            if !name.ends_with(".key") && !name.ends_with(".value") {
+                continue;
+            }
+
+            let [n_heads, size] = match input_info.shape().as_deref() {
+                Some(&[_, Dimension::Fixed(n_heads), _, Dimension::Fixed(size)]) => [n_heads, size],
+                _ => {
+                    return Err(GeneratorError::ShapeMismatch(format!("input \"{}\" has unexpected shape. expected (batch, heads, past_seq_len, size) where `heads` and `size` are fixed", name)));
+                }
+            };
+
+            let cache_type = if name.ends_with(".key") {
+                "key"
+            } else {
+                "value"
+            };
+
+            let layer_index_start = "past_key_values.".len();
+            let layer_index_str: String = name[layer_index_start..]
+                .chars()
+                .take_while(|ch| ch.is_ascii_digit())
+                .collect();
+            let Ok(layer_index) = layer_index_str.parse::<u32>() else {
+                continue;
+            };
+
+            let output_name = format!("present.{}.{}", layer_index, cache_type);
+            let output_id = model
+                .find_node(&output_name)
+                .ok_or(GeneratorError::OutputNotFound(output_name))?;
+
+            kv_cache.push(KvCache {
+                input_id,
+                output_id,
+                cache: NdTensor::zeros([batch_size, n_heads, 0 /* seq len */, size]),
+            });
+        }
+
+        Ok(Generator {
+            model,
+            input_ids: vec![],
+            input_ids_input,
+            attention_mask_input,
+            position_ids_input,
+            logits_output,
+            kv_cache,
+            seq_len: 0,
+        })
+    }
+
+    /// Set the initial sequence of tokens (aka. the prompt) passed to the model
+    /// when it is first run.
+    pub fn with_prompt(mut self, prompt: &'a [u32]) -> Self {
+        self.input_ids = prompt.to_vec();
+        self
+    }
+
+    /// Run the model and generate the next token.
+    fn generate_next_token(&mut self) -> Result<u32, GeneratorError> {
+        fn wrap_error<E>(e: E) -> GeneratorError
+        where
+            E: Into<Box<dyn Error>>,
+        {
+            GeneratorError::GenerateError(e.into())
+        }
+
+        let batch_size = 1;
+        let input_ids: NdTensor<i32, 2> = self
+            .input_ids
+            .iter()
+            .map(|id| *id as i32)
+            .collect::<Tensor<_>>()
+            .into_shape([batch_size, self.input_ids.len()]);
+        let attention_mask = NdTensor::full([batch_size, self.input_ids.len()], 1i32);
+
+        let position_ids = NdTensor::from_fn([batch_size, input_ids.len()], |[_batch, pos]| {
+            self.seq_len as i32 + pos as i32
+        });
+
+        let model_inputs: Vec<(NodeId, Input)> = [
+            (self.input_ids_input, input_ids.view().into()),
+            (self.attention_mask_input, attention_mask.view().into()),
+            (self.position_ids_input, position_ids.view().into()),
+        ]
+        .into_iter()
+        .chain(
+            self.kv_cache
+                .iter()
+                .map(|entry| (entry.input_id, entry.cache.view().into())),
+        )
+        .collect();
+
+        let model_outputs: Vec<NodeId> = [self.logits_output]
+            .into_iter()
+            .chain(self.kv_cache.iter().map(|entry| entry.output_id))
+            .collect();
+
+        let mut outputs = self
+            .model
+            .run(model_inputs.as_slice(), &model_outputs, None)
+            .map_err(wrap_error)?;
+
+        // Sample output token.
+        let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
+        let next_ids = logits
+            .arg_max(-1, false /* keep_dims */)
+            .map_err(wrap_error)?;
+        let next_id = next_ids
+            .slice::<0, _>((0, -1))
+            .item()
+            .map(|it| *it as u32)
+            .expect("expected scalar");
+
+        // Update the key-value cache.
+        for cache_entry in self.kv_cache.iter_mut() {
+            cache_entry.cache = outputs.remove(0).try_into().map_err(wrap_error)?;
+        }
+
+        // Update the token IDs for the next iteration.
+        self.seq_len += self.input_ids.len() as u32;
+        self.input_ids = vec![next_id];
+
+        Ok(next_id)
+    }
+}
+
+impl<'a> Iterator for Generator<'a> {
+    type Item = Result<u32, GeneratorError>;
+
+    /// Run the model and generate the next output token.
+    fn next(&mut self) -> Option<Self::Item> {
+        Some(self.generate_next_token())
+    }
+}
diff --git a/rten-examples/src/lib.rs b/rten-examples/src/lib.rs
new file mode 100644
index 00000000..2a22e3d0
--- /dev/null
+++ b/rten-examples/src/lib.rs
@@ -0,0 +1 @@
+pub mod generator;

From 2a8b295c97eb6d60815f489d38db31dedc628e7a Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Thu, 14 Dec 2023 06:48:12 +0000
Subject: [PATCH 2/2] Add GPT 2 text generation example and reference
 implementation

---
 rten-examples/Cargo.toml            |   4 +
 rten-examples/README.md             |   2 +
 rten-examples/src/generator.rs      |   2 +-
 rten-examples/src/gpt2.rs           | 126 ++++++++++++++++++++++++++++
 rten-examples/src/gpt2_reference.py |  24 ++++++
 5 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 rten-examples/src/gpt2.rs
 create mode 100644 rten-examples/src/gpt2_reference.py

diff --git a/rten-examples/Cargo.toml b/rten-examples/Cargo.toml
index 3f0b31ac..5fc81c2f 100644
--- a/rten-examples/Cargo.toml
+++ b/rten-examples/Cargo.toml
@@ -60,6 +60,10 @@ path = "src/depth_anything.rs"
 name = "bert_qa"
 path = "src/bert_qa.rs"
 
+[[bin]]
+name = "gpt2"
+path = "src/gpt2.rs"
+
 [[bin]]
 name = "jina_similarity"
 path = "src/jina_similarity.rs"
diff --git a/rten-examples/README.md b/rten-examples/README.md
index 36c3a69e..998e5e60 100644
--- a/rten-examples/README.md
+++ b/rten-examples/README.md
@@ -62,6 +62,8 @@ The examples have been chosen to cover common tasks and popular models.
 - **bert_qa** - Extractive question answering using
   [BERT](https://arxiv.org/abs/1810.04805)-based models which have been
   fine-tuned on the [SQuAD](https://paperswithcode.com/dataset/squad) dataset
+- **gpt2** - Text generation using the [GPT-2](https://openai.com/index/better-language-models/)
+  language model.
 - **jina_similarity** - Sentence similarity using vector embeddings of sentences
 
 ### Audio
diff --git a/rten-examples/src/generator.rs b/rten-examples/src/generator.rs
index c99d812f..fe44df7c 100644
--- a/rten-examples/src/generator.rs
+++ b/rten-examples/src/generator.rs
@@ -72,7 +72,7 @@ pub struct Generator<'a> {
 }
 
 impl<'a> Generator<'a> {
-    /// Create a generator that runs a given model.
+    /// Create a generator that iteratively produces tokens using a model.
     ///
     /// The model is expected to have the following inputs:
     ///
diff --git a/rten-examples/src/gpt2.rs b/rten-examples/src/gpt2.rs
new file mode 100644
index 00000000..fb510dcd
--- /dev/null
+++ b/rten-examples/src/gpt2.rs
@@ -0,0 +1,126 @@
+use std::collections::VecDeque;
+use std::error::Error;
+use std::fs;
+use std::io::prelude::*;
+
+use rten::Model;
+use rten_examples::generator::Generator;
+use rten_text::tokenizers::Tokenizer;
+
+struct Args {
+    model: String,
+    tokenizer_config: String,
+    prompt: String,
+    output_length: usize,
+}
+
+fn parse_args() -> Result<Args, lexopt::Error> {
+    use lexopt::prelude::*;
+
+    let mut values = VecDeque::new();
+    let mut parser = lexopt::Parser::from_env();
+    let mut output_length = 30;
+
+    while let Some(arg) = parser.next()? {
+        match arg {
+            Short('l') | Long("length") => {
+                output_length = parser.value()?.parse()?;
+            }
+            Value(val) => values.push_back(val.string()?),
+            Long("help") => {
+                println!(
+                    "Generate text using a prompt.
+
+Usage: {bin_name} [options] <model> <tokenizer> <prompt>
+
+Args:
+
+  <model>     - Input GPT-2 model
+  <tokenizer> - `tokenizer.json` file
+  <prompt>    - Text generation prompt
+
+Options:
+
+ -l, --length - Set max output length (in tokens)
+",
+                    bin_name = parser.bin_name().unwrap_or("gpt2")
+                );
+                std::process::exit(0);
+            }
+            _ => return Err(arg.unexpected()),
+        }
+    }
+
+    let model = values.pop_front().ok_or("missing `model` arg")?;
+    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
+    let prompt = values.make_contiguous().join(" ");
+
+    let args = Args {
+        model,
+        tokenizer_config,
+        prompt,
+        output_length,
+    };
+
+    Ok(args)
+}
+
+/// Generates text using GPT-2 [1] and a prompt.
+///
+/// To obtain the model from Hugging Face, use Optimum [2], then convert it:
+///
+/// ```sh
+/// optimum-cli export onnx --model gpt2 gpt2_onnx/
+/// rten-convert gpt2_onnx/model.onnx
+/// ```
+///
+/// Run the converted model with a prompt:
+///
+/// ```sh
+/// cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json <prompt>
+/// ```
+///
+/// Where `<prompt>` is the start of a sentence that the model should complete.
+///
+/// [1] https://openai.com/research/better-language-models
+/// [2] https://huggingface.co/docs/optimum/index
+fn main() -> Result<(), Box<dyn Error>> {
+    let args = parse_args()?;
+    let model_bytes = fs::read(args.model)?;
+    let model = Model::load(model_bytes)?;
+
+    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
+    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;
+
+    let prompt = args.prompt.as_str();
+    let encoded_prompt = tokenizer.encode(prompt.into(), Default::default())?;
+    let token_ids: Vec<u32> = encoded_prompt
+        .token_ids()
+        .iter()
+        .map(|id| *id as u32)
+        .collect();
+
+    // The output starts with the user's prompt.
+    print!("{}", prompt);
+
+    // Buffer that holds model output tokens until they form a valid UTF-8
+    // sequence.
+    let mut token_buf = Vec::new();
+
+    let generator = Generator::from_model(&model)?.with_prompt(&token_ids);
+    for token in generator.take(args.output_length) {
+        let token = token?;
+        token_buf.push(token as usize);
+
+        let token_strings = tokenizer.encoder().get_tokens(&token_buf);
+        if let Ok(strings) = token_strings {
+            for s in strings {
+                print!("{}", s);
+            }
+            let _ = std::io::stdout().flush();
+            token_buf.clear();
+        }
+    }
+
+    Ok(())
+}
diff --git a/rten-examples/src/gpt2_reference.py b/rten-examples/src/gpt2_reference.py
new file mode 100644
index 00000000..60517bf8
--- /dev/null
+++ b/rten-examples/src/gpt2_reference.py
@@ -0,0 +1,24 @@
+from argparse import ArgumentParser
+from transformers import pipeline, set_seed
+
+
+def main():
+    parser = ArgumentParser(description="Generate text using GPT-2 and a prompt")
+    parser.add_argument("prompt", nargs="*")
+    parser.add_argument("--seed", type=int, help="Random seed")
+    args = parser.parse_args()
+
+    prompt = " ".join(args.prompt)
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    print(f'prompt: "{prompt}"')
+    generator = pipeline("text-generation", model="gpt2")
+
+    sequences = generator(prompt, max_length=30, num_return_sequences=1, do_sample=False)
+    for seq in sequences:
+        print(seq)
+
+
+if __name__ == "__main__":
+    main()
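
Reviewer note: as a quick orientation, below is a minimal sketch (not part of the patch) of how the new `Generator` API is meant to be driven, condensed from the `gpt2` example above. The model and tokenizer paths, the prompt text, and the token count are placeholders.

```rust
use std::error::Error;
use std::fs;

use rten::Model;
use rten_examples::generator::Generator;
use rten_text::tokenizers::Tokenizer;

fn demo() -> Result<(), Box<dyn Error>> {
    // Placeholder paths; produced by `optimum-cli export onnx` plus
    // `rten-convert`, as described in the gpt2 example's doc comment.
    let model = Model::load(fs::read("gpt2_onnx/model.rten")?)?;
    let tokenizer = Tokenizer::from_json(&fs::read_to_string("gpt2_onnx/tokenizer.json")?)?;

    // Encode the prompt into token IDs, as gpt2.rs does.
    let encoded = tokenizer.encode("Tell me a story".into(), Default::default())?;
    let prompt_ids: Vec<u32> = encoded.token_ids().iter().map(|id| *id as u32).collect();

    // `Generator` is an iterator: each `next()` call runs the model once,
    // feeding back the per-layer key-value cache, and yields the argmax token.
    let generator = Generator::from_model(&model)?.with_prompt(&prompt_ids);
    for token in generator.take(20) {
        println!("next token id: {}", token?);
    }
    Ok(())
}
```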