Commit 5e584b8

Add warmup function to speed up the first embedding (huggingface#8)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 54beebd commit 5e584b8

3 files changed: +57 -4 lines changed

backends/Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ text-embeddings-backend-python = { path = "python", optional = true }
 text-embeddings-backend-candle = { path = "candle", optional = true }
 tokio = { version = "^1.25", features = ["sync"] }
 tracing = "^0.1"
-
+rand = "^0.8"
 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]

backends/src/lib.rs

Lines changed: 50 additions & 3 deletions

@@ -1,13 +1,13 @@
 mod dtype;
-
+use std::env;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::thread::JoinHandle;
 use std::time::{Duration, Instant};
 use text_embeddings_backend_core::{Backend as CoreBackend, Predictions};
 use tokio::sync::{mpsc, oneshot, watch};
 use tracing::{instrument, Span};
-
+use rand::Rng;
 pub use crate::dtype::DType;
 pub use text_embeddings_backend_core::{
     BackendError, Batch, Embedding, Embeddings, ModelType, Pool,
@@ -98,6 +98,54 @@ impl Backend {
         }
     }

+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &self,
+        max_input_length: u32,
+        max_token: u32,
+    ) -> Result<(), BackendError> {
+        let read_env_var = |key: &str, default: u32| -> u32 {
+            env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
+        };
+        // get all possible sequence lengths for prefill
+        let bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let mut seq_lengths: Vec<u32> = (bucket_size..max_input_length+1).step_by(bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
+        for &length in seq_lengths.iter() {
+            tracing::info!("warmup for length: {}", length);
+            let batch = self.create_warmup_batch(length, max_token);
+            match &self.model_type {
+                ModelType::Classifier => self.predict(batch).await.map(|_| ()),
+                ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
+            };
+        }
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    pub fn create_warmup_batch(
+        &self,
+        length: u32,
+        max_token: u32,
+    ) -> Batch {
+        let input_ids = (0..length).map(|_| rand::thread_rng().gen_range(0..max_token)).collect();
+        let token_type_ids: Vec<u32> = vec![0; length as usize];
+        let position_ids: Vec<u32> = (0..length).collect();
+        let cumulative_seq_lengths: Vec<u32> = vec![0, length - 1];
+        Batch {
+            input_ids: input_ids,
+            token_type_ids: token_type_ids,
+            position_ids: position_ids,
+            cumulative_seq_lengths: cumulative_seq_lengths,
+            max_length: length,
+            pooled_indices: vec![0],
+            raw_indices: vec![],
+        }
+    }
     #[instrument(skip(self))]
     pub fn health_watcher(&self) -> watch::Receiver<bool> {
         self.health_receiver.clone()
@@ -106,7 +154,6 @@ impl Backend {
     #[instrument(skip_all)]
     pub async fn embed(&self, batch: Batch) -> Result<(Embeddings, Duration), BackendError> {
         let (sender, receiver) = oneshot::channel();
-
         self.backend_sender
             .send(BackendCommand::Embed(batch, Span::current(), sender))
             .expect("No backend receiver. This is a bug.");
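
Note (not part of the commit): the sequence-length bucketing inside warmup above can be read in isolation as the small sketch below. It assumes the same semantics as the diff; warmup_lengths is a hypothetical helper name used only for illustration.

// Hypothetical standalone sketch of the bucketing logic in `warmup`:
// lengths are multiples of PAD_SEQUENCE_TO_MULTIPLE_OF up to
// max_input_length, with max_input_length appended when it is not
// itself a multiple of the bucket size.
fn warmup_lengths(bucket_size: u32, max_input_length: u32) -> Vec<u32> {
    let mut seq_lengths: Vec<u32> = (bucket_size..max_input_length + 1)
        .step_by(bucket_size as usize)
        .collect();
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    seq_lengths
}

fn main() {
    // With the default bucket size of 128, a 512-token limit warms up
    // prefill shapes 128, 256, 384 and 512 ...
    assert_eq!(warmup_lengths(128, 512), vec![128, 256, 384, 512]);
    // ... while a non-multiple limit is appended as the final shape.
    assert_eq!(warmup_lengths(128, 300), vec![128, 256, 300]);
}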

router/src/lib.rs

Lines changed: 6 additions & 0 deletions

@@ -205,6 +205,12 @@ pub async fn run(
         .await
         .context("Model backend is not healthy")?;

+    // Warmup
+    if backend.warmup(
+        max_input_length as u32,
+        max_batch_tokens as u32).await.is_ok() {
+        tracing::info!("Succeed doing warmup");
+    }
     let max_batch_requests = backend
         .max_batch_size
         .map(|s| {