diff --git a/Cargo.lock b/Cargo.lock index 36facc4e..1289b141 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4077,6 +4077,7 @@ dependencies = [ "anyhow", "ndarray", "nohash-hasher", + "num_cpus", "ort", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 82e79e57..64c4d721 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ clap = { version = "4.1", features = ["derive", "env"] } hf-hub = { version = "0.3.2", features = ["tokio", "online"], default-features = false } metrics = "0.23" nohash-hasher = "0.2" +num_cpus = "1.16.0" tokenizers = { version = "0.19.1", default-features = false, features = ["onig", "esaxx_fast"] } tokio = { version = "1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync", "signal"] } tracing = "0.1" diff --git a/backends/ort/Cargo.toml b/backends/ort/Cargo.toml index 8fdad006..0d40fddd 100644 --- a/backends/ort/Cargo.toml +++ b/backends/ort/Cargo.toml @@ -9,6 +9,7 @@ homepage.workspace = true anyhow = { workspace = true } nohash-hasher = { workspace = true } ndarray = "0.15.6" +num_cpus = { workspace = true } ort = { version = "2.0.0-rc.4", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] } text-embeddings-backend-core = { path = "../core" } tracing = { workspace = true } diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index 9573f6b0..08dbfa7e 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -52,6 +52,8 @@ impl OrtBackend { // Start onnx session let session = Session::builder() + .s()? + .with_intra_threads(num_cpus::get()) .s()? .with_optimization_level(GraphOptimizationLevel::Level3) .s()? diff --git a/router/Cargo.toml b/router/Cargo.toml index 3a3fba27..e648036b 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -23,7 +23,7 @@ futures = "^0.3" init-tracing-opentelemetry = { version = "0.18.1", features = ["opentelemetry-otlp"] } hf-hub = { workspace = true } http = "1.0.0" -num_cpus = "1.16.0" +num_cpus = { workspace = true } metrics = { workspace = true } metrics-exporter-prometheus = { version = "0.15.1", features = [] } opentelemetry = "0.23.0" diff --git a/router/src/lib.rs b/router/src/lib.rs index 86a0f884..540e61af 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -198,7 +198,7 @@ pub async fn run( }; tracing::info!("Maximum number of tokens per request: {max_input_length}"); - let tokenization_workers = tokenization_workers.unwrap_or_else(num_cpus::get_physical); + let tokenization_workers = tokenization_workers.unwrap_or_else(num_cpus::get); // Try to load new ST Config let mut new_st_config: Option = None;