Skip to content

feat: add /similarity route #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

172 changes: 172 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,93 @@
}
}
},
"/similarity": {
"post": {
"tags": [
"Text Embeddings Inference"
],
"summary": "Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.",
"operationId": "similarity",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SimilarityRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Sentence Similarity",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SimilarityResponse"
}
}
}
},
"413": {
"description": "Batch size error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Batch size error",
"error_type": "validation"
}
}
}
},
"422": {
"description": "Tokenization error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Tokenization error",
"error_type": "tokenizer"
}
}
}
},
"424": {
"description": "Embedding Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Inference failed",
"error_type": "backend"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
}
}
}
},
"/tokenize": {
"post": {
"tags": [
Expand Down Expand Up @@ -1441,6 +1528,91 @@
"$ref": "#/components/schemas/Rank"
}
},
"SimilarityInput": {
"type": "object",
"required": [
"source_sentence",
"sentences"
],
"properties": {
"sentences": {
"type": "array",
"items": {
"type": "string"
},
"description": "A list of strings which will be compared against the source_sentence.",
"example": [
"What is Machine Learning?"
]
},
"source_sentence": {
"type": "string",
"description": "The string that you wish to compare the other strings with. This can be a phrase, sentence,\nor longer passage, depending on the model being used.",
"example": "What is Deep Learning?"
}
}
},
"SimilarityParameters": {
"type": "object",
"required": [
"truncation_direction"
],
"properties": {
"prompt_name": {
"type": "string",
"description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
"default": "null",
"example": "null",
"nullable": true
},
"truncate": {
"type": "boolean",
"default": "false",
"example": "false",
"nullable": true
},
"truncation_direction": {
"allOf": [
{
"$ref": "#/components/schemas/TruncationDirection"
}
],
"default": "right"
}
}
},
"SimilarityRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"inputs": {
"$ref": "#/components/schemas/SimilarityInput"
},
"parameters": {
"allOf": [
{
"$ref": "#/components/schemas/SimilarityParameters"
}
],
"default": "null",
"nullable": true
}
}
},
"SimilarityResponse": {
"type": "array",
"items": {
"type": "number",
"format": "float"
},
"example": [
0.0,
1.0,
0.5
]
},
"SimpleToken": {
"type": "object",
"required": [
Expand Down
1 change: 1 addition & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ opentelemetry = "0.23.0"
opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.16.0"
reqwest = { version = "0.12.5", features = [] }
simsimd = "4.4.0"
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
Expand Down
98 changes: 96 additions & 2 deletions router/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::http::types::{
EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
RerankRequest, RerankResponse, Sequence, SimilarityInput, SimilarityParameters,
SimilarityRequest, SimilarityResponse, SimpleToken, SparseValue, TokenizeInput,
TokenizeRequest, TokenizeResponse, TruncationDirection, VertexPrediction, VertexRequest,
VertexResponse,
};
Expand All @@ -26,6 +27,7 @@ use futures::future::join_all;
use futures::FutureExt;
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use simsimd::SpatialSimilarity;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use text_embeddings_backend::BackendError;
Expand Down Expand Up @@ -455,6 +457,88 @@ async fn rerank(
Ok((headers, Json(response)))
}

/// Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.
#[utoipa::path(
post,
tag = "Text Embeddings Inference",
path = "/similarity",
request_body = SimilarityRequest,
responses(
(status = 200, description = "Sentence Similarity", body = SimilarityResponse),
(status = 424, description = "Embedding Error", body = ErrorResponse,
example = json ! ({"error": "Inference failed", "error_type": "backend"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Tokenization error", body = ErrorResponse,
example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
(status = 413, description = "Batch size error", body = ErrorResponse,
example = json ! ({"error": "Batch size error", "error_type": "validation"})),
)
)]
#[instrument(
skip_all,
fields(total_time, tokenization_time, queue_time, inference_time,)
)]
async fn similarity(
infer: Extension<Infer>,
info: Extension<Info>,
Json(req): Json<SimilarityRequest>,
) -> Result<(HeaderMap, Json<SimilarityResponse>), (StatusCode, Json<ErrorResponse>)> {
    // Reject an empty comparison list early: there is nothing to score and the
    // downstream embed batch would only contain the source sentence.
    if req.inputs.sentences.is_empty() {
        let message = "`inputs.sentences` cannot be empty".to_string();
        tracing::error!("{message}");
        let err = ErrorResponse {
            error: message,
            error_type: ErrorType::Validation,
        };
        let counter = metrics::counter!("te_request_failure", "err" => "validation");
        counter.increment(1);
        Err(err)?;
    }
    // +1 because the source sentence is embedded alongside the candidates
    let batch_size = req.inputs.sentences.len() + 1;
    if batch_size > info.max_client_batch_size {
        let message = format!(
            "batch size {batch_size} > maximum allowed batch size {}",
            info.max_client_batch_size
        );
        tracing::error!("{message}");
        let err = ErrorResponse {
            error: message,
            error_type: ErrorType::Validation,
        };
        let counter = metrics::counter!("te_request_failure", "err" => "batch_size");
        counter.increment(1);
        Err(err)?;
    }

    // Convert the similarity request into a single embed request:
    // index 0 is the source sentence, indices 1.. are the candidates.
    let mut inputs = Vec::with_capacity(batch_size);
    inputs.push(InputType::String(req.inputs.source_sentence));
    inputs.extend(req.inputs.sentences.into_iter().map(InputType::String));

    let parameters = req.parameters.unwrap_or_default();
    let embed_req = EmbedRequest {
        inputs: Input::Batch(inputs),
        truncate: parameters.truncate,
        truncation_direction: parameters.truncation_direction,
        prompt_name: parameters.prompt_name,
        // Normalization is unnecessary: cosine similarity is scale-invariant.
        normalize: false,
    };

    // Get embeddings; propagates the embed handler's error responses (413/422/424/429)
    let (header_map, embed_response) = embed(infer, info, Json(embed_req)).await?;
    let embeddings = embed_response.0 .0;

    // Compute the cosine similarity between the source sentence (index 0) and
    // each candidate. `SpatialSimilarity::cosine` returns the cosine *distance*
    // (1 - cos θ) as an f64, so `1.0 - distance` recovers the similarity.
    let similarities = (1..batch_size)
        .map(|i| {
            let distance = f32::cosine(&embeddings[0], &embeddings[i])
                .expect("embeddings from the same model must have equal dimensions");
            1.0 - distance as f32
        })
        .collect();

    Ok((header_map, Json(SimilarityResponse(similarities))))
}

/// Get Embeddings. Returns a 424 status code if the model is not an embedding model.
#[utoipa::path(
post,
Expand Down Expand Up @@ -1472,6 +1556,7 @@ pub async fn run(
embed_all,
embed_sparse,
openai_embed,
similarity,
tokenize,
decode,
metrics,
Expand Down Expand Up @@ -1509,6 +1594,10 @@ pub async fn run(
TokenizeRequest,
TokenizeResponse,
TruncationDirection,
SimilarityInput,
SimilarityParameters,
SimilarityRequest,
SimilarityResponse,
SimpleToken,
InputType,
InputIds,
Expand Down Expand Up @@ -1587,6 +1676,7 @@ pub async fn run(
.route("/embed_sparse", post(embed_sparse))
.route("/predict", post(predict))
.route("/rerank", post(rerank))
.route("/similarity", post(similarity))
.route("/tokenize", post(tokenize))
.route("/decode", post(decode))
// OpenAI compat route
Expand Down Expand Up @@ -1634,7 +1724,11 @@ pub async fn run(
.route("/invocations", post(rerank))
}
ModelType::Embedding(model) => {
if model.pooling == "splade" {
if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
app.route("/", post(similarity))
// AWS Sagemaker route
.route("/invocations", post(similarity))
} else if model.pooling == "splade" {
app.route("/", post(embed_sparse))
// AWS Sagemaker route
.route("/invocations", post(embed_sparse))
Expand Down
Loading
Loading