diff --git a/rten-generate/src/generator.rs b/rten-generate/src/generator.rs
index da79bbe2..679fff59 100644
--- a/rten-generate/src/generator.rs
+++ b/rten-generate/src/generator.rs
@@ -96,6 +96,23 @@ impl<'a> From<(&'a str, &'a str)> for KVCachePattern<'a> {
     }
 }
 
+/// Specifies a pair of patterns for corresponding input and output key-value
+/// cache entries.
+pub struct KVCachePair<'a> {
+    /// The pattern for the model input name.
+    pub input: KVCachePattern<'a>,
+
+    /// The pattern for the model output name.
+    pub output: KVCachePattern<'a>,
+
+    /// Whether this pair is used for a cross-attention ("encoder") KV cache.
+    ///
+    /// Encoder KV-cache entries are computed only on the first run of the
+    /// model and reused in subsequent runs.
+    pub encoder: bool,
+}
+
 /// Specifies the names of model inputs and outputs.
 ///
 /// The [`Default`] impl for this struct returns an instance whose names
@@ -116,17 +133,12 @@ pub struct ModelInputsConfig<'a> {
     /// Model input that contains position IDs for each position.
     pub position_ids: &'a str,
 
-    /// Pattern for key cache inputs.
-    pub key_cache: KVCachePattern<'a>,
+    /// Patterns for inputs and outputs used for key-value caches.
+    pub kv_caches: Vec<KVCachePair<'a>>,
 
-    /// Pattern for key cache outputs.
-    pub key_cache_output: KVCachePattern<'a>,
-
-    /// Pattern for value cache inputs.
-    pub value_cache: KVCachePattern<'a>,
-
-    /// Pattern for value cache outputs.
-    pub value_cache_output: KVCachePattern<'a>,
+    /// Boolean model input that is set to false on the first run and true on
+    /// subsequent runs.
+    pub use_cache_flag: &'a str,
 }
 
 /// Contains essential configuration needed for a `Generator` to execute a
@@ -147,10 +159,47 @@ impl<'a> Default for ModelInputsConfig<'a> {
             logits: "logits",
             attention_mask: "attention_mask",
             position_ids: "position_ids",
-            key_cache: ("past_key_values.", ".key").into(),
-            key_cache_output: ("present.", ".key").into(),
-            value_cache: ("past_key_values.", ".value").into(),
-            value_cache_output: ("present.", ".value").into(),
+            use_cache_flag: "use_cache_branch",
+
+            // Patterns are matched in order, so patterns with longer
+            // prefixes/suffixes are listed first to ensure they take
+            // precedence over shorter, more general ones.
+            kv_caches: [
+                // "Merged" decoders exported by Optimum for encoder-decoder
+                // models. These have KV caches for both the self-attention and
+                // cross-attention modules.
+                KVCachePair {
+                    input: ("past_key_values.", ".decoder.key").into(),
+                    output: ("present.", ".decoder.key").into(),
+                    encoder: false,
+                },
+                KVCachePair {
+                    input: ("past_key_values.", ".decoder.value").into(),
+                    output: ("present.", ".decoder.value").into(),
+                    encoder: false,
+                },
+                KVCachePair {
+                    input: ("past_key_values.", ".encoder.key").into(),
+                    output: ("present.", ".encoder.key").into(),
+                    encoder: true,
+                },
+                KVCachePair {
+                    input: ("past_key_values.", ".encoder.value").into(),
+                    output: ("present.", ".encoder.value").into(),
+                    encoder: true,
+                },
+                // Decoder-only models exported by Optimum.
+                KVCachePair {
+                    input: ("past_key_values.", ".key").into(),
+                    output: ("present.", ".key").into(),
+                    encoder: false,
+                },
+                KVCachePair {
+                    input: ("past_key_values.", ".value").into(),
+                    output: ("present.", ".value").into(),
+                    encoder: false,
+                },
+            ]
+            .into(),
         }
     }
 }
@@ -215,8 +264,14 @@ pub struct Generator<'a> {
     /// Length of the sequence generated so far.
    seq_len: u32,
 
-    /// Key-value cache.
+    /// Self-attention key-value cache. This is extended on each iteration.
     kv_cache: Vec<KvCache>,
+
+    /// Cross-attention key-value cache.
+    ///
+    /// This is used by encoder-decoder models. The cross-attention values
+    /// are computed on the first run and reused in subsequent runs.
+    encoder_kv_cache: Vec<KvCache>,
 }
 
 impl<'a> Generator<'a> {
@@ -284,6 +339,7 @@ impl<'a> Generator<'a> {
         // Find inputs and corresponding outputs for key-value cache.
         let batch_size = 1;
         let mut kv_cache = Vec::new();
+        let mut encoder_kv_cache = Vec::new();
         for &input_id in model.input_ids() {
             let input_info = model
                 .node_info(input_id)
                 .ok_or(GeneratorError::PromptError(format!(
@@ -293,14 +349,14 @@
                 )))?;
 
             let name = input_info.name();
 
-            let is_key_cache = name.starts_with(model_inputs.key_cache.prefix)
-                && name.ends_with(model_inputs.key_cache.suffix);
-            let is_value_cache = name.starts_with(model_inputs.value_cache.prefix)
-                && name.ends_with(model_inputs.value_cache.suffix);
-            if !is_key_cache && !is_value_cache {
+            let Some(kv_pattern) = model_inputs
+                .kv_caches
+                .iter()
+                .find(|pat| name.starts_with(pat.input.prefix) && name.ends_with(pat.input.suffix))
+            else {
                 continue;
-            }
+            };
 
             let (n_heads, size) = match *input_info.shape() {
                 [_, Dimension::Fixed(n_heads), _, Dimension::Fixed(size)] => (Some(n_heads), size),
@@ -310,32 +366,17 @@
                 }
             };
 
-            let prefix = if is_key_cache {
-                model_inputs.key_cache.prefix
-            } else {
-                model_inputs.value_cache.prefix
-            };
 
+            let prefix = kv_pattern.input.prefix;
             let layer_index_start = prefix.len();
-            let layer_index_str: String = name[layer_index_start..]
-                .chars()
-                .take_while(|ch| ch.is_ascii_digit())
-                .collect();
+            let layer_index_end = name.len() - kv_pattern.input.suffix.len();
+            let layer_index_str = &name[layer_index_start..layer_index_end];
 
             let Ok(layer_index) = layer_index_str.parse::<u32>() else {
                 continue;
             };
 
-            let (output_prefix, output_suffix) = if is_key_cache {
-                (
-                    model_inputs.key_cache_output.prefix,
-                    model_inputs.key_cache_output.suffix,
-                )
-            } else {
-                (
-                    model_inputs.value_cache_output.prefix,
-                    model_inputs.value_cache_output.suffix,
-                )
-            };
+            let output_prefix = kv_pattern.output.prefix;
+            let output_suffix = kv_pattern.output.suffix;
 
             let output_name = format!("{}{}{}", output_prefix, layer_index, output_suffix);
             let output_id = model
@@ -345,7 +386,7 @@
             // This value should be configurable.
             let max_seq_len = 512;
 
-            kv_cache.push(KvCache {
+            let kv_cache_entry = KvCache {
                 input_id,
                 output_id,
                 cache: if let Some(n_heads) = n_heads {
@@ -359,7 +400,13 @@
                         1, /* seq dim */
                     )))
                 },
-            });
+            };
+
+            if kv_pattern.encoder {
+                encoder_kv_cache.push(kv_cache_entry);
+            } else {
+                kv_cache.push(kv_cache_entry);
+            }
         }
 
         let mut generator = Generator {
@@ -376,6 +423,7 @@
             input_ids_input,
             logits_output,
             kv_cache,
+            encoder_kv_cache,
             seq_len: 0,
             sampler: Box::new(ArgMaxSampler {}),
         };
@@ -399,6 +447,13 @@
             });
         }
 
+        let use_cache_input = model.find_node(model_inputs.use_cache_flag);
+        if let Some(use_cache_input) = use_cache_input {
+            generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
+                Tensor::from(if positions.start == 0 { 0i32 } else { 1 }).into()
+            });
+        }
+
         Ok(generator)
     }
@@ -520,10 +575,24 @@
             }
         }
 
+        // Add cross-attention key-value cache.
+        for entry in self.encoder_kv_cache.iter() {
+            match &entry.cache {
+                Some(KvCacheData::BatchSeqChans(cache)) => {
+                    model_inputs.push((entry.input_id, cache.into()));
+                }
+                Some(KvCacheData::BatchHeadSeqChans(cache)) => {
+                    model_inputs.push((entry.input_id, cache.into()));
+                }
+                None => {}
+            }
+        }
+
         // Run the model and collect outputs and updated KV cache.
         let model_outputs: Vec<NodeId> = [self.logits_output]
             .into_iter()
             .chain(self.kv_cache.iter().map(|entry| entry.output_id))
+            .chain(self.encoder_kv_cache.iter().map(|entry| entry.output_id))
             .collect();
 
         let mut outputs = self
@@ -535,13 +604,34 @@
         let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
         let next_id = self.sampler.sample(logits.slice::<1, _>((0, -1)));
 
-        // Update the key-value cache.
+        // Update the self-attention key-value cache.
         //
         // The KV cache tensors returned from the model should be the same as
         // the passed in tensors, but extended by one element along the sequence
         // axis.
         for cache_entry in self.kv_cache.iter_mut() {
             let output = outputs.remove(0);
+
+            let kv_cache = match output.ndim() {
+                3 => KvCacheData::BatchSeqChans(output.try_into().map_err(wrap_error)?),
+                4 => KvCacheData::BatchHeadSeqChans(output.try_into().map_err(wrap_error)?),
+                _ => {
+                    return Err(wrap_error("expected KV cache output to have 3 or 4 dims"));
+                }
+            };
+            cache_entry.cache = Some(kv_cache);
+        }
+
+        // Update the cross-attention key-value cache.
+        for cache_entry in self.encoder_kv_cache.iter_mut() {
+            let output = outputs.remove(0);
+            if output.is_empty() {
+                // Optimum-exported models only return encoder KV-cache tensors
+                // on the first run and dummy empty tensors on subsequent runs.
+                // Ignore these and continue to use the value from the first run.
+                continue;
+            }
+
             let kv_cache = match output.ndim() {
                 3 => KvCacheData::BatchSeqChans(output.try_into().map_err(wrap_error)?),
                 4 => KvCacheData::BatchHeadSeqChans(output.try_into().map_err(wrap_error)?),
                 _ => {
@@ -711,6 +801,11 @@ mod tests {
             {
                 return Err(format!("invalid input ID {}", input_id).into());
             }
+            for &expected_input in self.input_ids.iter() {
+                if !inputs.iter().any(|&(id, _)| id == expected_input) {
+                    return Err(format!("missing input ID {}", expected_input).into());
+                }
+            }
 
             if let Some(output_id) = outputs.iter().find(|id| !self.output_ids.contains(id)) {
                 return Err(format!("invalid output ID {}", output_id).into());
@@ -785,11 +880,20 @@
         }
     }
 
+    #[derive(Copy, Clone, PartialEq)]
+    enum KvCacheType {
+        /// Add KV-cache inputs and outputs for self-attention.
+        Decoder,
+        /// Add KV-cache inputs and outputs for self-attention and cross-
+        /// attention.
+        EncoderDecoder,
+    }
+
     /// Create a fake transformer model using the default names for inputs and
     /// outputs.
     fn fake_transformer_model(
         params: TransformerParams,
-        use_kv_cache: bool,
+        kv_cache: Option<KvCacheType>,
         prompt_len: usize,
         output_token_ids: &[u32],
     ) -> FakeModel {
@@ -810,26 +914,58 @@
         // Add KV-cache inputs and outputs.
         let mut kv_cache_output_names = Vec::new();
-        if use_kv_cache {
+        if let Some(kv_cache_type) = kv_cache {
+            let dims = [
+                Dimension::Symbolic("batch".to_string()),
+                Dimension::Fixed(n_heads as usize),
+                Dimension::Symbolic("seq".to_string()),
+                Dimension::Fixed(n_embed),
+            ];
+            let make_name_info = |name: &str| NodeInfo::from_name_shape(name, &dims);
+
             for layer in 0..n_layers {
-                let dims = [
-                    Dimension::Symbolic("batch".to_string()),
-                    Dimension::Fixed(n_heads as usize),
-                    Dimension::Symbolic("seq".to_string()),
-                    Dimension::Fixed(n_embed),
-                ];
-                let past_key_name = format!("past_key_values.{}.key", layer);
-                let past_value_name = format!("past_key_values.{}.value", layer);
-                let present_key_name = format!("present.{}.key", layer);
-                let present_value_name = format!("present.{}.value", layer);
-
-                inputs.push(NodeInfo::from_name_shape(&past_key_name, &dims));
-                inputs.push(NodeInfo::from_name_shape(&past_value_name, &dims));
-
-                outputs.push(NodeInfo::from_name_shape(&present_key_name, &dims));
-                outputs.push(NodeInfo::from_name_shape(&present_value_name, &dims));
-                kv_cache_output_names.push(present_key_name);
-                kv_cache_output_names.push(present_value_name);
+                let past_names: Vec<String>;
+                let present_names: Vec<String>;
+
+                match kv_cache_type {
+                    KvCacheType::Decoder => {
+                        past_names = [
+                            format!("past_key_values.{}.key", layer),
+                            format!("past_key_values.{}.value", layer),
+                        ]
+                        .into();
+                        present_names = [
+                            format!("present.{}.key", layer),
+                            format!("present.{}.value", layer),
+                        ]
+                        .into();
+                    }
+                    KvCacheType::EncoderDecoder => {
+                        past_names = [
+                            format!("past_key_values.{}.decoder.key", layer),
+                            format!("past_key_values.{}.decoder.value", layer),
+                            format!("past_key_values.{}.encoder.key", layer),
+                            format!("past_key_values.{}.encoder.value", layer),
+                        ]
+                        .into();
+
+                        present_names = [
+                            format!("present.{}.decoder.key", layer),
+                            format!("present.{}.decoder.value", layer),
+                            format!("present.{}.encoder.key", layer),
+                            format!("present.{}.encoder.value", layer),
+                        ]
+                        .into();
+                    }
+                }
+
+                inputs.extend(past_names.iter().map(|name| make_name_info(&name)));
+                outputs.extend(present_names.iter().map(|name| make_name_info(&name)));
+                kv_cache_output_names.extend(present_names);
+            }
+
+            if kv_cache_type == KvCacheType::EncoderDecoder {
+                inputs.push(NodeInfo::from_name_shape("use_cache_branch", &[]));
+            }
         }
@@ -842,7 +978,7 @@
                 "token ID is invalid for vocab size"
             );
 
-            let logits = if use_kv_cache {
+            let logits = if kv_cache.is_some() {
                 generate_logits(n_vocab, &[output_token_id])
             } else {
                 generate_logits(n_vocab, &output_token_ids[..=step])
@@ -859,9 +995,28 @@
            } else {
                 prompt_len + step - 1
             };
+
+            let is_encoder = model
+                .node_info(kv_output_id)
+                .as_ref()
+                .map(|ni| ni.name())
+                .unwrap_or("")
+                .contains("encoder");
+
+            let output_n_embed = if is_encoder && step > 0 {
+                // Encoder KV-cache outputs are only used on the first run.
+                // On subsequent runs return a dummy output, which should
+                // be ignored.
+                0
+            } else {
+                n_embed
+            };
+
             outputs.insert(
                 kv_output_id,
-                Output::FloatTensor(NdTensor::zeros([1, n_heads, context_len, n_embed]).into()),
+                Output::FloatTensor(
+                    NdTensor::zeros([1, n_heads, context_len, output_n_embed]).into(),
+                ),
             );
         }
@@ -871,11 +1026,12 @@
         model
     }
 
-    fn test_generator_impl(use_kv_cache: bool) -> Result<(), Box<dyn Error>> {
+    fn test_generator_impl(kv_cache_type: Option<KvCacheType>) -> Result<(), Box<dyn Error>> {
         let params = TransformerParams::default();
         let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
         let prompt = [1, 2, 3, 1, 2, 3];
-        let model = fake_transformer_model(params, use_kv_cache, prompt.len(), &expected_token_ids);
+        let model =
+            fake_transformer_model(params, kv_cache_type, prompt.len(), &expected_token_ids);
 
         let generator = Generator::from_model(&model)?;
         let generation_len = 10;
@@ -894,6 +1050,7 @@
         let input_id = model.find_node("input_ids").unwrap();
         let position_ids = model.find_node("position_ids").unwrap();
         let attention_mask = model.find_node("attention_mask").unwrap();
+        let cache_branch = model.find_node("use_cache_branch");
 
         for step in 0..generation_len {
             let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -905,6 +1062,12 @@
             let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
             let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();
 
+            let cache_branch = cache_branch.map(|cb_id| {
+                let cb = model.get_inputs(step, cb_id).unwrap();
+                let cb: NdTensor<i32, 0> = cb.try_into().unwrap();
+                cb
+            });
+
             if step == 0 {
                 assert_eq!(step_inputs.size(1), prompt.len());
                 assert!(step_inputs
@@ -917,7 +1080,11 @@
                 assert_eq!(step_pos_ids.size(1), prompt.len());
                 assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
+
+                if let Some(cache_branch) = cache_branch {
+                    assert_eq!(cache_branch.item(), Some(&0));
+                }
-            } else if use_kv_cache {
+            } else if kv_cache_type.is_some() {
                 assert_eq!(step_inputs.size(1), 1);
                 assert_eq!(step_inputs[[0, 0]] as u32, expected_token_ids[step - 1]);
@@ -926,6 +1093,10 @@
                 assert_eq!(step_pos_ids.size(1), 1);
                 assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
+
+                if let Some(cache_branch) = cache_branch {
+                    assert_eq!(cache_branch.item(), Some(&1));
+                }
             } else {
                 let expected_inputs: Vec<i32> = prompt
                     .iter()
@@ -958,13 +1129,18 @@
     }
 
     #[test]
-    fn test_generator() -> Result<(), Box<dyn Error>> {
-        test_generator_impl(true /* use_kv_cache */)
+    fn test_generator_with_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
+        test_generator_impl(Some(KvCacheType::Decoder))
+    }
+
+    #[test]
+    fn test_generator_with_encoder_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
+        test_generator_impl(Some(KvCacheType::EncoderDecoder))
     }
 
     #[test]
     fn test_generator_without_kv_cache() -> Result<(), Box<dyn Error>> {
-        test_generator_impl(false /* use_kv_cache */)
+        test_generator_impl(None)
     }
 
     #[test]
@@ -975,7 +1151,7 @@
         let prompt = [99];
         let model = fake_transformer_model(
             params,
-            true, /* use_kv_cache */
+            Some(KvCacheType::Decoder),
             prompt.len(),
             &output_token_ids,
         );
@@ -1015,7 +1191,7 @@
         let prompt = [1, 2, 3, 1, 2, 3];
         let model = fake_transformer_model(
             params,
-            true, /* use_kv_cache */
+            Some(KvCacheType::Decoder),
            prompt.len(),
             &expected_token_ids,
         );
@@ -1040,7 +1216,7 @@
         let prompt = [1, 2, 3, 1, 2, 3];
        let model = fake_transformer_model(
             params,
-            true, /* use_kv_cache */
+            Some(KvCacheType::Decoder),
             prompt.len(),
             &expected_token_ids,
        );
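
Note on the pattern ordering this diff relies on: below is a minimal standalone sketch of the prefix/suffix matching and layer-index extraction that `Generator::from_model` performs. It assumes only the `prefix`/`suffix` pattern layout shown in the diff; `Pattern` and `kv_output_name` are hypothetical stand-ins for illustration and are not items from rten-generate.

// Sketch: resolve a KV-cache input name such as "past_key_values.3.decoder.key"
// to its output counterpart "present.3.decoder.key". Hypothetical helper, not
// part of the crate.
struct Pattern<'a> {
    prefix: &'a str,
    suffix: &'a str,
}

fn kv_output_name(name: &str, input: &Pattern, output: &Pattern) -> Option<String> {
    // Strip the input pattern; whatever remains is the layer index.
    let layer = name
        .strip_prefix(input.prefix)?
        .strip_suffix(input.suffix)?;
    // Reject names whose middle part is not a numeric layer index.
    layer.parse::<u32>().ok()?;
    Some(format!("{}{}{}", output.prefix, layer, output.suffix))
}

fn main() {
    let merged = Pattern { prefix: "past_key_values.", suffix: ".decoder.key" };
    let merged_out = Pattern { prefix: "present.", suffix: ".decoder.key" };
    assert_eq!(
        kv_output_name("past_key_values.3.decoder.key", &merged, &merged_out).as_deref(),
        Some("present.3.decoder.key")
    );

    // The generic decoder-only pattern also matches a merged-decoder name by
    // prefix and suffix, but the extracted "layer index" is then "3.decoder",
    // which fails to parse and the entry would be skipped. This is why the
    // longer ".decoder"/".encoder" patterns must come first in `kv_caches`.
    let generic = Pattern { prefix: "past_key_values.", suffix: ".key" };
    let generic_out = Pattern { prefix: "present.", suffix: ".key" };
    assert_eq!(
        kv_output_name("past_key_values.3.decoder.key", &generic, &generic_out),
        None
    );
}

This mirrors why the default `kv_caches` list in the diff orders the merged encoder-decoder patterns before the decoder-only ones: `find` takes the first pattern whose prefix and suffix both match, and only the ordering keeps the layer index numeric.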