Add GPT 2 text generation example and reference implementation
robertknight committed Jun 4, 2024
1 parent bbe71c5 commit 2a8b295
Showing 5 changed files with 157 additions and 1 deletion.
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
@@ -60,6 +60,10 @@ path = "src/depth_anything.rs"
name = "bert_qa"
path = "src/bert_qa.rs"

[[bin]]
name = "gpt2"
path = "src/gpt2.rs"

[[bin]]
name = "jina_similarity"
path = "src/jina_similarity.rs"
2 changes: 2 additions & 0 deletions rten-examples/README.md
@@ -62,6 +62,8 @@ The examples have been chosen to cover common tasks and popular models.
- **bert_qa** - Extractive question answering using
  [BERT](https://arxiv.org/abs/1810.04805)-based models which have been fine-tuned
  on the [SQuAD](https://paperswithcode.com/dataset/squad) dataset
- **gpt2** - Text generation using the [GPT-2](https://openai.com/index/better-language-models/)
  language model.
- **jina_similarity** - Sentence similarity using vector embeddings of sentences

### Audio
2 changes: 1 addition & 1 deletion rten-examples/src/generator.rs
@@ -72,7 +72,7 @@ pub struct Generator<'a> {
}

impl<'a> Generator<'a> {
-   /// Create a generator that runs a given model.
+   /// Create a generator that iteratively produces tokens using a model.
    ///
    /// The model is expected to have the following inputs:
    ///
126 changes: 126 additions & 0 deletions rten-examples/src/gpt2.rs
@@ -0,0 +1,126 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::io::prelude::*;

use rten::Model;
use rten_examples::generator::Generator;
use rten_text::tokenizers::Tokenizer;

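/// Command line arguments for the GPT-2 text generation example.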
struct Args {
    model: String,
    tokenizer_config: String,
    prompt: String,
    output_length: usize,
}

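/// Parse command line arguments.
///
/// Positional arguments are expected in the order `<model> <tokenizer>
/// <prompt...>`; the `-l` / `--length` flag caps the number of generated
/// tokens (default: 30).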
fn parse_args() -> Result<Args, lexopt::Error> {
    use lexopt::prelude::*;

    let mut values = VecDeque::new();
    let mut parser = lexopt::Parser::from_env();
    let mut output_length = 30;

    while let Some(arg) = parser.next()? {
        match arg {
            Short('l') | Long("length") => {
                output_length = parser.value()?.parse()?;
            }
            Value(val) => values.push_back(val.string()?),
            Long("help") => {
                println!(
                    "Generate text using a prompt.

Usage: {bin_name} [options] <model> <tokenizer> <prompt>

Args:

  <model> - Input GPT-2 model
  <tokenizer> - `tokenizer.json` file
  <prompt> - Text generation prompt

Options:

  -l, --length - Set max output length (in tokens)
",
                    bin_name = parser.bin_name().unwrap_or("gpt2")
                );
                std::process::exit(0);
            }
            _ => return Err(arg.unexpected()),
        }
    }

    let model = values.pop_front().ok_or("missing `model` arg")?;
    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
    let prompt = values.make_contiguous().join(" ");

    let args = Args {
        model,
        tokenizer_config,
        prompt,
        output_length,
    };

    Ok(args)
}

/// Generates text using GPT-2 [1] and a prompt.
///
/// To obtain the model from Hugging Face, use Optimum [2], then convert it:
///
/// ```sh
/// optimum-cli export onnx --model gpt2 gpt2_onnx/
/// rten-convert gpt2_onnx/model.onnx
/// ```
///
/// Run the converted model with a prompt:
///
/// ```sh
/// cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json <prompt>
/// ```
///
/// Where `<prompt>` is the start of a sentence that the model should complete.
///
/// [1] https://openai.com/research/better-language-models
/// [2] https://huggingface.co/docs/optimum/index
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;
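    // Read the converted `.rten` model into memory and load it.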
    let model_bytes = fs::read(args.model)?;
    let model = Model::load(model_bytes)?;

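    // Load the tokenizer from its `tokenizer.json` configuration.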
    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;

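    // Encode the prompt text into the token IDs that seed generation.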
    let prompt = args.prompt.as_str();
    let encoded_prompt = tokenizer.encode(prompt.into(), Default::default())?;
    let token_ids: Vec<u32> = encoded_prompt
        .token_ids()
        .iter()
        .map(|id| *id as u32)
        .collect();

    // The output starts with the user's prompt.
    print!("{}", prompt);

    // Buffer that holds model output tokens until they form a valid UTF-8
    // sequence.
    let mut token_buf = Vec::new();

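    // Create a generator over the model, seed it with the prompt tokens and
    // stream up to `output_length` new tokens.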
    let generator = Generator::from_model(&model)?.with_prompt(&token_ids);
    for token in generator.take(args.output_length) {
        let token = token?;
        token_buf.push(token as usize);

        let token_strings = tokenizer.encoder().get_tokens(&token_buf);
        if let Ok(strings) = token_strings {
            for s in strings {
                print!("{}", s);
            }
            let _ = std::io::stdout().flush();
            token_buf.clear();
        }
    }

    Ok(())
}
24 changes: 24 additions & 0 deletions rten-examples/src/gpt2_reference.py
@@ -0,0 +1,24 @@
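"""Reference implementation of the GPT-2 text generation example.

Generates text from a prompt with the Hugging Face `transformers` pipeline,
for comparison against the Rust example in `gpt2.rs`.
"""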
from argparse import ArgumentParser
from transformers import pipeline, set_seed


def main():
    parser = ArgumentParser(description="Generate text using GPT-2 and a prompt")
    parser.add_argument("prompt", nargs="*")
    parser.add_argument("--seed", type=int, help="Random seed")
    args = parser.parse_args()

    prompt = " ".join(args.prompt)
    if args.seed is not None:
        set_seed(args.seed)

    print(f'prompt: "{prompt}"')
    generator = pipeline("text-generation", model="gpt2")

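    # do_sample=False selects greedy decoding, so the output is deterministic
    # and can be compared against the Rust example's output.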
    sequences = generator(prompt, max_length=30, num_return_sequences=1, do_sample=False)
    for seq in sequences:
        print(seq)


if __name__ == "__main__":
    main()

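To compare the two implementations side by side, something like the following should work (a sketch: it assumes the `transformers` Python package is installed and that the model was converted as described in `gpt2.rs`; the prompt is only an illustration):

```sh
# Reference output via Hugging Face transformers (greedy decoding).
python rten-examples/src/gpt2_reference.py "The Eiffel Tower is located in"

# Output from the RTen example.
cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json \
  "The Eiffel Tower is located in"
```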