From ab7100a8b42d3b9e5d8f0e3e0ce38244389242dc Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 27 Jun 2024 19:32:51 +0200 Subject: [PATCH 1/3] feat: add default prompts --- core/src/download.rs | 7 +++ core/src/infer.rs | 15 ++++-- core/src/tokenization.rs | 99 +++++++++++++++++++++++++++++++++++---- router/src/http/server.rs | 65 +++++++++++++++++++++---- router/src/http/types.rs | 8 ++++ router/src/lib.rs | 35 +++++++++++++- router/src/main.rs | 10 ++++ 7 files changed, 217 insertions(+), 22 deletions(-) diff --git a/core/src/download.rs b/core/src/download.rs index 24dc041f..08220752 100644 --- a/core/src/download.rs +++ b/core/src/download.rs @@ -102,3 +102,10 @@ pub async fn download_st_config(api: &ApiRepo) -> Result { Err(err) } + +#[instrument(skip_all)] +pub async fn download_new_st_config(api: &ApiRepo) -> Result { + tracing::info!("Downloading `config_sentence_transformers.json`"); + let pool_config_path = api.get("config_sentence_transformers.json").await?; + Ok(pool_config_path) +} diff --git a/core/src/infer.rs b/core/src/infer.rs index 7e6a4629..6c973249 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -60,9 +60,10 @@ impl Infer { &self, inputs: I, add_special_tokens: bool, + prompt_name: Option, ) -> Result { self.tokenization - .tokenize(inputs.into(), add_special_tokens) + .tokenize(inputs.into(), add_special_tokens, prompt_name) .await .map_err(|err| { let counter = metrics::counter!("te_request_failure", "err" => "tokenization"); @@ -119,6 +120,7 @@ impl Infer { inputs: I, truncate: bool, truncation_direction: TruncationDirection, + prompt_name: Option, permit: OwnedSemaphorePermit, ) -> Result { let start_time = Instant::now(); @@ -138,6 +140,7 @@ impl Infer { inputs, truncate, truncation_direction, + prompt_name, false, &start_time, permit, @@ -172,6 +175,7 @@ impl Infer { inputs: I, truncate: bool, truncation_direction: TruncationDirection, + prompt_name: Option, permit: OwnedSemaphorePermit, ) -> Result { let start_time = Instant::now(); @@ -191,6 +195,7 @@ impl Infer { inputs, truncate, truncation_direction, + prompt_name, true, &start_time, permit, @@ -225,6 +230,7 @@ impl Infer { inputs: I, truncate: bool, truncation_direction: TruncationDirection, + prompt_name: Option, normalize: bool, permit: OwnedSemaphorePermit, ) -> Result { @@ -245,6 +251,7 @@ impl Infer { inputs, truncate, truncation_direction, + prompt_name, true, &start_time, permit, @@ -290,11 +297,13 @@ impl Infer { Ok(response) } + #[allow(clippy::too_many_arguments)] async fn embed + std::fmt::Debug>( &self, inputs: I, truncate: bool, truncation_direction: TruncationDirection, + prompt_name: Option, pooling: bool, start_time: &Instant, _permit: OwnedSemaphorePermit, @@ -315,7 +324,7 @@ impl Infer { // Tokenization let encoding = self .tokenization - .encode(inputs.into(), truncate, truncation_direction) + .encode(inputs.into(), truncate, truncation_direction, prompt_name) .await .map_err(|err| { let counter = metrics::counter!("te_request_failure", "err" => "tokenization"); @@ -381,7 +390,7 @@ impl Infer { // Tokenization let encoding = self .tokenization - .encode(inputs.into(), truncate, truncation_direction) + .encode(inputs.into(), truncate, truncation_direction, None) .await .map_err(|err| { let counter = metrics::counter!("te_request_failure", "err" => "tokenization"); diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index ae281691..d2886343 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs 
@@ -1,5 +1,6 @@ /// Payload tokenization logic use crate::TextEmbeddingsError; +use std::collections::HashMap; use tokenizers::tokenizer::Tokenizer; pub use tokenizers::Encoding as RawEncoding; use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy}; @@ -19,6 +20,8 @@ impl Tokenization { tokenizer: Tokenizer, max_input_length: usize, position_offset: usize, + default_prompt_name: Option, + prompts: Option>, ) -> Self { tracing::info!("Starting {workers} tokenization workers"); @@ -29,12 +32,16 @@ impl Tokenization { for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); let receiver_clone = receiver.clone(); + let default_prompt_name_clone = default_prompt_name.clone(); + let prompts_clone = prompts.clone(); // Spawn worker std::thread::spawn(move || { tokenizer_worker( tokenizer_clone, max_input_length, position_offset, + default_prompt_name_clone, + prompts_clone, receiver_clone, ) }); @@ -49,6 +56,7 @@ impl Tokenization { inputs: EncodingInput, truncate: bool, truncation_direction: TruncationDirection, + prompt_name: Option, ) -> Result { // Check if inputs is empty if inputs.is_empty() { @@ -66,6 +74,7 @@ impl Tokenization { inputs, truncate, truncation_direction, + prompt_name, response_sender, Span::current(), )) @@ -82,6 +91,7 @@ impl Tokenization { &self, inputs: EncodingInput, add_special_tokens: bool, + prompt_name: Option, ) -> Result { // Check if inputs is empty if inputs.is_empty() { @@ -98,6 +108,7 @@ impl Tokenization { .send(TokenizerRequest::Tokenize( inputs, add_special_tokens, + prompt_name, response_sender, Span::current(), )) @@ -147,6 +158,8 @@ fn tokenizer_worker( mut tokenizer: Tokenizer, max_input_length: usize, position_offset: usize, + default_prompt_name: Option, + prompts: Option>, receiver: async_channel::Receiver, ) { // Loop over requests @@ -156,10 +169,13 @@ fn tokenizer_worker( inputs, truncate, truncation_direction, + prompt_name, response_tx, parent_span, ) => { parent_span.in_scope(|| { + let prompt_name = prompt_name.or(default_prompt_name.clone()); + if !response_tx.is_closed() { // It's possible that the user dropped its request resulting in a send error. // We just discard the error @@ -169,12 +185,22 @@ fn tokenizer_worker( truncation_direction, max_input_length, position_offset, + prompt_name, + prompts.as_ref(), &mut tokenizer, )); } }) } - TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => { + TokenizerRequest::Tokenize( + inputs, + add_special_tokens, + prompt_name, + response_tx, + parent_span, + ) => { + let prompt_name = prompt_name.or(default_prompt_name.clone()); + parent_span.in_scope(|| { if !response_tx.is_closed() { // It's possible that the user dropped its request resulting in a send error. 
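The tokenization changes above thread an optional prompt name through both the `Embed` and `Tokenize` worker paths: a per-request `prompt_name` falls back to the configured default, the matching prompt text is looked up in the Sentence Transformers `prompts` map, and that text is prepended to the input before encoding. A minimal standalone sketch of that resolution order follows; `resolve_prompt` is a hypothetical helper, the prompt values are invented, and it omits the validation errors the real worker returns for unknown prompt names.

    // Sketch of the prompt lookup and prepending described above (illustrative only).
    use std::collections::HashMap;

    fn resolve_prompt(
        prompt_name: Option<&str>,
        default_prompt_name: Option<&str>,
        prompts: &HashMap<String, String>,
    ) -> Option<String> {
        // A per-request prompt name wins; otherwise the server-wide default applies.
        prompt_name
            .or(default_prompt_name)
            .and_then(|name| prompts.get(name).cloned())
    }

    fn main() {
        let prompts = HashMap::from([("query".to_string(), "query: ".to_string())]);
        let prompt = resolve_prompt(None, Some("query"), &prompts);
        // The resolved prompt text is prepended to the input before tokenization.
        let text = format!(
            "{}{}",
            prompt.unwrap_or_default(),
            "What is the capital of France?"
        );
        assert_eq!(text, "query: What is the capital of France?");
    }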
@@ -183,6 +209,8 @@ fn tokenizer_worker( inputs, add_special_tokens, None, + prompt_name, + prompts.as_ref(), &mut tokenizer, )); } @@ -216,36 +244,80 @@ fn tokenize_input( inputs: EncodingInput, add_special_tokens: bool, truncate_params: Option, + prompt_name: Option, + prompts: Option<&HashMap>, tokenizer: &mut Tokenizer, ) -> Result { + let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() { + match prompts { + None => { + return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration"))); + } + Some(prompts) if !prompts.contains_key(prompt_name) => { + return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys()))); + } + Some(prompts) => prompts.get(prompt_name).cloned(), + } + } else { + None + }; + let encoding = match inputs { // encode input - EncodingInput::Single(s) => tokenizer - .with_truncation(truncate_params)? - .encode::(s, add_special_tokens)?, + EncodingInput::Single(s) => { + let s = if let Some(mut pre_prompt) = pre_prompt { + pre_prompt.push_str(&s); + pre_prompt + } else { + s + }; + + tokenizer + .with_truncation(truncate_params)? + .encode::(s, add_special_tokens)? + } EncodingInput::Dual(s1, s2) => { + if pre_prompt.is_some() { + return Err(TextEmbeddingsError::Validation( + "`prompt_name` cannot be set with dual inputs".to_string(), + )); + } + tokenizer .with_truncation(truncate_params)? .encode::<(String, String)>((s1, s2), add_special_tokens)? } // input is encoded -> convert to tokenizers Encoding EncodingInput::Ids(ids) => { - let text = tokenizer.decode(&ids, false)?; - tokenizer - .with_truncation(truncate_params)? - .encode::(text, false)? + if let Some(mut pre_prompt) = pre_prompt { + let text = tokenizer.decode(&ids, true)?; + pre_prompt.push_str(&text); + + tokenizer + .with_truncation(truncate_params)? + .encode::(pre_prompt, false)? + } else { + let text = tokenizer.decode(&ids, false)?; + + tokenizer + .with_truncation(truncate_params)? + .encode::(text, false)? 
+ } } }; Ok(encoding) } /// Get input length and optionally truncate it +#[allow(clippy::too_many_arguments)] fn encode_input( inputs: EncodingInput, truncate: bool, truncation_direction: TruncationDirection, max_input_length: usize, position_offset: usize, + prompt_name: Option, + prompts: Option<&HashMap>, tokenizer: &mut Tokenizer, ) -> Result { // Default truncation params @@ -256,7 +328,14 @@ fn encode_input( stride: 0, }); - let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?; + let encoding = tokenize_input( + inputs, + true, + truncate_params, + prompt_name, + prompts, + tokenizer, + )?; let seq_len = encoding.len(); if seq_len > max_input_length { @@ -315,12 +394,14 @@ enum TokenizerRequest { EncodingInput, bool, TruncationDirection, + Option, oneshot::Sender>, Span, ), Tokenize( EncodingInput, bool, + Option, oneshot::Sender>, Span, ), diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 3e09101d..6adca38f 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -500,6 +500,7 @@ async fn embed( input, truncate, req.truncation_direction, + req.prompt_name, req.normalize, permit, ) @@ -560,6 +561,7 @@ async fn embed( compute_chars += input.count_chars(); let local_infer = infer.clone(); + let prompt_name = req.prompt_name.clone(); futures.push(async move { let permit = local_infer.acquire_permit().await; local_infer @@ -567,6 +569,7 @@ async fn embed( input, truncate, req.truncation_direction, + prompt_name, req.normalize, permit, ) @@ -671,7 +674,13 @@ async fn embed_sparse( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_sparse(input, truncate, req.truncation_direction, permit) + .embed_sparse( + input, + truncate, + req.truncation_direction, + req.prompt_name, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -729,10 +738,17 @@ async fn embed_sparse( compute_chars += input.count_chars(); let local_infer = infer.clone(); + let prompt_name = req.prompt_name.clone(); futures.push(async move { let permit = local_infer.acquire_permit().await; let response = local_infer - .embed_sparse(input, truncate, req.truncation_direction, permit) + .embed_sparse( + input, + truncate, + req.truncation_direction, + prompt_name, + permit, + ) .await?; Ok((sparsify(response.results), response.metadata)) }) @@ -827,7 +843,13 @@ async fn embed_all( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_all(input, truncate, req.truncation_direction, permit) + .embed_all( + input, + truncate, + req.truncation_direction, + req.prompt_name, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -885,10 +907,17 @@ async fn embed_all( compute_chars += input.count_chars(); let local_infer = infer.clone(); + let prompt_name = req.prompt_name.clone(); futures.push(async move { let permit = local_infer.acquire_permit().await; local_infer - .embed_all(input, truncate, req.truncation_direction, permit) + .embed_all( + input, + truncate, + req.truncation_direction, + prompt_name, + permit, + ) .await }) } @@ -997,7 +1026,14 @@ async fn openai_embed( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_pooled(input, truncate, TruncationDirection::Right, true, permit) + .embed_pooled( + input, + truncate, + TruncationDirection::Right, + None, + true, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -1063,7 +1099,14 @@ async fn openai_embed( futures.push(async move { let permit = 
local_infer.acquire_permit().await; local_infer - .embed_pooled(input, truncate, TruncationDirection::Right, true, permit) + .embed_pooled( + input, + truncate, + TruncationDirection::Right, + None, + true, + permit, + ) .await }) } @@ -1148,9 +1191,12 @@ async fn tokenize( info: Extension, Json(req): Json, ) -> Result, (StatusCode, Json)> { - let tokenize_inner = move |input: String, add_special_tokens: bool, infer: Infer| async move { + let tokenize_inner = move |input: String, + add_special_tokens: bool, + prompt_name: Option, + infer: Infer| async move { let encoding = infer - .tokenize(input.clone(), add_special_tokens) + .tokenize(input.clone(), add_special_tokens, prompt_name) .await .map_err(ErrorResponse::from)?; let tokens: Vec = encoding @@ -1187,7 +1233,7 @@ async fn tokenize( let tokens = match req.inputs { TokenizeInput::Single(input) => { - vec![tokenize_inner(input, req.add_special_tokens, infer.0).await?] + vec![tokenize_inner(input, req.add_special_tokens, req.prompt_name, infer.0).await?] } TokenizeInput::Batch(inputs) => { if inputs.is_empty() { @@ -1223,6 +1269,7 @@ async fn tokenize( futures.push(tokenize_inner( input, req.add_special_tokens, + req.prompt_name.clone(), infer.0.clone(), )); } diff --git a/router/src/http/types.rs b/router/src/http/types.rs index a2a773e8..a47a995b 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -351,6 +351,8 @@ pub(crate) struct EmbedRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + #[schema(default = "null", example = "null", nullable = true)] + pub prompt_name: Option, #[serde(default = "default_normalize")] #[schema(default = "true", example = "true")] pub normalize: bool, @@ -373,6 +375,8 @@ pub(crate) struct EmbedSparseRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + #[schema(default = "null", example = "null", nullable = true)] + pub prompt_name: Option, } #[derive(Serialize, ToSchema)] @@ -393,6 +397,8 @@ pub(crate) struct EmbedAllRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + #[schema(default = "null", example = "null", nullable = true)] + pub prompt_name: Option, } #[derive(Serialize, ToSchema)] @@ -420,6 +426,8 @@ pub(crate) struct TokenizeRequest { #[serde(default = "default_add_special_tokens")] #[schema(default = "true", example = "true")] pub add_special_tokens: bool, + #[schema(default = "null", example = "null", nullable = true)] + pub prompt_name: Option, } fn default_add_special_tokens() -> bool { diff --git a/router/src/lib.rs b/router/src/lib.rs index 5c7899ec..fa47c7ea 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -28,7 +28,8 @@ use std::path::Path; use std::time::{Duration, Instant}; use text_embeddings_backend::{DType, Pool}; use text_embeddings_core::download::{ - download_artifacts, download_pool_config, download_st_config, ST_CONFIG_NAMES, + download_artifacts, download_new_st_config, download_pool_config, download_st_config, + ST_CONFIG_NAMES, }; use text_embeddings_core::infer::Infer; use text_embeddings_core::queue::Queue; @@ -52,6 +53,7 @@ pub async fn run( max_batch_requests: Option, max_client_batch_size: usize, auto_truncate: bool, + default_prompt_name: Option, hf_api_token: Option, hostname: Option, port: u16, @@ -91,6 +93,8 @@ pub async fn run( // Download sentence transformers config let _ = download_st_config(&api_repo).await; + // 
Download new sentence transformers config + let _ = download_new_st_config(&api_repo).await; // Download model from the Hub download_artifacts(&api_repo) @@ -171,12 +175,36 @@ pub async fn run( let tokenization_workers = tokenization_workers.unwrap_or_else(num_cpus::get_physical); + // Try to load new ST Config + let mut new_st_config: Option = None; + let config_path = model_root.join("config_sentence_transformers.json"); + if let Ok(config) = fs::read_to_string(config_path) { + new_st_config = Some( + serde_json::from_str(&config) + .context("Failed to parse `config_sentence_transformers.json`")?, + ); + } + let prompts = new_st_config.map(|c| c.prompts); + if let Some(default_prompt_name) = default_prompt_name.as_ref() { + match &prompts { + None => { + anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but no prompts were found in the Sentence Transformers configuration")); + } + Some(prompts) if !prompts.contains_key(default_prompt_name) => { + anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys())); + } + _ => (), + } + } + // Tokenization logic let tokenization = Tokenization::new( tokenization_workers, tokenizer, max_input_length, position_offset, + default_prompt_name, + prompts, ); // Get dtype @@ -390,6 +418,11 @@ pub struct STConfig { pub max_seq_length: usize, } +#[derive(Debug, Deserialize)] +pub struct NewSTConfig { + pub prompts: HashMap, +} + #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct EmbeddingModel { diff --git a/router/src/main.rs b/router/src/main.rs index 3c85f5f7..8036fec9 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -79,6 +79,15 @@ struct Args { #[clap(long, env)] auto_truncate: bool, + /// The name of the prompt to use for encoding. Must be a key in the `Sentence Transformers` + /// configuration `prompts` dictionary. + /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// then the sentence "What is the capital of France?" will be encoded as + /// "query: What is the capital of France?" because the sentence + /// is appended to the prompt. 
+ #[clap(long, env)] + default_prompt_name: Option, + /// Your HuggingFace hub token #[clap(long, env)] #[redact(partial)] @@ -172,6 +181,7 @@ async fn main() -> Result<()> { args.max_batch_requests, args.max_client_batch_size, args.auto_truncate, + args.default_prompt_name, args.hf_api_token, Some(args.hostname), args.port, From 0e2511b7a9f8d31b8348f0ad142a0b8aef59549c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 28 Jun 2024 11:45:37 +0200 Subject: [PATCH 2/3] add grpc implem --- core/src/infer.rs | 2 +- core/src/tokenization.rs | 83 ++++++++++++++++++++++++++------------- proto/tei.proto | 4 ++ router/src/grpc/server.rs | 13 +++++- router/src/http/server.rs | 4 +- router/src/lib.rs | 15 ++++--- router/src/main.rs | 31 ++++++++++++--- router/tests/common.rs | 2 + 8 files changed, 111 insertions(+), 43 deletions(-) diff --git a/core/src/infer.rs b/core/src/infer.rs index 6c973249..23b343bf 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -61,7 +61,7 @@ impl Infer { inputs: I, add_special_tokens: bool, prompt_name: Option, - ) -> Result { + ) -> Result<(Option, RawEncoding), TextEmbeddingsError> { self.tokenization .tokenize(inputs.into(), add_special_tokens, prompt_name) .await diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index d2886343..c33bfafc 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -20,7 +20,7 @@ impl Tokenization { tokenizer: Tokenizer, max_input_length: usize, position_offset: usize, - default_prompt_name: Option, + default_prompt: Option, prompts: Option>, ) -> Self { tracing::info!("Starting {workers} tokenization workers"); @@ -32,7 +32,7 @@ impl Tokenization { for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); let receiver_clone = receiver.clone(); - let default_prompt_name_clone = default_prompt_name.clone(); + let default_prompt_clone = default_prompt.clone(); let prompts_clone = prompts.clone(); // Spawn worker std::thread::spawn(move || { @@ -40,7 +40,7 @@ impl Tokenization { tokenizer_clone, max_input_length, position_offset, - default_prompt_name_clone, + default_prompt_clone, prompts_clone, receiver_clone, ) @@ -92,7 +92,7 @@ impl Tokenization { inputs: EncodingInput, add_special_tokens: bool, prompt_name: Option, - ) -> Result { + ) -> Result<(Option, RawEncoding), TextEmbeddingsError> { // Check if inputs is empty if inputs.is_empty() { return Err(TextEmbeddingsError::Validation( @@ -158,7 +158,7 @@ fn tokenizer_worker( mut tokenizer: Tokenizer, max_input_length: usize, position_offset: usize, - default_prompt_name: Option, + default_prompt: Option, prompts: Option>, receiver: async_channel::Receiver, ) { @@ -174,9 +174,12 @@ fn tokenizer_worker( parent_span, ) => { parent_span.in_scope(|| { - let prompt_name = prompt_name.or(default_prompt_name.clone()); - if !response_tx.is_closed() { + let default_prompt_clone = match prompt_name { + None => default_prompt.clone(), + Some(_) => None, + }; + // It's possible that the user dropped its request resulting in a send error. 
// We just discard the error let _ = response_tx.send(encode_input( @@ -185,6 +188,7 @@ fn tokenizer_worker( truncation_direction, max_input_length, position_offset, + default_prompt_clone, prompt_name, prompts.as_ref(), &mut tokenizer, @@ -199,16 +203,20 @@ fn tokenizer_worker( response_tx, parent_span, ) => { - let prompt_name = prompt_name.or(default_prompt_name.clone()); - parent_span.in_scope(|| { if !response_tx.is_closed() { + let default_prompt_clone = match prompt_name { + None => default_prompt.clone(), + Some(_) => None, + }; + // It's possible that the user dropped its request resulting in a send error. // We just discard the error let _ = response_tx.send(tokenize_input( inputs, add_special_tokens, None, + default_prompt_clone, prompt_name, prompts.as_ref(), &mut tokenizer, @@ -240,14 +248,11 @@ fn decode_ids( .decode(&ids, skip_special_tokens)?) } -fn tokenize_input( - inputs: EncodingInput, - add_special_tokens: bool, - truncate_params: Option, +fn prepare_pre_prompt( + default_prompt: Option, prompt_name: Option, prompts: Option<&HashMap>, - tokenizer: &mut Tokenizer, -) -> Result { +) -> Result, TextEmbeddingsError> { let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() { match prompts { None => { @@ -259,8 +264,21 @@ fn tokenize_input( Some(prompts) => prompts.get(prompt_name).cloned(), } } else { - None + default_prompt }; + Ok(pre_prompt) +} + +fn tokenize_input( + inputs: EncodingInput, + add_special_tokens: bool, + truncate_params: Option, + default_prompt: Option, + prompt_name: Option, + prompts: Option<&HashMap>, + tokenizer: &mut Tokenizer, +) -> Result<(Option, RawEncoding), TextEmbeddingsError> { + let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?; let encoding = match inputs { // encode input @@ -272,9 +290,11 @@ fn tokenize_input( s }; - tokenizer + let encoding = tokenizer .with_truncation(truncate_params)? - .encode::(s, add_special_tokens)? + .encode::<&str>(&s, add_special_tokens)?; + + (Some(s), encoding) } EncodingInput::Dual(s1, s2) => { if pre_prompt.is_some() { @@ -283,9 +303,12 @@ fn tokenize_input( )); } - tokenizer - .with_truncation(truncate_params)? - .encode::<(String, String)>((s1, s2), add_special_tokens)? + ( + None, + tokenizer + .with_truncation(truncate_params)? + .encode::<(String, String)>((s1, s2), add_special_tokens)?, + ) } // input is encoded -> convert to tokenizers Encoding EncodingInput::Ids(ids) => { @@ -293,15 +316,19 @@ fn tokenize_input( let text = tokenizer.decode(&ids, true)?; pre_prompt.push_str(&text); - tokenizer + let encoding = tokenizer .with_truncation(truncate_params)? - .encode::(pre_prompt, false)? + .encode::<&str>(&pre_prompt, true)?; + + (Some(pre_prompt), encoding) } else { let text = tokenizer.decode(&ids, false)?; - tokenizer + let encoding = tokenizer .with_truncation(truncate_params)? - .encode::(text, false)? 
+ .encode::<&str>(&text, false)?; + + (Some(text), encoding) } } }; @@ -316,6 +343,7 @@ fn encode_input( truncation_direction: TruncationDirection, max_input_length: usize, position_offset: usize, + default_prompt: Option, prompt_name: Option, prompts: Option<&HashMap>, tokenizer: &mut Tokenizer, @@ -328,10 +356,11 @@ fn encode_input( stride: 0, }); - let encoding = tokenize_input( + let (_, encoding) = tokenize_input( inputs, true, truncate_params, + default_prompt, prompt_name, prompts, tokenizer, @@ -402,7 +431,7 @@ enum TokenizerRequest { EncodingInput, bool, Option, - oneshot::Sender>, + oneshot::Sender, RawEncoding), TextEmbeddingsError>>, Span, ), Decode( diff --git a/proto/tei.proto b/proto/tei.proto index 394c0262..aac6c2ba 100644 --- a/proto/tei.proto +++ b/proto/tei.proto @@ -79,6 +79,7 @@ message EmbedRequest { bool truncate = 2; bool normalize = 3; TruncationDirection truncation_direction = 4; + optional string prompt_name = 5; } message EmbedResponse { @@ -90,6 +91,7 @@ message EmbedSparseRequest { string inputs = 1; bool truncate = 2; TruncationDirection truncation_direction = 3; + optional string prompt_name = 4; } message SparseValue { @@ -106,6 +108,7 @@ message EmbedAllRequest { string inputs = 1; bool truncate = 2; TruncationDirection truncation_direction = 3; + optional string prompt_name = 4; } message TokenEmbedding { @@ -175,6 +178,7 @@ message RerankResponse { message EncodeRequest { string inputs = 1; bool add_special_tokens = 2; + optional string prompt_name = 3; } message SimpleToken { diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index c428e065..d3666214 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -87,6 +87,7 @@ impl TextEmbeddingsService { request.inputs, request.truncate, truncation_direction, + request.prompt_name, request.normalize, permit, ) @@ -142,6 +143,7 @@ impl TextEmbeddingsService { request.inputs, request.truncate, truncation_direction, + request.prompt_name, permit, ) .await @@ -207,6 +209,7 @@ impl TextEmbeddingsService { request.inputs, request.truncate, truncation_direction, + request.prompt_name, permit, ) .await @@ -326,11 +329,17 @@ impl TextEmbeddingsService { #[instrument(skip_all)] async fn tokenize_inner(&self, request: EncodeRequest) -> Result { let inputs = request.inputs; - let encoding = self + let (encoded_inputs, encoding) = self .infer - .tokenize(inputs.clone(), request.add_special_tokens) + .tokenize( + inputs.clone(), + request.add_special_tokens, + request.prompt_name, + ) .await .map_err(ErrorResponse::from)?; + let inputs = encoded_inputs.unwrap_or(inputs); + let tokens: Vec = encoding .get_ids() .iter() diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 6adca38f..49e6029a 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1195,10 +1195,12 @@ async fn tokenize( add_special_tokens: bool, prompt_name: Option, infer: Infer| async move { - let encoding = infer + let (encoded_input, encoding) = infer .tokenize(input.clone(), add_special_tokens, prompt_name) .await .map_err(ErrorResponse::from)?; + let input = encoded_input.unwrap_or(input); + let tokens: Vec = encoding .get_ids() .iter() diff --git a/router/src/lib.rs b/router/src/lib.rs index fa47c7ea..3be03190 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,6 +53,7 @@ pub async fn run( max_batch_requests: Option, max_client_batch_size: usize, auto_truncate: bool, + default_prompt: Option, default_prompt_name: Option, hf_api_token: Option, hostname: Option, @@ -184,8 
+185,8 @@ pub async fn run( .context("Failed to parse `config_sentence_transformers.json`")?, ); } - let prompts = new_st_config.map(|c| c.prompts); - if let Some(default_prompt_name) = default_prompt_name.as_ref() { + let prompts = new_st_config.and_then(|c| c.prompts); + let default_prompt = if let Some(default_prompt_name) = default_prompt_name.as_ref() { match &prompts { None => { anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but no prompts were found in the Sentence Transformers configuration")); @@ -193,9 +194,11 @@ pub async fn run( Some(prompts) if !prompts.contains_key(default_prompt_name) => { anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys())); } - _ => (), + Some(prompts) => prompts.get(default_prompt_name).cloned(), } - } + } else { + default_prompt + }; // Tokenization logic let tokenization = Tokenization::new( @@ -203,7 +206,7 @@ pub async fn run( tokenizer, max_input_length, position_offset, - default_prompt_name, + default_prompt, prompts, ); @@ -420,7 +423,7 @@ pub struct STConfig { #[derive(Debug, Deserialize)] pub struct NewSTConfig { - pub prompts: HashMap, + pub prompts: Option>, } #[derive(Clone, Debug, Serialize)] diff --git a/router/src/main.rs b/router/src/main.rs index 8036fec9..a72caa8b 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -79,15 +79,33 @@ struct Args { #[clap(long, env)] auto_truncate: bool, - /// The name of the prompt to use for encoding. Must be a key in the `Sentence Transformers` - /// configuration `prompts` dictionary. - /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// The name of the prompt that should be used by default for encoding. If not set, no prompt + /// will be applied. + /// + /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + /// + /// For example if ``default_prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, /// then the sentence "What is the capital of France?" will be encoded as - /// "query: What is the capital of France?" because the sentence - /// is appended to the prompt. - #[clap(long, env)] + /// "query: What is the capital of France?" because the prompt text will be prepended before + /// any text to encode. + /// + /// The argument '--default-prompt-name ' cannot be used with + /// '--default-prompt ` + #[clap(long, env, conflicts_with = "default_prompt")] default_prompt_name: Option, + /// The prompt that should be used by default for encoding. If not set, no prompt + /// will be applied. + /// + /// For example if ``default_prompt`` is "query: " then the sentence "What is the capital of + /// France?" will be encoded as "query: What is the capital of France?" because the prompt + /// text will be prepended before any text to encode. 
+ /// + /// The argument '--default-prompt ' cannot be used with + /// '--default-prompt-name ` + #[clap(long, env, conflicts_with = "default_prompt_name")] + default_prompt: Option, + /// Your HuggingFace hub token #[clap(long, env)] #[redact(partial)] @@ -181,6 +199,7 @@ async fn main() -> Result<()> { args.max_batch_requests, args.max_client_batch_size, args.auto_truncate, + args.default_prompt, args.default_prompt_name, args.hf_api_token, Some(args.hostname), diff --git a/router/tests/common.rs b/router/tests/common.rs index c8669c12..55fdf5f5 100644 --- a/router/tests/common.rs +++ b/router/tests/common.rs @@ -58,6 +58,8 @@ pub async fn start_server(model_id: String, revision: Option, dtype: DTy false, None, None, + None, + None, 8090, None, None, From 881e44f8b4f3c809a9a6310cb12fc29b9c8407a6 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:07:15 +0200 Subject: [PATCH 3/3] update openapi.json --- Cargo.lock | 4 +- docs/openapi.json | 580 ++++++++++++-------------------------- router/src/http/server.rs | 27 +- router/src/http/types.rs | 56 +++- router/src/lib.rs | 12 +- 5 files changed, 256 insertions(+), 423 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 134036f6..7d449f0d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2071,9 +2071,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "macro_rules_attribute" diff --git a/docs/openapi.json b/docs/openapi.json index 7368145e..d2f36301 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "1.2.3" + "version": "1.3.0" }, "paths": { "/decode": { @@ -19,7 +19,6 @@ "Text Embeddings Inference" ], "summary": "Decode input ids", - "description": "Decode input ids", "operationId": "decode", "requestBody": { "content": { @@ -65,7 +64,6 @@ "Text Embeddings Inference" ], "summary": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.", - "description": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.", "operationId": "embed", "requestBody": { "content": { @@ -153,7 +151,7 @@ "Text Embeddings Inference" ], "summary": "Get all Embeddings without Pooling.", - "description": "Get all Embeddings without Pooling.\nReturns a 424 status code if the model is not an embedding model.", + "description": "Returns a 424 status code if the model is not an embedding model.", "operationId": "embed_all", "requestBody": { "content": { @@ -241,7 +239,6 @@ "Text Embeddings Inference" ], "summary": "Get Sparse Embeddings. Returns a 424 status code if the model is not an embedding model with SPLADE pooling.", - "description": "Get Sparse Embeddings. Returns a 424 status code if the model is not an embedding model with SPLADE pooling.", "operationId": "embed_sparse", "requestBody": { "content": { @@ -323,101 +320,12 @@ } } }, - "/embeddings": { - "post": { - "tags": [ - "Text Embeddings Inference" - ], - "summary": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.", - "description": "OpenAI compatible route. 
Returns a 424 status code if the model is not an embedding model.", - "operationId": "openai_embed", - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatRequest" - } - } - }, - "required": true - }, - "responses": { - "200": { - "description": "Embeddings", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatResponse" - } - } - } - }, - "413": { - "description": "Batch size error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" - }, - "example": { - "message": "Batch size error", - "type": "validation" - } - } - } - }, - "422": { - "description": "Tokenization error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" - }, - "example": { - "message": "Tokenization error", - "type": "tokenizer" - } - } - } - }, - "424": { - "description": "Embedding Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" - }, - "example": { - "message": "Inference failed", - "type": "backend" - } - } - } - }, - "429": { - "description": "Model is overloaded", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" - }, - "example": { - "message": "Model is overloaded", - "type": "overloaded" - } - } - } - } - } - } - }, "/health": { "get": { "tags": [ "Text Embeddings Inference" ], "summary": "Health check method", - "description": "Health check method", "operationId": "health", "responses": { "200": { @@ -446,7 +354,6 @@ "Text Embeddings Inference" ], "summary": "Text Embeddings Inference endpoint info", - "description": "Text Embeddings Inference endpoint info", "operationId": "get_model_info", "responses": { "200": { @@ -468,7 +375,6 @@ "Text Embeddings Inference" ], "summary": "Prometheus metrics scrape endpoint", - "description": "Prometheus metrics scrape endpoint", "operationId": "metrics", "responses": { "200": { @@ -490,7 +396,6 @@ "Text Embeddings Inference" ], "summary": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model", - "description": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model", "operationId": "predict", "requestBody": { "content": { @@ -578,7 +483,7 @@ "Text Embeddings Inference" ], "summary": "Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with", - "description": "Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with\na single class.", + "description": "a single class.", "operationId": "rerank", "requestBody": { "content": { @@ -666,7 +571,6 @@ "Text Embeddings Inference" ], "summary": "Tokenize inputs", - "description": "Tokenize inputs", "operationId": "tokenize", "requestBody": { "content": { @@ -706,19 +610,18 @@ } } }, - "/vertex": { + "/v1/embeddings": { "post": { "tags": [ "Text Embeddings Inference" ], - "summary": "Generate embeddings from a Vertex request", - "description": "Generate embeddings from a Vertex request", - "operationId": "vertex_compatibility", + "summary": "OpenAI compatible route. 
Returns a 424 status code if the model is not an embedding model.", + "operationId": "openai_embed", "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/VertexRequest" + "$ref": "#/components/schemas/OpenAICompatRequest" } } }, @@ -726,18 +629,25 @@ }, "responses": { "200": { - "description": "Results" + "description": "Embeddings", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatResponse" + } + } + } }, "413": { "description": "Batch size error", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ErrorResponse" + "$ref": "#/components/schemas/OpenAICompatErrorResponse" }, "example": { - "error": "Batch size error", - "error_type": "validation" + "message": "Batch size error", + "type": "validation" } } } @@ -747,25 +657,25 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ErrorResponse" + "$ref": "#/components/schemas/OpenAICompatErrorResponse" }, "example": { - "error": "Tokenization error", - "error_type": "tokenizer" + "message": "Tokenization error", + "type": "tokenizer" } } } }, "424": { - "description": "Error", + "description": "Embedding Error", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ErrorResponse" + "$ref": "#/components/schemas/OpenAICompatErrorResponse" }, "example": { - "error": "Inference failed", - "error_type": "backend" + "message": "Inference failed", + "type": "backend" } } } @@ -775,11 +685,11 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ErrorResponse" + "$ref": "#/components/schemas/OpenAICompatErrorResponse" }, "example": { - "error": "Model is overloaded", - "error_type": "overloaded" + "message": "Model is overloaded", + "type": "overloaded" } } } @@ -852,10 +762,26 @@ "inputs": { "$ref": "#/components/schemas/Input" }, + "prompt_name": { + "type": "string", + "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.", + "default": "null", + "example": "null", + "nullable": true + }, "truncate": { "type": "boolean", "default": "false", - "example": "false" + "example": "false", + "nullable": true + }, + "truncation_direction": { + "allOf": [ + { + "$ref": "#/components/schemas/TruncationDirection" + } + ], + "default": "right" } } }, @@ -895,10 +821,26 @@ "default": "true", "example": "true" }, + "prompt_name": { + "type": "string", + "description": "The name of the prompt that should be used by for encoding. 
If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.", + "default": "null", + "example": "null", + "nullable": true + }, "truncate": { "type": "boolean", "default": "false", - "example": "false" + "example": "false", + "nullable": true + }, + "truncation_direction": { + "allOf": [ + { + "$ref": "#/components/schemas/TruncationDirection" + } + ], + "default": "right" } } }, @@ -928,10 +870,26 @@ "inputs": { "$ref": "#/components/schemas/Input" }, + "prompt_name": { + "type": "string", + "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.", + "default": "null", + "example": "null", + "nullable": true + }, "truncate": { "type": "boolean", "default": "false", - "example": "false" + "example": "false", + "nullable": true + }, + "truncation_direction": { + "allOf": [ + { + "$ref": "#/components/schemas/TruncationDirection" + } + ], + "default": "right" } } }, @@ -944,6 +902,20 @@ } } }, + "Embedding": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "number", + "format": "float" + } + }, + { + "type": "string" + } + ] + }, "EmbeddingModel": { "type": "object", "required": [ @@ -956,6 +928,13 @@ } } }, + "EncodingFormat": { + "type": "string", + "enum": [ + "float", + "base64" + ] + }, "ErrorResponse": { "type": "object", "required": [ @@ -991,10 +970,14 @@ "max_input_length", "max_batch_tokens", "max_client_batch_size", + "auto_truncate", "tokenization_workers", "version" ], "properties": { + "auto_truncate": { + "type": "boolean" + }, "docker_label": { "type": "string", "example": "null", @@ -1065,12 +1048,12 @@ "Input": { "oneOf": [ { - "type": "string" + "$ref": "#/components/schemas/InputType" }, { "type": "array", "items": { - "type": "string" + "$ref": "#/components/schemas/InputType" } } ] @@ -1098,6 +1081,21 @@ } ] }, + "InputType": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "integer", + "format": "int32", + "minimum": 0 + } + } + ] + }, "ModelType": { "oneOf": [ { @@ -1144,16 +1142,7 @@ ], "properties": { "embedding": { - "type": "array", - "items": { - "type": "number", - "format": "float" - }, - "example": [ - 0.0, - 1.0, - 2.0 - ] + "$ref": "#/components/schemas/Embedding" }, "index": { "type": "integer", @@ -1193,6 +1182,14 @@ "input" ], "properties": { + "encoding_format": { + "allOf": [ + { + "$ref": "#/components/schemas/EncodingFormat" + } + ], + "default": "float" + }, "input": { "$ref": "#/components/schemas/Input" }, @@ -1317,7 +1314,16 @@ "truncate": { "type": "boolean", "default": "false", - "example": "false" + "example": "false", + "nullable": true + }, + "truncation_direction": { + "allOf": [ + { + "$ref": "#/components/schemas/TruncationDirection" + } + ], + "default": "right" } } }, @@ -1416,7 +1422,16 @@ "truncate": { "type": 
"boolean", "default": "false", - "example": "false" + "example": "false", + "nullable": true + }, + "truncation_direction": { + "allOf": [ + { + "$ref": "#/components/schemas/TruncationDirection" + } + ], + "default": "right" } } }, @@ -1479,6 +1494,19 @@ } } }, + "TokenizeInput": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, "TokenizeRequest": { "type": "object", "required": [ @@ -1491,7 +1519,14 @@ "example": "true" }, "inputs": { - "$ref": "#/components/schemas/Input" + "$ref": "#/components/schemas/TokenizeInput" + }, + "prompt_name": { + "type": "string", + "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.", + "default": "null", + "example": "null", + "nullable": true } } }, @@ -1515,273 +1550,12 @@ ] ] }, - "VertexInstance": { - "oneOf": [ - { - "allOf": [ - { - "$ref": "#/components/schemas/EmbedRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "embed" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/EmbedAllRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "embed_all" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/EmbedSparseRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "embed_sparse" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/PredictRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "predict" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/RerankRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "rerank" - ] - } - } - } - ] - }, - { - "allOf": [ - { - "$ref": "#/components/schemas/TokenizeRequest" - }, - { - "type": "object", - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "tokenize" - ] - } - } - } - ] - } - ], - "discriminator": { - "propertyName": "type" - } - }, - "VertexRequest": { - "type": "object", - "required": [ - "instances" - ], - "properties": { - "instances": { - "type": "array", - "items": { - "$ref": "#/components/schemas/VertexInstance" - } - } - } - }, - "VertexResponse": { - "type": "array", - "items": { - "$ref": "#/components/schemas/VertexResponseInstance" - } - }, - "VertexResponseInstance": { - "oneOf": [ - { - "type": "object", - "required": [ - "type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/EmbedResponse" - }, - "type": { - "type": "string", - "enum": [ - "embed" - ] - } - } - }, - { - "type": "object", - "required": [ - "type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/EmbedAllResponse" - }, - "type": { - "type": "string", - "enum": [ - "embed_all" - ] - } - } - }, - { - "type": "object", - "required": [ - 
"type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/EmbedSparseResponse" - }, - "type": { - "type": "string", - "enum": [ - "embed_sparse" - ] - } - } - }, - { - "type": "object", - "required": [ - "type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/PredictResponse" - }, - "type": { - "type": "string", - "enum": [ - "predict" - ] - } - } - }, - { - "type": "object", - "required": [ - "type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/RerankResponse" - }, - "type": { - "type": "string", - "enum": [ - "rerank" - ] - } - } - }, - { - "type": "object", - "required": [ - "type", - "result" - ], - "properties": { - "result": { - "$ref": "#/components/schemas/TokenizeResponse" - }, - "type": { - "type": "string", - "enum": [ - "tokenize" - ] - } - } - } - ], - "discriminator": { - "propertyName": "type" - } + "TruncationDirection": { + "type": "string", + "enum": [ + "Left", + "Right" + ] } } }, diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 49e6029a..17baada6 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -5,7 +5,8 @@ use crate::http::types::{ OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput, - TokenizeRequest, TokenizeResponse, VertexPrediction, VertexRequest, VertexResponse, + TokenizeRequest, TokenizeResponse, TruncationDirection, VertexPrediction, VertexRequest, + VertexResponse, }; use crate::{ shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType, @@ -32,7 +33,6 @@ use text_embeddings_core::infer::{ AllEmbeddingsInferResponse, Infer, InferMetadata, PooledEmbeddingsInferResponse, }; use text_embeddings_core::TextEmbeddingsError; -use tokenizers::TruncationDirection; use tokio::sync::OwnedSemaphorePermit; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; @@ -118,7 +118,7 @@ async fn predict( .predict( inputs, truncate, - req.truncation_direction, + req.truncation_direction.into(), req.raw_scores, permit, ) @@ -335,7 +335,7 @@ async fn rerank( .predict( (query, text), truncate, - req.truncation_direction, + req.truncation_direction.into(), req.raw_scores, permit, ) @@ -499,7 +499,7 @@ async fn embed( .embed_pooled( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), req.prompt_name, req.normalize, permit, @@ -568,7 +568,7 @@ async fn embed( .embed_pooled( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), prompt_name, req.normalize, permit, @@ -677,7 +677,7 @@ async fn embed_sparse( .embed_sparse( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), req.prompt_name, permit, ) @@ -745,7 +745,7 @@ async fn embed_sparse( .embed_sparse( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), prompt_name, permit, ) @@ -846,7 +846,7 @@ async fn embed_all( .embed_all( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), req.prompt_name, permit, ) @@ -914,7 +914,7 @@ async fn embed_all( .embed_all( input, truncate, - req.truncation_direction, + req.truncation_direction.into(), prompt_name, permit, ) @@ -1029,7 +1029,7 @@ async fn openai_embed( .embed_pooled( input, truncate, - TruncationDirection::Right, + tokenizers::TruncationDirection::Right, None, 
true, permit, @@ -1102,7 +1102,7 @@ async fn openai_embed( .embed_pooled( input, truncate, - TruncationDirection::Right, + tokenizers::TruncationDirection::Right, None, true, permit, @@ -1483,6 +1483,8 @@ pub async fn run( Info, ModelType, ClassifierModel, + Embedding, + EncodingFormat, EmbeddingModel, PredictRequest, Prediction, @@ -1506,6 +1508,7 @@ pub async fn run( TokenizeInput, TokenizeRequest, TokenizeResponse, + TruncationDirection, SimpleToken, InputType, InputIds, diff --git a/router/src/http/types.rs b/router/src/http/types.rs index a47a995b..4414ecb4 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -4,7 +4,6 @@ use serde::{de, Deserialize, Deserializer, Serialize}; use serde_json::json; use std::fmt::Formatter; use text_embeddings_core::tokenization::EncodingInput; -use tokenizers::TruncationDirection; use utoipa::openapi::{RefOr, Schema}; use utoipa::ToSchema; @@ -194,6 +193,22 @@ impl<'__s> ToSchema<'__s> for PredictInput { } } +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, ToSchema, Eq, Default)] +pub(crate) enum TruncationDirection { + Left, + #[default] + Right, +} + +impl From for tokenizers::TruncationDirection { + fn from(value: TruncationDirection) -> Self { + match value { + TruncationDirection::Left => Self::Left, + TruncationDirection::Right => Self::Right, + } + } +} + #[derive(Deserialize, ToSchema)] pub(crate) struct PredictRequest { pub inputs: PredictInput, @@ -262,6 +277,7 @@ pub(crate) enum InputType { String(String), Ids(Vec), } + impl InputType { pub(crate) fn count_chars(&self) -> usize { match self { @@ -270,6 +286,7 @@ impl InputType { } } } + impl From for EncodingInput { fn from(value: InputType) -> Self { match value { @@ -278,6 +295,7 @@ impl From for EncodingInput { } } } + #[derive(Deserialize, ToSchema)] #[serde(untagged)] pub(crate) enum Input { @@ -351,6 +369,15 @@ pub(crate) struct EmbedRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + /// The name of the prompt that should be used by for encoding. If not set, no prompt + /// will be applied. + /// + /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + /// + /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// then the sentence "What is the capital of France?" will be encoded as + /// "query: What is the capital of France?" because the prompt text will be prepended before + /// any text to encode. #[schema(default = "null", example = "null", nullable = true)] pub prompt_name: Option, #[serde(default = "default_normalize")] @@ -375,6 +402,15 @@ pub(crate) struct EmbedSparseRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + /// The name of the prompt that should be used by for encoding. If not set, no prompt + /// will be applied. + /// + /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + /// + /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// then the sentence "What is the capital of France?" will be encoded as + /// "query: What is the capital of France?" because the prompt text will be prepended before + /// any text to encode. 
#[schema(default = "null", example = "null", nullable = true)] pub prompt_name: Option, } @@ -397,6 +433,15 @@ pub(crate) struct EmbedAllRequest { #[serde(default)] #[schema(default = "right", example = "right")] pub truncation_direction: TruncationDirection, + /// The name of the prompt that should be used by for encoding. If not set, no prompt + /// will be applied. + /// + /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + /// + /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// then the sentence "What is the capital of France?" will be encoded as + /// "query: What is the capital of France?" because the prompt text will be prepended before + /// any text to encode. #[schema(default = "null", example = "null", nullable = true)] pub prompt_name: Option, } @@ -426,6 +471,15 @@ pub(crate) struct TokenizeRequest { #[serde(default = "default_add_special_tokens")] #[schema(default = "true", example = "true")] pub add_special_tokens: bool, + /// The name of the prompt that should be used by for encoding. If not set, no prompt + /// will be applied. + /// + /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + /// + /// For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, + /// then the sentence "What is the capital of France?" will be encoded as + /// "query: What is the capital of France?" because the prompt text will be prepended before + /// any text to encode. #[schema(default = "null", example = "null", nullable = true)] pub prompt_name: Option, } diff --git a/router/src/lib.rs b/router/src/lib.rs index 3be03190..f5fd102c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -238,11 +238,13 @@ pub async fn run( .await .context("Model backend is not healthy")?; - tracing::info!("Warming up model"); - backend - .warmup(max_input_length, max_batch_tokens, max_batch_requests) - .await - .context("Model backend is not healthy")?; + if !backend.padded_model { + tracing::info!("Warming up model"); + backend + .warmup(max_input_length, max_batch_tokens, max_batch_requests) + .await + .context("Model backend is not healthy")?; + } let max_batch_requests = backend .max_batch_size
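Taken together, the three patches add prompt support end to end: the router downloads and parses `config_sentence_transformers.json`, validates `--default-prompt-name` (or `--default-prompt`) against its `prompts` map, and the HTTP, gRPC and OpenAI-compatible routes accept an optional per-request `prompt_name`. A small sketch of the configuration shape the router expects follows; the prompt entries are invented, and it assumes `serde` (with the `derive` feature) and `serde_json` as dependencies, mirroring the `NewSTConfig` struct introduced in `router/src/lib.rs`.

    // Sketch of parsing the `prompts` section of `config_sentence_transformers.json`.
    use serde::Deserialize;
    use std::collections::HashMap;

    #[derive(Debug, Deserialize)]
    struct NewSTConfig {
        prompts: Option<HashMap<String, String>>,
    }

    fn main() -> Result<(), serde_json::Error> {
        // Invented example; real models ship their own prompt names and texts.
        let raw = r#"{"prompts": {"query": "query: ", "passage": "passage: "}}"#;
        let config: NewSTConfig = serde_json::from_str(raw)?;
        let query_prompt = config
            .prompts
            .as_ref()
            .and_then(|prompts| prompts.get("query"))
            .cloned();
        // `--default-prompt-name query` would prepend this text to every input
        // that does not set its own `prompt_name`.
        assert_eq!(query_prompt.as_deref(), Some("query: "));
        Ok(())
    }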