Skip to content

Commit 00d2b19

Browse files
committed
Support cache_position inputs in Hugging Face models
Support the `cache_position` input that was added to Hugging Face Whisper models as part of a revision of how they handle KV-caching. This input is like `position_ids`, but has no batch dimension. See huggingface/optimum#1971 and huggingface/transformers#31166.
1 parent 63634b8 commit 00d2b19

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

rten-generate/src/generator.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,18 @@ pub struct ModelInputsConfig<'a> {
132132
pub attention_mask: &'a str,
133133

134134
/// Model input that contains position IDs for each position.
135-
pub position_ids: &'a str,
135+
///
136+
/// This input does not have a batch dimension.
137+
pub cache_position: &'a str,
136138

137139
/// Patterns for inputs and outputs used for key-value caches.
138140
pub kv_caches: Vec<KVCachePair<'a>>,
139141

142+
/// Model input that contains position IDs for each position.
143+
///
144+
/// This input has a batch dimension.
145+
pub position_ids: &'a str,
146+
140147
/// Boolean input that is set to false on the first run and true on
141148
/// subsequent runs.
142149
pub use_cache_flag: &'a str,
@@ -159,6 +166,7 @@ impl<'a> Default for ModelInputsConfig<'a> {
159166
input_ids: "input_ids",
160167
logits: "logits",
161168
attention_mask: "attention_mask",
169+
cache_position: "cache_position",
162170
position_ids: "position_ids",
163171
use_cache_flag: "use_cache_branch",
164172

@@ -317,6 +325,7 @@ impl<'a> Generator<'a> {
317325
/// The model may have the optional inputs:
318326
///
319327
/// - `attention_mask` - (batch, sequence) tensor of booleans
328+
/// - `cache_position` - (sequence) tensor of position indices
320329
/// - `position_ids` - (batch, sequence) tensor of position indices
321330
/// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
322331
/// where `N` is the layer index
@@ -480,6 +489,15 @@ impl<'a> Generator<'a> {
480489
});
481490
}
482491

492+
let cache_position_input = model.find_node(model_inputs.cache_position);
493+
if let Some(cache_position_input) = cache_position_input {
494+
generator =
495+
generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
496+
NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
497+
.into()
498+
});
499+
}
500+
483501
let use_cache_input = model.find_node(model_inputs.use_cache_flag);
484502
if let Some(use_cache_input) = use_cache_input {
485503
generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
@@ -981,6 +999,7 @@ mod tests {
981999
// Add inputs and outputs using the standard names.
9821000
let mut inputs = vec![
9831001
NodeInfo::from_name_shape("input_ids", &[]),
1002+
NodeInfo::from_name_shape("cache_position", &[]),
9841003
NodeInfo::from_name_shape("position_ids", &[]),
9851004
NodeInfo::from_name_shape("attention_mask", &[]),
9861005
];
@@ -1125,6 +1144,7 @@ mod tests {
11251144
let position_ids = model.find_node("position_ids").unwrap();
11261145
let attention_mask = model.find_node("attention_mask").unwrap();
11271146
let cache_branch = model.find_node("use_cache_branch");
1147+
let cache_position = model.find_node("cache_position").unwrap();
11281148

11291149
for step in 0..generation_len {
11301150
let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -1133,6 +1153,9 @@ mod tests {
11331153
let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
11341154
let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();
11351155

1156+
let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
1157+
let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();
1158+
11361159
let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
11371160
let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();
11381161

@@ -1155,6 +1178,12 @@ mod tests {
11551178
assert_eq!(step_pos_ids.size(1), prompt.len());
11561179
assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
11571180

1181+
assert_eq!(step_cache_pos.size(0), prompt.len());
1182+
assert!(step_cache_pos
1183+
.iter()
1184+
.map(|x| *x as usize)
1185+
.eq(0..prompt.len()));
1186+
11581187
if let Some(cache_branch) = cache_branch {
11591188
assert_eq!(cache_branch.item(), Some(&0));
11601189
}
@@ -1168,6 +1197,9 @@ mod tests {
11681197
assert_eq!(step_pos_ids.size(1), 1);
11691198
assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
11701199

1200+
assert_eq!(step_cache_pos.size(0), 1);
1201+
assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);
1202+
11711203
if let Some(cache_branch) = cache_branch {
11721204
assert_eq!(cache_branch.item(), Some(&1));
11731205
}
@@ -1194,7 +1226,11 @@ mod tests {
11941226
(0..prompt.len() + step).map(|x| x as i32).collect();
11951227
assert_eq!(
11961228
step_pos_ids,
1197-
NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
1229+
NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
1230+
);
1231+
assert_eq!(
1232+
step_cache_pos,
1233+
NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
11981234
);
11991235
}
12001236
}

0 commit comments

Comments
 (0)