-
Notifications
You must be signed in to change notification settings - Fork 325
Input Types Compatibility with OpenAI's API #112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
f5cad1f
32dfd0a
3f32209
2ee3044
7292326
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,12 +125,23 @@ fn encode_input( | |
| strategy: TruncationStrategy::LongestFirst, | ||
| stride: 0, | ||
| }); | ||
| if inputs.is_encoded() { | ||
| let seq_len = inputs.len(); | ||
| if seq_len > max_input_length { | ||
| return Err(TextEmbeddingsError::Validation(format!( | ||
| "`inputs` must have less than {max_input_length} tokens. Given: {seq_len}" | ||
| ))); | ||
| } | ||
| return inputs.try_into_encoding(position_offset); | ||
| } | ||
|
|
||
| let inputs: EncodeInput = match inputs { | ||
| EncodingInput::Single(s) => s.into(), | ||
| EncodingInput::Dual(s1, s2) => (s1, s2).into(), | ||
| _ => Err(TextEmbeddingsError::Validation( | ||
|
||
| "`inputs` must be a string or a tuple of strings".to_string(), | ||
| ))?, | ||
| }; | ||
|
|
||
| let encoding = tokenizer | ||
| .with_truncation(truncate_params)? | ||
| .encode(inputs, true)?; | ||
|
|
@@ -143,7 +154,6 @@ fn encode_input( | |
| } | ||
|
|
||
| metrics::histogram!("te_request_input_length", seq_len as f64); | ||
|
|
||
| Ok(Encoding { | ||
| input_ids: encoding.get_ids().to_vec(), | ||
| token_type_ids: encoding.get_type_ids().to_vec(), | ||
|
|
@@ -163,13 +173,45 @@ pub struct Encoding { | |
| /// Input handed to `encode_input`: raw text to tokenize, or ids that are already tokenized. | ||
| pub enum EncodingInput { | ||
|     /// A single text sequence. | ||
|     Single(String), | ||
|     /// A pair of text sequences. | ||
|     Dual(String, String), | ||
|     /// Token ids supplied directly; `encode_input` returns them without running the tokenizer. | ||
|     Vector(Vec<u32>), | ||
| } | ||
|
|
||
| impl EncodingInput { | ||
|     /// Returns `true` when the input carries no content. | ||
|     /// | ||
|     /// A `Dual` input counts as empty only when both of its strings are empty. | ||
|     fn is_empty(&self) -> bool { | ||
|         match self { | ||
|             EncodingInput::Single(s) => s.is_empty(), | ||
|             EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(), | ||
|             EncodingInput::Vector(v) => v.is_empty(), | ||
|         } | ||
|     } | ||
|
|
||
|     /// Returns `true` when the input already consists of token ids (`Vector`). | ||
|     fn is_encoded(&self) -> bool { | ||
|         matches!(self, EncodingInput::Vector(_)) | ||
|     } | ||
|
|
||
|     /// Returns the size of the input. | ||
|     /// | ||
|     /// Note the mixed units: text variants report their length in bytes (`String::len`, | ||
|     /// summed for `Dual`), while `Vector` reports its number of token ids. | ||
|     fn len(&self) -> usize { | ||
|         match self { | ||
|             EncodingInput::Single(s) => s.len(), | ||
|             EncodingInput::Dual(s1, s2) => s1.len() + s2.len(), | ||
|             EncodingInput::Vector(v) => v.len(), | ||
|         } | ||
|     } | ||
|
|
||
|     /// Builds an `Encoding` directly from pre-tokenized ids. | ||
|     /// | ||
|     /// `input_ids` is a copy of the vector, `token_type_ids` is all zeros, and | ||
|     /// `position_ids` counts up from `position_offset`, one position per id. | ||
|     /// | ||
|     /// # Errors | ||
|     /// | ||
|     /// Returns `TextEmbeddingsError::Validation` when `self` is not `Vector`, or when | ||
|     /// `position_offset + len` does not fit in a `u32`. | ||
|     fn try_into_encoding(&self, position_offset: usize) -> Result<Encoding, TextEmbeddingsError> { | ||
|         match self { | ||
|             EncodingInput::Vector(v) => { | ||
|                 // Bound the end position before casting: a bare `as u32` would silently | ||
|                 // wrap and yield a wrong (possibly empty) position range. | ||
|                 let end = position_offset | ||
|                     .checked_add(v.len()) | ||
|                     .filter(|&end| end <= u32::MAX as usize) | ||
|                     .ok_or_else(|| { | ||
|                         TextEmbeddingsError::Validation( | ||
|                             "`inputs` positions must fit in a 32-bit unsigned integer".to_string(), | ||
|                         ) | ||
|                     })?; | ||
|                 Ok(Encoding { | ||
|                     input_ids: v.clone(), | ||
|                     token_type_ids: vec![0; v.len()], | ||
|                     // Lossless casts: `position_offset <= end <= u32::MAX` was checked above. | ||
|                     position_ids: (position_offset as u32..end as u32).collect::<Vec<_>>(), | ||
|                 }) | ||
|             } | ||
|             EncodingInput::Single(_) | EncodingInput::Dual(_, _) => { | ||
|                 Err(TextEmbeddingsError::Validation( | ||
|                     "`inputs` must be a vector of input_ids".to_string(), | ||
|                 )) | ||
|             } | ||
|         } | ||
|     } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -250,11 +250,36 @@ pub(crate) struct Rank { | |
| #[derive(Serialize, ToSchema)] | ||
| pub(crate) struct RerankResponse(pub Vec<Rank>); | ||
|
|
||
| /// One request input: a string, a single integer, or a list of integers. | ||
| /// | ||
| /// NOTE(review): the integers are presumably token ids (they are converted into | ||
| /// `EncodingInput::Vector`) — confirm against the API documentation. | ||
| #[derive(Deserialize, ToSchema, Debug)] | ||
| #[serde(untagged)] | ||
| pub(crate) enum InputType { | ||
|     /// A plain text input. | ||
| SingleString(String), | ||
|     /// A single integer input. | ||
| SingleInt(u32), | ||
|     /// A list of integers. | ||
| VectorInt(Vec<u32>), | ||
| } | ||
| impl InputType { | ||
| pub(crate) fn count_chars(&self) -> usize { | ||
| match self { | ||
| InputType::SingleString(s) => s.chars().count(), | ||
| InputType::SingleInt(_) => 1, | ||
| InputType::VectorInt(v) => v.len(), | ||
|
Comment on lines
+264
to
+265
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this how OpenAI counts when ids are given to the API? Or do they still count the chars by decoding the ids?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll look into it and modify this per my findings.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @OlivierDehaene I looked through the source of |
||
| } | ||
| } | ||
| } | ||
| /// Converts a deserialized `InputType` into an `EncodingInput`. | ||
| impl From<InputType> for EncodingInput { | ||
|     /// Text maps to `Single`; integer inputs map to `Vector`, with a lone integer | ||
|     /// wrapped into a one-element vector. | ||
|     fn from(value: InputType) -> Self { | ||
|         match value { | ||
|             InputType::VectorInt(ids) => EncodingInput::Vector(ids), | ||
|             InputType::SingleInt(id) => EncodingInput::Vector(vec![id]), | ||
|             InputType::SingleString(text) => EncodingInput::Single(text), | ||
|         } | ||
|     } | ||
| } | ||
| #[derive(Deserialize, ToSchema)] | ||
| #[serde(untagged)] | ||
| pub(crate) enum Input { | ||
| Single(String), | ||
| Batch(Vec<String>), | ||
| Single(InputType), | ||
| Batch(Vec<InputType>), | ||
| } | ||
|
|
||
| #[derive(Deserialize, ToSchema)] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be merged with the `match` below? Since `is_encoded` is basically a match on `EncodingInput::Vector`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done