@@ -132,11 +132,18 @@ pub struct ModelInputsConfig<'a> {
132132 pub attention_mask: &'a str,
133133
134134 /// Model input that contains position IDs for each position.
135-    pub position_ids: &'a str,
135+    ///
136+    /// This input does not have a batch dimension.
137+    pub cache_position: &'a str,
136138
137139 /// Patterns for inputs and outputs used for key-value caches.
138140 pub kv_caches: Vec<KVCachePair<'a>>,
139141
142+    /// Model input that contains position IDs for each position.
143+    ///
144+    /// This input has a batch dimension.
145+    pub position_ids: &'a str,
146+
140147 /// Boolean input that is set to false on the first run and true on
141148 /// subsequent runs.
142149 pub use_cache_flag: &'a str,
@@ -159,6 +166,7 @@ impl<'a> Default for ModelInputsConfig<'a> {
159166 input_ids: "input_ids",
160167 logits: "logits",
161168 attention_mask: "attention_mask",
169+    cache_position: "cache_position",
162170 position_ids: "position_ids",
163171 use_cache_flag: "use_cache_branch",
164172
@@ -317,6 +325,7 @@ impl<'a> Generator<'a> {
317325 /// The model may have the optional inputs:
318326 ///
319327 /// - `attention_mask` - (batch, sequence) tensor of booleans
328+    /// - `cache_position` - (sequence) tensor of position indices
320329 /// - `position_ids` - (batch, sequence) tensor of position indices
321330 /// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
322331 ///   where `N` is the layer index
@@ -480,6 +489,15 @@ impl<'a> Generator<'a> {
480489 });
481490 }
482491
492+    let cache_position_input = model.find_node(model_inputs.cache_position);
493+    if let Some(cache_position_input) = cache_position_input {
494+        generator =
495+            generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
496+                NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
497+                    .into()
498+            });
499+    }
500+
483501 let use_cache_input = model.find_node(model_inputs.use_cache_flag);
484502 if let Some(use_cache_input) = use_cache_input {
485503     generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
@@ -981,6 +999,7 @@ mod tests {
981999  // Add inputs and outputs using the standard names.
9821000 let mut inputs = vec![
9831001     NodeInfo::from_name_shape("input_ids", &[]),
1002+        NodeInfo::from_name_shape("cache_position", &[]),
9841003     NodeInfo::from_name_shape("position_ids", &[]),
9851004     NodeInfo::from_name_shape("attention_mask", &[]),
9861005 ];
@@ -1125,6 +1144,7 @@ mod tests {
11251144 let position_ids = model.find_node("position_ids").unwrap();
11261145 let attention_mask = model.find_node("attention_mask").unwrap();
11271146 let cache_branch = model.find_node("use_cache_branch");
1147+        let cache_position = model.find_node("cache_position").unwrap();
11281148
11291149 for step in 0..generation_len {
11301150     let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -1133,6 +1153,9 @@ mod tests {
11331153 let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
11341154 let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();
11351155
1156+        let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
1157+        let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();
1158+
11361159 let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
11371160 let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();
11381161
@@ -1155,6 +1178,12 @@ mod tests {
11551178 assert_eq!(step_pos_ids.size(1), prompt.len());
11561179 assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
11571180
1181+        assert_eq!(step_cache_pos.size(0), prompt.len());
1182+        assert!(step_cache_pos
1183+            .iter()
1184+            .map(|x| *x as usize)
1185+            .eq(0..prompt.len()));
1186+
11581187 if let Some(cache_branch) = cache_branch {
11591188     assert_eq!(cache_branch.item(), Some(&0));
11601189 }
@@ -1168,6 +1197,9 @@ mod tests {
11681197 assert_eq!(step_pos_ids.size(1), 1);
11691198 assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
11701199
1200+        assert_eq!(step_cache_pos.size(0), 1);
1201+        assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);
1202+
11711203 if let Some(cache_branch) = cache_branch {
11721204     assert_eq!(cache_branch.item(), Some(&1));
11731205 }
@@ -1194,7 +1226,11 @@ mod tests {
11941226 (0..prompt.len() + step).map(|x| x as i32).collect();
11951227 assert_eq!(
11961228     step_pos_ids,
1197-            NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
1229+            NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
1230+        );
1231+        assert_eq!(
1232+            step_cache_pos,
1233+            NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
1234+        );
11981234 );
11991235 }
12001236 }
0 commit comments