feat: add default prompts #312

Merged · 3 commits · Jun 28, 2024
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

7 changes: 7 additions & 0 deletions core/src/download.rs
@@ -102,3 +102,10 @@ pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {

Err(err)
}

#[instrument(skip_all)]
pub async fn download_new_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
tracing::info!("Downloading `config_sentence_transformers.json`");
let pool_config_path = api.get("config_sentence_transformers.json").await?;
Ok(pool_config_path)
}
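
For context, a minimal sketch of how this new download helper might be consumed when loading a model. The parsing step, the helper name, and the `serde_json`/`anyhow` dependencies are assumptions for illustration; the PR itself only adds the download:

```rust
use hf_hub::api::tokio::ApiRepo;

// Hypothetical helper: fetch and parse `config_sentence_transformers.json`,
// which for Sentence Transformers models can carry `prompts` and
// `default_prompt_name` entries.
async fn load_st_config(api: &ApiRepo) -> anyhow::Result<serde_json::Value> {
    let path = download_new_st_config(api).await?;
    let raw = tokio::fs::read_to_string(path).await?;
    Ok(serde_json::from_str(&raw)?)
}
```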
17 changes: 13 additions & 4 deletions core/src/infer.rs
@@ -60,9 +60,10 @@ impl Infer {
&self,
inputs: I,
add_special_tokens: bool,
) -> Result<RawEncoding, TextEmbeddingsError> {
prompt_name: Option<String>,
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
self.tokenization
.tokenize(inputs.into(), add_special_tokens)
.tokenize(inputs.into(), add_special_tokens, prompt_name)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -119,6 +120,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
permit: OwnedSemaphorePermit,
) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
@@ -138,6 +140,7 @@
inputs,
truncate,
truncation_direction,
prompt_name,
false,
&start_time,
permit,
@@ -172,6 +175,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
@@ -191,6 +195,7 @@
inputs,
truncate,
truncation_direction,
prompt_name,
true,
&start_time,
permit,
@@ -225,6 +230,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
normalize: bool,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
@@ -245,6 +251,7 @@
inputs,
truncate,
truncation_direction,
prompt_name,
true,
&start_time,
permit,
@@ -290,11 +297,13 @@ impl Infer {
Ok(response)
}

#[allow(clippy::too_many_arguments)]
async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
pooling: bool,
start_time: &Instant,
_permit: OwnedSemaphorePermit,
@@ -315,7 +324,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate, truncation_direction)
.encode(inputs.into(), truncate, truncation_direction, prompt_name)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -381,7 +390,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate, truncation_direction)
.encode(inputs.into(), truncate, truncation_direction, None)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
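
Each public embedding entry point above now threads an optional `prompt_name` through to tokenization, while the remaining path in this file passes `None` explicitly. A minimal call-site sketch, assuming the crate paths shown and an already-acquired `permit`; the prompt name and input are illustrative:

```rust
use text_embeddings_core::infer::{Infer, PooledEmbeddingsInferResponse};
use text_embeddings_core::TextEmbeddingsError;
use tokenizers::TruncationDirection;
use tokio::sync::OwnedSemaphorePermit;

// Hypothetical wrapper: embed one input using the "query" prompt from the
// Sentence Transformers configuration instead of the server-wide default.
async fn embed_with_prompt(
    infer: &Infer,
    text: String,
    permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
    infer
        .embed_pooled(
            text,                       // inputs
            true,                       // truncate
            TruncationDirection::Right, // truncation_direction
            Some("query".to_string()),  // prompt_name
            true,                       // normalize
            permit,
        )
        .await
}
```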
140 changes: 125 additions & 15 deletions core/src/tokenization.rs
@@ -1,5 +1,6 @@
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer;
pub use tokenizers::Encoding as RawEncoding;
use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
@@ -19,6 +20,8 @@ impl Tokenization {
tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt: Option<String>,
prompts: Option<HashMap<String, String>>,
) -> Self {
tracing::info!("Starting {workers} tokenization workers");

@@ -29,12 +32,16 @@
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let receiver_clone = receiver.clone();
let default_prompt_clone = default_prompt.clone();
let prompts_clone = prompts.clone();
// Spawn worker
std::thread::spawn(move || {
tokenizer_worker(
tokenizer_clone,
max_input_length,
position_offset,
default_prompt_clone,
prompts_clone,
receiver_clone,
)
});
@@ -49,6 +56,7 @@
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
@@ -66,6 +74,7 @@
inputs,
truncate,
truncation_direction,
prompt_name,
response_sender,
Span::current(),
))
@@ -82,7 +91,8 @@
&self,
inputs: EncodingInput,
add_special_tokens: bool,
) -> Result<RawEncoding, TextEmbeddingsError> {
prompt_name: Option<String>,
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
return Err(TextEmbeddingsError::Validation(
@@ -98,6 +108,7 @@
.send(TokenizerRequest::Tokenize(
inputs,
add_special_tokens,
prompt_name,
response_sender,
Span::current(),
))
@@ -147,6 +158,8 @@ fn tokenizer_worker(
mut tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt: Option<String>,
prompts: Option<HashMap<String, String>>,
receiver: async_channel::Receiver<TokenizerRequest>,
) {
// Loop over requests
@@ -156,11 +169,17 @@
inputs,
truncate,
truncation_direction,
prompt_name,
response_tx,
parent_span,
) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
let default_prompt_clone = match prompt_name {
None => default_prompt.clone(),
Some(_) => None,
};

// It's possible that the user dropped their request, resulting in a send error.
// We just discard the error
let _ = response_tx.send(encode_input(
@@ -169,20 +188,37 @@
truncation_direction,
max_input_length,
position_offset,
default_prompt_clone,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
));
}
})
}
TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => {
TokenizerRequest::Tokenize(
inputs,
add_special_tokens,
prompt_name,
response_tx,
parent_span,
) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
let default_prompt_clone = match prompt_name {
None => default_prompt.clone(),
Some(_) => None,
};

// It's possible that the user dropped their request, resulting in a send error.
// We just discard the error
let _ = response_tx.send(tokenize_input(
inputs,
add_special_tokens,
None,
default_prompt_clone,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
));
}
Expand Down Expand Up @@ -212,40 +248,104 @@ fn decode_ids(
.decode(&ids, skip_special_tokens)?)
}

fn prepare_pre_prompt(
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
) -> Result<Option<String>, TextEmbeddingsError> {
let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() {
match prompts {
None => {
return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration")));
}
Some(prompts) if !prompts.contains_key(prompt_name) => {
return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys())));
}
Some(prompts) => prompts.get(prompt_name).cloned(),
}
} else {
default_prompt
};
Ok(pre_prompt)
}
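
For reference, a small test-style sketch of the resolution rules above (not part of the diff; the prompt strings are made up): an explicit `prompt_name` must resolve against the configured prompts, otherwise the default prompt passes through unchanged.

```rust
#[test]
fn prompt_resolution_rules() {
    use std::collections::HashMap;

    let prompts = HashMap::from([("query".to_string(), "query: ".to_string())]);

    // An explicit prompt name wins, and must exist in the map.
    let resolved =
        prepare_pre_prompt(None, Some("query".to_string()), Some(&prompts)).unwrap();
    assert_eq!(resolved, Some("query: ".to_string()));

    // Without a name, the default prompt (if any) passes through untouched.
    let resolved =
        prepare_pre_prompt(Some("passage: ".to_string()), None, Some(&prompts)).unwrap();
    assert_eq!(resolved, Some("passage: ".to_string()));

    // An unknown name, or a missing prompts map, is a validation error.
    assert!(prepare_pre_prompt(None, Some("oops".to_string()), Some(&prompts)).is_err());
    assert!(prepare_pre_prompt(None, Some("query".to_string()), None).is_err());
}
```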

fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
truncate_params: Option<TruncationParams>,
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<RawEncoding, TextEmbeddingsError> {
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;

let encoding = match inputs {
// encode input
EncodingInput::Single(s) => tokenizer
.with_truncation(truncate_params)?
.encode::<String>(s, add_special_tokens)?,
EncodingInput::Dual(s1, s2) => {
tokenizer
EncodingInput::Single(s) => {
let s = if let Some(mut pre_prompt) = pre_prompt {
pre_prompt.push_str(&s);
pre_prompt
} else {
s
};

let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2), add_special_tokens)?
.encode::<&str>(&s, add_special_tokens)?;

(Some(s), encoding)
}
EncodingInput::Dual(s1, s2) => {
if pre_prompt.is_some() {
return Err(TextEmbeddingsError::Validation(
"`prompt_name` cannot be set with dual inputs".to_string(),
));
}

(
None,
tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2), add_special_tokens)?,
)
}
// input is encoded -> convert to tokenizers Encoding
EncodingInput::Ids(ids) => {
let text = tokenizer.decode(&ids, false)?;
tokenizer
.with_truncation(truncate_params)?
.encode::<String>(text, false)?
if let Some(mut pre_prompt) = pre_prompt {
let text = tokenizer.decode(&ids, true)?;
pre_prompt.push_str(&text);

let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<&str>(&pre_prompt, true)?;

(Some(pre_prompt), encoding)
} else {
let text = tokenizer.decode(&ids, false)?;

let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<&str>(&text, false)?;

(Some(text), encoding)
}
}
};
Ok(encoding)
}
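
The net effect for a plain string input is simple prefix concatenation before encoding, with the rewritten text returned so callers can see what was actually tokenized. A sketch, assuming a `tokenizer` is already loaded:

```rust
// Illustration only: with a default prompt and no explicit prompt_name,
// the worker encodes "query: What is Deep Learning?".
fn demo(tokenizer: &mut Tokenizer) -> Result<(), TextEmbeddingsError> {
    let (text, encoding) = tokenize_input(
        EncodingInput::Single("What is Deep Learning?".to_string()),
        true,                        // add_special_tokens
        None,                        // no truncation
        Some("query: ".to_string()), // default_prompt
        None,                        // prompt_name
        None,                        // prompts map
        tokenizer,
    )?;
    assert_eq!(text.as_deref(), Some("query: What is Deep Learning?"));
    assert!(!encoding.get_ids().is_empty());
    Ok(())
}
```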

/// Get input length and optionally truncate it
#[allow(clippy::too_many_arguments)]
fn encode_input(
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
max_input_length: usize,
position_offset: usize,
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Default truncation params
@@ -256,7 +356,15 @@ fn encode_input(
stride: 0,
});

let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
let (_, encoding) = tokenize_input(
inputs,
true,
truncate_params,
default_prompt,
prompt_name,
prompts,
tokenizer,
)?;
let seq_len = encoding.len();

if seq_len > max_input_length {
Expand Down Expand Up @@ -315,13 +423,15 @@ enum TokenizerRequest {
EncodingInput,
bool,
TruncationDirection,
Option<String>,
oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
Span,
),
Tokenize(
EncodingInput,
bool,
oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
Option<String>,
oneshot::Sender<Result<(Option<String>, RawEncoding), TextEmbeddingsError>>,
Span,
),
Decode(
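
Putting the pieces together, an end-to-end sketch of the new constructor arguments and a request that overrides the default prompt by name. `Tokenization` and `EncodingInput` come from this module; the file name, sizes, and prompt strings are assumptions for illustration:

```rust
use std::collections::HashMap;
use tokenizers::{Tokenizer, TruncationDirection};

async fn build_and_encode() -> Result<(), TextEmbeddingsError> {
    let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
    let prompts = HashMap::from([
        ("query".to_string(), "query: ".to_string()),
        ("passage".to_string(), "passage: ".to_string()),
    ]);

    let tokenization = Tokenization::new(
        4,                           // workers
        tokenizer,
        512,                         // max_input_length
        0,                           // position_offset
        Some("query: ".to_string()), // default_prompt, used when no prompt_name is sent
        Some(prompts),               // named prompts, e.g. from config_sentence_transformers.json
    );

    // A request can override the default by naming a configured prompt.
    let _encoding = tokenization
        .encode(
            EncodingInput::Single("the cat sits outside".to_string()),
            true, // truncate
            TruncationDirection::Right,
            Some("passage".to_string()),
        )
        .await?;
    Ok(())
}
```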