diff --git a/README.md b/README.md index b4b8de3..a06e85e 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,8 @@ Minimal `mizan-daemon.toml`: control_plane_url = "http://127.0.0.1:18180" daemon_token_path = "/run/secrets/mizan-daemon-token" local_provider_url = "http://127.0.0.1:11434/v1" +# Optional for local providers that require bearer auth. +# local_provider_api_key = "local-secret" provider_family = "openai-compatible" advertised_models = ["llama3.1"] max_concurrency = 2 diff --git a/crates/mizan-daemon/Cargo.toml b/crates/mizan-daemon/Cargo.toml index 09b20dd..6c6c652 100644 --- a/crates/mizan-daemon/Cargo.toml +++ b/crates/mizan-daemon/Cargo.toml @@ -9,9 +9,11 @@ rust-version.workspace = true [dependencies] clap = { version = "4", features = ["derive"] } mizan-core = { path = "../mizan-core" } +mizan-providers = { path = "../mizan-providers" } reqwest.workspace = true serde.workspace = true serde_json.workspace = true tokio.workspace = true toml = "0.8" tracing.workspace = true +uuid.workspace = true diff --git a/crates/mizan-daemon/src/main.rs b/crates/mizan-daemon/src/main.rs index 7b8908c..f736115 100644 --- a/crates/mizan-daemon/src/main.rs +++ b/crates/mizan-daemon/src/main.rs @@ -1,7 +1,8 @@ use std::{net::SocketAddr, path::PathBuf, process, time::Duration}; use clap::{Args, CommandFactory, Parser, Subcommand}; -use mizan_core::{AppError, AppResult, init_tracing, redact_for_logs}; +use mizan_core::{AppError, AppResult, RequestContextBuilder, init_tracing, redact_for_logs}; +use mizan_providers::{ChatRequest, ChatResponse, OpenAiCompatibleProvider, ProviderAdapter}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::{ @@ -9,6 +10,7 @@ use tokio::{ time::{sleep, timeout}, }; use tracing::{info, warn}; +use uuid::Uuid; #[tokio::main] async fn main() { @@ -136,7 +138,7 @@ async fn lease_and_run_one_job( return Ok(()); }; - let completion = match call_local_provider(client, config, &job.request).await { + let completion = match call_local_provider(config, &job).await { Ok(response) => DispatchJobCompleteRequest { status: "succeeded".to_owned(), response: Some(response), @@ -182,32 +184,34 @@ async fn lease_and_run_one_job( } async fn call_local_provider( - client: &reqwest::Client, config: &DaemonConfig, - request: &ChatRequest, + job: &DispatchJobLeaseResponse, ) -> AppResult { - let url = control_plane_endpoint(&config.local_provider_url, "/chat/completions"); - let response = client - .post(url) - .json(request) - .send() - .await - .map_err(|error| { - AppError::infrastructure(format!("local provider request failed: {error}")) - })?; - - let status = response.status(); - let body = response.text().await.map_err(|error| { - AppError::infrastructure(format!("local provider response read failed: {error}")) + let request_id = Uuid::parse_str(&job.request_id).map_err(|error| { + AppError::infrastructure(format!("daemon dispatch request_id is invalid: {error}")) })?; - if !status.is_success() { - return Err(AppError::provider(format!( - "local provider returned status={status} body={}", - redact_for_logs(body) - ))); - } + let context = RequestContextBuilder::default() + .request_id(request_id) + .trace_id(request_id) + .provider("mizan-daemon") + .route(job.model.clone()) + .model(job.request.model.clone()) + .method("POST") + .path("/v1/chat/completions") + .streaming(false) + .build(); + let provider = OpenAiCompatibleProvider::with_optional_api_key( + "mizan-daemon", + config.local_provider_url.clone(), + config.local_provider_api_key.clone(), + ); - parse_chat_completion_response(&body, request.model.clone()) + timeout( + Duration::from_secs(30), + provider.chat_completions(&context, job.request.clone()), + ) + .await + .map_err(|_| AppError::infrastructure("local provider request timed out after 30 seconds"))? } async fn register(args: ConfigArgs) -> AppResult<()> { @@ -356,6 +360,7 @@ struct DaemonConfig { control_plane_url: String, daemon_token_path: String, local_provider_url: String, + local_provider_api_key: Option, provider_family: String, advertised_models: Vec, max_concurrency: u32, @@ -379,6 +384,10 @@ impl DaemonConfig { let daemon_token_path = required_field(raw_config.daemon_token_path, "daemon_token_path")?; let local_provider_url = required_field(raw_config.local_provider_url, "local_provider_url")?; + let local_provider_api_key = raw_config + .local_provider_api_key + .map(|value| value.trim().to_owned()) + .filter(|value| !value.is_empty()); let provider_family = raw_config .provider_family .unwrap_or_else(|| "openai-compatible".to_owned()) @@ -422,6 +431,7 @@ impl DaemonConfig { control_plane_url, daemon_token_path, local_provider_url, + local_provider_api_key, provider_family, advertised_models, max_concurrency, @@ -458,6 +468,7 @@ struct RawDaemonConfig { control_plane_url: Option, daemon_token_path: Option, local_provider_url: Option, + local_provider_api_key: Option, provider_family: Option, advertised_models: Option>, max_concurrency: Option, @@ -529,80 +540,6 @@ struct DispatchJobCompleteRequest { error_message: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatRequest { - model: String, - messages: Vec, - #[serde(default)] - stream: bool, - #[serde(default)] - max_tokens: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatMessage { - role: String, - content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatResponse { - provider: String, - model: String, - content: String, - usage: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct TokenUsage { - prompt_tokens: u64, - completion_tokens: u64, - total_tokens: u64, - #[serde(default)] - estimated: bool, -} - -#[derive(Debug, Deserialize)] -struct OpenAiChatCompletionResponse { - model: Option, - choices: Vec, - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct OpenAiChoice { - message: Option, -} - -#[derive(Debug, Deserialize)] -struct OpenAiMessage { - content: Option, -} - -fn parse_chat_completion_response( - raw_body: &str, - requested_model: String, -) -> AppResult { - let response: OpenAiChatCompletionResponse = serde_json::from_str(raw_body) - .map_err(|error| AppError::provider(format!("invalid local provider response: {error}")))?; - let Some(first_choice) = response.choices.into_iter().next() else { - return Err(AppError::provider( - "local provider response returned no choices", - )); - }; - let content = first_choice - .message - .and_then(|message| message.content) - .unwrap_or_default(); - - Ok(ChatResponse { - provider: "mizan-daemon".to_owned(), - model: response.model.unwrap_or(requested_model), - content, - usage: response.usage, - }) -} - fn required_field(value: Option, key: &'static str) -> AppResult { value.ok_or_else(|| AppError::invalid_config(key, "is required")) } @@ -635,6 +572,7 @@ fn normalize_string_list(values: Vec) -> Vec { mod tests { use super::*; use mizan_core::redact_for_logs; + use mizan_providers::ChatMessage; const VALID_CONFIG: &str = r#" control_plane_url = "https://mizan.example.test" @@ -654,6 +592,7 @@ heartbeat_interval_seconds = 15 let config = DaemonConfig::parse(VALID_CONFIG).expect("config should parse"); assert_eq!(config.control_plane_url, "https://mizan.example.test"); + assert_eq!(config.local_provider_api_key, None); assert_eq!( config.advertised_models, vec!["llama3.1".to_owned(), "qwen2.5-coder".to_owned()] @@ -690,6 +629,21 @@ heartbeat_interval_seconds = 15 ); } + #[test] + fn parses_optional_local_provider_api_key() { + let raw = VALID_CONFIG.replace( + "local_provider_url = \"http://127.0.0.1:11434/v1\"", + "local_provider_url = \"http://127.0.0.1:11434/v1\"\n\ + local_provider_api_key = \" local-secret \"", + ); + let config = DaemonConfig::parse(&raw).expect("config should parse"); + + assert_eq!( + config.local_provider_api_key.as_deref(), + Some("local-secret") + ); + } + #[test] fn builds_capability_payload_from_config() { let config = DaemonConfig::parse(VALID_CONFIG).expect("config should parse"); @@ -750,4 +704,168 @@ heartbeat_interval_seconds = 15 "daemon_token=[REDACTED] bearer=[REDACTED] abc" ); } + + #[tokio::test] + async fn local_openai_compatible_provider_is_called_and_normalized() { + let request_id = Uuid::now_v7(); + let server = MockOpenAiServer::spawn(request_id); + let raw_config = VALID_CONFIG.replace( + "local_provider_url = \"http://127.0.0.1:11434/v1\"", + &format!("local_provider_url = \"{}\"", server.base_url), + ); + let config = DaemonConfig::parse(&raw_config).expect("config should parse"); + let job = DispatchJobLeaseResponse { + id: Uuid::now_v7().to_string(), + request_id: request_id.to_string(), + model: "public-llama".to_owned(), + request: ChatRequest { + model: "llama3.1".to_owned(), + messages: vec![ChatMessage { + role: "user".to_owned(), + content: "hello".to_owned(), + }], + stream: false, + max_tokens: Some(16), + }, + }; + + let response = call_local_provider(&config, &job) + .await + .expect("local provider should respond"); + + server.assert_ok(); + assert_eq!(response.provider, "mizan-daemon"); + assert_eq!(response.model, "llama3.1"); + assert_eq!(response.content, "hello from mock"); + let usage = response.usage.expect("usage should be propagated"); + assert_eq!(usage.prompt_tokens, 3); + assert_eq!(usage.completion_tokens, 4); + assert_eq!(usage.total_tokens, 7); + assert!(!usage.estimated); + } + + struct MockOpenAiServer { + base_url: String, + result_rx: std::sync::mpsc::Receiver>, + } + + impl MockOpenAiServer { + fn spawn(expected_request_id: Uuid) -> Self { + let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind mock server"); + let addr = listener.local_addr().expect("mock server addr"); + let (result_tx, result_rx) = std::sync::mpsc::channel(); + std::thread::spawn(move || { + let result = handle_mock_openai_request(listener, expected_request_id); + let _ = result_tx.send(result); + }); + + Self { + base_url: format!("http://{addr}"), + result_rx, + } + } + + fn assert_ok(self) { + let result = self + .result_rx + .recv_timeout(Duration::from_secs(2)) + .expect("mock server should finish"); + result.expect("mock server request should be valid"); + } + } + + fn handle_mock_openai_request( + listener: std::net::TcpListener, + expected_request_id: Uuid, + ) -> Result<(), String> { + use std::io::{Read, Write}; + + let (mut stream, _) = listener.accept().map_err(|error| error.to_string())?; + let mut request = Vec::new(); + let mut buffer = [0_u8; 1024]; + let header_end = loop { + let read = stream + .read(&mut buffer) + .map_err(|error| error.to_string())?; + if read == 0 { + return Err("connection closed before headers".to_owned()); + } + request.extend_from_slice(&buffer[..read]); + if let Some(position) = find_header_end(&request) { + break position; + } + }; + + let headers = String::from_utf8_lossy(&request[..header_end]).to_string(); + let content_length = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .ok_or_else(|| "missing content-length".to_owned())?; + while request.len() < header_end + 4 + content_length { + let read = stream + .read(&mut buffer) + .map_err(|error| error.to_string())?; + if read == 0 { + return Err("connection closed before body".to_owned()); + } + request.extend_from_slice(&buffer[..read]); + } + + assert!(headers.starts_with("POST /v1/chat/completions HTTP/1.1")); + assert!(headers.lines().any(|line| { + line.eq_ignore_ascii_case(&format!("x-request-id: {expected_request_id}")) + })); + assert!( + !headers + .lines() + .any(|line| line.to_ascii_lowercase().starts_with("authorization:")), + "local provider auth should be omitted when no key is configured" + ); + + let body_start = header_end + 4; + let payload: serde_json::Value = + serde_json::from_slice(&request[body_start..body_start + content_length]) + .map_err(|error| error.to_string())?; + assert_eq!(payload["model"], "llama3.1"); + assert_eq!(payload["stream"], false); + assert_eq!(payload["max_tokens"], 16); + + let body = serde_json::json!({ + "id": "chatcmpl-mock", + "object": "chat.completion", + "model": "llama3.1", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "hello from mock"}, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 3, + "completion_tokens": 4, + "total_tokens": 7 + } + }) + .to_string(); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + body.len(), + body + ); + stream + .write_all(response.as_bytes()) + .map_err(|error| error.to_string())?; + + Ok(()) + } + + fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") + } } diff --git a/crates/mizan-providers/src/lib.rs b/crates/mizan-providers/src/lib.rs index 55919f5..4eaf533 100644 --- a/crates/mizan-providers/src/lib.rs +++ b/crates/mizan-providers/src/lib.rs @@ -108,7 +108,7 @@ pub trait ProviderAdapter: Send + Sync { pub struct OpenAiCompatibleProvider { name: String, base_url: String, - api_key: String, + api_key: Option, } impl OpenAiCompatibleProvider { @@ -122,10 +122,22 @@ impl OpenAiCompatibleProvider { base_url: impl Into, api_key: impl Into, ) -> Self { + Self::with_optional_api_key(name, base_url, Some(api_key.into())) + } + + pub fn with_optional_api_key( + name: impl Into, + base_url: impl Into, + api_key: Option, + ) -> Self { + let api_key = api_key + .map(|value| value.trim().to_owned()) + .filter(|value| !value.is_empty()); + Self { name: name.into(), base_url: base_url.into().trim().trim_end_matches('/').to_string(), - api_key: api_key.into(), + api_key, } } @@ -142,10 +154,15 @@ impl OpenAiCompatibleProvider { request: reqwest::RequestBuilder, request_id: &str, ) -> reqwest::RequestBuilder { - request + let request = request .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) - .header(AUTHORIZATION, format!("Bearer {}", self.api_key)) - .header("x-request-id", request_id) + .header("x-request-id", request_id); + + if let Some(api_key) = &self.api_key { + request.header(AUTHORIZATION, format!("Bearer {api_key}")) + } else { + request + } } }