Commit f135641

Add warmup function to speed up the first embedding (huggingface#8)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 2b4da4f commit f135641

3 files changed: +59 −68 lines


backends/Cargo.toml

Lines changed: 3 additions & 2 deletions
@@ -10,8 +10,9 @@ clap = { workspace = true, optional = true }
 text-embeddings-backend-core = { path = "core" }
 text-embeddings-backend-python = { path = "python", optional = true }
 text-embeddings-backend-candle = { path = "candle", optional = true }
-tokio = { workspace = true }
-tracing = { workspace = true }
+tokio = { version = "^1.25", features = ["sync"] }
+tracing = "^0.1"
+rand = "^0.8"
 
 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
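
The new `rand` dependency supplies the random token ids used by `create_warmup_batch` in the backends/src/lib.rs diff below. A minimal sketch of the call it enables, with a hypothetical vocabulary size standing in for the `max_token` bound the backend actually receives:

    use rand::Rng;

    fn main() {
        // Hypothetical vocabulary size; in the backend the upper bound is
        // the `max_token` argument supplied by the caller.
        let vocab_size: u32 = 30_522;
        // rand 0.8: gen_range takes a range and samples uniformly
        // from [0, vocab_size).
        let input_ids: Vec<u32> = (0..8)
            .map(|_| rand::thread_rng().gen_range(0..vocab_size))
            .collect();
        println!("{:?}", input_ids);
    }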

backends/src/lib.rs

Lines changed: 50 additions & 60 deletions
@@ -1,14 +1,13 @@
 mod dtype;
-
-use std::cmp::{max, min};
+use std::env;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::thread::JoinHandle;
 use std::time::{Duration, Instant};
 use text_embeddings_backend_core::{Backend as CoreBackend, Predictions};
 use tokio::sync::{mpsc, oneshot, watch};
 use tracing::{instrument, Span};
-
+use rand::Rng;
 pub use crate::dtype::DType;
 pub use text_embeddings_backend_core::{
     BackendError, Batch, Embedding, Embeddings, ModelType, Pool,
@@ -68,62 +67,6 @@ impl Backend {
         })
     }
 
-    #[instrument(skip(self))]
-    pub async fn warmup(
-        &self,
-        max_input_length: usize,
-        max_batch_tokens: usize,
-        max_batch_requests: Option<usize>,
-    ) -> Result<(), BackendError> {
-        let mut input_ids = Vec::with_capacity(max_batch_tokens);
-        let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
-        let mut position_ids = Vec::with_capacity(max_batch_tokens);
-
-        let mut cumulative_seq_lengths = vec![0];
-        let mut pooled_indices = Vec::new();
-
-        let mut i = 0_u32;
-        let mut remaining = max_batch_tokens;
-        let mut cumulative_length = 0;
-        let mut max_length = 0;
-
-        while remaining > 0 {
-            let request_length = min(remaining, max_input_length);
-            cumulative_length += request_length;
-            max_length = max(max_length, request_length as u32);
-
-            input_ids.extend(vec![0; request_length]);
-            token_type_ids.extend(vec![0; request_length]);
-            position_ids.extend((0..request_length as u32).collect::<Vec<u32>>());
-
-            cumulative_seq_lengths.push(cumulative_length as u32);
-            pooled_indices.push(i);
-
-            i += 1;
-            remaining = remaining.saturating_sub(max_input_length);
-            if let Some(max_batch_requests) = &max_batch_requests {
-                if i as usize == *max_batch_requests {
-                    break;
-                }
-            }
-        }
-
-        let batch = Batch {
-            input_ids,
-            token_type_ids,
-            position_ids,
-            cumulative_seq_lengths,
-            max_length,
-            pooled_indices,
-            raw_indices: vec![],
-        };
-
-        match &self.model_type {
-            ModelType::Classifier => self.predict(batch).await.map(|_| ()),
-            ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
-        }
-    }
-
     #[instrument(skip(self))]
     pub async fn health(&self) -> Result<(), BackendError> {
         if *self.health_receiver.borrow() {
@@ -158,6 +101,54 @@ impl Backend {
         }
     }
 
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &self,
+        max_input_length: u32,
+        max_token: u32,
+    ) -> Result<(), BackendError> {
+        let read_env_var = |key: &str, default: u32| -> u32 {
+            env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
+        };
+        // get all possible sequence lengths for prefill
+        let bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let mut seq_lengths: Vec<u32> = (bucket_size..max_input_length + 1).step_by(bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
+        for &length in seq_lengths.iter() {
+            tracing::info!("warmup for length: {}", length);
+            let batch = self.create_warmup_batch(length, max_token);
+            match &self.model_type {
+                ModelType::Classifier => self.predict(batch).await.map(|_| ()),
+                ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
+            };
+        }
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    pub fn create_warmup_batch(
+        &self,
+        length: u32,
+        max_token: u32,
+    ) -> Batch {
+        let input_ids = (0..length).map(|_| rand::thread_rng().gen_range(0..max_token)).collect();
+        let token_type_ids: Vec<u32> = vec![0; length as usize];
+        let position_ids: Vec<u32> = (0..length).collect();
+        let cumulative_seq_lengths: Vec<u32> = vec![0, length - 1];
+        Batch {
+            input_ids: input_ids,
+            token_type_ids: token_type_ids,
+            position_ids: position_ids,
+            cumulative_seq_lengths: cumulative_seq_lengths,
+            max_length: length,
+            pooled_indices: vec![0],
+            raw_indices: vec![],
+        }
+    }
     #[instrument(skip(self))]
     pub fn health_watcher(&self) -> watch::Receiver<bool> {
         self.health_receiver.clone()
@@ -166,7 +157,6 @@ impl Backend {
     #[instrument(skip_all)]
     pub async fn embed(&self, batch: Batch) -> Result<(Embeddings, Duration), BackendError> {
         let (sender, receiver) = oneshot::channel();
-
         self.backend_sender
             .try_send(BackendCommand::Embed(batch, Span::current(), sender))
             .expect("No backend receiver. This is a bug.");
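
The heart of the new warmup is the choice of sequence lengths: one pass per multiple of PAD_SEQUENCE_TO_MULTIPLE_OF (default 128) up to max_input_length, plus a final pass at max_input_length itself when it is not an exact multiple. A standalone sketch of that selection, mirroring the loop in the new warmup() above (the free function is ours, not part of the crate):

    fn warmup_seq_lengths(bucket_size: u32, max_input_length: u32) -> Vec<u32> {
        // Equivalent to the diff's (bucket_size..max_input_length + 1):
        // every multiple of bucket_size up to max_input_length.
        let mut seq_lengths: Vec<u32> = (bucket_size..=max_input_length)
            .step_by(bucket_size as usize)
            .collect();
        // If the maximum is not itself a multiple, warm it up explicitly too.
        if let Some(&last) = seq_lengths.last() {
            if last < max_input_length {
                seq_lengths.push(max_input_length);
            }
        }
        seq_lengths
    }

    fn main() {
        // Default 128-token buckets with a 512-token model: four warmup passes.
        assert_eq!(warmup_seq_lengths(128, 512), vec![128, 256, 384, 512]);
        // A non-multiple maximum gets one extra pass at the exact length.
        assert_eq!(warmup_seq_lengths(128, 300), vec![128, 256, 300]);
    }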

router/src/lib.rs

Lines changed: 6 additions & 6 deletions
@@ -262,12 +262,12 @@ pub async fn run(
         .await
         .context("Model backend is not healthy")?;
 
-    if !backend.padded_model {
-        tracing::info!("Warming up model");
-        backend
-            .warmup(max_input_length, max_batch_tokens, max_batch_requests)
-            .await
-            .context("Model backend is not healthy")?;
+
+    // Warmup
+    if backend.warmup(
+        max_input_length as u32,
+        max_batch_tokens as u32).await.is_ok() {
+        tracing::info!("Succeed doing warmup");
     }
 
     let max_batch_requests = backend
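
Since the bucket size is read from the PAD_SEQUENCE_TO_MULTIPLE_OF environment variable, warmup granularity can be tuned per deployment by exporting it before launch (e.g. PAD_SEQUENCE_TO_MULTIPLE_OF=256 halves the passes for a 512-token model). A sketch of the lookup-with-default; note the committed read_env_var closure calls .parse::<u32>().unwrap(), so a malformed value panics there, whereas this sketch falls back to the default:

    use std::env;

    fn main() {
        // Default of 128 matches read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128).
        let bucket_size: u32 = env::var("PAD_SEQUENCE_TO_MULTIPLE_OF")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(128);
        println!("warmup bucket size: {bucket_size}");
    }

Note also that the call site now treats warmup as best-effort: where the removed code propagated failure with `?` ("Model backend is not healthy"), the new code only gates the success log with `is_ok()`, so a failed warmup no longer aborts startup.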
