Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion architectures/inference-only/inference-node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ uuid = { version = "1", features = ["v4"] }
pyo3.workspace = true
postcard.workspace = true
axum = { version = "0.7", features = ["macros"] }
tower = { version = "0.4" }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.5", features = ["cors"] }
tikv-jemallocator.workspace = true
145 changes: 144 additions & 1 deletion architectures/inference-only/inference-node/src/bin/gateway-node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use anyhow::{Context, Result};
use axum::{
Json, Router,
extract::State,
http::StatusCode,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::post,
};
Expand Down Expand Up @@ -50,6 +50,10 @@ struct Args {

#[arg(long)]
write_endpoint_file: Option<PathBuf>,

/// Bearer token secret required for API authentication (optional)
#[arg(long, env = "GATEWAY_API_SECRET")]
api_secret: Option<String>,
}

#[derive(Clone, Debug)]
Expand All @@ -66,6 +70,7 @@ struct GatewayState {
available_nodes: RwLock<HashMap<EndpointId, InferenceNodeInfo>>,
pending_requests: RwLock<HashMap<String, mpsc::Sender<InferenceResponse>>>,
network_tx: mpsc::Sender<(EndpointId, InferenceMessage)>,
api_secret: Option<String>,
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
Expand Down Expand Up @@ -118,8 +123,20 @@ struct ChatCompletionResponse {
#[axum::debug_handler]
async fn handle_inference(
State(state): State<Arc<GatewayState>>,
headers: HeaderMap,
Json(req): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, AppError> {
if let Some(ref secret) = state.api_secret {
let authorized = headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.is_some_and(|token| token == secret);
if !authorized {
return Err(AppError::Unauthorized);
}
}

let nodes = state.available_nodes.read().await;
let node = nodes.values().next().ok_or(AppError::NoNodesAvailable)?;

Expand Down Expand Up @@ -201,6 +218,7 @@ enum AppError {
NoNodesAvailable,
Timeout,
InternalError,
Unauthorized,
}

impl IntoResponse for AppError {
Expand All @@ -212,6 +230,7 @@ impl IntoResponse for AppError {
),
AppError::Timeout => (StatusCode::GATEWAY_TIMEOUT, "Inference request timed out"),
AppError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized"),
};
(status, message).into_response()
}
Expand Down Expand Up @@ -357,6 +376,7 @@ async fn run_gateway() -> Result<()> {
available_nodes: RwLock::new(HashMap::new()),
pending_requests: RwLock::new(HashMap::new()),
network_tx,
api_secret: args.api_secret.clone(),
});

info!("Gateway ready! Listening on http://{}", args.listen_addr);
Expand Down Expand Up @@ -494,3 +514,126 @@ async fn run_gateway() -> Result<()> {
info!("Shutdown complete");
Ok(())
}

#[cfg(test)]
mod tests {
use axum::{
body::Body,
http::{Request, StatusCode},
};
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tower::ServiceExt;

use super::*;

fn make_app(secret: Option<&str>) -> axum::Router {
let (network_tx, _network_rx) = mpsc::channel(1);
let state = Arc::new(GatewayState {
available_nodes: RwLock::new(Default::default()),
pending_requests: RwLock::new(Default::default()),
network_tx,
api_secret: secret.map(|s| s.to_string()),
});
Router::new()
.route(
"/v1/chat/completions",
axum::routing::post(handle_inference),
)
.with_state(state)
}

fn chat_request_body() -> &'static str {
r#"{"messages":[{"role":"user","content":"hello"}]}"#
}

#[tokio::test]
async fn no_secret_configured_allows_unauthenticated_requests() {
let app = make_app(None);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(chat_request_body()))
.unwrap(),
)
.await
.unwrap();
// No nodes available, but auth passed — expect 503 not 401
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
}

#[tokio::test]
async fn correct_bearer_token_is_accepted() {
let app = make_app(Some("supersecret"));
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Bearer supersecret")
.body(Body::from(chat_request_body()))
.unwrap(),
)
.await
.unwrap();
// Auth passed, no nodes available — expect 503 not 401
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
}

#[tokio::test]
async fn missing_auth_header_is_rejected() {
let app = make_app(Some("supersecret"));
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(chat_request_body()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn wrong_token_is_rejected() {
let app = make_app(Some("supersecret"));
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Bearer wrongtoken")
.body(Body::from(chat_request_body()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn non_bearer_scheme_is_rejected() {
let app = make_app(Some("supersecret"));
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Basic supersecret")
.body(Body::from(chat_request_body()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
}
Loading