diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 040ac070..c83e06d7 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1666,9 +1666,7 @@ pub async fn run( ApiDoc::openapi() }; - // Create router - let mut app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc)) + let mut routes = Router::new() // Base routes .route("/info", get(get_model_info)) .route("/embed", post(embed)) @@ -1683,7 +1681,10 @@ pub async fn run( .route("/embeddings", post(openai_embed)) .route("/v1/embeddings", post(openai_embed)) // Vertex compat route - .route("/vertex", post(vertex_compatibility)) + .route("/vertex", post(vertex_compatibility)); + + #[allow(unused_mut)] + let mut public_routes = Router::new() // Base Health route .route("/health", get(health)) // Inference API health route @@ -1691,9 +1692,7 @@ pub async fn run( // AWS Sagemaker health route .route("/ping", get(health)) // Prometheus metrics route - .route("/metrics", get(metrics)) - // Update payload limit - .layer(DefaultBodyLimit::max(payload_limit)); + .route("/metrics", get(metrics)); #[cfg(feature = "google")] { @@ -1701,39 +1700,44 @@ pub async fn run( if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") { tracing::info!("Serving Vertex compatible route on {env_predict_route}"); - app = app.route(&env_predict_route, post(vertex_compatibility)); + routes = routes.route(&env_predict_route, post(vertex_compatibility)); } if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") { tracing::info!("Serving Vertex compatible health route on {env_health_route}"); - app = app.route(&env_health_route, get(health)); + public_routes = public_routes.route(&env_health_route, get(health)); } } #[cfg(not(feature = "google"))] { // Set default routes - app = match &info.model_type { + routes = match &info.model_type { ModelType::Classifier(_) => { - app.route("/", post(predict)) + routes + .route("/", post(predict)) // AWS Sagemaker route .route("/invocations", post(predict)) } ModelType::Reranker(_) => { - app.route("/", post(rerank)) + routes + .route("/", post(rerank)) // AWS Sagemaker route .route("/invocations", post(rerank)) } ModelType::Embedding(model) => { if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) { - app.route("/", post(similarity)) + routes + .route("/", post(similarity)) // AWS Sagemaker route .route("/invocations", post(similarity)) } else if model.pooling == "splade" { - app.route("/", post(embed_sparse)) + routes + .route("/", post(embed_sparse)) // AWS Sagemaker route .route("/invocations", post(embed_sparse)) } else { - app.route("/", post(embed)) + routes + .route("/", post(embed)) // AWS Sagemaker route .route("/invocations", post(embed)) } @@ -1741,16 +1745,8 @@ pub async fn run( }; } - app = app - .layer(Extension(infer)) - .layer(Extension(info)) - .layer(Extension(prom_handle.clone())) - .layer(OtelAxumLayer::default()) - .layer(cors_layer); - if let Some(api_key) = api_key { - let mut prefix = "Bearer ".to_string(); - prefix.push_str(&api_key); + let prefix = format!("Bearer {}", api_key); // Leak to allow FnMut let api_key: &'static str = prefix.leak(); @@ -1767,9 +1763,20 @@ pub async fn run( } }; - app = app.layer(axum::middleware::from_fn(auth)); + routes = routes.layer(axum::middleware::from_fn(auth)); } + let app = Router::new() + .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc)) + .merge(routes) + .merge(public_routes) + .layer(Extension(infer)) + .layer(Extension(info)) + .layer(Extension(prom_handle.clone())) + .layer(OtelAxumLayer::default()) + .layer(DefaultBodyLimit::max(payload_limit)) + .layer(cors_layer); + // Run server let listener = tokio::net::TcpListener::bind(&addr) .await