Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion core/src/tokenization.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::{
Encoding as TokenizersEncoding, Token, TruncationDirection, TruncationParams,
TruncationStrategy,
};
pub use tokenizers::Encoding as RawEncoding;
use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy};
use tokio::sync::{mpsc, oneshot};
Expand Down Expand Up @@ -248,6 +252,25 @@ fn encode_input(
stride: 0,
});

let encoding: TokenizersEncoding = match inputs {
// encode input
EncodingInput::Single(s) => tokenizer
.with_truncation(truncate_params)?
.encode::<String>(s.into(), true)?,
EncodingInput::Dual(s1, s2) => {
tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2).into(), true)?
}
// input is encoded -> convert to tokenizers Encoding
EncodingInput::Vector(v) => {
let (tokens, token_type_ids) = create_tokens_and_type_ids(tokenizer, v)?;
let mut encoding = TokenizersEncoding::from_tokens(tokens, 0u32);
encoding.set_type_ids(token_type_ids);
encoding
}
};

let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
let seq_len = encoding.len();

Expand All @@ -256,8 +279,8 @@ fn encode_input(
"`inputs` must have less than {max_input_length} tokens. Given: {seq_len}"
)));
}

metrics::histogram!("te_request_input_length", seq_len as f64);
Ok(Encoding {

Ok(ValidEncoding {
input_ids: encoding.get_ids().to_vec(),
Expand All @@ -267,6 +290,36 @@ fn encode_input(
})
}

/// Reconstructs `Token`s (with character offsets) from raw input ids and
/// derives segment token_type_ids: 0 up to and including the first "[SEP]",
/// 1 for every token after it.
///
/// NOTE(review): this assumes `tokenizer.decode(ids, false)` produces exactly
/// one space-separated chunk per id — TODO confirm for tokenizers whose
/// decoded pieces contain spaces or merge across ids, since the `zip` below
/// would silently drop the surplus side.
fn create_tokens_and_type_ids(
    tokenizer: &mut Tokenizer,
    ids: Vec<u32>,
) -> Result<(Vec<Token>, Vec<u32>), TextEmbeddingsError> {
    // Decode WITHOUT skipping special tokens so "[SEP]" remains visible below.
    let decoded = tokenizer.decode(&ids, false)?;
    let splits = decoded.split(' ').map(|x| x.to_owned()).collect::<Vec<_>>();
    // Running (start, end) character offsets per split; the `+ 1` accounts for
    // the single joining space between consecutive splits.
    let offsets: Vec<_> = splits
        .iter()
        .scan(0, |state, x| {
            let res = (*state, *state + x.len());
            *state += x.len() + 1;
            Some(res)
        })
        .collect();
    // Becomes true once a "[SEP]" has been seen; tokens AFTER the first
    // separator get type id 1.
    let mut sep_flag = false;
    let (tokens, token_type_ids): (Vec<_>, Vec<_>) = splits
        .iter()
        .zip(ids.iter())
        .zip(offsets.iter())
        .map(|((value, id), offset)| {
            let token = Token::new(*id, value.clone(), *offset);
            // The "[SEP]" token itself still receives type id 0; only the
            // tokens following it receive 1.
            let token_type_id = if sep_flag { 1 } else { 0 };
            sep_flag = sep_flag || value == "[SEP]";
            (token, token_type_id)
        })
        .unzip();
    Ok((tokens, token_type_ids))
}

#[derive(Debug)]
pub struct ValidEncoding {
pub input_ids: Vec<u32>,
Expand All @@ -278,13 +331,15 @@ pub struct ValidEncoding {
/// The accepted forms of input to the tokenization pipeline.
pub enum EncodingInput {
    // A single text sequence to tokenize.
    Single(String),
    // A sequence pair tokenized together (two text segments).
    Dual(String, String),
    // Pre-tokenized input: raw token ids, bypassing text tokenization.
    Vector(Vec<u32>),
}

impl EncodingInput {
fn is_empty(&self) -> bool {
match self {
EncodingInput::Single(s) => s.is_empty(),
EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(),
EncodingInput::Vector(v) => v.is_empty(),
}
}
}
Expand Down
12 changes: 8 additions & 4 deletions router/src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
/// HTTP Server logic
use crate::http::types::{
EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding,
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
Sequence,
DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, OpenAICompatEmbedding,
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
Expand Down Expand Up @@ -474,7 +478,7 @@ async fn embed(
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");

let compute_chars = input.chars().count();
let compute_chars = input.count_chars();

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
Expand Down Expand Up @@ -529,7 +533,7 @@ async fn embed(
let mut compute_chars = 0;

for input in inputs {
compute_chars += input.chars().count();
compute_chars += input.count_chars();

let local_infer = infer.clone();
futures.push(async move {
Expand Down Expand Up @@ -923,7 +927,7 @@ async fn openai_embed(
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");

let compute_chars = input.chars().count();
let compute_chars = input.count_chars();

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
Expand Down Expand Up @@ -982,7 +986,7 @@ async fn openai_embed(
let mut compute_chars = 0;

for input in inputs {
compute_chars += input.chars().count();
compute_chars += input.count_chars();

let local_infer = infer.clone();
futures.push(async move {
Expand Down
29 changes: 27 additions & 2 deletions router/src/http/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,36 @@ pub(crate) struct Rank {
#[derive(Serialize, ToSchema)]
pub(crate) struct RerankResponse(pub Vec<Rank>);

/// One embedding input as accepted over HTTP: raw text, a single token id,
/// or a full vector of token ids. `#[serde(untagged)]` lets serde pick the
/// variant from the JSON shape (string vs. integer vs. array of integers).
#[derive(Deserialize, ToSchema, Debug)]
#[serde(untagged)]
pub(crate) enum InputType {
    SingleString(String),
    SingleInt(u32),
    VectorInt(Vec<u32>),
}
impl InputType {
    /// Size of this input for request metrics: the character count for text
    /// inputs, and the token count for pre-tokenized inputs.
    ///
    /// NOTE(review): for id inputs this counts tokens rather than characters;
    /// per the PR discussion it is unclear whether the metric should instead
    /// decode the ids and count characters — confirm the intended semantics.
    pub(crate) fn count_chars(&self) -> usize {
        match self {
            InputType::SingleString(s) => s.chars().count(),
            // A lone id is one token.
            InputType::SingleInt(_) => 1,
            // A vector of ids contributes one unit per token.
            InputType::VectorInt(v) => v.len(),
        }
    }
}
impl From<InputType> for EncodingInput {
fn from(value: InputType) -> Self {
match value {
InputType::SingleString(s) => Self::Single(s),
InputType::SingleInt(i) => Self::Vector(vec![i]),
InputType::VectorInt(v) => Self::Vector(v),
}
}
}
#[derive(Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum Input {
Single(String),
Batch(Vec<String>),
Single(InputType),
Batch(Vec<InputType>),
}

#[derive(Deserialize, ToSchema)]
Expand Down
27 changes: 25 additions & 2 deletions router/tests/test_http_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ async fn test_embeddings() -> Result<()> {
let request = json!({
"inputs": "test"
});

let client = reqwest::Client::new();
let res = client
.post("http://0.0.0.0:8090/embed")
Expand All @@ -31,6 +30,18 @@ async fn test_embeddings() -> Result<()> {
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);

let test_tokens = vec![[101, 3231, 102]]; // tokenized "test"
let request = json!({"inputs": &test_tokens});
let res = client
.post("http://0.0.0.0:8090/embed")
.json(&request)
.send()
.await?;

let embeddings_single = res.json::<Vec<Vec<Score>>>().await?;
let matcher = YamlMatcher::<Vec<Vec<Score>>>::new();
insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher);

let request = json!({
"inputs": vec!["test", "test", "test", "test", "test"],
});
Expand All @@ -41,10 +52,22 @@ async fn test_embeddings() -> Result<()> {
.json(&request)
.send()
.await?;

let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
for embeddings in &embeddings_batch {
assert_eq!(embeddings, &embeddings_single[0]);
}

let request =
json!({"inputs": &test_tokens.repeat(request["inputs"].as_array().unwrap().len())});
let res = client
.post("http://0.0.0.0:8090/embed")
.json(&request)
.send()
.await?;

let embeddings_batch = res.json::<Vec<Vec<Score>>>().await?;
insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher);
for embeddings in &embeddings_batch {
assert_eq!(embeddings, &embeddings_single[0]);
}
Expand Down