11mod dtype;
2-
3- use std:: cmp:: { max, min} ;
2+ use std:: env;
43use std:: path:: PathBuf ;
54use std:: sync:: Arc ;
65use std:: thread:: JoinHandle ;
76use std:: time:: { Duration , Instant } ;
87use text_embeddings_backend_core:: { Backend as CoreBackend , Predictions } ;
98use tokio:: sync:: { mpsc, oneshot, watch} ;
109use tracing:: { instrument, Span } ;
11-
10+ use rand :: Rng ;
1211pub use crate :: dtype:: DType ;
1312pub use text_embeddings_backend_core:: {
1413 BackendError , Batch , Embedding , Embeddings , ModelType , Pool ,
@@ -68,62 +67,6 @@ impl Backend {
6867 } )
6968 }
7069
71- #[ instrument( skip( self ) ) ]
72- pub async fn warmup (
73- & self ,
74- max_input_length : usize ,
75- max_batch_tokens : usize ,
76- max_batch_requests : Option < usize > ,
77- ) -> Result < ( ) , BackendError > {
78- let mut input_ids = Vec :: with_capacity ( max_batch_tokens) ;
79- let mut token_type_ids = Vec :: with_capacity ( max_batch_tokens) ;
80- let mut position_ids = Vec :: with_capacity ( max_batch_tokens) ;
81-
82- let mut cumulative_seq_lengths = vec ! [ 0 ] ;
83- let mut pooled_indices = Vec :: new ( ) ;
84-
85- let mut i = 0_u32 ;
86- let mut remaining = max_batch_tokens;
87- let mut cumulative_length = 0 ;
88- let mut max_length = 0 ;
89-
90- while remaining > 0 {
91- let request_length = min ( remaining, max_input_length) ;
92- cumulative_length += request_length;
93- max_length = max ( max_length, request_length as u32 ) ;
94-
95- input_ids. extend ( vec ! [ 0 ; request_length] ) ;
96- token_type_ids. extend ( vec ! [ 0 ; request_length] ) ;
97- position_ids. extend ( ( 0 ..request_length as u32 ) . collect :: < Vec < u32 > > ( ) ) ;
98-
99- cumulative_seq_lengths. push ( cumulative_length as u32 ) ;
100- pooled_indices. push ( i) ;
101-
102- i += 1 ;
103- remaining = remaining. saturating_sub ( max_input_length) ;
104- if let Some ( max_batch_requests) = & max_batch_requests {
105- if i as usize == * max_batch_requests {
106- break ;
107- }
108- }
109- }
110-
111- let batch = Batch {
112- input_ids,
113- token_type_ids,
114- position_ids,
115- cumulative_seq_lengths,
116- max_length,
117- pooled_indices,
118- raw_indices : vec ! [ ] ,
119- } ;
120-
121- match & self . model_type {
122- ModelType :: Classifier => self . predict ( batch) . await . map ( |_| ( ) ) ,
123- ModelType :: Embedding ( _) => self . embed ( batch) . await . map ( |_| ( ) ) ,
124- }
125- }
126-
12770 #[ instrument( skip( self ) ) ]
12871 pub async fn health ( & self ) -> Result < ( ) , BackendError > {
12972 if * self . health_receiver . borrow ( ) {
@@ -158,6 +101,54 @@ impl Backend {
158101 }
159102 }
160103
104+ #[ instrument( skip( self ) ) ]
105+ pub async fn warmup (
106+ & self ,
107+ max_input_length : u32 ,
108+ max_token : u32 ,
109+ ) -> Result < ( ) , BackendError > {
110+ let read_env_var = |key : & str , default : u32 | -> u32 {
111+ env:: var ( key) . ok ( ) . map_or ( default, |value| value. parse :: < u32 > ( ) . unwrap ( ) )
112+ } ;
113+ // get all possible sequence lengths for prefill
114+ let bucket_size: u32 = read_env_var ( "PAD_SEQUENCE_TO_MULTIPLE_OF" , 128 ) ;
115+ let mut seq_lengths: Vec < u32 > = ( bucket_size..max_input_length+1 ) . step_by ( bucket_size as usize ) . collect ( ) ;
116+ if let Some ( & last) = seq_lengths. last ( ) {
117+ if last < max_input_length {
118+ seq_lengths. push ( max_input_length) ;
119+ }
120+ }
121+ for & length in seq_lengths. iter ( ) {
122+ tracing:: info!( "warmup for length: {}" , length) ;
123+ let batch = self . create_warmup_batch ( length, max_token) ;
124+ match & self . model_type {
125+ ModelType :: Classifier => self . predict ( batch) . await . map ( |_| ( ) ) ,
126+ ModelType :: Embedding ( _) => self . embed ( batch) . await . map ( |_| ( ) ) ,
127+ } ;
128+ }
129+ Ok ( ( ) )
130+ }
131+
132+ #[ instrument( skip_all) ]
133+ pub fn create_warmup_batch (
134+ & self ,
135+ length : u32 ,
136+ max_token : u32 ,
137+ ) -> Batch {
138+ let input_ids = ( 0 ..length) . map ( |_| rand:: thread_rng ( ) . gen_range ( 0 ..max_token) ) . collect ( ) ;
139+ let token_type_ids: Vec < u32 > = vec ! [ 0 ; length as usize ] ;
140+ let position_ids: Vec < u32 > = ( 0 ..length) . collect ( ) ;
141+ let cumulative_seq_lengths: Vec < u32 > = vec ! [ 0 , length - 1 ] ;
142+ Batch {
143+ input_ids : input_ids,
144+ token_type_ids : token_type_ids,
145+ position_ids : position_ids,
146+ cumulative_seq_lengths : cumulative_seq_lengths,
147+ max_length : length,
148+ pooled_indices : vec ! [ 0 ] ,
149+ raw_indices : vec ! [ ] ,
150+ }
151+ }
161152 #[ instrument( skip( self ) ) ]
162153 pub fn health_watcher ( & self ) -> watch:: Receiver < bool > {
163154 self . health_receiver . clone ( )
@@ -166,7 +157,6 @@ impl Backend {
166157 #[ instrument( skip_all) ]
167158 pub async fn embed ( & self , batch : Batch ) -> Result < ( Embeddings , Duration ) , BackendError > {
168159 let ( sender, receiver) = oneshot:: channel ( ) ;
169-
170160 self . backend_sender
171161 . try_send ( BackendCommand :: Embed ( batch, Span :: current ( ) , sender) )
172162 . expect ( "No backend receiver. This is a bug." ) ;
0 commit comments