Skip to content

Fix to allow health check w/o auth #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
57 changes: 32 additions & 25 deletions router/src/http/server.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -1666,9 +1666,7 @@ pub async fn run(
ApiDoc::openapi()
};

// Create router
let mut app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
let mut routes = Router::new()
// Base routes
.route("/info", get(get_model_info))
.route("/embed", post(embed))
Expand All @@ -1683,74 +1681,72 @@ pub async fn run(
.route("/embeddings", post(openai_embed))
.route("/v1/embeddings", post(openai_embed))
// Vertex compat route
.route("/vertex", post(vertex_compatibility))
.route("/vertex", post(vertex_compatibility));

#[allow(unused_mut)]
let mut public_routes = Router::new()
// Base Health route
.route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
// Update payload limit
.layer(DefaultBodyLimit::max(payload_limit));
.route("/metrics", get(metrics));

#[cfg(feature = "google")]
{
tracing::info!("Built with `google` feature");

if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
tracing::info!("Serving Vertex compatible route on {env_predict_route}");
app = app.route(&env_predict_route, post(vertex_compatibility));
routes = routes.route(&env_predict_route, post(vertex_compatibility));
}

if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
tracing::info!("Serving Vertex compatible health route on {env_health_route}");
app = app.route(&env_health_route, get(health));
public_routes = public_routes.route(&env_health_route, get(health));
}
}
#[cfg(not(feature = "google"))]
{
// Set default routes
app = match &info.model_type {
routes = match &info.model_type {
ModelType::Classifier(_) => {
app.route("/", post(predict))
routes
.route("/", post(predict))
// AWS Sagemaker route
.route("/invocations", post(predict))
}
ModelType::Reranker(_) => {
app.route("/", post(rerank))
routes
.route("/", post(rerank))
// AWS Sagemaker route
.route("/invocations", post(rerank))
}
ModelType::Embedding(model) => {
if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
app.route("/", post(similarity))
routes
.route("/", post(similarity))
// AWS Sagemaker route
.route("/invocations", post(similarity))
} else if model.pooling == "splade" {
app.route("/", post(embed_sparse))
routes
.route("/", post(embed_sparse))
// AWS Sagemaker route
.route("/invocations", post(embed_sparse))
} else {
app.route("/", post(embed))
routes
.route("/", post(embed))
// AWS Sagemaker route
.route("/invocations", post(embed))
}
}
};
}

app = app
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);
let prefix = format!("Bearer {}", api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();
Expand All @@ -1767,9 +1763,20 @@ pub async fn run(
}
};

app = app.layer(axum::middleware::from_fn(auth));
routes = routes.layer(axum::middleware::from_fn(auth));
}

let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
.merge(routes)
.merge(public_routes)
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(DefaultBodyLimit::max(payload_limit))
.layer(cors_layer);

// Run server
let listener = tokio::net::TcpListener::bind(&addr)
.await
Expand Down