router/src/http/server.rs (28 changes: 17 additions & 11 deletions)

@@ -159,7 +159,7 @@ async fn predict(
     let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
     let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner(
         inputs,
-        req.truncate,
+        req.truncate.unwrap_or(info.auto_truncate),
         req.raw_scores,
         infer.0,
         info.0,
@@ -208,7 +208,7 @@ async fn predict(
             let local_info = info.clone();
             futures.push(predict_inner(
                 input,
-                req.truncate,
+                req.truncate.unwrap_or(info.auto_truncate),
                 req.raw_scores,
                 local_infer.0,
                 local_info.0,
@@ -370,7 +370,7 @@ async fn rerank(
             futures.push(rerank_inner(
                 req.query.clone(),
                 text.clone(),
-                req.truncate,
+                req.truncate.unwrap_or(info.auto_truncate),
                 req.raw_scores,
                 local_infer.0,
             ))
@@ -478,7 +478,12 @@ async fn embed(
 
         let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
         let response = infer
-            .embed_pooled(input, req.truncate, req.normalize, permit)
+            .embed_pooled(
+                input,
+                req.truncate.unwrap_or(info.auto_truncate),
+                req.normalize,
+                permit,
+            )
             .await
             .map_err(ErrorResponse::from)?;
 
@@ -531,11 +536,12 @@ async fn embed(
         for input in inputs {
             compute_chars += input.count_chars();
 
+            let truncate = req.truncate.unwrap_or(info.auto_truncate);
             let local_infer = infer.clone();
             futures.push(async move {
                 let permit = local_infer.acquire_permit().await;
                 local_infer
-                    .embed_pooled(input, req.truncate, req.normalize, permit)
+                    .embed_pooled(input, truncate, req.normalize, permit)
                     .await
             })
         }
@@ -634,7 +640,7 @@ async fn embed_sparse(
 
         let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
         let response = infer
-            .embed_sparse(input, req.truncate, permit)
+            .embed_sparse(input, req.truncate.unwrap_or(info.auto_truncate), permit)
             .await
             .map_err(ErrorResponse::from)?;
 
@@ -687,12 +693,11 @@ async fn embed_sparse(
         for input in inputs {
             compute_chars += input.count_chars();
 
+            let truncate = req.truncate.unwrap_or(info.auto_truncate);
             let local_infer = infer.clone();
             futures.push(async move {
                 let permit = local_infer.acquire_permit().await;
-                let response = local_infer
-                    .embed_sparse(input, req.truncate, permit)
-                    .await?;
+                let response = local_infer.embed_sparse(input, truncate, permit).await?;
                 Ok((sparsify(response.results), response.metadata))
             })
         }
@@ -782,7 +787,7 @@ async fn embed_all(
 
         let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
         let response = infer
-            .embed_all(input, req.truncate, permit)
+            .embed_all(input, req.truncate.unwrap_or(info.auto_truncate), permit)
             .await
             .map_err(ErrorResponse::from)?;
 
@@ -835,10 +840,11 @@ async fn embed_all(
         for input in inputs {
             compute_chars += input.count_chars();
 
+            let truncate = req.truncate.unwrap_or(info.auto_truncate);
             let local_infer = infer.clone();
             futures.push(async move {
                 let permit = local_infer.acquire_permit().await;
-                local_infer.embed_all(input, req.truncate, permit).await
+                local_infer.embed_all(input, truncate, permit).await
             })
         }
         let results = join_all(futures)
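All of the server.rs changes apply one resolution rule: an explicit `truncate` in the request body wins, and a missing one falls back to the server-wide `auto_truncate` default. A minimal, self-contained sketch of that rule, with hypothetical `Request` and `Info` structs standing in for the router's actual types:

```rust
// Sketch of the truncate-resolution rule used throughout server.rs.
// `Request` and `Info` are hypothetical stand-ins for the router's types.
struct Request {
    truncate: Option<bool>, // per-request override; None if the client omitted it
}

struct Info {
    auto_truncate: bool, // server-wide default set at startup
}

fn resolve_truncate(req: &Request, info: &Info) -> bool {
    // An explicit request value wins; otherwise fall back to the server default.
    req.truncate.unwrap_or(info.auto_truncate)
}

fn main() {
    let info = Info { auto_truncate: true };
    assert!(resolve_truncate(&Request { truncate: None }, &info));
    assert!(!resolve_truncate(&Request { truncate: Some(false) }, &info));
}
```

Note the pattern in the batched handlers: the resolved value is bound before the `async move` block (`let truncate = ...;`), so each spawned future captures a plain `Copy` bool instead of borrowing or moving `req` and `info`.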
router/src/http/types.rs (21 changes: 10 additions & 11 deletions)

@@ -196,9 +196,8 @@ impl<'__s> ToSchema<'__s> for PredictInput {
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct PredictRequest {
     pub inputs: PredictInput,
-    #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default)]
     #[schema(default = "false", example = "false")]
     pub raw_scores: bool,
@@ -226,8 +225,8 @@ pub(crate) struct RerankRequest {
     #[schema(example = json!(["Deep Learning is ..."]))]
     pub texts: Vec<String>,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default)]
     #[schema(default = "false", example = "false")]
     pub raw_scores: bool,
@@ -322,8 +321,8 @@ pub(crate) struct OpenAICompatResponse {
 pub(crate) struct EmbedRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
     #[serde(default = "default_normalize")]
     #[schema(default = "true", example = "true")]
     pub normalize: bool,
@@ -341,8 +340,8 @@ pub(crate) struct EmbedResponse(pub Vec<Vec<f32>>);
 pub(crate) struct EmbedSparseRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -358,8 +357,8 @@ pub(crate) struct EmbedSparseResponse(pub Vec<Vec<SparseValue>>);
 pub(crate) struct EmbedAllRequest {
     pub inputs: Input,
     #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
 }
 
 #[derive(Serialize, ToSchema)]
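The schema change is identical across all five request types: `truncate` goes from `bool` to `Option<bool>` (with `nullable = true` in the OpenAPI schema), so "not set" becomes distinguishable from an explicit `false`. Serde deserializes a missing JSON key into `None` for `Option` fields even without `#[serde(default)]`, which is why the `PredictRequest` hunk can drop that attribute. A small sketch of the behavior, assuming `serde` and `serde_json` as dependencies; `MiniRequest` is a hypothetical stand-in for `EmbedRequest`:

```rust
use serde::Deserialize;

#[derive(Deserialize)]
#[allow(dead_code)]
struct MiniRequest {
    inputs: String,
    // Option fields deserialize missing JSON keys as None by default.
    truncate: Option<bool>,
}

fn main() {
    // Key omitted -> None, so the router falls back to `info.auto_truncate`.
    let r: MiniRequest = serde_json::from_str(r#"{"inputs": "hi"}"#).unwrap();
    assert_eq!(r.truncate, None);

    // Key set -> Some(..), overriding the server-wide default.
    let r: MiniRequest =
        serde_json::from_str(r#"{"inputs": "hi", "truncate": false}"#).unwrap();
    assert_eq!(r.truncate, Some(false));
}
```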
router/src/lib.rs (3 changes: 3 additions & 0 deletions)

@@ -51,6 +51,7 @@ pub async fn run(
     max_batch_tokens: usize,
     max_batch_requests: Option<usize>,
     max_client_batch_size: usize,
+    auto_truncate: bool,
     hf_api_token: Option<String>,
     hostname: Option<String>,
     port: u16,
@@ -236,6 +237,7 @@ pub async fn run(
         tokenization_workers,
         max_batch_requests,
         max_client_batch_size,
+        auto_truncate,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         docker_label: option_env!("DOCKER_LABEL"),
@@ -428,6 +430,7 @@ pub struct Info {
     pub max_batch_requests: Option<usize>,
     #[cfg_attr(feature = "http", schema(example = "32"))]
     pub max_client_batch_size: usize,
+    pub auto_truncate: bool,
     #[cfg_attr(feature = "http", schema(example = "4"))]
     pub tokenization_workers: usize,
     /// Router Info
router/src/main.rs (9 changes: 8 additions & 1 deletion)

@@ -73,6 +73,12 @@ struct Args {
     #[clap(default_value = "32", long, env)]
     max_client_batch_size: usize,
 
+    /// Automatically truncate inputs that are longer than the maximum supported size
+    ///
+    /// Unused for gRPC servers
+    #[clap(long, env)]
+    auto_truncate: bool,
+
     /// Your HuggingFace hub token
     #[clap(long, env)]
     #[redact(partial)]
@@ -117,7 +123,7 @@ struct Args {
     #[clap(long, env)]
     otlp_endpoint: Option<String>,
 
-    // Unused for gRPC servers
+    /// Unused for gRPC servers
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
 }
@@ -143,6 +149,7 @@ async fn main() -> Result<()> {
         args.max_batch_tokens,
         args.max_batch_requests,
         args.max_client_batch_size,
+        args.auto_truncate,
         args.hf_api_token,
         Some(args.hostname),
         args.port,
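Given clap's derive conventions, the new `Args` field should surface as a `--auto-truncate` boolean flag that can also be set through an `AUTO_TRUNCATE` environment variable (assuming clap's `env` feature is enabled, as the other `#[clap(long, env)]` fields suggest). A sketch of the expected parsing behavior; `MiniArgs` is a hypothetical stand-in for the router's `Args`:

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct MiniArgs {
    /// Automatically truncate inputs that are longer than the maximum supported size
    #[clap(long, env)]
    auto_truncate: bool,
}

fn main() {
    // Flag absent -> false: truncation stays opt-in per request.
    let off = MiniArgs::parse_from(["router"]);
    assert!(!off.auto_truncate);

    // `--auto-truncate` -> true: requests that omit `truncate` now truncate.
    let on = MiniArgs::parse_from(["router", "--auto-truncate"]);
    assert!(on.auto_truncate);
}
```

The doc comment also notes the flag is unused for gRPC servers, matching the same caveat on `cors_allow_origin`, whose comment this PR fixes from `//` to `///` so clap picks it up as help text.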