11mod dtype;
2-
2+ use std :: env ;
33use std:: path:: PathBuf ;
44use std:: sync:: Arc ;
55use std:: thread:: JoinHandle ;
66use std:: time:: { Duration , Instant } ;
77use text_embeddings_backend_core:: { Backend as CoreBackend , Predictions } ;
88use tokio:: sync:: { mpsc, oneshot, watch} ;
99use tracing:: { instrument, Span } ;
10-
10+ use rand :: Rng ;
1111pub use crate :: dtype:: DType ;
1212pub use text_embeddings_backend_core:: {
1313 BackendError , Batch , Embedding , Embeddings , ModelType , Pool ,
@@ -98,6 +98,54 @@ impl Backend {
9898 }
9999 }
100100
101+ #[ instrument( skip( self ) ) ]
102+ pub async fn warmup (
103+ & self ,
104+ max_input_length : u32 ,
105+ max_token : u32 ,
106+ ) -> Result < ( ) , BackendError > {
107+ let read_env_var = |key : & str , default : u32 | -> u32 {
108+ env:: var ( key) . ok ( ) . map_or ( default, |value| value. parse :: < u32 > ( ) . unwrap ( ) )
109+ } ;
110+ // get all possible sequence lengths for prefill
111+ let bucket_size: u32 = read_env_var ( "PAD_SEQUENCE_TO_MULTIPLE_OF" , 128 ) ;
112+ let mut seq_lengths: Vec < u32 > = ( bucket_size..max_input_length+1 ) . step_by ( bucket_size as usize ) . collect ( ) ;
113+ if let Some ( & last) = seq_lengths. last ( ) {
114+ if last < max_input_length {
115+ seq_lengths. push ( max_input_length) ;
116+ }
117+ }
118+ for & length in seq_lengths. iter ( ) {
119+ tracing:: info!( "warmup for length: {}" , length) ;
120+ let batch = self . create_warmup_batch ( length, max_token) ;
121+ match & self . model_type {
122+ ModelType :: Classifier => self . predict ( batch) . await . map ( |_| ( ) ) ,
123+ ModelType :: Embedding ( _) => self . embed ( batch) . await . map ( |_| ( ) ) ,
124+ } ;
125+ }
126+ Ok ( ( ) )
127+ }
128+
129+ #[ instrument( skip_all) ]
130+ pub fn create_warmup_batch (
131+ & self ,
132+ length : u32 ,
133+ max_token : u32 ,
134+ ) -> Batch {
135+ let input_ids = ( 0 ..length) . map ( |_| rand:: thread_rng ( ) . gen_range ( 0 ..max_token) ) . collect ( ) ;
136+ let token_type_ids: Vec < u32 > = vec ! [ 0 ; length as usize ] ;
137+ let position_ids: Vec < u32 > = ( 0 ..length) . collect ( ) ;
138+ let cumulative_seq_lengths: Vec < u32 > = vec ! [ 0 , length - 1 ] ;
139+ Batch {
140+ input_ids : input_ids,
141+ token_type_ids : token_type_ids,
142+ position_ids : position_ids,
143+ cumulative_seq_lengths : cumulative_seq_lengths,
144+ max_length : length,
145+ pooled_indices : vec ! [ 0 ] ,
146+ raw_indices : vec ! [ ] ,
147+ }
148+ }
101149 #[ instrument( skip( self ) ) ]
102150 pub fn health_watcher ( & self ) -> watch:: Receiver < bool > {
103151 self . health_receiver . clone ( )
@@ -106,7 +154,6 @@ impl Backend {
106154 #[ instrument( skip_all) ]
107155 pub async fn embed ( & self , batch : Batch ) -> Result < ( Embeddings , Duration ) , BackendError > {
108156 let ( sender, receiver) = oneshot:: channel ( ) ;
109-
110157 self . backend_sender
111158 . send ( BackendCommand :: Embed ( batch, Span :: current ( ) , sender) )
112159 . expect ( "No backend receiver. This is a bug." ) ;
0 commit comments