@@ -7,6 +7,8 @@ use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
77use tokio:: sync:: oneshot;
88use tracing:: { instrument, Span } ;
99
10+ static MAX_CHAR_MULTIPLIER : usize = 250 ;
11+
1012/// Validation
1113#[ derive( Debug , Clone ) ]
1214pub struct Tokenization {
@@ -215,6 +217,7 @@ fn tokenizer_worker(
215217 let _ = response_tx. send ( tokenize_input (
216218 inputs,
217219 add_special_tokens,
220+ max_input_length,
218221 None ,
219222 default_prompt_clone,
220223 prompt_name,
@@ -269,9 +272,11 @@ fn prepare_pre_prompt(
269272 Ok ( pre_prompt)
270273}
271274
275+ #[ allow( clippy:: too_many_arguments) ]
272276fn tokenize_input (
273277 inputs : EncodingInput ,
274278 add_special_tokens : bool ,
279+ max_input_length : usize ,
275280 truncate_params : Option < TruncationParams > ,
276281 default_prompt : Option < String > ,
277282 prompt_name : Option < String > ,
@@ -280,6 +285,14 @@ fn tokenize_input(
280285) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
281286 let pre_prompt = prepare_pre_prompt ( default_prompt, prompt_name, prompts) ?;
282287
288+ let input_chars = inputs. count_chars ( ) ;
289+ let limit = max_input_length * MAX_CHAR_MULTIPLIER ;
290+ if input_chars > limit {
291+ return Err ( TextEmbeddingsError :: Validation ( format ! (
292+ "`inputs` must have less than {limit} characters. Given: {input_chars}"
293+ ) ) ) ;
294+ }
295+
283296 let encoding = match inputs {
284297 // encode input
285298 EncodingInput :: Single ( s) => {
@@ -359,6 +372,7 @@ fn encode_input(
359372 let ( _, encoding) = tokenize_input (
360373 inputs,
361374 true ,
375+ max_input_length,
362376 truncate_params,
363377 default_prompt,
364378 prompt_name,
@@ -404,6 +418,14 @@ impl EncodingInput {
404418 EncodingInput :: Ids ( v) => v. is_empty ( ) ,
405419 }
406420 }
421+
422+ fn count_chars ( & self ) -> usize {
423+ match self {
424+ EncodingInput :: Single ( s) => s. chars ( ) . count ( ) ,
425+ EncodingInput :: Dual ( s1, s2) => s1. chars ( ) . count ( ) + s2. chars ( ) . count ( ) ,
426+ EncodingInput :: Ids ( v) => v. len ( ) ,
427+ }
428+ }
407429}
408430
409431impl From < String > for EncodingInput {
0 commit comments