Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions core/src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,23 @@ fn encode_input(
strategy: TruncationStrategy::LongestFirst,
stride: 0,
});
if inputs.is_encoded() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be merged with the match below? Since `is_encoded` is basically a match on `EncodingInput::Vector`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

let seq_len = inputs.len();
if seq_len > max_input_length {
return Err(TextEmbeddingsError::Validation(format!(
"`inputs` must have less than {max_input_length} tokens. Given: {seq_len}"
)));
}
return inputs.try_into_encoding(position_offset);
}

let inputs: EncodeInput = match inputs {
EncodingInput::Single(s) => s.into(),
EncodingInput::Dual(s1, s2) => (s1, s2).into(),
_ => Err(TextEmbeddingsError::Validation(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, this branch cannot be reached. Can we merge the logic above here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll give it a try.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged the logic above per your recommendation, which made this branch irrelevant.

"`inputs` must be a string or a tuple of strings".to_string(),
))?,
};

let encoding = tokenizer
.with_truncation(truncate_params)?
.encode(inputs, true)?;
Expand All @@ -143,7 +154,6 @@ fn encode_input(
}

metrics::histogram!("te_request_input_length", seq_len as f64);

Ok(Encoding {
input_ids: encoding.get_ids().to_vec(),
token_type_ids: encoding.get_type_ids().to_vec(),
Expand All @@ -163,13 +173,45 @@ pub struct Encoding {
/// The forms of input the tokenization pipeline accepts.
pub enum EncodingInput {
    /// A single text sequence to be tokenized.
    Single(String),
    /// A pair of text sequences to be tokenized together.
    Dual(String, String),
    /// Pre-tokenized input: raw token ids that bypass the tokenizer.
    Vector(Vec<u32>),
}

impl EncodingInput {
fn is_empty(&self) -> bool {
match self {
EncodingInput::Single(s) => s.is_empty(),
EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(),
EncodingInput::Vector(v) => v.is_empty(),
}
}

/// Returns `true` when the input is already tokenized (a vector of
/// token ids) and therefore does not need to go through the tokenizer.
fn is_encoded(&self) -> bool {
    // Idiomatic replacement for an exhaustive three-arm boolean match:
    // only the `Vector` variant holds pre-encoded ids.
    matches!(self, EncodingInput::Vector(_))
}

/// Length of the input, compared against `max_input_length` by callers.
///
/// NOTE(review): units are mixed here — `String::len` counts *bytes*
/// (not chars or tokens), while the `Vector` arm counts token ids.
/// Confirm that comparing byte length against a token limit is the
/// intended semantics for the text variants.
fn len(&self) -> usize {
    match self {
        EncodingInput::Single(s) => s.len(),
        EncodingInput::Dual(s1, s2) => s1.len() + s2.len(),
        EncodingInput::Vector(v) => v.len(),
    }
}

fn try_into_encoding(&self, position_offset: usize) -> Result<Encoding, TextEmbeddingsError> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this needs to be a separate function. You can just take the logic here and add it to the match directly.

match self {
EncodingInput::Vector(v) => Ok(Encoding {
input_ids: v.clone(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There needs to be validation on whether v contains invalid ids, e.g. values that are outside of the vocab.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

token_type_ids: vec![0; v.len()],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit brittle. In the future this could be false.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

position_ids: (position_offset as u32..(v.len() + position_offset) as u32)
.collect::<Vec<_>>(),
}),
_ => Err(TextEmbeddingsError::Validation(
"`inputs` must be a vector of input_ids".to_string(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a logic error on our part and should not be a concern to the client.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this.

)),
}
}
}
Expand Down
15 changes: 8 additions & 7 deletions router/src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
/// HTTP Server logic
use crate::http::types::{
EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence,
EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding,
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
Sequence,
};
use crate::{
shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
Expand Down Expand Up @@ -455,7 +456,7 @@ async fn embed(
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");

let compute_chars = input.chars().count();
let compute_chars = input.count_chars();

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
Expand Down Expand Up @@ -499,7 +500,7 @@ async fn embed(
let mut compute_chars = 0;

for input in inputs {
compute_chars += input.chars().count();
compute_chars += input.count_chars();

let local_infer = infer.clone();
futures.push(async move {
Expand Down Expand Up @@ -591,7 +592,7 @@ async fn openai_embed(
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");

let compute_chars = input.chars().count();
let compute_chars = input.count_chars();

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
Expand Down Expand Up @@ -639,7 +640,7 @@ async fn openai_embed(
let mut compute_chars = 0;

for input in inputs {
compute_chars += input.chars().count();
compute_chars += input.count_chars();

let local_infer = infer.clone();
futures.push(async move {
Expand Down
29 changes: 27 additions & 2 deletions router/src/http/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,36 @@ pub(crate) struct Rank {
#[derive(Serialize, ToSchema)]
pub(crate) struct RerankResponse(pub Vec<Rank>);

#[derive(Deserialize, ToSchema, Debug)]
#[serde(untagged)]
/// A single request input: either raw text or already-tokenized ids.
/// With `serde(untagged)`, deserialization tries the variants in
/// declaration order and picks the first that matches.
pub(crate) enum InputType {
    /// Plain text to be tokenized server-side.
    SingleString(String),
    /// A single token id.
    SingleInt(u32),
    /// A sequence of token ids (pre-tokenized input).
    VectorInt(Vec<u32>),
}
impl InputType {
pub(crate) fn count_chars(&self) -> usize {
match self {
InputType::SingleString(s) => s.chars().count(),
InputType::SingleInt(_) => 1,
InputType::VectorInt(v) => v.len(),
Comment on lines +264 to +265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this how OpenAI counts when ids are given to the API? Or do they still count the chars by decoding the ids?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into it and modify this per my findings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@OlivierDehaene I looked through the source of openai-python v1.7.1 and couldn't find a reference to character counting in the embeddings API.
Should count_chars return 0 for InputType::SingleInt and InputType::VectorInt for correctness?
Let me know how you want to proceed.

}
}
}
impl From<InputType> for EncodingInput {
fn from(value: InputType) -> Self {
match value {
InputType::SingleString(s) => Self::Single(s),
InputType::SingleInt(i) => Self::Vector(vec![i]),
InputType::VectorInt(v) => Self::Vector(v),
}
}
}
#[derive(Deserialize, ToSchema)]
#[serde(untagged)]
/// Top-level `inputs` field of a request: one input or a batch.
/// NOTE(review): with `serde(untagged)`, variants are tried in order,
/// so a JSON array of integers will match `Single(VectorInt)` rather
/// than `Batch` — confirm this precedence is intended.
pub(crate) enum Input {
    Single(InputType),
    Batch(Vec<InputType>),
}

#[derive(Deserialize, ToSchema)]
Expand Down
27 changes: 25 additions & 2 deletions router/tests/test_http_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ async fn test_embeddings() -> Result<()> {
let request = json!({
"inputs": "test"
});

let client = reqwest::Client::new();
let res = client
.post("http://0.0.0.0:8090/embed")
Expand All @@ -31,6 +30,18 @@ async fn test_embeddings() -> Result<()> {
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);

let test_tokens = vec![[101, 3231, 102]]; // tokenized "test"
let request = json!({"inputs": &test_tokens});
let res = client
.post("http://0.0.0.0:8090/embed")
.json(&request)
.send()
.await?;

let embeddings_single = res.json::<Vec<Vec<Score>>>().await?;
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);

let request = json!({
"inputs": vec!["test", "test", "test", "test", "test"],
});
Expand All @@ -41,10 +52,22 @@ async fn test_embeddings() -> Result<()> {
.json(&request)
.send()
.await?;

let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
for embeddings in &embeddings_batch {
assert_eq!(embeddings, &embeddings_single[0]);
}

let request =
json!({"inputs": &test_tokens.repeat(request["inputs"].as_array().unwrap().len())});
let res = client
.post("http://0.0.0.0:8090/embed")
.json(&request)
.send()
.await?;

let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
for embeddings in &embeddings_batch {
assert_eq!(embeddings, &embeddings_single[0]);
}
Expand Down