Support models without a key-value cache in rten-generate #305

Merged 1 commit on Aug 14, 2024.
143 changes: 108 additions & 35 deletions rten-generate/src/generator.rs
@@ -155,25 +155,34 @@ impl<'a> Default for ModelInputsConfig<'a> {
}
}

/// Generates a token ID sequence using an auto-regressive language model.
/// Generates a token ID sequence using a transformer decoder model.
///
/// This is an iterator that runs the model on each call to [`Iterator::next`]
/// and yields a result containing the next token ID or an error.
///
/// The token ID sequence can be converted to text using the
/// [`decode`](GeneratorUtils::decode) method of the [`GeneratorUtils`] trait.
///
/// This trait also provides useful wrappers for the output, such as stopping
/// generation when an end-of-text token is reached. You can also use all of
/// the standard iterator adapters. For example `generator.take(30)` will
/// return an iterator that stops generation after 30 tokens have been produced.
/// The `GeneratorUtils` trait also provides useful wrappers for the output,
/// such as stopping generation when an end-of-text token is reached. You can
/// also use all of the standard iterator adapters. For example
/// `generator.take(30)` will return an iterator that stops generation after 30
/// tokens have been produced.
///
/// ## Sampling
///
/// The token ID is sampled from the outputs of the model (the "logits") using
/// a [`Sampler`]. By default this is an [`ArgMaxSampler`] which simply chooses
/// the token with the highest probability. The sampler can be configured using
/// [`with_sampler`](Self::with_sampler).
///
/// ## Key-value caches and generation performance
///
/// To enable efficient decoding, the model should have inputs and outputs for
/// the [key-value
/// cache](https://peterchng.com/blog/2024/06/11/what-is-the-transformer-kv-cache/).
/// The generator will work with models that do not have cache inputs, but
/// decoding of long output sequences will be much slower.
pub struct Generator<'a> {
model: &'a dyn Model,

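As background for the doc comment above, here is a minimal usage sketch of the iterator-style API. The import paths, the `GeneratorError` type, and the `u32` token ID type are assumptions inferred from this diff; `from_model`, `with_prompt`, and the standard `take` adapter appear in the docs and tests below.

use rten_generate::generator::{Generator, GeneratorError};
use rten_generate::model::Model;

/// Sketch: run up to 30 decoding steps and collect the sampled token IDs.
/// `Generator` is a plain `Iterator`, so standard adapters such as `take`
/// compose directly; collecting into `Result` stops at the first error.
fn generate_up_to_30(
    model: &dyn Model,
    prompt: &[u32],
) -> Result<Vec<u32>, GeneratorError> {
    Generator::from_model(model)?
        .with_prompt(prompt)
        .take(30)
        .collect()
}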
@@ -230,6 +239,9 @@ impl<'a> Generator<'a> {
/// - `past_key_values.N.value` - (batch, head, past_key_values, size) value vector cache,
/// where `N` is the layer index
///
/// **Warning:** Generation of long sequences will be much slower in models without
/// key-value caches.
///
/// The model must have the outputs:
///
/// - `logits` - output (batch, sequence, vocab) tensor of next token probabilities
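To make the naming convention above concrete, the following hypothetical helper (not part of the crate) builds the cache input names the generator looks for in a model with `n_layers` decoder layers:

/// Build the `past_key_values.N.{key,value}` input names described above.
fn kv_cache_input_names(n_layers: usize) -> Vec<String> {
    (0..n_layers)
        .flat_map(|layer| {
            [
                format!("past_key_values.{}.key", layer),
                format!("past_key_values.{}.value", layer),
            ]
        })
        .collect()
}

// kv_cache_input_names(2) yields ["past_key_values.0.key",
// "past_key_values.0.value", "past_key_values.1.key",
// "past_key_values.1.value"], matching the fake model in the tests below.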
@@ -540,9 +552,13 @@ impl<'a> Generator<'a> {
cache_entry.cache = Some(kv_cache);
}

// Update the token IDs for the next iteration.
self.seq_len += self.input_ids.len() as u32;
self.input_ids = vec![next_id];
// Update the token IDs and sequence offset for the next iteration.
if !self.kv_cache.is_empty() {
self.seq_len += self.input_ids.len() as u32;
self.input_ids = vec![next_id];
} else {
self.input_ids.push(next_id);
}

Ok(next_id)
}
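This branch is the core of the change. As a standalone sketch of the logic (not the crate's code): with a KV cache, attention state for earlier positions is already cached, so each step feeds only the newly sampled token; without one, the full sequence so far is re-fed and re-processed on every step, which is why long generations are slower.

/// Standalone sketch of the per-step input selection above, where
/// `all_ids` is the prompt plus every token sampled so far.
fn next_step_inputs(all_ids: &[u32], has_kv_cache: bool) -> Vec<u32> {
    if has_kv_cache {
        // Earlier positions are covered by the cache; process only the
        // most recent token.
        vec![*all_ids.last().expect("sequence is non-empty")]
    } else {
        // No cache: the model must re-process the whole sequence.
        all_ids.to_vec()
    }
}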
@@ -773,6 +789,7 @@ mod tests {
/// outputs.
fn fake_transformer_model(
params: TransformerParams,
use_kv_cache: bool,
prompt_len: usize,
output_token_ids: &[u32],
) -> FakeModel {
@@ -793,25 +810,27 @@

// Add KV-cache inputs and outputs.
let mut kv_cache_output_names = Vec::new();
for layer in 0..n_layers {
let dims = [
Dimension::Symbolic("batch".to_string()),
Dimension::Fixed(n_heads as usize),
Dimension::Symbolic("seq".to_string()),
Dimension::Fixed(n_embed),
];
let past_key_name = format!("past_key_values.{}.key", layer);
let past_value_name = format!("past_key_values.{}.value", layer);
let present_key_name = format!("present.{}.key", layer);
let present_value_name = format!("present.{}.value", layer);

inputs.push(NodeInfo::from_name_shape(&past_key_name, &dims));
inputs.push(NodeInfo::from_name_shape(&past_value_name, &dims));

outputs.push(NodeInfo::from_name_shape(&present_key_name, &dims));
outputs.push(NodeInfo::from_name_shape(&present_value_name, &dims));
kv_cache_output_names.push(present_key_name);
kv_cache_output_names.push(present_value_name);
if use_kv_cache {
for layer in 0..n_layers {
let dims = [
Dimension::Symbolic("batch".to_string()),
Dimension::Fixed(n_heads as usize),
Dimension::Symbolic("seq".to_string()),
Dimension::Fixed(n_embed),
];
let past_key_name = format!("past_key_values.{}.key", layer);
let past_value_name = format!("past_key_values.{}.value", layer);
let present_key_name = format!("present.{}.key", layer);
let present_value_name = format!("present.{}.value", layer);

inputs.push(NodeInfo::from_name_shape(&past_key_name, &dims));
inputs.push(NodeInfo::from_name_shape(&past_value_name, &dims));

outputs.push(NodeInfo::from_name_shape(&present_key_name, &dims));
outputs.push(NodeInfo::from_name_shape(&present_value_name, &dims));
kv_cache_output_names.push(present_key_name);
kv_cache_output_names.push(present_value_name);
}
}

let mut model = FakeModel::with_inputs_and_outputs(&inputs, &outputs);
@@ -823,7 +842,12 @@
"token ID is invalid for vocab size"
);

let logits = generate_logits(n_vocab, &[output_token_id]);
let logits = if use_kv_cache {
generate_logits(n_vocab, &[output_token_id])
} else {
generate_logits(n_vocab, &output_token_ids[..=step])
};

let mut outputs = HashMap::new();
outputs.insert(logits_id, Output::FloatTensor(logits.into()));

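The test change above mirrors a real difference in output shape: without a cache the model emits logits for every input position, and only the final position's logits are sampled. A sketch of that slicing, assuming a flat buffer in (batch=1, seq, vocab) layout (the crate itself slices a tensor type rather than a raw slice):

/// Return the logits for the final sequence position from a flat
/// (1, seq_len, n_vocab) buffer.
fn last_position_logits(logits: &[f32], seq_len: usize, n_vocab: usize) -> &[f32] {
    assert_eq!(logits.len(), seq_len * n_vocab);
    let start = (seq_len - 1) * n_vocab;
    &logits[start..start + n_vocab]
}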
@@ -847,12 +871,11 @@
model
}

#[test]
fn test_generator() -> Result<(), Box<dyn Error>> {
fn test_generator_impl(use_kv_cache: bool) -> Result<(), Box<dyn Error>> {
let params = TransformerParams::default();
let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
let prompt = [1, 2, 3, 1, 2, 3];
let model = fake_transformer_model(params, prompt.len(), &expected_token_ids);
let model = fake_transformer_model(params, use_kv_cache, prompt.len(), &expected_token_ids);

let generator = Generator::from_model(&model)?;
let generation_len = 10;
@@ -894,7 +917,7 @@

assert_eq!(step_pos_ids.size(1), prompt.len());
assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
} else {
} else if use_kv_cache {
assert_eq!(step_inputs.size(1), 1);
assert_eq!(step_inputs[[0, 0]] as u32, expected_token_ids[step - 1]);

@@ -903,19 +926,59 @@

assert_eq!(step_pos_ids.size(1), 1);
assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
} else {
let expected_inputs: Vec<i32> = prompt
.iter()
.copied()
.chain(expected_token_ids)
.take(prompt.len() + step)
.map(|x| x as i32)
.collect();
assert_eq!(
step_inputs,
NdTensor::from_data([1, expected_inputs.len()], expected_inputs)
);

let expected_attn_mask = vec![1i32; prompt.len() + step];
assert_eq!(
step_attn_mask,
NdTensor::from_data([1, expected_attn_mask.len()], expected_attn_mask)
);

let expected_pos_ids: Vec<i32> =
(0..prompt.len() + step).map(|x| x as i32).collect();
assert_eq!(
step_pos_ids,
NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
);
}
}

Ok(())
}

#[test]
fn test_generator() -> Result<(), Box<dyn Error>> {
test_generator_impl(true /* use_kv_cache */)
}

#[test]
fn test_generator_without_kv_cache() -> Result<(), Box<dyn Error>> {
test_generator_impl(false /* use_kv_cache */)
}

#[test]
fn test_generator_append_prompt() -> Result<(), Box<dyn Error>> {
let mut params = TransformerParams::default();
params.n_vocab = 110;
let output_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8];
let prompt = [99];
let model = fake_transformer_model(params, prompt.len(), &output_token_ids);
let model = fake_transformer_model(
params,
true, /* use_kv_cache */
prompt.len(),
&output_token_ids,
);

let mut generator = Generator::from_model(&model)?.with_prompt(&prompt);

@@ -950,7 +1013,12 @@
let params = TransformerParams::default();
let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
let prompt = [1, 2, 3, 1, 2, 3];
let model = fake_transformer_model(params, prompt.len(), &expected_token_ids);
let model = fake_transformer_model(
params,
true, /* use_kv_cache */
prompt.len(),
&expected_token_ids,
);

let generator = Generator::from_model(&model)?;

@@ -970,7 +1038,12 @@
let params = TransformerParams::default();
let expected_token_ids = [0, 1, 2, 3, 4];
let prompt = [1, 2, 3, 1, 2, 3];
let model = fake_transformer_model(params, prompt.len(), &expected_token_ids);
let model = fake_transformer_model(
params,
true, /* use_kv_cache */
prompt.len(),
&expected_token_ids,
);

let generator = Generator::from_model(&model)?;
let mut metrics = Metrics::new();