diff --git a/Cargo.lock b/Cargo.lock index a17f6027..9575f044 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,6 +544,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arc-swap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] + [[package]] name = "ark-ff" version = "0.3.0" @@ -984,6 +993,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bindgen" version = "0.71.1" @@ -1320,9 +1349,12 @@ version = "0.5.6" dependencies = [ "anyhow", "bon", + "bytes", "enum_dispatch", "fs-err", "hickory-resolver 0.24.4", + "http", + "http-body-util", "instant-acme", "path-absolutize", "rand 0.8.5", @@ -2204,6 +2236,8 @@ name = "dstack-gateway" version = "0.5.6" dependencies = [ "anyhow", + "arc-swap", + "base64 0.22.1", "bytes", "certbot", "clap", @@ -2212,13 +2246,16 @@ dependencies = [ "dstack-guest-agent-rpc", "dstack-kms-rpc", "dstack-types", + "flate2", "fs-err", "futures", "git-version", "hex", "hickory-resolver 0.24.4", + "http-body-util", "http-client", "hyper", + "hyper-rustls", "hyper-util", "insta", "ipnet", @@ -2233,6 +2270,7 @@ dependencies = [ "rand 0.8.5", "reqwest", "rinja", + "rmp-serde", "rocket", "rustls", "safe-write", @@ -2242,10 +2280,15 @@ dependencies = [ "sha2 0.10.9", 
"shared_child", "smallvec", + "tdx-attest", + "tempfile", "tokio", "tokio-rustls", "tracing", "tracing-subscriber", + "uuid", + "wavekv", + "x509-parser", ] [[package]] @@ -4303,7 +4346,7 @@ checksum = "2044d8bd5489b199890c3dbf38d4c8f50f3a5a38833986808b14e2367fe267fa" dependencies = [ "aes 0.7.5", "base64 0.13.1", - "bincode", + "bincode 1.3.3", "crossterm", "hmac 0.11.0", "pbkdf2", @@ -6003,6 +6046,28 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rocket" version = "0.6.0-dev" @@ -7878,6 +7943,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "url" version = "2.5.7" @@ -7931,6 +8002,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "void" version = "1.0.2" @@ -8063,6 +8140,29 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wavekv" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9b73bc556dfdb7ef33617a9d477b803198db43ea3df25463efaf43d4986fe8" +dependencies = [ + "anyhow", + "bincode 2.0.1", + "chrono", + "crc32fast", + "dashmap", + "fs-err", + "futures", + "hex", + "rmp-serde", + "serde", + "serde-human-bytes", + "serde_json", + "sha2 0.10.9", + "tokio", + "tracing", +] + [[package]] name = "web-sys" version = "0.3.83" diff --git a/Cargo.toml b/Cargo.toml index 0758da53..8b18f604 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,9 +86,11 @@ serde-duration = { path = "serde-duration" } dstack-mr = { path = "dstack-mr" } dstack-verifier = { path = "verifier", default-features = false } size-parser = { path = "size-parser" } +wavekv = "1.0.0" # Core dependencies anyhow = { version = "1.0.97", default-features = false } +arc-swap = "1" or-panic = { version = "1.0", default-features = false } chrono = "0.4.40" clap = { version = "4.5.32", features = ["derive", "string"] } @@ -109,6 +111,7 @@ sd-notify = "0.4.5" jemallocator = "0.5.4" # Serialization/Parsing +flate2 = "1.1" borsh = { version = "1.5.7", default-features = false, features = ["derive"] } bon = { version = "3.4.0", default-features = false } base64 = "0.22.1" @@ -122,6 +125,7 @@ scale = { version = "3.7.4", package = "parity-scale-codec", features = [ ] } serde = { version = "1.0.228", features = ["derive"], default-features = false } serde-human-bytes = "0.1.2" +rmp-serde = "1.3.0" serde_json = { version = "1.0.140", default-features = false } serde_ini = "0.2.0" toml = "0.8.20" @@ -145,6 +149,11 @@ hyper-util = { version = "0.1.10", features = [ "client-legacy", "http1", ] } +hyper-rustls = { version = "0.27", default-features = false, features = [ + "ring", + "http1", + "tls12", +] } hyperlocal = "0.9.1" ipnet = { version = "2.11.0", features = ["serde"] } reqwest = { version = "0.12.14", default-features = false, features = [ @@ -233,7 +242,6 @@ yaml-rust2 = "0.10.4" luks2 = "0.5.0" scopeguard = "1.2.0" -flate2 
= "1.1" tar = "0.4" [profile.release] diff --git a/REUSE.toml b/REUSE.toml index 18aced85..f2c711f1 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -191,3 +191,12 @@ SPDX-License-Identifier = "CC0-1.0" path = "guest-agent/fixtures/*" SPDX-FileCopyrightText = "NONE" SPDX-License-Identifier = "CC0-1.0" + +[[annotations]] +path = [ + "gateway/test-run/e2e/certs/*", + "gateway/test-run/e2e/configs/*", + "gateway/test-run/e2e/pebble-config.json", +] +SPDX-FileCopyrightText = "NONE" +SPDX-License-Identifier = "CC0-1.0" diff --git a/certbot/Cargo.toml b/certbot/Cargo.toml index 52c49bea..eaf594cf 100644 --- a/certbot/Cargo.toml +++ b/certbot/Cargo.toml @@ -12,9 +12,12 @@ license.workspace = true [dependencies] anyhow.workspace = true bon.workspace = true +bytes.workspace = true enum_dispatch.workspace = true fs-err.workspace = true hickory-resolver.workspace = true +http.workspace = true +http-body-util.workspace = true instant-acme.workspace = true path-absolutize.workspace = true rcgen.workspace = true diff --git a/certbot/cli/src/main.rs b/certbot/cli/src/main.rs index b22d3246..b30eade3 100644 --- a/certbot/cli/src/main.rs +++ b/certbot/cli/src/main.rs @@ -164,7 +164,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } rustls::crypto::ring::default_provider() .install_default() diff --git a/certbot/src/acme_client.rs b/certbot/src/acme_client.rs index 50ec4589..fd32c2ad 100644 --- a/certbot/src/acme_client.rs +++ b/certbot/src/acme_client.rs @@ -21,6 +21,7 @@ use tracing::{debug, error, info}; use x509_parser::prelude::{GeneralName, Pem}; use super::dns01_client::{Dns01Api, Dns01Client}; +use super::http_client::ReqwestHttpClient; /// A AcmeClient instance. 
pub struct AcmeClient { @@ -63,7 +64,9 @@ impl AcmeClient { dns_txt_ttl: u32, ) -> Result { let credentials: Credentials = serde_json::from_str(encoded_credentials)?; - let account = Account::from_credentials(credentials.credentials).await?; + let http_client = Box::new(ReqwestHttpClient::new()?); + let account = + Account::from_credentials_and_http(credentials.credentials, http_client).await?; let credentials: Credentials = serde_json::from_str(encoded_credentials)?; Ok(Self { account, @@ -81,7 +84,8 @@ impl AcmeClient { max_dns_wait: Duration, dns_txt_ttl: u32, ) -> Result { - let (account, credentials) = Account::create( + let http_client = Box::new(ReqwestHttpClient::new()?); + let (account, credentials) = Account::create_with_http( &NewAccount { contact: &[], terms_of_service_agreed: true, @@ -89,6 +93,7 @@ impl AcmeClient { }, acme_url, None, + http_client, ) .await .with_context(|| format!("failed to create ACME account for {acme_url}"))?; diff --git a/certbot/src/acme_client/tests.rs b/certbot/src/acme_client/tests.rs index d77504a6..1481538c 100644 --- a/certbot/src/acme_client/tests.rs +++ b/certbot/src/acme_client/tests.rs @@ -10,6 +10,7 @@ async fn new_acme_client() -> Result { let dns01_client = Dns01Client::new_cloudflare( std::env::var("CLOUDFLARE_ZONE_ID").expect("CLOUDFLARE_ZONE_ID not set"), std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), + std::env::var("CLOUDFLARE_API_URL").ok(), ); let credentials = std::env::var("LETSENCRYPT_CREDENTIAL").expect("LETSENCRYPT_CREDENTIAL not set"); diff --git a/certbot/src/bot.rs b/certbot/src/bot.rs index 1b59b767..76cd8491 100644 --- a/certbot/src/bot.rs +++ b/certbot/src/bot.rs @@ -28,6 +28,7 @@ pub struct CertBotConfig { credentials_file: PathBuf, auto_create_account: bool, cf_api_token: String, + cf_api_url: Option, cert_file: PathBuf, key_file: PathBuf, cert_dir: PathBuf, @@ -94,8 +95,12 @@ impl CertBot { .trim_start_matches("*.") .trim_end_matches('.') .to_string(); - let 
dns01_client = - Dns01Client::new_cloudflare(config.cf_api_token.clone(), base_domain).await?; + let dns01_client = Dns01Client::new_cloudflare( + base_domain, + config.cf_api_token.clone(), + config.cf_api_url.clone(), + ) + .await?; let acme_client = match fs::read_to_string(&config.credentials_file) { Ok(credentials) => { if acme_matches(&credentials, &config.acme_url) { diff --git a/certbot/src/dns01_client.rs b/certbot/src/dns01_client.rs index b4d4aeaa..88fbf91a 100644 --- a/certbot/src/dns01_client.rs +++ b/certbot/src/dns01_client.rs @@ -72,9 +72,12 @@ pub enum Dns01Client { } impl Dns01Client { - pub async fn new_cloudflare(api_token: String, base_domain: String) -> Result { - Ok(Self::Cloudflare( - CloudflareClient::new(api_token, base_domain).await?, - )) + pub async fn new_cloudflare( + base_domain: String, + api_token: String, + api_url: Option, + ) -> Result { + let client = CloudflareClient::new(base_domain, api_token, api_url).await?; + Ok(Self::Cloudflare(client)) } } diff --git a/certbot/src/dns01_client/cloudflare.rs b/certbot/src/dns01_client/cloudflare.rs index d7a6b1f5..620defb0 100644 --- a/certbot/src/dns01_client/cloudflare.rs +++ b/certbot/src/dns01_client/cloudflare.rs @@ -14,12 +14,18 @@ use crate::dns01_client::Record; use super::Dns01Api; -const CLOUDFLARE_API_URL: &str = "https://api.cloudflare.com/client/v4"; +const DEFAULT_CLOUDFLARE_API_URL: &str = "https://api.cloudflare.com/client/v4"; #[derive(Debug, Serialize, Deserialize)] pub struct CloudflareClient { zone_id: String, api_token: String, + #[serde(default = "default_api_url")] + api_url: String, +} + +fn default_api_url() -> String { + DEFAULT_CLOUDFLARE_API_URL.to_string() } #[derive(Deserialize)] @@ -59,12 +65,21 @@ struct ZonesResultInfo { } impl CloudflareClient { - pub async fn new(api_token: String, base_domain: String) -> Result { - let zone_id = Self::resolve_zone_id(&api_token, &base_domain).await?; - Ok(Self { api_token, zone_id }) + pub async fn new( + base_domain: 
String, + api_token: String, + api_url: Option, + ) -> Result { + let api_url = api_url.unwrap_or_else(|| DEFAULT_CLOUDFLARE_API_URL.to_string()); + let zone_id = Self::resolve_zone_id(&api_token, &base_domain, &api_url).await?; + Ok(Self { + zone_id, + api_token, + api_url, + }) } - async fn resolve_zone_id(api_token: &str, base_domain: &str) -> Result { + async fn resolve_zone_id(api_token: &str, base_domain: &str, api_url: &str) -> Result { let base = base_domain .trim() .trim_start_matches("*.") @@ -72,7 +87,7 @@ impl CloudflareClient { .to_lowercase(); let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones"); + let url = format!("{api_url}/zones"); let per_page = 50u32; let mut page = 1u32; @@ -150,8 +165,7 @@ impl CloudflareClient { async fn add_record(&self, record: &impl Serialize) -> Result { let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones/{}/dns_records", self.zone_id); - + let url = format!("{}/zones/{}/dns_records", self.api_url, self.zone_id); let response = client .post(&url) .header("Authorization", format!("Bearer {}", self.api_token)) @@ -176,8 +190,8 @@ impl CloudflareClient { async fn remove_record_inner(&self, record_id: &str) -> Result<()> { let client = Client::new(); let url = format!( - "{CLOUDFLARE_API_URL}/zones/{zone_id}/dns_records/{record_id}", - zone_id = self.zone_id + "{}/zones/{}/dns_records/{}", + self.api_url, self.zone_id, record_id ); debug!(url = %url, "cloudflare remove_record request"); @@ -201,7 +215,7 @@ impl CloudflareClient { async fn get_records_inner(&self, domain: &str) -> Result> { let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones/{}/dns_records", self.zone_id); + let url = format!("{}/zones/{}/dns_records", self.api_url, self.zone_id); let per_page = 100u32; let mut records = Vec::new(); @@ -338,8 +352,9 @@ mod tests { async fn create_client() -> CloudflareClient { CloudflareClient::new( - 
std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), std::env::var("TEST_DOMAIN").expect("TEST_DOMAIN not set"), + std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), + std::env::var("CLOUDFLARE_API_URL").ok(), ) .await .unwrap() diff --git a/certbot/src/http_client.rs b/certbot/src/http_client.rs new file mode 100644 index 00000000..2de8f823 --- /dev/null +++ b/certbot/src/http_client.rs @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Custom HTTP client for instant_acme that supports both HTTP and HTTPS. + +use anyhow::{Context, Result}; +use bytes::Bytes; +use http::Request; +use http_body_util::{BodyExt, Full}; +use instant_acme::{BytesResponse, HttpClient}; +use reqwest::Client; +use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; + +/// A HTTP client that supports both HTTP and HTTPS connections. +/// This is needed because the default instant_acme client only supports HTTPS. +#[derive(Clone)] +pub struct ReqwestHttpClient { + client: Client, +} + +impl ReqwestHttpClient { + /// Create a new HTTP client. + pub fn new() -> Result { + let client = Client::builder() + .user_agent("dstack-certbot/0.1") + .build() + .context("failed to build reqwest client")?; + Ok(Self { client }) + } +} + +impl HttpClient for ReqwestHttpClient { + fn request( + &self, + req: Request>, + ) -> Pin> + Send>> { + let client = self.client.clone(); + Box::pin(async move { + let (parts, body) = req.into_parts(); + let uri = parts.uri.to_string(); + let method = parts.method.clone(); + let body_bytes = body + .collect() + .await + .map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })? 
+ .to_bytes(); + + tracing::debug!( + target: "certbot::http_client", + %uri, + %method, + request_body_len = body_bytes.len(), + "sending ACME request" + ); + + let mut builder = client.request(parts.method, uri.clone()); + for (name, value) in &parts.headers { + builder = builder.header(name, value); + } + + let response = builder + .body(body_bytes.to_vec()) + .send() + .await + .map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })?; + + let status = response.status(); + let headers = response.headers().clone(); + let body = response.bytes().await.map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })?; + + tracing::debug!( + target: "certbot::http_client", + %uri, + %status, + response_body = %String::from_utf8_lossy(&body), + "received ACME response" + ); + + let mut http_response = http::Response::builder().status(status); + for (name, value) in headers { + if let Some(name) = name { + http_response = http_response.header(name, value); + } + } + let http_response = http_response + .body(Full::new(body)) + .map_err(|e| instant_acme::Error::Other(Box::new(e)))?; + + Ok(BytesResponse::from(http_response)) + }) + } +} diff --git a/certbot/src/lib.rs b/certbot/src/lib.rs index 20cf8ed1..df71b993 100644 --- a/certbot/src/lib.rs +++ b/certbot/src/lib.rs @@ -24,4 +24,5 @@ pub use workdir::WorkDir; mod acme_client; mod bot; mod dns01_client; +mod http_client; mod workdir; diff --git a/ct_monitor/src/main.rs b/ct_monitor/src/main.rs index bfa0565c..dd5d9550 100644 --- a/ct_monitor/src/main.rs +++ b/ct_monitor/src/main.rs @@ -413,7 +413,7 @@ async fn main() -> anyhow::Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); let mut monitor = Monitor::new(args.gateway, args.verifier_url, args.pccs_url)?; diff --git 
a/dstack-util/src/main.rs b/dstack-util/src/main.rs index a5cd6d8d..1c552936 100644 --- a/dstack-util/src/main.rs +++ b/dstack-util/src/main.rs @@ -453,7 +453,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let cli = Cli::parse(); diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 8a30c6da..1f57ebf0 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -10,7 +10,8 @@ edition.workspace = true license.workspace = true [dependencies] -rocket = { workspace = true, features = ["mtls"] } +arc-swap.workspace = true +rocket = { workspace = true, features = ["mtls", "json"] } tracing.workspace = true tracing-subscriber.workspace = true anyhow.workspace = true @@ -48,12 +49,26 @@ dstack-types.workspace = true serde-duration.workspace = true reqwest = { workspace = true, features = ["json"] } hyper = { workspace = true, features = ["server", "http1"] } -hyper-util = { version = "0.1", features = ["tokio"] } +hyper-util = { workspace = true, features = ["tokio"] } +hyper-rustls.workspace = true +http-body-util.workspace = true +x509-parser.workspace = true jemallocator.workspace = true +wavekv.workspace = true +tdx-attest.workspace = true +flate2.workspace = true +uuid = { workspace = true, features = ["v4"] } +rmp-serde.workspace = true or-panic.workspace = true +base64.workspace = true [target.'cfg(unix)'.dependencies] nix = { workspace = true, features = ["resource"] } +[[bin]] +name = "gen_debug_key" +path = "src/gen_debug_key.rs" + [dev-dependencies] insta.workspace = true +tempfile.workspace = true diff --git a/gateway/docs/cluster-deployment.md b/gateway/docs/cluster-deployment.md new file mode 100644 index 00000000..8fe0d575 --- /dev/null +++ b/gateway/docs/cluster-deployment.md @@ -0,0 +1,333 @@ +# dstack-gateway Cluster Deployment 
Guide + +This document describes how to deploy a dstack-gateway cluster, including single-node and multi-node configurations. + +## Table of Contents + +1. [Overview](#1-overview) +2. [Cluster Deployment (2-Node Example)](#2-cluster-deployment-2-node-example) +3. [Adding Reverse Proxy Domains](#3-adding-reverse-proxy-domains) +4. [Operations and Monitoring](#4-operations-and-monitoring) + +## 1. Overview + +dstack-gateway is a distributed reverse proxy gateway for dstack services. Key features include: + +- TLS termination and SNI routing: Automatically selects certificates and routes traffic based on SNI +- Automatic certificate management: Automatically requests and renews certificates via ACME protocol (Let's Encrypt) +- Multi-node cluster: Multiple gateway nodes automatically sync state for high availability +- WireGuard tunnels: Provides secure network access for CVM instances + +### Architecture Diagram + +```mermaid +flowchart TB + subgraph Internet + LB[Load Balancer] + end + + subgraph Gateway Cluster + G1[Gateway 1
node_id=1] + G2[Gateway 2<br/>node_id=2] + G1 <-->|Sync| G2 + end + + subgraph CVM Pool + CVM1[CVM 1<br/>App A] + CVM2[CVM 2<br/>App B] + end + + LB --> G1 + LB --> G2 + + G1 -.->|WireGuard| CVM1 + G1 -.->|WireGuard| CVM2 + G2 -.->|WireGuard| CVM1 + G2 -.->|WireGuard| CVM2 +``` + +When a CVM starts, it registers with one of the Gateways. The Gateway cluster automatically syncs the CVM's information (including WireGuard public key), enabling all Gateway nodes to establish WireGuard tunnel connections to that CVM. + +### Port Description + +| Default Port | Protocol | Purpose | Security Recommendation | +|--------------|----------|---------|-------------------------| +| 9012 | HTTPS | RPC port for inter-node sync communication | Internal network only | +| 9013 | UDP | WireGuard tunnel port | Internal network only | +| 9014 | HTTPS | Proxy port for external TLS proxy service | Can be exposed to public | +| 9015 | HTTP | Debug port for health checks and debugging | Must be disabled in production | +| 9016 | HTTP | Admin port for management API | Do not expose to public, recommend using Unix Domain Socket | + +Production security configuration example: + +```toml +[core.debug] +insecure_enable_debug_rpc = false # Disable Debug port + +[core.admin] +enabled = true +address = "unix:/run/dstack/admin.sock" # Use Unix Domain Socket +``` + +## 2. 
Cluster Deployment (2-Node Example) + +### 2.1 Node Planning + +| Node | node_id | Gateway IP | Client IP range | bootnode | +|------|---------|------------|-----------------|----------| +| gateway-1 | 1 | 10.8.128.1/16 | 10.8.128.0/18 | gateway-2 | +| gateway-2 | 2 | 10.8.0.1/16 | 10.8.0.0/18 | gateway-1 | + +Notes: +- Each node's node_id must be unique +- Each node's Client IP range should not overlap (used for allocating IPs to different CVMs) +- bootnode is configured as another node's RPC URL, used for cluster discovery at startup + +### 2.2 CIDR Description + +Client IP range (/18): +- /18 means the first 18 bits are the network prefix +- For example, 10.8.128.0/18 covers the address range 10.8.128.0 ~ 10.8.191.255 +- Each Gateway's /18 range does not overlap, so each Gateway can allocate IPs locally without syncing with other Gateways + +Gateway IP (/16): +- Gateway IP uses /16 netmask to allow network routing to cover the larger 10.8.0.0/16 address space +- This way, when another Gateway allocates an address in a /18 subnet, traffic can still be correctly routed + +### 2.3 WireGuard Configuration Fields + +Key fields in the `[core.wg]` section: + +- `ip`: Gateway's own WireGuard address in CIDR format (e.g., 10.8.128.1/16) +- `client_ip_range`: Address pool range for allocating to CVMs (e.g., 10.8.128.0/18) +- `reserved_net`: Reserved address range that will not be allocated to CVMs (e.g., 10.8.128.1/32, reserving the gateway's own address) + +Recommendation: Design client_ip_range and reserved_net to ensure clear address pool planning for each Gateway, avoiding address conflicts. 
+ +### 2.4 Configuration File Examples + +gateway-1.toml: + +```toml +log_level = "info" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/var/lib/gateway/certs/gateway-rpc.key" +certs = "/var/lib/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/var/lib/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "https://kms.demo.dstack.org" +rpc_domain = "rpc.gateway-1.demo.dstack.org" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = false +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "30s" +timeout = "60s" +my_url = "https://rpc.gateway-1.demo.dstack.org:9012" +bootnode = "https://rpc.gateway-2.demo.dstack.org:9012" +node_id = 1 +data_dir = "/var/lib/gateway/data" + +[core.wg] +private_key = "" +public_key = "" +listen_port = 9013 +ip = "10.8.128.1/16" +reserved_net = ["10.8.128.1/32"] +client_ip_range = "10.8.128.0/18" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-gw1" +endpoint = ":9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +external_port = 443 +``` + +gateway-2.toml: + +```toml +log_level = "info" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/var/lib/gateway/certs/gateway-rpc.key" +certs = "/var/lib/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/var/lib/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "https://kms.demo.dstack.org" +rpc_domain = "rpc.gateway-2.demo.dstack.org" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = false +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "30s" +timeout = "60s" +my_url = "https://rpc.gateway-2.demo.dstack.org:9012" +bootnode = "https://rpc.gateway-1.demo.dstack.org:9012" +node_id = 2 +data_dir = "/var/lib/gateway/data" + +[core.wg] +private_key = "" +public_key = "" 
+listen_port = 9013 +ip = "10.8.0.1/16" +reserved_net = ["10.8.0.1/32"] +client_ip_range = "10.8.0.0/18" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-gw2" +endpoint = ":9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +external_port = 443 +``` + +### 2.5 Verify Cluster Sync + +```bash +# Check sync status on any node +curl -s http://localhost:9016/prpc/Admin.WaveKvStatus | jq . +``` + +## 3. Adding Reverse Proxy Domains + +Gateway supports automatic TLS certificate management via the ACME protocol. + +### 3.1 Configure ACME Service + +```bash +# Set ACME URL (Let's Encrypt production) +curl -X POST "http://localhost:9016/prpc/Admin.SetCertbotConfig" \ + -H "Content-Type: application/json" \ + -d '{"acme_url": "https://acme-v02.api.letsencrypt.org/directory"}' + +# For testing, use Let's Encrypt Staging +# "acme_url": "https://acme-staging-v02.api.letsencrypt.org/directory" +``` + +### 3.2 Configure DNS Credential + +Gateway uses DNS-01 validation, which requires configuring DNS provider API credentials. + +Cloudflare example: + +```bash +curl -X POST "http://localhost:9016/prpc/Admin.CreateDnsCredential" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "cloudflare-prod", + "provider_type": "cloudflare", + "cf_api_token": "your-cloudflare-api-token", + "set_as_default": true + }' +``` + +### 3.3 Add Domain + +Call the Admin.AddZtDomain API to add a domain. Gateway will automatically request a *.domain wildcard certificate. 
+ +Parameter description: + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| domain | string | Yes | Base domain (e.g., example.com), certificate will be issued for *.example.com | +| port | uint32 | Yes | External service port for this domain (usually 443) | +| dns_cred_id | string | No | DNS credential ID, leave empty to use default credential | +| node | uint32 | No | Bind to specific node (node_id), leave empty for any node to serve this domain | +| priority | int32 | No | Priority for selecting default base_domain (higher value = higher priority, default is 0) | + +Basic usage (using default DNS credential): + +```bash +curl -X POST "http://localhost:9016/prpc/Admin.AddZtDomain" \ + -H "Content-Type: application/json" \ + -d '{"domain": "example.com", "port": 443}' +``` + +Specifying DNS credential and node binding: + +```bash +curl -X POST "http://localhost:9016/prpc/Admin.AddZtDomain" \ + -H "Content-Type: application/json" \ + -d '{ + "domain": "internal.example.com", + "port": 443, + "dns_cred_id": "cloudflare-prod", + "node": 1, + "priority": 10 + }' +``` + +Response example: + +```json +{ + "config": { + "domain": "example.com", + "port": 443, + "priority": 0 + }, + "cert_status": { + "has_cert": false, + "not_after": 0, + "issued_by": 0, + "issued_at": 0 + } +} +``` + +Note: After adding a domain, the certificate is not issued immediately. Gateway will request the certificate asynchronously in the background. You can check certificate status via section 3.5, or manually trigger certificate request via section 3.4. 
+ +### 3.4 Manually Trigger Certificate Renewal + +```bash +curl -X POST "http://localhost:9016/prpc/Admin.RenewZtDomainCert" \ + -H "Content-Type: application/json" \ + -d '{"domain": "example.com", "force": true}' +``` + +### 3.5 Check Certificate Status + +```bash +curl -s http://localhost:9016/prpc/Admin.Status | jq '.zt_domains' +``` + +### 3.6 Web UI + +All the above command-line operations can also be performed via Web UI by visiting http://localhost:9016 in a browser. diff --git a/gateway/dstack-app/builder/entrypoint.sh b/gateway/dstack-app/builder/entrypoint.sh index 9cd46755..61a662fa 100755 --- a/gateway/dstack-app/builder/entrypoint.sh +++ b/gateway/dstack-app/builder/entrypoint.sh @@ -50,6 +50,13 @@ validate_env "$BOOTNODE_URL" validate_env "$CF_API_TOKEN" validate_env "$SRV_DOMAIN" validate_env "$WG_ENDPOINT" +validate_env "$NODE_ID" + +# Validate $NODE_ID, must be a number +if [[ ! "$NODE_ID" =~ ^[0-9]+$ ]]; then + echo "Invalid NODE_ID: $NODE_ID" + exit 1 +fi # Validate $SUBNET_INDEX, valid range is 0-15 if [[ ! "$SUBNET_INDEX" =~ ^[0-9]+$ ]] || [ "$SUBNET_INDEX" -lt 0 ] || [ "$SUBNET_INDEX" -gt 15 ]; then @@ -79,8 +86,7 @@ echo "RPC_DOMAIN: $RPC_DOMAIN" cat >$CONFIG_PATH < node_connections = 2; +} + +// Node status entry +message NodeStatusEntry { + uint32 node_id = 1; + string status = 2; // "up" or "down" +} + +// Get node statuses response +message GetNodeStatusesResponse { + repeated NodeStatusEntry statuses = 1; +} + service Admin { // Get the status of the gateway. rpc Status(google.protobuf.Empty) returns (StatusResponse) {} @@ -192,4 +362,253 @@ service Admin { rpc SetCaa(google.protobuf.Empty) returns (google.protobuf.Empty) {} // Summary API for inspect. 
rpc GetMeta(google.protobuf.Empty) returns (GetMetaResponse) {} + // Set a node's sync URL - used for dynamic peer management + rpc SetNodeUrl(SetNodeUrlRequest) returns (google.protobuf.Empty) {} + // Set a node's status (up/down) + rpc SetNodeStatus(SetNodeStatusRequest) returns (google.protobuf.Empty) {} + // Get WaveKV sync status + rpc WaveKvStatus(google.protobuf.Empty) returns (WaveKvStatusResponse) {} + // Get instance handshakes from all nodes + rpc GetInstanceHandshakes(GetInstanceHandshakesRequest) returns (GetInstanceHandshakesResponse) {} + // Get global connections statistics + rpc GetGlobalConnections(google.protobuf.Empty) returns (GlobalConnectionsStats) {} + // Get all node statuses + rpc GetNodeStatuses(google.protobuf.Empty) returns (GetNodeStatusesResponse) {} + + // ==================== DNS Credential Management ==================== + // List all DNS credentials + rpc ListDnsCredentials(google.protobuf.Empty) returns (ListDnsCredentialsResponse) {} + // Get a DNS credential by ID + rpc GetDnsCredential(GetDnsCredentialRequest) returns (DnsCredentialInfo) {} + // Create a new DNS credential + rpc CreateDnsCredential(CreateDnsCredentialRequest) returns (DnsCredentialInfo) {} + // Update a DNS credential + rpc UpdateDnsCredential(UpdateDnsCredentialRequest) returns (DnsCredentialInfo) {} + // Delete a DNS credential + rpc DeleteDnsCredential(DeleteDnsCredentialRequest) returns (google.protobuf.Empty) {} + // Get the default DNS credential ID + rpc GetDefaultDnsCredential(google.protobuf.Empty) returns (GetDefaultDnsCredentialResponse) {} + // Set the default DNS credential ID + rpc SetDefaultDnsCredential(SetDefaultDnsCredentialRequest) returns (google.protobuf.Empty) {} + + // ==================== ZT-Domain Management ==================== + // List all ZT-Domain configurations + rpc ListZtDomains(google.protobuf.Empty) returns (ListZtDomainsResponse) {} + // Get a ZT-Domain configuration and status + rpc GetZtDomain(GetZtDomainRequest) returns 
(ZtDomainInfo) {} + // Add a new ZT-Domain (config.domain must not exist) + rpc AddZtDomain(ZtDomainConfig) returns (ZtDomainInfo) {} + // Update a ZT-Domain configuration (config.domain must exist) + rpc UpdateZtDomain(ZtDomainConfig) returns (ZtDomainInfo) {} + // Delete a ZT-Domain configuration + rpc DeleteZtDomain(DeleteZtDomainRequest) returns (google.protobuf.Empty) {} + // Manually trigger certificate renewal for a ZT-Domain + rpc RenewZtDomainCert(RenewZtDomainCertRequest) returns (RenewZtDomainCertResponse) {} + // List certificate attestations for a domain + rpc ListCertAttestations(ListCertAttestationsRequest) returns (ListCertAttestationsResponse) {} + + // ==================== Global Certbot Configuration ==================== + // Get global certbot configuration (includes ACME URL) + rpc GetCertbotConfig(google.protobuf.Empty) returns (CertbotConfigResponse) {} + // Set global certbot configuration (includes ACME URL) + rpc SetCertbotConfig(SetCertbotConfigRequest) returns (google.protobuf.Empty) {} +} + +// ==================== DNS Credential Messages ==================== + +// DNS credential information +message DnsCredentialInfo { + string id = 1; + string name = 2; + // Provider type: "cloudflare" + string provider_type = 3; + // Cloudflare-specific fields (when provider_type = "cloudflare") + string cf_api_token = 4; + // Cloudflare API URL (empty means default) + string cf_api_url = 5; + // Timestamps + uint64 created_at = 6; + uint64 updated_at = 7; +} + +// List DNS credentials response +message ListDnsCredentialsResponse { + repeated DnsCredentialInfo credentials = 1; + // The default credential ID (if set) + optional string default_id = 2; +} + +// Get DNS credential request +message GetDnsCredentialRequest { + string id = 1; +} + +// Create DNS credential request +message CreateDnsCredentialRequest { + string name = 1; + // Provider type: "cloudflare" + string provider_type = 2; + // Cloudflare-specific fields (when provider_type = 
"cloudflare") + string cf_api_token = 3; + string cf_zone_id = 4; + // If true, set this as the default credential + bool set_as_default = 5; + // Optional Cloudflare API URL (defaults to https://api.cloudflare.com/client/v4) + optional string cf_api_url = 6; + // Optional Cloudflare DNS TXT record TTL (defaults to 60) + optional uint32 dns_txt_ttl = 7; + // Optional Cloudflare maximum DNS wait time (defaults to 60) + optional uint32 max_dns_wait = 8; +} + +// Update DNS credential request +message UpdateDnsCredentialRequest { + string id = 1; + // Optional new name + optional string name = 2; + // Optional new Cloudflare api token + optional string cf_api_token = 3; + // Optional new Cloudflare zone id + optional string cf_zone_id = 4; + // Optional new Cloudflare API URL + optional string cf_api_url = 5; +} + +// Delete DNS credential request +message DeleteDnsCredentialRequest { + string id = 1; +} + +// Get default DNS credential response +message GetDefaultDnsCredentialResponse { + // The default credential ID (empty if not set) + string default_id = 1; + // The default credential info (if exists) + optional DnsCredentialInfo credential = 2; +} + +// Set default DNS credential request +message SetDefaultDnsCredentialRequest { + string id = 1; +} + +// ==================== ZT-Domain Messages ==================== + +// ZT-Domain configuration (shared by Add/Update/Info) +message ZtDomainConfig { + // Base domain name (e.g., "example.com", certificate will be issued for "*.example.com") + string domain = 1; + // DNS credential ID (None = use default) + optional string dns_cred_id = 2; + // Port this domain serves on (e.g., 443) + uint32 port = 3; + // Node binding (None = any node can serve this domain) + optional uint32 node = 4; + // Priority for default base_domain selection (higher = preferred) + int32 priority = 5; +} + +// ZT-Domain information (config + certificate status) +message ZtDomainInfo { + // Domain configuration + ZtDomainConfig config = 1; + // 
Certificate status + ZtDomainCertStatus cert_status = 2; +} + +// ZT-Domain certificate status +message ZtDomainCertStatus { + // Whether a certificate is currently loaded + bool has_cert = 1; + // Certificate expiry timestamp (0 if no cert) + uint64 not_after = 2; + // Node that issued the current certificate + uint32 issued_by = 3; + // When the certificate was issued + uint64 issued_at = 4; + // Whether the certificate is loaded in memory + bool loaded_in_memory = 5; +} + +// List ZT-Domains response +message ListZtDomainsResponse { + repeated ZtDomainInfo domains = 1; +} + +// Get ZT-Domain request +message GetZtDomainRequest { + string domain = 1; +} + +// Delete ZT-Domain request +message DeleteZtDomainRequest { + string domain = 1; +} + +// Renew ZT-Domain certificate request +message RenewZtDomainCertRequest { + string domain = 1; + // Force renewal even if not near expiry + bool force = 2; +} + +// Renew ZT-Domain certificate response +message RenewZtDomainCertResponse { + // True if renewal was performed + bool renewed = 1; + // New certificate expiry (if renewed) + uint64 not_after = 2; +} + +// Certificate attestation info +message CertAttestationInfo { + // Certificate public key (DER encoded) + bytes public_key = 1; + // TDX Quote (JSON serialized) + string quote = 2; + // Node that generated this attestation + uint32 generated_by = 3; + // Timestamp when this attestation was generated + uint64 generated_at = 4; +} + +// List certificate attestations request +message ListCertAttestationsRequest { + string domain = 1; + // Maximum number of attestations to return (0 = all) + uint32 limit = 2; +} + +// List certificate attestations response +message ListCertAttestationsResponse { + // Latest attestation (if exists) + optional CertAttestationInfo latest = 1; + // Historical attestations (sorted by generated_at descending) + repeated CertAttestationInfo history = 2; +} + +// ==================== Global Certbot Configuration Messages ==================== + 
+// Certbot configuration response +message CertbotConfigResponse { + // Interval between renewal checks (in seconds) + uint64 renew_interval_secs = 1; + // Time before expiration to trigger renewal (in seconds) + uint64 renew_before_expiration_secs = 2; + // Timeout for certificate renewal operations (in seconds) + uint64 renew_timeout_secs = 3; + // ACME server URL (empty means default Let's Encrypt production) + string acme_url = 4; +} + +// Set certbot configuration request +message SetCertbotConfigRequest { + // Interval between renewal checks (in seconds) + optional uint64 renew_interval_secs = 1; + // Time before expiration to trigger renewal (in seconds) + optional uint64 renew_before_expiration_secs = 2; + // Timeout for certificate renewal operations (in seconds) + optional uint64 renew_timeout_secs = 3; + // ACME server URL (empty means use default Let's Encrypt production) + optional string acme_url = 4; } diff --git a/gateway/src/admin_service.rs b/gateway/src/admin_service.rs index 541dee0d..8ebe111e 100644 --- a/gateway/src/admin_service.rs +++ b/gateway/src/admin_service.rs @@ -3,17 +3,31 @@ // SPDX-License-Identifier: Apache-2.0 use std::sync::atomic::Ordering; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use dstack_gateway_rpc::{ admin_server::{AdminRpc, AdminServer}, - GetInfoRequest, GetInfoResponse, GetMetaResponse, HostInfo, RenewCertResponse, StatusResponse, + CertAttestationInfo, CertbotConfigResponse, CreateDnsCredentialRequest, + DeleteDnsCredentialRequest, DeleteZtDomainRequest, DnsCredentialInfo, + GetDefaultDnsCredentialResponse, GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse, + GetInstanceHandshakesRequest, GetInstanceHandshakesResponse, GetMetaResponse, + GetNodeStatusesResponse, GetZtDomainRequest, GlobalConnectionsStats, HandshakeEntry, HostInfo, + LastSeenEntry, ListCertAttestationsRequest, 
ListCertAttestationsResponse, + ListDnsCredentialsResponse, ListZtDomainsResponse, NodeStatusEntry, + PeerSyncStatus as ProtoPeerSyncStatus, RenewCertResponse, RenewZtDomainCertRequest, + RenewZtDomainCertResponse, SetCertbotConfigRequest, SetDefaultDnsCredentialRequest, + SetNodeStatusRequest, SetNodeUrlRequest, StatusResponse, StoreSyncStatus, + UpdateDnsCredentialRequest, WaveKvStatusResponse, ZtDomainCertStatus, + ZtDomainConfig as ProtoZtDomainConfig, ZtDomainInfo, }; use ra_rpc::{CallContext, RpcCall}; +use tracing::info; +use wavekv::node::NodeStatus as WaveKvNodeStatus; use crate::{ - main_service::{encode_ts, Proxy}, + kv::{DnsCredential, DnsProvider, NodeStatus, ZtDomainConfig}, + main_service::Proxy, proxy::NUM_CONNECTIONS, }; @@ -23,35 +37,39 @@ pub struct AdminRpcHandler { impl AdminRpcHandler { pub(crate) async fn status(self) -> Result { + let (base_domain, port) = self + .state + .kv_store() + .get_best_zt_domain() + .unwrap_or_default(); let mut state = self.state.lock(); state.refresh_state()?; - let base_domain = &state.config.proxy.base_domain; let hosts = state .state .instances .values() - .map(|instance| HostInfo { - instance_id: instance.id.clone(), - ip: instance.ip.to_string(), - app_id: instance.app_id.clone(), - base_domain: base_domain.clone(), - port: state.config.proxy.listen_port as u32, - latest_handshake: encode_ts(instance.last_seen), - num_connections: instance.num_connections(), + .map(|instance| { + // Get global latest_handshake from KvStore (max across all nodes) + let latest_handshake = state + .get_instance_latest_handshake(&instance.id) + .unwrap_or(0); + HostInfo { + instance_id: instance.id.clone(), + ip: instance.ip.to_string(), + app_id: instance.app_id.clone(), + base_domain: base_domain.clone(), + port: port.into(), + latest_handshake, + num_connections: instance.num_connections(), + } }) .collect::>(); - let nodes = state - .state - .nodes - .values() - .cloned() - .map(Into::into) - .collect::>(); Ok(StatusResponse 
{ + id: state.config.sync.node_id, url: state.config.sync.my_url.clone(), - id: state.config.id(), + uuid: state.config.uuid(), bootnode_url: state.config.sync.bootnode.clone(), - nodes, + nodes: state.get_all_nodes(), hosts, num_connections: NUM_CONNECTIONS.load(Ordering::Relaxed), }) @@ -64,22 +82,19 @@ impl AdminRpc for AdminRpcHandler { } async fn renew_cert(self) -> Result { - let renewed = self.state.renew_cert(true).await?; + // Renew all domains with force=true + let renewed = self.state.renew_cert(None, true).await?; Ok(RenewCertResponse { renewed }) } async fn set_caa(self) -> Result<()> { - self.state - .certbot - .as_ref() - .context("Certbot is not enabled")? - .set_caa() - .await?; - Ok(()) + // TODO: Implement CAA setting for multi-domain certificates + // This requires iterating over all domain configurations and setting CAA records + bail!("set_caa is not implemented for multi-domain certificates yet"); } async fn reload_cert(self) -> Result<()> { - self.state.reload_certificates() + self.state.reload_all_certs_from_kvstore() } async fn status(self) -> Result { @@ -87,8 +102,12 @@ impl AdminRpc for AdminRpcHandler { } async fn get_info(self, request: GetInfoRequest) -> Result { + let (base_domain, port) = self + .state + .kv_store() + .get_best_zt_domain() + .unwrap_or_default(); let state = self.state.lock(); - let base_domain = &state.config.proxy.base_domain; let handshakes = state.latest_handshakes(None)?; if let Some(instance) = state.state.instances.get(&request.id) { @@ -96,8 +115,8 @@ impl AdminRpc for AdminRpcHandler { instance_id: instance.id.clone(), ip: instance.ip.to_string(), app_id: instance.app_id.clone(), - base_domain: base_domain.clone(), - port: state.config.proxy.listen_port as u32, + base_domain, + port: port.into(), latest_handshake: { let (ts, _) = handshakes .get(&instance.public_key) @@ -146,6 +165,490 @@ impl AdminRpc for AdminRpcHandler { online: online as u32, }) } + + async fn set_node_url(self, request: 
SetNodeUrlRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + kv_store.register_peer_url(request.id, &request.url)?; + info!("Updated peer URL: node {} -> {}", request.id, request.url); + Ok(()) + } + + async fn set_node_status(self, request: SetNodeStatusRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + let status = match request.status.as_str() { + "up" => NodeStatus::Up, + "down" => NodeStatus::Down, + _ => anyhow::bail!("invalid status: expected 'up' or 'down'"), + }; + kv_store.set_node_status(request.id, status)?; + info!("Updated node status: node {} -> {:?}", request.id, status); + Ok(()) + } + + async fn wave_kv_status(self) -> Result { + let kv_store = self.state.kv_store(); + + let persistent_status = kv_store.persistent().read().status(); + let ephemeral_status = kv_store.ephemeral().read().status(); + + let get_peer_last_seen = |peer_id: u32| -> Vec<(u32, u64)> { + kv_store + .get_node_last_seen_by_all(peer_id) + .into_iter() + .collect() + }; + + Ok(WaveKvStatusResponse { + enabled: self.state.config.sync.enabled, + persistent: Some(build_store_status( + "persistent", + persistent_status, + &get_peer_last_seen, + )), + ephemeral: Some(build_store_status( + "ephemeral", + ephemeral_status, + &get_peer_last_seen, + )), + }) + } + + async fn get_instance_handshakes( + self, + request: GetInstanceHandshakesRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + let handshakes = kv_store.get_instance_handshakes(&request.instance_id); + + let entries = handshakes + .into_iter() + .map(|(observer_node_id, timestamp)| HandshakeEntry { + observer_node_id, + timestamp, + }) + .collect(); + + Ok(GetInstanceHandshakesResponse { + handshakes: entries, + }) + } + + async fn get_global_connections(self) -> Result { + let state = self.state.lock(); + let kv_store = self.state.kv_store(); + + let mut node_connections = std::collections::HashMap::new(); + let mut total_connections = 0u64; + + // Iterate through all instances 
and sum up connections per node + for instance_id in state.state.instances.keys() { + // Get connection counts from ephemeral KV for this instance + let conn_prefix = format!("conn/{}/", instance_id); + for (key, count) in kv_store + .ephemeral() + .read() + .iter_by_prefix(&conn_prefix) + .filter_map(|(k, entry)| { + let value = entry.value.as_ref()?; + let count: u64 = rmp_serde::decode::from_slice(value).ok()?; + Some((k.to_string(), count)) + }) + { + // Parse node_id from key: "conn/{instance_id}/{node_id}" + if let Some(node_id_str) = key.strip_prefix(&conn_prefix) { + if let Ok(node_id) = node_id_str.parse::() { + *node_connections.entry(node_id).or_insert(0) += count; + total_connections += count; + } + } + } + } + + Ok(GlobalConnectionsStats { + total_connections, + node_connections, + }) + } + + async fn get_node_statuses(self) -> Result { + let kv_store = self.state.kv_store(); + let statuses = kv_store.load_all_node_statuses(); + + let entries = statuses + .into_iter() + .map(|(node_id, status)| { + let status_str = match status { + NodeStatus::Up => "up", + NodeStatus::Down => "down", + }; + NodeStatusEntry { + node_id, + status: status_str.to_string(), + } + }) + .collect(); + + Ok(GetNodeStatusesResponse { statuses: entries }) + } + + // ==================== DNS Credential Management ==================== + + async fn list_dns_credentials(self) -> Result { + let kv_store = self.state.kv_store(); + let credentials = kv_store + .list_dns_credentials() + .into_iter() + .map(dns_cred_to_proto) + .collect(); + let default_id = kv_store.get_default_dns_credential_id(); + Ok(ListDnsCredentialsResponse { + credentials, + default_id, + }) + } + + async fn get_dns_credential( + self, + request: GetDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + let cred = kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + Ok(dns_cred_to_proto(cred)) + } + + async fn create_dns_credential( + self, + request: 
CreateDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + // Validate provider type + let provider = match request.provider_type.as_str() { + "cloudflare" => DnsProvider::Cloudflare { + api_token: request.cf_api_token, + api_url: request.cf_api_url, + }, + _ => bail!("unsupported provider type: {}", request.provider_type), + }; + + let now = now_secs(); + let id = generate_cred_id(); + let dns_txt_ttl = request.dns_txt_ttl.unwrap_or(60); + let max_dns_wait = Duration::from_secs(request.max_dns_wait.unwrap_or(60 * 5).into()); + let cred = DnsCredential { + id: id.clone(), + name: request.name, + provider, + created_at: now, + updated_at: now, + dns_txt_ttl, + max_dns_wait, + }; + + kv_store.save_dns_credential(&cred)?; + info!("Created DNS credential: {} ({})", cred.name, cred.id); + + // Set as default if requested + if request.set_as_default { + kv_store.set_default_dns_credential_id(&id)?; + info!("Set DNS credential {} as default", id); + } + + Ok(dns_cred_to_proto(cred)) + } + + async fn update_dns_credential( + self, + request: UpdateDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + let mut cred = kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + + // Update name if provided + if let Some(name) = request.name { + cred.name = name; + } + + // Update provider fields if provided + match &mut cred.provider { + DnsProvider::Cloudflare { api_token, api_url } => { + if let Some(new_token) = request.cf_api_token { + *api_token = new_token; + } + if let Some(new_url) = request.cf_api_url { + *api_url = Some(new_url); + } + } + } + + cred.updated_at = now_secs(); + kv_store.save_dns_credential(&cred)?; + info!("Updated DNS credential: {} ({})", cred.name, cred.id); + + Ok(dns_cred_to_proto(cred)) + } + + async fn delete_dns_credential(self, request: DeleteDnsCredentialRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Check if this is the default 
credential + if let Some(default_id) = kv_store.get_default_dns_credential_id() { + if default_id == request.id { + bail!("cannot delete the default DNS credential; set a different default first"); + } + } + + // Check if any ZT-Domain configs reference this credential + let configs = kv_store.list_zt_domain_configs(); + for config in configs { + if config.dns_cred_id.as_deref() == Some(&request.id) { + bail!( + "cannot delete DNS credential: domain {} uses it", + config.domain + ); + } + } + + kv_store.delete_dns_credential(&request.id)?; + info!("Deleted DNS credential: {}", request.id); + Ok(()) + } + + async fn get_default_dns_credential(self) -> Result { + let kv_store = self.state.kv_store(); + let default_id = kv_store.get_default_dns_credential_id().unwrap_or_default(); + let credential = kv_store.get_default_dns_credential().map(dns_cred_to_proto); + Ok(GetDefaultDnsCredentialResponse { + default_id, + credential, + }) + } + + async fn set_default_dns_credential( + self, + request: SetDefaultDnsCredentialRequest, + ) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Verify the credential exists + kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + + kv_store.set_default_dns_credential_id(&request.id)?; + info!("Set default DNS credential: {}", request.id); + Ok(()) + } + + // ==================== ZT-Domain Management ==================== + + async fn list_zt_domains(self) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + let domains = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|config| zt_domain_to_proto(config, kv_store, cert_resolver)) + .collect(); + + Ok(ListZtDomainsResponse { domains }) + } + + async fn get_zt_domain(self, request: GetZtDomainRequest) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + let config = kv_store + .get_zt_domain_config(&request.domain) + 
.context("ZT-Domain config not found")?; + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn add_zt_domain(self, request: ProtoZtDomainConfig) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + // Check if domain already exists + if kv_store.get_zt_domain_config(&request.domain).is_some() { + bail!("ZT-Domain config already exists: {}", request.domain); + } + + let config = proto_to_zt_domain_config(&request, kv_store)?; + + kv_store.save_zt_domain_config(&config)?; + info!("Added ZT-Domain config: {}", config.domain); + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn update_zt_domain(self, request: ProtoZtDomainConfig) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + // Check if config exists + kv_store + .get_zt_domain_config(&request.domain) + .context("ZT-Domain config not found")?; + + let config = proto_to_zt_domain_config(&request, kv_store)?; + + kv_store.save_zt_domain_config(&config)?; + info!("Updated ZT-Domain config: {}", config.domain); + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn delete_zt_domain(self, request: DeleteZtDomainRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Check if config exists + kv_store + .get_zt_domain_config(&request.domain) + .context("ZT-Domain config not found")?; + + // Delete config (cert data, acme, attestations are kept for historical purposes) + kv_store.delete_zt_domain_config(&request.domain)?; + info!("Deleted ZT-Domain config: {}", request.domain); + Ok(()) + } + + async fn renew_zt_domain_cert( + self, + request: RenewZtDomainCertRequest, + ) -> Result { + let certbot = &self.state.certbot; + let renewed = certbot + .try_renew(&request.domain, request.force) + .await + .context("certificate renewal failed")?; + + if renewed { + // Get the new certificate data for response + let kv_store = 
self.state.kv_store(); + let cert_data = kv_store.get_cert_data(&request.domain); + let not_after = cert_data.map(|d| d.not_after).unwrap_or(0); + Ok(RenewZtDomainCertResponse { renewed, not_after }) + } else { + Ok(RenewZtDomainCertResponse { + renewed: false, + not_after: 0, + }) + } + } + + async fn list_cert_attestations( + self, + request: ListCertAttestationsRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + let latest = kv_store + .get_cert_attestation_latest(&request.domain) + .map(|att| CertAttestationInfo { + public_key: att.public_key, + quote: att.quote, + generated_by: att.generated_by, + generated_at: att.generated_at, + }); + + let mut history: Vec = kv_store + .list_cert_attestations(&request.domain) + .into_iter() + .map(|att| CertAttestationInfo { + public_key: att.public_key, + quote: att.quote, + generated_by: att.generated_by, + generated_at: att.generated_at, + }) + .collect(); + + // Apply limit if specified + if request.limit > 0 { + history.truncate(request.limit as usize); + } + + Ok(ListCertAttestationsResponse { latest, history }) + } + + // ==================== Global Certbot Configuration ==================== + + async fn get_certbot_config(self) -> Result { + let config = self.state.kv_store().get_certbot_config(); + Ok(CertbotConfigResponse { + renew_interval_secs: config.renew_interval.as_secs(), + renew_before_expiration_secs: config.renew_before_expiration.as_secs(), + renew_timeout_secs: config.renew_timeout.as_secs(), + acme_url: config.acme_url, + }) + } + + async fn set_certbot_config(self, request: SetCertbotConfigRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + let mut config = kv_store.get_certbot_config(); + + // Update only the fields that are specified + if let Some(secs) = request.renew_interval_secs { + config.renew_interval = Duration::from_secs(secs); + } + if let Some(secs) = request.renew_before_expiration_secs { + config.renew_before_expiration = Duration::from_secs(secs); + } + 
if let Some(secs) = request.renew_timeout_secs { + config.renew_timeout = Duration::from_secs(secs); + } + if let Some(url) = request.acme_url { + config.acme_url = url; + } + + kv_store.set_certbot_config(&config)?; + info!( + "Updated certbot config: renew_interval={:?}, renew_before_expiration={:?}, renew_timeout={:?}, acme_url={:?}", + config.renew_interval, + config.renew_before_expiration, + config.renew_timeout, + config.acme_url + ); + Ok(()) + } +} + +fn build_store_status( + name: &str, + status: WaveKvNodeStatus, + get_peer_last_seen: &impl Fn(u32) -> Vec<(u32, u64)>, +) -> StoreSyncStatus { + StoreSyncStatus { + name: name.to_string(), + node_id: status.id, + n_keys: status.n_kvs as u64, + next_seq: status.next_seq, + dirty: status.dirty, + wal_enabled: status.wal, + peers: status + .peers + .into_iter() + .map(|p| { + let last_seen = get_peer_last_seen(p.id) + .into_iter() + .map(|(node_id, timestamp)| LastSeenEntry { node_id, timestamp }) + .collect(); + ProtoPeerSyncStatus { + id: p.id, + local_ack: p.ack, + peer_ack: p.pack, + buffered_logs: p.logs as u64, + last_seen, + } + }) + .collect(), + } } impl RpcCall for AdminRpcHandler { @@ -157,3 +660,107 @@ impl RpcCall for AdminRpcHandler { }) } } + +// ==================== Helper Functions ==================== + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn generate_cred_id() -> String { + use std::time::SystemTime; + let ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + // Simple ID: timestamp + random suffix + let random: u32 = rand::random(); + format!("{:x}{:08x}", ts, random) +} + +fn dns_cred_to_proto(cred: DnsCredential) -> DnsCredentialInfo { + let (provider_type, cf_api_token, cf_api_url) = match &cred.provider { + DnsProvider::Cloudflare { api_token, api_url } => ( + "cloudflare".to_string(), + api_token.clone(), + api_url.clone().unwrap_or_default(), + ), + }; + 
DnsCredentialInfo { + id: cred.id, + name: cred.name, + provider_type, + cf_api_token, + cf_api_url, + created_at: cred.created_at, + updated_at: cred.updated_at, + } +} + +/// Convert proto ZtDomainConfig to internal ZtDomainConfig +fn proto_to_zt_domain_config( + proto: &ProtoZtDomainConfig, + kv_store: &crate::kv::KvStore, +) -> Result { + // Normalize dns_cred_id: treat empty string as None (use default) + let dns_cred_id = proto + .dns_cred_id + .as_ref() + .filter(|s| !s.is_empty()) + .cloned(); + + // Validate DNS credential if specified + if let Some(ref cred_id) = dns_cred_id { + kv_store + .get_dns_credential(cred_id) + .context("specified dns credential not found")?; + } + + // Strip wildcard prefix if user entered it + let domain = proto + .domain + .strip_prefix("*.") + .unwrap_or(&proto.domain) + .to_string(); + + Ok(ZtDomainConfig { + domain, + dns_cred_id, + port: proto.port.try_into().context("port out of range")?, + node: proto.node, + priority: proto.priority, + }) +} + +/// Convert internal ZtDomainConfig to proto ZtDomainInfo (with cert status) +fn zt_domain_to_proto( + config: ZtDomainConfig, + kv_store: &crate::kv::KvStore, + cert_resolver: &crate::cert_store::CertResolver, +) -> ZtDomainInfo { + // Get certificate data for status + let cert_data = kv_store.get_cert_data(&config.domain); + let loaded_in_memory = cert_resolver.has_cert(&config.domain); + + let cert_status = Some(ZtDomainCertStatus { + has_cert: cert_data.is_some(), + not_after: cert_data.as_ref().map(|d| d.not_after).unwrap_or(0), + issued_by: cert_data.as_ref().map(|d| d.issued_by).unwrap_or(0), + issued_at: cert_data.as_ref().map(|d| d.issued_at).unwrap_or(0), + loaded_in_memory, + }); + + ZtDomainInfo { + config: Some(ProtoZtDomainConfig { + domain: config.domain, + dns_cred_id: config.dns_cred_id, + port: config.port.into(), + node: config.node, + priority: config.priority, + }), + cert_status, + } +} diff --git a/gateway/src/cert_store.rs b/gateway/src/cert_store.rs new 
file mode 100644 index 00000000..e65dd2de --- /dev/null +++ b/gateway/src/cert_store.rs @@ -0,0 +1,443 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network + +// +// SPDX-License-Identifier: Apache-2.0 + +//! In-memory certificate store with SNI-based certificate resolution. +//! +//! This module provides a lock-free certificate store that supports: +//! - Multiple certificates for different domains +//! - Wildcard certificate matching +//! - Dynamic certificate updates via atomic replacement +//! - SNI-based certificate selection for TLS connections +//! +//! Architecture: `CertStore` is immutable after construction for lock-free reads. +//! Updates are done by building a new `CertStore` and atomically swapping the `Arc` +//! via `ArcSwap` in `CertResolver`. + +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use arc_swap::{ArcSwap, Guard}; +use or_panic::ResultOrPanic; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use tracing::{debug, info}; + +use crate::kv::CertData; + +/// Immutable, lock-free certificate store. +/// +/// This struct is designed for maximum read performance - no locks required for lookups. +/// Updates are done by creating a new instance and atomically swapping via `ArcSwap` in `CertResolver`. 
+pub struct CertStore {
+    /// Exact domain -> CertifiedKey
+    exact_certs: HashMap<String, Arc<CertifiedKey>>,
+    /// Parent domain -> CertifiedKey (for wildcard certs)
+    /// e.g., "example.com" -> cert for "*.example.com"
+    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
+    /// Domain -> CertData (for metadata like expiry)
+    cert_data: HashMap<String, CertData>,
+}
+
+impl CertStore {
+    /// Create a new empty certificate store
+    pub fn new() -> Self {
+        Self {
+            exact_certs: HashMap::new(),
+            wildcard_certs: HashMap::new(),
+            cert_data: HashMap::new(),
+        }
+    }
+
+    /// Resolve certificate for a given SNI hostname (lock-free)
+    fn resolve_cert(&self, sni: &str) -> Option<Arc<CertifiedKey>> {
+        // 1. Try exact match first
+        if let Some(cert) = self.exact_certs.get(sni) {
+            debug!("exact match for {sni}");
+            return Some(cert.clone());
+        }
+
+        // 2. Try wildcard match (only one level deep per TLS spec)
+        //    For "foo.bar.example.com", only try "bar.example.com"
+        if let Some((_, parent)) = sni.split_once('.') {
+            if let Some(cert) = self.wildcard_certs.get(parent) {
+                debug!("wildcard match *.{parent} for {sni}");
+                return Some(cert.clone());
+            }
+        }
+
+        debug!("no certificate found for {sni}");
+        None
+    }
+
+    /// Check if a certificate exists for a domain
+    pub fn has_cert(&self, domain: &str) -> bool {
+        self.cert_data.contains_key(domain)
+    }
+
+    /// Get certificate data for a domain
+    pub fn get_cert_data(&self, domain: &str) -> Option<&CertData> {
+        self.cert_data.get(domain)
+    }
+
+    /// List all loaded domains
+    pub fn list_domains(&self) -> Vec<String> {
+        self.cert_data.keys().cloned().collect()
+    }
+
+    /// Check if a wildcard certificate exists for a domain
+    pub fn contains_wildcard(&self, base_domain: &str) -> bool {
+        self.wildcard_certs.contains_key(base_domain)
+    }
+}
+
+impl fmt::Debug for CertStore {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let exact_domains: Vec<_> = self.exact_certs.keys().cloned().collect();
+        let wildcard_domains: Vec<_> = self
+            .wildcard_certs
+            .keys()
+            .map(|k| format!("*.{}", k))
+            .collect();
+
+        f.debug_struct("CertStore")
+            .field("exact_domains", &exact_domains)
+            .field("wildcard_domains", &wildcard_domains)
+            .finish()
+    }
+}
+
+impl Default for CertStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ResolvesServerCert for CertStore {
+    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
+        let sni = client_hello.server_name()?;
+        self.resolve_cert(sni)
+    }
+}
+
+/// Certificate resolver that wraps `ArcSwap` for lock-free reads.
+///
+/// This allows TLS acceptors to be created once and certificates to be updated
+/// without recreating the acceptor. The read path (TLS handshake) is completely
+/// lock-free via `ArcSwap`. Write operations are serialized via a `Mutex` to
+/// prevent lost updates during concurrent certificate changes.
+pub struct CertResolver {
+    store: ArcSwap<CertStore>,
+    /// Mutex to serialize write operations (reads are still lock-free)
+    write_lock: std::sync::Mutex<()>,
+}
+
+impl CertResolver {
+    /// Create a new resolver with an empty CertStore
+    pub fn new() -> Self {
+        Self {
+            store: ArcSwap::from_pointee(CertStore::new()),
+            write_lock: std::sync::Mutex::new(()),
+        }
+    }
+
+    /// Get the current CertStore (lock-free)
+    pub fn get(&self) -> Guard<Arc<CertStore>> {
+        self.store.load()
+    }
+
+    /// Replace the CertStore atomically (lock-free)
+    pub fn set(&self, new_store: Arc<CertStore>) {
+        self.store.store(new_store);
+    }
+
+    /// List all domains
+    pub fn list_domains(&self) -> Vec<String> {
+        self.get().list_domains()
+    }
+
+    /// Check if a certificate exists for a domain
+    pub fn has_cert(&self, domain: &str) -> bool {
+        self.get().has_cert(domain)
+    }
+
+    /// Update a single certificate (creates new store with updated cert)
+    ///
+    /// This is an incremental update that preserves all existing certificates.
+    /// Write operations are serialized to prevent lost updates.
+    pub fn update_cert(&self, domain: &str, data: &CertData) -> Result<()> {
+        // NOTE(review): `or_panic` is presumably a project extension trait on
+        // `LockResult` — TODO confirm it exists in scope of this module.
+        let _guard = self
+            .write_lock
+            .lock()
+            .or_panic("failed to acquire write lock");
+
+        let old_store = self.get();
+
+        // Build new store with all existing certs plus the new/updated one
+        let mut builder = CertStoreBuilder::new();
+
+        // Copy existing certs (except the one we're replacing)
+        for existing_domain in old_store.list_domains() {
+            if existing_domain != domain {
+                if let Some(existing_data) = old_store.get_cert_data(&existing_domain) {
+                    builder.add_cert(&existing_domain, existing_data)?;
+                }
+            }
+        }
+
+        // Add the new/updated cert
+        builder.add_cert(domain, data)?;
+
+        // Atomically swap
+        self.set(Arc::new(builder.build()));
+        Ok(())
+    }
+}
+
+impl Default for CertResolver {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl fmt::Debug for CertResolver {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.get().fmt(f)
+    }
+}
+
+impl ResolvesServerCert for CertResolver {
+    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
+        // Lock-free load via ArcSwap
+        let store = self.store.load();
+        let sni = client_hello.server_name()?;
+        store.resolve_cert(sni)
+    }
+}
+
+/// Builder for constructing a new CertStore.
+///
+/// Use this to build a complete certificate store, then call `build()` to get the immutable CertStore.
+pub struct CertStoreBuilder {
+    exact_certs: HashMap<String, Arc<CertifiedKey>>,
+    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
+    cert_data: HashMap<String, CertData>,
+}
+
+impl CertStoreBuilder {
+    /// Create a new empty builder
+    pub fn new() -> Self {
+        Self {
+            exact_certs: HashMap::new(),
+            wildcard_certs: HashMap::new(),
+            cert_data: HashMap::new(),
+        }
+    }
+
+    /// Add a certificate to the builder
+    ///
+    /// The domain is the base domain (e.g., "example.com").
+    /// All gateway certificates are wildcard certs for "*.{domain}".
+    pub fn add_cert(&mut self, domain: &str, data: &CertData) -> Result<()> {
+        let certified_key = parse_certified_key(&data.cert_pem, &data.key_pem)
+            .with_context(|| format!("failed to parse certificate for {}", domain))?;
+
+        let certified_key = Arc::new(certified_key);
+
+        // Gateway certificates are always wildcard certs
+        // domain is the base domain (e.g., "example.com"), cert is for "*.example.com"
+        self.wildcard_certs
+            .insert(domain.to_string(), certified_key);
+        info!(
+            "cert_store: prepared wildcard certificate for *.{} (expires: {})",
+            domain,
+            format_expiry(data.not_after)
+        );
+
+        // Store metadata
+        self.cert_data.insert(domain.to_string(), data.clone());
+
+        Ok(())
+    }
+
+    /// Build the immutable CertStore
+    pub fn build(self) -> CertStore {
+        CertStore {
+            exact_certs: self.exact_certs,
+            wildcard_certs: self.wildcard_certs,
+            cert_data: self.cert_data,
+        }
+    }
+}
+
+impl Default for CertStoreBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Parse certificate and private key PEM strings into a CertifiedKey
+fn parse_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
+    let certs = CertificateDer::pem_slice_iter(cert_pem.as_bytes())
+        .collect::<Result<Vec<_>, _>>()
+        .context("failed to parse certificate chain")?;
+
+    if certs.is_empty() {
+        anyhow::bail!("no certificates found in PEM");
+    }
+
+    let key =
+        PrivateKeyDer::from_pem_slice(key_pem.as_bytes()).context("failed to parse private key")?;
+
+    let signing_key = rustls::crypto::aws_lc_rs::sign::any_supported_type(&key)
+        .map_err(|e| anyhow::anyhow!("failed to create signing key: {:?}", e))?;
+
+    Ok(CertifiedKey::new(certs, signing_key))
+}
+
+/// Format expiry timestamp as human-readable string
+fn format_expiry(not_after: u64) -> String {
+    use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+    let expiry = UNIX_EPOCH + Duration::from_secs(not_after);
+    let now = SystemTime::now();
+
+    match expiry.duration_since(now) {
+        Ok(remaining) => {
+            let days = remaining.as_secs() / 86400;
+            format!("{} days remaining", days)
+        }
+        Err(_) => "expired".to_string(),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    impl CertStore {
+        /// Check if a certificate can be resolved for a given SNI hostname
+        pub fn has_cert_for_sni(&self, sni: &str) -> bool {
+            self.resolve_cert(sni).is_some()
+        }
+    }
+
+    fn make_test_cert_data() -> CertData {
+        // Generate a self-signed test certificate using rcgen
+        use ra_tls::rcgen::{self, CertificateParams, KeyPair};
+        use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+        let key_pair = KeyPair::generate().expect("failed to generate key pair");
+        let mut params = CertificateParams::new(vec!["test.example.com".to_string()])
+            .expect("failed to create cert params");
+        params.not_after = rcgen::date_time_ymd(2030, 1, 1);
+        let cert = params
+            .self_signed(&key_pair)
+            .expect("failed to generate self-signed cert");
+
+        let not_after = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs()
+            + Duration::from_secs(365 * 24 * 3600).as_secs();
+
+        CertData {
+            cert_pem: cert.pem(),
+            key_pem: key_pair.serialize_pem(),
+            not_after,
+            issued_by: 1,
+            issued_at: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs(),
+        }
+    }
+
+    #[test]
+    fn test_cert_store_basic() {
+        let store = CertStore::new();
+        assert!(store.list_domains().is_empty());
+    }
+
+    #[test]
+    fn test_cert_store_builder() {
+        let data = make_test_cert_data();
+
+        // Use builder - domain is base domain (e.g., "example.com")
+        // All gateway certs are wildcard certs
+        let mut builder = CertStoreBuilder::new();
+        builder
+            .add_cert("example.com", &data)
+            .expect("failed to add cert");
+
+        let store = builder.build();
+
+        // Check it's loaded (stored by base domain)
+        assert!(store.has_cert("example.com"));
+        assert_eq!(store.list_domains().len(), 1);
+
+        // Should resolve any subdomain via wildcard matching
+        assert!(store.has_cert_for_sni("test.example.com"));
+        assert!(store.has_cert_for_sni("foo.example.com"));
+
+        // Should not resolve exact base domain (wildcard doesn't match base)
+        assert!(!store.has_cert_for_sni("example.com"));
+
+        // Should not resolve different domain
+        assert!(!store.has_cert_for_sni("example.org"));
+    }
+
+    #[test]
+    fn test_cert_store_wildcard() {
+        // Generate wildcard cert
+        use ra_tls::rcgen::{self, CertificateParams, KeyPair};
+        use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+        let key_pair = KeyPair::generate().expect("failed to generate key pair");
+        let mut params = CertificateParams::new(vec!["*.example.com".to_string()])
+            .expect("failed to create cert params");
+        params.not_after = rcgen::date_time_ymd(2030, 1, 1);
+        let cert = params
+            .self_signed(&key_pair)
+            .expect("failed to generate self-signed cert");
+
+        let not_after = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs()
+            + Duration::from_secs(365 * 24 * 3600).as_secs();
+
+        let data = CertData {
+            cert_pem: cert.pem(),
+            key_pem: key_pair.serialize_pem(),
+            not_after,
+            issued_by: 1,
+            issued_at: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs(),
+        };
+
+        let mut builder = CertStoreBuilder::new();
+        // Now we use base domain format (without *. prefix)
+        builder
+            .add_cert("example.com", &data)
+            .expect("failed to add wildcard cert");
+
+        let store = builder.build();
+
+        // Should resolve any subdomain
+        assert!(store.has_cert_for_sni("foo.example.com"));
+        assert!(store.has_cert_for_sni("bar.example.com"));
+
+        // Wildcard certs do not match nested subdomains
+        assert!(!store.has_cert_for_sni("sub.foo.example.com"));
+
+        // Should not resolve different domain
+        assert!(!store.has_cert_for_sni("example.org"));
+    }
+}
diff --git a/gateway/src/config.rs b/gateway/src/config.rs
index 3b990795..5db0cbf2 100644
--- a/gateway/src/config.rs
+++ b/gateway/src/config.rs
@@ -69,15 +69,10 @@ pub enum TlsVersion {
 
 #[derive(Debug, Clone, Deserialize)]
 pub struct ProxyConfig {
-    pub cert_chain: String,
-    pub cert_key: String,
     pub tls_crypto_provider: CryptoProvider,
     pub tls_versions: Vec<TlsVersion>,
-    pub base_domain: String,
-    pub external_port: u16,
     pub listen_addr: Ipv4Addr,
     pub listen_port: u16,
-    pub agent_port: u16,
     pub timeouts: Timeouts,
     pub buffer_size: usize,
     pub connect_top_n: usize,
@@ -125,30 +120,52 @@ pub struct SyncConfig {
     #[serde(with = "serde_duration")]
     pub interval: Duration,
     #[serde(with = "serde_duration")]
-    pub broadcast_interval: Duration,
-    #[serde(with = "serde_duration")]
     pub timeout: Duration,
     pub my_url: String,
+    /// The URL of the bootnode used to fetch initial peer list when joining the network
     pub bootnode: String,
+    /// WaveKV node ID for this gateway (must be unique across cluster)
+    pub node_id: u32,
+    /// Data directory for WaveKV persistence
+    pub data_dir: String,
+    /// Interval for periodic WAL persistence (default: 10s)
+    #[serde(with = "serde_duration")]
+    pub persist_interval: Duration,
+    /// Enable periodic sync of instance connections to KV store
+    pub sync_connections_enabled: bool,
+    /// Interval for syncing instance connections to KV store
+    #[serde(with = "serde_duration")]
+    pub sync_connections_interval: Duration,
 }
 
 #[derive(Debug, Clone, Deserialize)]
 pub struct Config {
wg: WgConfig, pub proxy: ProxyConfig, - pub certbot: CertbotConfig, pub pccs_url: Option, pub recycle: RecycleConfig, - pub state_path: String, pub set_ulimit: bool, pub rpc_domain: String, pub kms_url: String, pub admin: AdminConfig, - pub run_in_dstack: bool, + /// Debug server configuration (separate port for debug RPCs) + pub debug: DebugConfig, pub sync: SyncConfig, pub auth: AuthConfig, } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct DebugConfig { + /// Enable debug server + #[serde(default)] + pub insecure_enable_debug_rpc: bool, + #[serde(default)] + pub insecure_skip_attestation: bool, + /// Path to pre-generated debug key data file (JSON format containing key, quote, event_log, and vm_config) + #[serde(default)] + pub key_file: String, +} + #[derive(Debug, Clone, Deserialize)] pub struct AuthConfig { pub enabled: bool, @@ -158,11 +175,41 @@ pub struct AuthConfig { } impl Config { - pub fn id(&self) -> Vec { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(self.wg.public_key.as_bytes()); - hasher.finalize()[..20].to_vec() + /// Get or generate a unique node UUID. + /// The UUID is stored in `{data_dir}/node_uuid` and persisted across restarts. 
+ pub fn uuid(&self) -> Vec { + use std::fs; + use std::path::Path; + + let uuid_path = Path::new(&self.sync.data_dir).join("node_uuid"); + + // Try to read existing UUID + if let Ok(content) = fs::read_to_string(&uuid_path) { + if let Ok(uuid) = uuid::Uuid::parse_str(content.trim()) { + return uuid.as_bytes().to_vec(); + } + } + + // Generate new UUID + let uuid = uuid::Uuid::new_v4(); + + // Ensure directory exists + if let Some(parent) = uuid_path.parent() { + let _ = fs::create_dir_all(parent); + } + + // Save UUID to file + if let Err(err) = fs::write(&uuid_path, uuid.to_string()) { + tracing::warn!( + "failed to save node UUID to {}: {}", + uuid_path.display(), + err + ); + } else { + tracing::info!("generated new node UUID: {}", uuid); + } + + uuid.as_bytes().to_vec() } } @@ -183,68 +230,6 @@ pub struct MutualConfig { pub ca_certs: String, } -#[derive(Debug, Clone, Deserialize)] -pub struct CertbotConfig { - /// Enable certbot - pub enabled: bool, - /// Path to the working directory - pub workdir: String, - /// ACME server URL - pub acme_url: String, - /// Cloudflare API token - pub cf_api_token: String, - /// Auto set CAA record - pub auto_set_caa: bool, - /// Domain to issue certificates for - pub domain: String, - /// Renew interval - #[serde(with = "serde_duration")] - pub renew_interval: Duration, - /// Time gap before expiration to trigger renewal - #[serde(with = "serde_duration")] - pub renew_before_expiration: Duration, - /// Renew timeout - #[serde(with = "serde_duration")] - pub renew_timeout: Duration, - /// Maximum time to wait for DNS propagation - #[serde(with = "serde_duration")] - pub max_dns_wait: Duration, - /// TTL for DNS TXT records used in ACME challenges (in seconds). - /// Minimum is 60 for Cloudflare. Lower TTL means faster DNS propagation. 
- #[serde(default = "default_dns_txt_ttl")] - pub dns_txt_ttl: u32, -} - -fn default_dns_txt_ttl() -> u32 { - 60 -} - -impl CertbotConfig { - fn to_bot_config(&self) -> certbot::CertBotConfig { - let workdir = certbot::WorkDir::new(&self.workdir); - certbot::CertBotConfig::builder() - .auto_create_account(true) - .cert_dir(workdir.backup_dir()) - .cert_file(workdir.cert_path()) - .key_file(workdir.key_path()) - .credentials_file(workdir.account_credentials_path()) - .acme_url(self.acme_url.clone()) - .cert_subject_alt_names(vec![self.domain.clone()]) - .cf_api_token(self.cf_api_token.clone()) - .renew_interval(self.renew_interval) - .renew_timeout(self.renew_timeout) - .renew_expires_in(self.renew_before_expiration) - .auto_set_caa(self.auto_set_caa) - .max_dns_wait(self.max_dns_wait) - .dns_txt_ttl(self.dns_txt_ttl) - .build() - } - - pub async fn build_bot(&self) -> Result { - self.to_bot_config().build_bot().await - } -} - pub const DEFAULT_CONFIG: &str = include_str!("../gateway.toml"); pub fn load_config_figment(config_file: Option<&str>) -> Figment { load_config("gateway", DEFAULT_CONFIG, config_file, false) diff --git a/gateway/src/debug_service.rs b/gateway/src/debug_service.rs new file mode 100644 index 00000000..e53b4ed2 --- /dev/null +++ b/gateway/src/debug_service.rs @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! 
+//! Debug service for testing - runs on a separate port when debug.enabled=true
+
+use anyhow::Result;
+use dstack_gateway_rpc::{
+    debug_server::{DebugRpc, DebugServer},
+    DebugProxyStateResponse, DebugRegisterCvmRequest, DebugSyncDataResponse, InfoResponse,
+    InstanceEntry, NodeInfoEntry, PeerAddrEntry, ProxyStateInstance, RegisterCvmResponse,
+};
+use ra_rpc::{CallContext, RpcCall};
+use tracing::warn;
+
+use crate::main_service::Proxy;
+
+pub struct DebugRpcHandler {
+    state: Proxy,
+}
+
+impl DebugRpcHandler {
+    pub fn new(state: Proxy) -> Self {
+        Self { state }
+    }
+}
+
+impl DebugRpc for DebugRpcHandler {
+    async fn register_cvm(self, request: DebugRegisterCvmRequest) -> Result<RegisterCvmResponse> {
+        warn!(
+            "Debug register CVM: app_id={}, instance_id={}",
+            request.app_id, request.instance_id
+        );
+        self.state.do_register_cvm(
+            &request.app_id,
+            &request.instance_id,
+            &request.client_public_key,
+        )
+    }
+
+    async fn info(self) -> Result<InfoResponse> {
+        let config = &self.state.config;
+        let (base_domain, port) = self
+            .state
+            .kv_store()
+            .get_best_zt_domain()
+            .unwrap_or_default();
+        Ok(InfoResponse {
+            base_domain,
+            external_port: port.into(),
+            app_address_ns_prefix: config.proxy.app_address_ns_prefix.clone(),
+        })
+    }
+
+    async fn get_sync_data(self) -> Result<DebugSyncDataResponse> {
+        let kv_store = self.state.kv_store();
+        let my_node_id = kv_store.my_node_id();
+
+        // Get all peer addresses
+        let peer_addrs: Vec<PeerAddrEntry> = kv_store
+            .get_all_peer_addrs()
+            .into_iter()
+            .map(|(node_id, url)| PeerAddrEntry {
+                node_id: node_id as u64,
+                url,
+            })
+            .collect();
+
+        // Get all node info
+        let nodes: Vec<NodeInfoEntry> = kv_store
+            .load_all_nodes()
+            .into_iter()
+            .map(|(node_id, data)| NodeInfoEntry {
+                node_id: node_id as u64,
+                url: data.url,
+                wg_public_key: data.wg_public_key,
+                wg_endpoint: data.wg_endpoint,
+                wg_ip: data.wg_ip,
+            })
+            .collect();
+
+        // Get all instances
+        let instances: Vec<InstanceEntry> = kv_store
+            .load_all_instances()
+            .into_iter()
+            .map(|(instance_id, data)| InstanceEntry {
+                instance_id,
+                app_id: data.app_id,
+                ip: data.ip.to_string(),
+                public_key: data.public_key,
+            })
+            .collect();
+
+        // Get key counts
+        let persistent_keys = kv_store.persistent().read().status().n_kvs as u64;
+        let ephemeral_keys = kv_store.ephemeral().read().status().n_kvs as u64;
+
+        Ok(DebugSyncDataResponse {
+            my_node_id: my_node_id as u64,
+            peer_addrs,
+            nodes,
+            instances,
+            persistent_keys,
+            ephemeral_keys,
+        })
+    }
+
+    async fn get_proxy_state(self) -> Result<DebugProxyStateResponse> {
+        let state = self.state.lock();
+
+        // Get all instances from ProxyState
+        let instances: Vec<ProxyStateInstance> = state
+            .state
+            .instances
+            .values()
+            .map(|inst| {
+                let reg_time = inst
+                    .reg_time
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .map(|d| d.as_secs())
+                    .unwrap_or(0);
+                ProxyStateInstance {
+                    instance_id: inst.id.clone(),
+                    app_id: inst.app_id.clone(),
+                    ip: inst.ip.to_string(),
+                    public_key: inst.public_key.clone(),
+                    reg_time,
+                }
+            })
+            .collect();
+
+        // Get all allocated addresses
+        let allocated_addresses: Vec<String> = state
+            .state
+            .allocated_addresses
+            .iter()
+            .map(|ip| ip.to_string())
+            .collect();
+
+        Ok(DebugProxyStateResponse {
+            instances,
+            allocated_addresses,
+        })
+    }
+}
+
+impl RpcCall<Proxy> for DebugRpcHandler {
+    type PrpcService = DebugServer<Self>;
+
+    fn construct(context: CallContext<'_, Proxy>) -> Result<Self> {
+        Ok(DebugRpcHandler::new(context.state.clone()))
+    }
+}
diff --git a/gateway/src/distributed_certbot.rs b/gateway/src/distributed_certbot.rs
new file mode 100644
index 00000000..cbb316ea
--- /dev/null
+++ b/gateway/src/distributed_certbot.rs
@@ -0,0 +1,520 @@
+// SPDX-FileCopyrightText: © 2024-2025 Phala Network
+//
+// SPDX-License-Identifier: Apache-2.0
+
+//! Multi-domain certificate management using WaveKV for synchronization.
+//!
+//! This module provides distributed certificate management for multiple domains
+//! with dynamic DNS credential configuration and attestation storage.
+
+use std::sync::Arc;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+use anyhow::{bail, Context, Result};
+use certbot::{AcmeClient, Dns01Client};
+use dstack_guest_agent_rpc::RawQuoteArgs;
+use ra_tls::attestation::QuoteContentType;
+use ra_tls::rcgen::KeyPair;
+use tracing::{error, info, warn};
+
+use crate::cert_store::CertResolver;
+use crate::kv::{
+    AcmeAttestation, CertAttestation, CertCredentials, CertData, DnsProvider, KvStore,
+    ZtDomainConfig,
+};
+
+/// Lock timeout for certificate renewal (10 minutes)
+const RENEW_LOCK_TIMEOUT_SECS: u64 = 600;
+
+/// Default ACME URL (Let's Encrypt production)
+const DEFAULT_ACME_URL: &str = "https://acme-v02.api.letsencrypt.org/directory";
+
+/// Multi-domain certificate manager
+pub struct DistributedCertBot {
+    kv_store: Arc<KvStore>,
+    cert_resolver: Arc<CertResolver>,
+}
+
+impl DistributedCertBot {
+    pub fn new(kv_store: Arc<KvStore>, cert_resolver: Arc<CertResolver>) -> Self {
+        Self {
+            kv_store,
+            cert_resolver,
+        }
+    }
+
+    /// Get the current certbot configuration from KV store
+    fn config(&self) -> crate::kv::GlobalCertbotConfig {
+        self.kv_store.get_certbot_config()
+    }
+
+    /// Initialize all ZT-Domain certificates
+    pub async fn init_all(&self) -> Result<()> {
+        let configs = self.kv_store.list_zt_domain_configs();
+        for config in configs {
+            if let Err(err) = self.init_domain(&config.domain).await {
+                error!("cert[{}]: failed to initialize: {err:?}", config.domain);
+            }
+        }
+        Ok(())
+    }
+
+    /// Initialize certificate for a specific domain
+    pub async fn init_domain(&self, domain: &str) -> Result<()> {
+        // First, try to load from KvStore (synced from other nodes)
+        if let Some(cert_data) = self.kv_store.get_cert_data(domain) {
+            let now = now_secs();
+            if cert_data.not_after > now {
+                info!(
+                    domain,
+                    "loaded from KvStore (issued by node {}, expires in {} days)",
+                    cert_data.issued_by,
+                    (cert_data.not_after - now) / 86400
+                );
+                self.cert_resolver.update_cert(domain, &cert_data)?;
+                return Ok(());
+            }
+            info!(domain, "KvStore certificate expired, will request new one");
+        }
+
+        // No valid cert, need to request new one
+        info!(domain, "no valid certificate found, requesting from ACME");
+        self.request_new_cert(domain).await
+    }
+
+    /// Try to renew all ZT-Domain certificates
+    pub async fn try_renew_all(&self) -> Result<()> {
+        let configs = self.kv_store.list_zt_domain_configs();
+        for config in configs {
+            if let Err(err) = self.try_renew(&config.domain, false).await {
+                error!("cert[{}]: failed to renew: {err:?}", config.domain);
+            }
+        }
+        Ok(())
+    }
+
+    /// Try to renew certificate for a specific domain if needed
+    #[tracing::instrument(skip(self))]
+    pub async fn try_renew(&self, domain: &str, force: bool) -> Result<bool> {
+        // Check if config exists
+        let config = self
+            .kv_store
+            .get_zt_domain_config(domain)
+            .context("ZT-Domain config not found")?;
+
+        // Check if renewal is needed
+        let cert_data = self.kv_store.get_cert_data(domain);
+        let needs_renew = if force {
+            true
+        } else if let Some(ref data) = cert_data {
+            let now = now_secs();
+            let expires_in = data.not_after.saturating_sub(now);
+            expires_in < self.config().renew_before_expiration.as_secs()
+        } else {
+            true
+        };
+
+        if !needs_renew {
+            info!("does not need renewal");
+            return Ok(false);
+        }
+
+        // Try to acquire lock
+        if !self
+            .kv_store
+            .try_acquire_cert_lock(domain, RENEW_LOCK_TIMEOUT_SECS)
+        {
+            info!("another node is renewing, skipping");
+            return Ok(false);
+        }
+
+        info!("acquired renew lock, starting renewal");
+
+        // Perform renewal or initial issuance
+        let result = if cert_data.is_some() {
+            self.do_renew(domain, &config).await
+        } else {
+            // No existing certificate, request new one
+            info!("no existing certificate, requesting new one");
+            self.do_request_new(domain, &config).await.map(|_| true)
+        };
+
+        // Release lock regardless of result
+        if let Err(err) = self.kv_store.release_cert_lock(domain) {
+            error!("failed to release lock: {err:?}");
+        }
+
+        result
+    }
+
+    /// Request new certificate for a domain
+    #[tracing::instrument(skip(self))]
+    async fn request_new_cert(&self, domain: &str) -> Result<()> {
+        let config = self
+            .kv_store
+            .get_zt_domain_config(domain)
+            .context("ZT-Domain config not found")?;
+
+        // Try to acquire lock first
+        if !self
+            .kv_store
+            .try_acquire_cert_lock(domain, RENEW_LOCK_TIMEOUT_SECS)
+        {
+            // Another node is requesting, wait for it
+            info!("another node is requesting, waiting...");
+            tokio::time::sleep(Duration::from_secs(30)).await;
+            if let Some(cert_data) = self.kv_store.get_cert_data(domain) {
+                self.cert_resolver.update_cert(domain, &cert_data)?;
+                return Ok(());
+            }
+            bail!("failed to get certificate from KvStore after waiting");
+        }
+
+        let result = self.do_request_new(domain, &config).await;
+
+        if let Err(err) = self.kv_store.release_cert_lock(domain) {
+            error!("failed to release lock: {err:?}");
+        }
+
+        result
+    }
+
+    async fn do_request_new(&self, domain: &str, config: &ZtDomainConfig) -> Result<()> {
+        let acme_client = self.get_or_create_acme_client(domain, config).await?;
+
+        // Generate new key pair (always use new key for security)
+        let key = KeyPair::generate().context("failed to generate key")?;
+        let key_pem = key.serialize_pem();
+        let public_key_der = key.public_key_der();
+
+        // Request wildcard certificate (domain in config is base domain, cert is *.domain)
+        let wildcard_domain = format!("*.{}", domain);
+        info!(
+            "requesting new certificate from ACME for {}...",
+            wildcard_domain
+        );
+        let cert_pem = tokio::time::timeout(
+            self.config().renew_timeout,
+            acme_client.request_new_certificate(&key_pem, &[wildcard_domain]),
+        )
+        .await
+        .context("certificate request timed out")?
+        .context("failed to request new certificate")?;
+
+        let not_after = get_cert_expiry(&cert_pem).context("failed to parse certificate expiry")?;
+
+        // Save certificate to KvStore
+        self.save_cert_to_kvstore(domain, &cert_pem, &key_pem, not_after)?;
+        info!("new certificate obtained from ACME, saved to KvStore");
+
+        // Generate and save attestation
+        self.generate_and_save_attestation(domain, &public_key_der)
+            .await?;
+
+        // Load into memory cert store
+        let cert_data = CertData {
+            cert_pem,
+            key_pem,
+            not_after,
+            issued_by: self.kv_store.my_node_id(),
+            issued_at: now_secs(),
+        };
+        self.cert_resolver.update_cert(domain, &cert_data)?;
+
+        info!(
+            "new certificate loaded (expires in {} days)",
+            // saturating_sub: an already-expired cert from ACME would otherwise
+            // underflow here (panic in debug builds)
+            not_after.saturating_sub(now_secs()) / 86400
+        );
+        Ok(())
+    }
+
+    async fn do_renew(&self, domain: &str, config: &ZtDomainConfig) -> Result<bool> {
+        let acme_client = self.get_or_create_acme_client(domain, config).await?;
+
+        // Generate new key pair (always use new key for each renewal)
+        let key = KeyPair::generate().context("failed to generate key")?;
+        let key_pem = key.serialize_pem();
+        let public_key_der = key.public_key_der();
+
+        // Verify there's a current cert (for audit trail, even though we don't use its key)
+        if self.kv_store.get_cert_data(domain).is_none() {
+            bail!("no current certificate to renew");
+        }
+
+        // Renew with new key (request wildcard certificate)
+        let wildcard_domain = format!("*.{}", domain);
+        info!(
+            "renewing certificate with new key from ACME for {}...",
+            wildcard_domain
+        );
+        let new_cert_pem = tokio::time::timeout(
+            self.config().renew_timeout,
+            // Note: we request a new cert rather than renew, since we have a new key
+            acme_client.request_new_certificate(&key_pem, &[wildcard_domain]),
+        )
+        .await
+        .context("certificate renewal timed out")?
+        .context("failed to renew certificate")?;
+
+        let not_after =
+            get_cert_expiry(&new_cert_pem).context("failed to parse certificate expiry")?;
+
+        // Save to KvStore
+        self.save_cert_to_kvstore(domain, &new_cert_pem, &key_pem, not_after)?;
+        info!("renewed certificate saved to KvStore");
+
+        // Generate and save attestation
+        self.generate_and_save_attestation(domain, &public_key_der)
+            .await?;
+
+        // Load into memory cert store
+        let cert_data = CertData {
+            cert_pem: new_cert_pem,
+            key_pem,
+            not_after,
+            issued_by: self.kv_store.my_node_id(),
+            issued_at: now_secs(),
+        };
+        self.cert_resolver.update_cert(domain, &cert_data)?;
+
+        info!(
+            "renewed certificate loaded (expires in {} days)",
+            // saturating_sub: guard against an already-expired cert (underflow)
+            not_after.saturating_sub(now_secs()) / 86400
+        );
+        Ok(true)
+    }
+
+    async fn get_or_create_acme_client(
+        &self,
+        domain: &str,
+        config: &ZtDomainConfig,
+    ) -> Result<AcmeClient> {
+        // Get DNS credential (from config or default)
+        let dns_cred = if let Some(ref cred_id) = config.dns_cred_id {
+            self.kv_store
+                .get_dns_credential(cred_id)
+                .context("specified DNS credential not found")?
+        } else {
+            self.kv_store
+                .get_default_dns_credential()
+                .context("no default DNS credential configured")?
+        };
+
+        // Create DNS client based on provider
+        let dns01_client = match &dns_cred.provider {
+            DnsProvider::Cloudflare { api_token, api_url } => {
+                Dns01Client::new_cloudflare(domain.to_string(), api_token.clone(), api_url.clone())
+                    .await?
+            }
+        };
+
+        // Use ACME URL from certbot config, fall back to default if not set
+        let config = self.config();
+        let acme_url = if config.acme_url.is_empty() {
+            DEFAULT_ACME_URL
+        } else {
+            &config.acme_url
+        };
+
+        // Try to load global ACME credentials from KvStore
+        if let Some(creds) = self.kv_store.get_acme_credentials() {
+            if acme_url_matches(&creds.acme_credentials, acme_url) {
+                info!("loaded global ACME account credentials from KvStore");
+                return AcmeClient::load(
+                    dns01_client,
+                    &creds.acme_credentials,
+                    dns_cred.max_dns_wait,
+                    dns_cred.dns_txt_ttl,
+                )
+                .await
+                .context("failed to load ACME client from KvStore credentials");
+            }
+            warn!("ACME URL mismatch in KvStore credentials, will create new account");
+        }
+
+        // Create new global ACME account
+        info!("creating new global ACME account at {acme_url}");
+        let client = AcmeClient::new_account(
+            acme_url,
+            dns01_client,
+            dns_cred.max_dns_wait,
+            dns_cred.dns_txt_ttl,
+        )
+        .await
+        .context("failed to create new ACME account")?;
+
+        let creds_json = client
+            .dump_credentials()
+            .context("failed to dump ACME credentials")?;
+
+        // Save global ACME credentials to KvStore
+        self.kv_store.save_acme_credentials(&CertCredentials {
+            acme_credentials: creds_json.clone(),
+        })?;
+
+        // Generate and save ACME account attestation
+        if let Some(account_uri) = extract_account_uri(&creds_json) {
+            self.generate_and_save_acme_attestation(&account_uri)
+                .await?;
+        }
+
+        Ok(client)
+    }
+
+    async fn generate_and_save_acme_attestation(&self, account_uri: &str) -> Result<()> {
+        let agent = match crate::dstack_agent() {
+            Ok(a) => a,
+            Err(err) => {
+                warn!("failed to create dstack agent: {err:?}");
+                return Ok(());
+            }
+        };
+
+        let report_data = QuoteContentType::Custom("acme-account")
+            .to_report_data(account_uri.as_bytes())
+            .to_vec();
+
+        // Get quote
+        let quote = match agent
+            .get_quote(RawQuoteArgs {
+                report_data: report_data.clone(),
+            })
+            .await
+        {
+            Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(),
+            Err(err) => {
+                warn!("failed to get TDX quote for ACME account: {err:?}");
+                return Ok(());
+            }
+        };
+
+        // Get attestation
+        let attestation_str = match agent.attest(RawQuoteArgs { report_data }).await {
+            Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(),
+            Err(err) => {
+                warn!("failed to get attestation for ACME account: {err:?}");
+                String::new()
+            }
+        };
+
+        let attestation = AcmeAttestation {
+            account_uri: account_uri.to_string(),
+            quote,
+            attestation: attestation_str,
+            generated_by: self.kv_store.my_node_id(),
+            generated_at: now_secs(),
+        };
+
+        self.kv_store.save_acme_attestation(&attestation)?;
+        info!("ACME account attestation saved to KvStore");
+        Ok(())
+    }
+
+    fn save_cert_to_kvstore(
+        &self,
+        domain: &str,
+        cert_pem: &str,
+        key_pem: &str,
+        not_after: u64,
+    ) -> Result<()> {
+        let cert_data = CertData {
+            cert_pem: cert_pem.to_string(),
+            key_pem: key_pem.to_string(),
+            not_after,
+            issued_by: self.kv_store.my_node_id(),
+            issued_at: now_secs(),
+        };
+        self.kv_store.save_cert_data(domain, &cert_data)
+    }
+
+    async fn generate_and_save_attestation(
+        &self,
+        domain: &str,
+        public_key_der: &[u8],
+    ) -> Result<()> {
+        let agent = match crate::dstack_agent() {
+            Ok(a) => a,
+            Err(err) => {
+                warn!(domain, "failed to create dstack agent: {err:?}");
+                return Ok(());
+            }
+        };
+
+        let report_data = QuoteContentType::Custom("zt-cert")
+            .to_report_data(public_key_der)
+            .to_vec();
+
+        // Get quote
+        let quote = match agent
+            .get_quote(RawQuoteArgs {
+                report_data: report_data.clone(),
+            })
+            .await
+        {
+            Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(),
+            Err(err) => {
+                warn!(domain, "failed to generate TDX quote: {err:?}");
+                return Ok(());
+            }
+        };
+
+        // Get attestation
+        let attestation = match agent.attest(RawQuoteArgs { report_data }).await {
+            Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(),
+            Err(err) => {
+                warn!(domain, "failed to get attestation: {err:?}");
+                String::new()
+            }
+        };
+
+        let attestation = CertAttestation {
+            public_key: public_key_der.to_vec(),
+            quote,
+            attestation,
+            generated_by: self.kv_store.my_node_id(),
+            generated_at: now_secs(),
+        };
+
+        self.kv_store.save_cert_attestation(domain, &attestation)?;
+        info!(domain, "attestation saved to KvStore");
+        Ok(())
+    }
+}
+
+fn now_secs() -> u64 {
+    SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_secs()
+}
+
+fn get_cert_expiry(cert_pem: &str) -> Option<u64> {
+    use x509_parser::prelude::*;
+    let pem = Pem::iter_from_buffer(cert_pem.as_bytes()).next()?.ok()?;
+    let cert = pem.parse_x509().ok()?;
+    Some(cert.validity().not_after.timestamp() as u64)
+}
+
+fn acme_url_matches(credentials_json: &str, expected_url: &str) -> bool {
+    #[derive(serde::Deserialize)]
+    struct Creds {
+        #[serde(default)]
+        acme_url: String,
+    }
+    serde_json::from_str::<Creds>(credentials_json)
+        .map(|c| c.acme_url == expected_url)
+        .unwrap_or(false)
+}
+
+/// Extract account_id (URI) from ACME credentials JSON
+fn extract_account_uri(credentials_json: &str) -> Option<String> {
+    #[derive(serde::Deserialize)]
+    struct Creds {
+        #[serde(default)]
+        account_id: String,
+    }
+    serde_json::from_str::<Creds>(credentials_json)
+        .ok()
+        .filter(|c| !c.account_id.is_empty())
+        .map(|c| c.account_id)
+}
diff --git a/gateway/src/gen_debug_key.rs b/gateway/src/gen_debug_key.rs
new file mode 100644
index 00000000..c710548a
--- /dev/null
+++ b/gateway/src/gen_debug_key.rs
@@ -0,0 +1,84 @@
+// SPDX-FileCopyrightText: © 2025 Phala Network
+//
+// SPDX-License-Identifier: Apache-2.0
+
+// Run with: cargo run --bin gen_debug_key -- <simulator-url>
+// Example: cargo run --bin gen_debug_key -- https://daee134c3b9f66aa2401c3b5ea64f1d34038f45d-3000.tdxlab.dstack.org:12004
+
+use anyhow::{Context, Result};
+use base64::{engine::general_purpose::STANDARD, Engine as _};
+use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, RawQuoteArgs};
+use http_client::prpc::PrpcClient;
ra_tls::attestation::QuoteContentType; +use ra_tls::rcgen::KeyPair; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct DebugKeyData { + /// Private key in PEM format + key_pem: String, + /// TDX quote in base64 format + quote_base64: String, + /// Event log in JSON string format + event_log: String, + /// VM config in JSON string format + vm_config: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + eprintln!("Usage: {} ", args[0]); + eprintln!("Example: {} https://daee134c3b9f66aa2401c3b5ea64f1d34038f45d-3000.tdxlab.dstack.org:12004", args[0]); + std::process::exit(1); + } + let simulator_url = &args[1]; + + // Generate key pair + let key = KeyPair::generate().context("Failed to generate key")?; + let pubkey = key.public_key_der(); + let key_pem = key.serialize_pem(); + + // Calculate report_data + let report_data = QuoteContentType::RaTlsCert.to_report_data(&pubkey); + + // Get quote from simulator + println!("Getting quote from simulator: {simulator_url}"); + let simulator_client = PrpcClient::new(simulator_url.to_string()); + let simulator_client = DstackGuestClient::new(simulator_client); + let quote_response = simulator_client + .get_quote(RawQuoteArgs { + report_data: report_data.to_vec(), + }) + .await + .context("Failed to get quote from simulator")?; + + // Create debug key data structure + let debug_data = DebugKeyData { + key_pem, + quote_base64: STANDARD.encode("e_response.quote), + event_log: quote_response.event_log, + vm_config: quote_response.vm_config, + }; + + // Write to single JSON file + let json_content = + serde_json::to_string_pretty(&debug_data).context("Failed to serialize debug key data")?; + let output_file = "debug_key.json"; + fs_err::write(output_file, json_content).context("Failed to write debug key file")?; + + println!("✓ Successfully generated debug key data:"); + println!(" - {output_file}"); + 
println!("\nYou can now configure this path in your gateway config:"); + println!("[core.debug]"); + println!("insecure_skip_attestation = true"); + println!( + "key_file = \"{}\"", + fs_err::canonicalize(output_file) + .unwrap_or_default() + .display() + ); + + Ok(()) +} diff --git a/gateway/src/kv/https_client.rs b/gateway/src/kv/https_client.rs new file mode 100644 index 00000000..d0d034a9 --- /dev/null +++ b/gateway/src/kv/https_client.rs @@ -0,0 +1,322 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! HTTPS client with mTLS and custom certificate verification during TLS handshake. + +use std::fmt::Debug; +use std::io::{Read, Write}; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http_body_util::{BodyExt, Full}; +use hyper::body::Bytes; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; +use rustls::{DigitallySignedStruct, SignatureScheme}; +use serde::{de::DeserializeOwned, Serialize}; + +use super::{decode, encode}; + +/// Custom certificate validator trait for TLS handshake verification. +/// +/// Implementations can perform additional validation on the peer certificate +/// during the TLS handshake, before any application data is sent. +pub trait CertValidator: Debug + Send + Sync + 'static { + /// Validate the peer certificate. + /// + /// Called after standard X.509 chain verification succeeds. + /// Return `Ok(())` to accept the certificate, or `Err` to reject. 
+ fn validate(&self, cert_der: &[u8]) -> Result<(), String>; +} + +/// TLS configuration for mTLS with optional custom certificate validation +#[derive(Clone)] +pub struct HttpsClientConfig { + pub cert_path: String, + pub key_path: String, + pub ca_cert_path: String, + /// Optional custom certificate validator (checked during TLS handshake) + pub cert_validator: Option>, +} + +/// Wrapper that adapts a CertValidator to rustls ServerCertVerifier +#[derive(Debug)] +struct CustomCertVerifier { + validator: Arc, + root_store: Arc, +} + +impl CustomCertVerifier { + fn new( + validator: Arc, + ca_cert_der: CertificateDer<'static>, + ) -> Result { + let mut root_store = rustls::RootCertStore::empty(); + root_store + .add(ca_cert_der) + .context("failed to add CA cert to root store")?; + Ok(Self { + validator, + root_store: Arc::new(root_store), + }) + } +} + +impl ServerCertVerifier for CustomCertVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + _ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // First, do standard certificate verification + let verifier = rustls::client::WebPkiServerVerifier::builder(self.root_store.clone()) + .build() + .map_err(|e| rustls::Error::General(format!("failed to build verifier: {e}")))?; + + verifier.verify_server_cert(end_entity, intermediates, server_name, &[], now)?; + + // Then run custom validation + self.validator + .validate(end_entity.as_ref()) + .map_err(rustls::Error::General)?; + + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: 
&DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +type HyperClient = Client, Full>; + +/// HTTPS client with mTLS and optional custom certificate validation. +/// +/// When a `cert_validator` is set in `TlsConfig`, the client runs the validator +/// during the TLS handshake, before any application data is sent. +#[derive(Clone)] +pub struct HttpsClient { + client: HyperClient, +} + +impl HttpsClient { + /// Create a new HTTPS client with mTLS configuration + pub fn new(tls: &HttpsClientConfig) -> Result { + // Load client certificate and key + let cert_pem = std::fs::read(&tls.cert_path) + .with_context(|| format!("failed to read TLS cert from {}", tls.cert_path))?; + let key_pem = std::fs::read(&tls.key_path) + .with_context(|| format!("failed to read TLS key from {}", tls.key_path))?; + + let certs: Vec> = CertificateDer::pem_slice_iter(&cert_pem) + .collect::>() + .context("failed to parse client certs")?; + + let key = PrivateKeyDer::from_pem_slice(&key_pem).context("failed to parse private key")?; + + // Load CA certificate + let ca_cert_pem = std::fs::read(&tls.ca_cert_path) + .with_context(|| format!("failed to read CA cert from {}", tls.ca_cert_path))?; + let ca_certs: Vec> = CertificateDer::pem_slice_iter(&ca_cert_pem) + .collect::>() + .context("failed to parse CA certs")?; + let ca_cert = ca_certs + .into_iter() + .next() + .context("no CA certificate found")?; + + // Build rustls config with custom verifier if validator is provided + let tls_config_builder = rustls::ClientConfig::builder(); + + let tls_config = if let Some(ref validator) = tls.cert_validator { + let verifier = CustomCertVerifier::new(validator.clone(), ca_cert)?; + tls_config_builder + 
.dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + } else { + // Standard verification without custom validator + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(ca_cert).context("failed to add CA cert")?; + tls_config_builder.with_root_certificates(root_store) + } + .with_client_auth_cert(certs, key) + .context("failed to set client auth cert")?; + + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http1() + .build(); + + let client = Client::builder(TokioExecutor::new()).build(https); + Ok(Self { client }) + } + + /// Send a POST request with JSON body and receive JSON response + pub async fn post_json( + &self, + url: &str, + body: &T, + ) -> Result { + let body = serde_json::to_vec(body).context("failed to serialize request body")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/json") + .body(Full::new(Bytes::from(body))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? 
+ .to_bytes(); + + serde_json::from_slice(&body).context("failed to parse response") + } + + /// Send a POST request with msgpack + gzip encoded body and receive msgpack + gzip response + pub async fn post_compressed_msg( + &self, + url: &str, + body: &T, + ) -> Result { + let encoded = encode(body).context("failed to encode request body")?; + + // Compress with gzip + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder + .write_all(&encoded) + .context("failed to compress request")?; + let compressed = encoder.finish().context("failed to finish compression")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/x-msgpack-gz") + .body(Full::new(Bytes::from(compressed))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? + .to_bytes(); + + // Decompress + let mut decoder = GzDecoder::new(body.as_ref()); + let mut decompressed = Vec::new(); + decoder + .read_to_end(&mut decompressed) + .context("failed to decompress response")?; + + decode(&decompressed).context("failed to decode response") + } +} + +// ============================================================================ +// Built-in validators +// ============================================================================ + +/// Validator that checks the peer certificate contains a specific app_id. 
+#[derive(Debug)] +pub struct AppIdValidator { + expected_app_id: Vec, +} + +impl AppIdValidator { + pub fn new(expected_app_id: Vec) -> Self { + Self { expected_app_id } + } +} + +impl CertValidator for AppIdValidator { + fn validate(&self, cert_der: &[u8]) -> Result<(), String> { + use ra_tls::traits::CertExt; + + let (_, cert) = x509_parser::parse_x509_certificate(cert_der) + .map_err(|e| format!("failed to parse certificate: {e}"))?; + + let peer_app_id = cert + .get_app_id() + .map_err(|e| format!("failed to get app_id: {e}"))?; + + let Some(peer_app_id) = peer_app_id else { + return Err("peer certificate does not contain app_id".into()); + }; + + if peer_app_id != self.expected_app_id { + return Err(format!( + "app_id mismatch: expected {}, got {}", + hex::encode(&self.expected_app_id), + hex::encode(&peer_app_id) + )); + } + + Ok(()) + } +} diff --git a/gateway/src/kv/mod.rs b/gateway/src/kv/mod.rs new file mode 100644 index 00000000..97b195c9 --- /dev/null +++ b/gateway/src/kv/mod.rs @@ -0,0 +1,997 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV-based sync layer for dstack-gateway. +//! +//! This module provides synchronization between gateway nodes. The local ProxyState +//! remains the primary data store for fast reads, while WaveKV handles cross-node sync. +//! +//! Key schema: +//! +//! # Persistent WaveKV (needs persistence + sync) +//! - `inst/{instance_id}` → InstanceData +//! - `node/{node_id}` → NodeData +//! - `dns_cred/{cred_id}` → DnsCredential +//! - `dns_cred_default` → cred_id (default credential ID) +//! - `global/certbot_config` → GlobalCertbotConfig +//! - `cert/{domain}/config` → ZtDomainConfig +//! - `cert/{domain}/data` → CertData +//! - `global/acme_credentials` → CertCredentials (shared ACME account) +//! - `global/acme_attestation` → AcmeAttestation (TDX quote of ACME account URI) +//! - `cert/{domain}/lock` → CertRenewLock +//! 
- `cert/{domain}/attestation/latest` → CertAttestation +//! - `cert/{domain}/attestation/{timestamp}` → CertAttestation (history) +//! +//! # Ephemeral WaveKV (no persistence, sync only) +//! - `conn/{instance_id}/{node_id}` → u64 (connection count) +//! - `last_seen/inst/{instance_id}` → u64 (timestamp) +//! - `last_seen/node/{node_id}/{seen_by_node_id}` → u64 (timestamp) + +mod https_client; +mod sync_service; + +pub use https_client::{AppIdValidator, HttpsClientConfig}; +pub use sync_service::{fetch_peers_from_bootnode, WaveKvSyncService}; +use tracing::warn; + +use std::{collections::BTreeMap, net::Ipv4Addr, path::Path, time::Duration}; + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use tokio::sync::watch; +use wavekv::{node::NodeState, types::NodeId, Node}; + +/// Instance core data (persistent) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct InstanceData { + pub app_id: String, + pub ip: Ipv4Addr, + pub public_key: String, + pub reg_time: u64, +} + +/// Gateway node status (stored separately for independent updates) +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum NodeStatus { + #[default] + Up, + Down, +} + +/// Gateway node data (persistent, rarely changes) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct NodeData { + pub uuid: Vec, + pub url: String, + pub wg_public_key: String, + pub wg_endpoint: String, + pub wg_ip: String, +} + +/// Certificate credentials (ACME account) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertCredentials { + pub acme_credentials: String, +} + +/// ACME account attestation (TDX Quote of account URI) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct AcmeAttestation { + /// ACME account URI + pub account_uri: String, + /// TDX Quote (JSON serialized) + #[serde(default)] + pub quote: String, + /// Full attestation (JSON serialized) + 
#[serde(default)] + pub attestation: String, + /// Node that generated this attestation + #[serde(default)] + pub generated_by: NodeId, + /// Timestamp when this attestation was generated + #[serde(default)] + pub generated_at: u64, +} + +/// Certificate data (cert + key) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertData { + pub cert_pem: String, + pub key_pem: String, + pub not_after: u64, + pub issued_by: NodeId, + pub issued_at: u64, +} + +/// Certificate renew lock +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertRenewLock { + pub started_at: u64, + pub started_by: NodeId, +} + +/// Certificate attestation (TDX Quote of certificate public key) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CertAttestation { + /// Certificate public key (DER encoded) + pub public_key: Vec, + /// TDX Quote (JSON serialized) + #[serde(default)] + pub quote: String, + /// Full attestation (JSON serialized) + #[serde(default)] + pub attestation: String, + /// Node that generated this attestation + #[serde(default)] + pub generated_by: NodeId, + /// Timestamp when this attestation was generated + #[serde(default)] + pub generated_at: u64, +} + +/// DNS credential configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DnsCredential { + /// Unique identifier + pub id: String, + /// Display name + pub name: String, + /// DNS provider configuration + pub provider: DnsProvider, + /// Maximum DNS wait time + #[serde(with = "serde_duration")] + pub max_dns_wait: Duration, + /// DNS TXT record TTL + pub dns_txt_ttl: u32, + /// Creation timestamp + pub created_at: u64, + /// Last update timestamp + pub updated_at: u64, +} + +/// DNS provider configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DnsProvider { + Cloudflare { + api_token: String, + /// Cloudflare API URL (defaults to https://api.cloudflare.com/client/v4 if not set) + 
#[serde(default, skip_serializing_if = "Option::is_none")] + api_url: Option, + }, + // Future providers can be added here +} + +/// ZT-Domain configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ZtDomainConfig { + /// Base domain name (e.g., "app.example.com") + /// Certificate will be issued for "*.{domain}" automatically + pub domain: String, + /// DNS credential ID to use (None = use default) + pub dns_cred_id: Option, + /// Port this domain serves on (e.g., 443) + #[serde(default)] + pub port: u16, + /// Node binding (None = any node can serve this domain) + /// If set, only this node will serve this domain + #[serde(default)] + pub node: Option, + /// Priority for default base_domain selection (higher = preferred) + /// The domain with highest priority is returned as the default base_domain in APIs + #[serde(default)] + pub priority: i32, +} + +/// Global certbot configuration (stored in KV, synced across nodes) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GlobalCertbotConfig { + /// Interval between renewal checks + #[serde(with = "serde_duration")] + pub renew_interval: Duration, + /// Time before expiration to trigger renewal (e.g., 30 days) + #[serde(with = "serde_duration")] + pub renew_before_expiration: Duration, + /// Timeout for certificate renewal operations + #[serde(with = "serde_duration")] + pub renew_timeout: Duration, + /// ACME server URL (None means use default Let's Encrypt production) + pub acme_url: String, +} + +impl Default for GlobalCertbotConfig { + fn default() -> Self { + Self { + renew_interval: Duration::from_secs(12 * 3600), // 12 hours + renew_before_expiration: Duration::from_secs(30 * 86400), // 30 days + renew_timeout: Duration::from_secs(300), // 5 minutes + acme_url: Default::default(), // default Let's Encrypt + } + } +} + +// Key prefixes and builders +pub mod keys { + use super::NodeId; + + pub const INST_PREFIX: &str = "inst/"; + pub const NODE_PREFIX: &str = "node/"; + pub const 
NODE_INFO_PREFIX: &str = "node/info/"; + pub const NODE_STATUS_PREFIX: &str = "node/status/"; + pub const CONN_PREFIX: &str = "conn/"; + pub const HANDSHAKE_PREFIX: &str = "handshake/"; + pub const LAST_SEEN_NODE_PREFIX: &str = "last_seen/node/"; + pub const PEER_ADDR_PREFIX: &str = "__peer_addr/"; + pub const CERT_PREFIX: &str = "cert/"; + pub const DNS_CRED_PREFIX: &str = "dns_cred/"; + pub const DNS_CRED_DEFAULT: &str = "dns_cred_default"; + pub const GLOBAL_CERTBOT_CONFIG: &str = "global/certbot_config"; + pub const GLOBAL_ACME_CREDENTIALS: &str = "global/acme_credentials"; + pub const GLOBAL_ACME_ATTESTATION: &str = "global/acme_attestation"; + + pub fn inst(instance_id: &str) -> String { + format!("{INST_PREFIX}{instance_id}") + } + + pub fn node_info(node_id: NodeId) -> String { + format!("{NODE_INFO_PREFIX}{node_id}") + } + + pub fn node_status(node_id: NodeId) -> String { + format!("{NODE_STATUS_PREFIX}{node_id}") + } + + pub fn conn(instance_id: &str, node_id: NodeId) -> String { + format!("{CONN_PREFIX}{instance_id}/{node_id}") + } + + /// Key for instance handshake timestamp observed by a specific node + /// Format: handshake/{instance_id}/{observer_node_id} + pub fn handshake(instance_id: &str, observer_node_id: NodeId) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/{observer_node_id}") + } + + /// Prefix to iterate all handshake observations for an instance + pub fn handshake_prefix(instance_id: &str) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/") + } + + pub fn last_seen_node(node_id: NodeId, seen_by: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/{seen_by}") + } + + pub fn last_seen_node_prefix(node_id: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/") + } + + pub fn peer_addr(node_id: NodeId) -> String { + format!("{PEER_ADDR_PREFIX}{node_id}") + } + + // ==================== DNS Credential keys ==================== + + /// Key for a DNS credential + pub fn dns_cred(cred_id: &str) -> String { + 
format!("{DNS_CRED_PREFIX}{cred_id}") + } + + // ==================== Certificate keys (per domain) ==================== + + /// Key for ZT-Domain configuration + pub fn zt_domain_config(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/config") + } + + /// Key for domain certificate data (cert + key) + pub fn cert_data(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/data") + } + + /// Key for domain certificate renew lock + pub fn cert_lock(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/lock") + } + + /// Key for latest attestation of a domain + pub fn cert_attestation_latest(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/attestation/latest") + } + + /// Key for historical attestation of a domain + pub fn cert_attestation_history(domain: &str, timestamp: u64) -> String { + format!("{CERT_PREFIX}{domain}/attestation/{timestamp}") + } + + /// Prefix for all attestations of a domain (for iteration) + pub fn cert_attestation_prefix(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/attestation/") + } + + /// Parse domain from cert/{domain}/... 
key + pub fn parse_cert_domain(key: &str) -> Option<&str> { + let rest = key.strip_prefix(CERT_PREFIX)?; + rest.split('/').next() + } + + // ==================== Parse helpers ==================== + + /// Parse instance_id from key + pub fn parse_inst_key(key: &str) -> Option<&str> { + key.strip_prefix(INST_PREFIX) + } + + /// Parse node_id from node/info/{node_id} key + pub fn parse_node_info_key(key: &str) -> Option { + key.strip_prefix(NODE_INFO_PREFIX)?.parse().ok() + } +} + +pub fn encode(value: &T) -> Result> { + rmp_serde::encode::to_vec(value).context("failed to encode value") +} + +pub fn decode Deserialize<'de>>(bytes: &[u8]) -> Result { + rmp_serde::decode::from_slice(bytes).context("failed to decode value") +} + +trait GetPutCodec { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option; + fn put_encoded(&mut self, key: String, value: &T) -> Result<()>; + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; +} + +impl GetPutCodec for NodeState { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option { + self.get(key) + .and_then(|entry| match decode(entry.value.as_ref()?) { + Ok(value) => Some(value), + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + None + } + }) + } + + fn put_encoded(&mut self, key: String, value: &T) -> Result<()> { + self.put(key.clone(), encode(value)?) + .with_context(|| format!("failed to put key {key}"))?; + Ok(()) + } + + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) 
{ + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some((key.to_string(), value)) + }) + } + + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) { + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some(value) + }) + } +} + +/// Sync store wrapping two WaveKV Nodes (persistent and ephemeral). +/// +/// This is the sync layer - not the primary data store. +/// ProxyState remains in memory for fast reads. +#[derive(Clone)] +pub struct KvStore { + /// Persistent WaveKV Node (with WAL) + persistent: Node, + /// Ephemeral WaveKV Node (in-memory only) + ephemeral: Node, + /// This gateway's node ID + my_node_id: NodeId, +} + +impl KvStore { + /// Create a new sync store + pub fn new( + my_node_id: NodeId, + peer_ids: Vec, + data_dir: impl AsRef, + ) -> Result { + let persistent = + Node::new_with_persistence(my_node_id, peer_ids.clone(), data_dir.as_ref()) + .context("failed to create persistent wavekv node")?; + + // Get peers from persistent store (may have been restored from WAL) + // and include them when creating ephemeral store + let persistent_peers = persistent.read().status().peers; + let mut all_peer_ids = peer_ids; + for peer_status in persistent_peers { + if !all_peer_ids.contains(&peer_status.id) { + all_peer_ids.push(peer_status.id); + } + } + + let ephemeral = Node::new(my_node_id, all_peer_ids); + + Ok(Self { + persistent, + ephemeral, + my_node_id, + }) + } + + pub fn my_node_id(&self) -> NodeId { + self.my_node_id + } + + pub fn persistent(&self) -> &Node { + &self.persistent + } + + pub fn ephemeral(&self) -> &Node { + &self.ephemeral + } + + // ==================== Instance Sync ==================== + + /// Sync instance data to other nodes + pub fn 
sync_instance(&self, instance_id: &str, data: &InstanceData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::inst(instance_id), data) + } + + /// Sync instance deletion to other nodes + pub fn sync_delete_instance(&self, instance_id: &str) -> Result<()> { + self.persistent.write().delete(keys::inst(instance_id))?; + self.ephemeral + .write() + .delete(keys::conn(instance_id, self.my_node_id))?; + // Delete this node's handshake record + self.ephemeral + .write() + .delete(keys::handshake(instance_id, self.my_node_id))?; + Ok(()) + } + + /// Load all instances from sync store (for initial sync on startup) + pub fn load_all_instances(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::INST_PREFIX) + .filter_map(|(key, data)| { + let instance_id = keys::parse_inst_key(&key)?; + Some((instance_id.into(), data)) + }) + .collect() + } + + // ==================== Node Sync ==================== + + /// Sync node data to other nodes + pub fn sync_node(&self, node_id: NodeId, data: &NodeData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_info(node_id), data) + } + + /// Load all nodes from sync store + pub fn load_all_nodes(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::NODE_INFO_PREFIX) + .filter_map(|(key, data)| { + let node_id = keys::parse_node_info_key(&key)?; + Some((node_id, data)) + }) + .collect() + } + + // ==================== Node Status Sync ==================== + + /// Set node status (stored separately from NodeData) + pub fn set_node_status(&self, node_id: NodeId, status: NodeStatus) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_status(node_id), &status)?; + Ok(()) + } + + /// Get node status + pub fn get_node_status(&self, node_id: NodeId) -> NodeStatus { + self.persistent + .read() + .decode(&keys::node_status(node_id)) + .unwrap_or_default() + } + + /// Load all node statuses + pub fn load_all_node_statuses(&self) -> BTreeMap { + self.persistent 
+ .read() + .iter_decoded(keys::NODE_STATUS_PREFIX) + .filter_map(|(key, status)| { + let node_id: NodeId = key.strip_prefix(keys::NODE_STATUS_PREFIX)?.parse().ok()?; + Some((node_id, status)) + }) + .collect() + } + + // ==================== Connection Count Sync ==================== + + /// Sync connection count for an instance (from this node) + pub fn sync_connections(&self, instance_id: &str, count: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::conn(instance_id, self.my_node_id), &count)?; + Ok(()) + } + + // ==================== Handshake Sync ==================== + + /// Sync handshake timestamp for an instance (as observed by this node) + pub fn sync_instance_handshake(&self, instance_id: &str, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::handshake(instance_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all handshake observations for an instance (from all nodes) + pub fn get_instance_handshakes(&self, instance_id: &str) -> BTreeMap { + self.ephemeral + .read() + .iter_decoded(&keys::handshake_prefix(instance_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::handshake_prefix(instance_id))?; + let observer: NodeId = suffix.parse().ok()?; + Some((observer, ts)) + }) + .collect() + } + + /// Get the latest handshake timestamp for an instance (max across all nodes) + pub fn get_instance_latest_handshake(&self, instance_id: &str) -> Option { + self.ephemeral + .read() + .iter_decoded_values(&keys::handshake_prefix(instance_id)) + .max() + } + + /// Sync node last_seen (as observed by this node) + pub fn sync_node_last_seen(&self, node_id: NodeId, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::last_seen_node(node_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all observations of a node's last_seen + pub fn get_node_last_seen_by_all(&self, node_id: NodeId) -> BTreeMap { + self.ephemeral + .read() + 
.iter_decoded(&keys::last_seen_node_prefix(node_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::last_seen_node_prefix(node_id))?; + let seen_by: NodeId = suffix.parse().ok()?; + Some((seen_by, ts)) + }) + .collect() + } + + /// Get the latest last_seen timestamp for a node (max across all observers) + pub fn get_node_latest_last_seen(&self, node_id: NodeId) -> Option { + self.ephemeral + .read() + .iter_decoded_values(&keys::last_seen_node_prefix(node_id)) + .max() + } + + // ==================== Watch for Remote Changes ==================== + + /// Watch for remote instance changes (for updating local ProxyState) + pub fn watch_instances(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::INST_PREFIX) + } + + /// Watch for remote node changes + pub fn watch_nodes(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::NODE_PREFIX) + } + + // ==================== Persistence ==================== + + pub fn persist_if_dirty(&self) -> Result { + self.persistent.persist_if_dirty() + } + + // ==================== Peer Management ==================== + + pub fn add_peer(&self, peer_id: NodeId) -> Result<()> { + self.persistent.write().add_peer(peer_id)?; + self.ephemeral.write().add_peer(peer_id)?; + Ok(()) + } + + // ==================== Peer Address (in DB) ==================== + + /// Register a node's sync URL in DB and add to peer list for sync + /// + /// This stores the URL in KvStore (for address lookup) and also adds the node + /// to the wavekv peer list (so SyncManager knows to sync with it). 
+ pub fn register_peer_url(&self, node_id: NodeId, url: &str) -> Result<()> { + // Store URL in persistent KvStore + self.persistent + .write() + .put_encoded(keys::peer_addr(node_id), &url)?; + + let _ = self.add_peer(node_id); + Ok(()) + } + + /// Get a peer's sync URL from DB + pub fn get_peer_url(&self, node_id: NodeId) -> Option { + self.persistent.read().decode(&keys::peer_addr(node_id)) + } + + /// Query the UUID for a given node ID from KvStore + pub fn get_peer_uuid(&self, peer_id: NodeId) -> Option> { + let node_data: NodeData = self.persistent.read().decode(&keys::node_info(peer_id))?; + Some(node_data.uuid) + } + + pub fn update_peer_last_seen(&self, peer_id: NodeId) { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + let key = keys::last_seen_node(peer_id, self.my_node_id); + if let Err(e) = self.ephemeral.write().put_encoded(key, &ts) { + warn!("failed to update peer {peer_id} last_seen: {e}"); + } + } + + /// Get all peer addresses from DB (for debugging/testing) + pub fn get_all_peer_addrs(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::PEER_ADDR_PREFIX) + .filter_map(|(key, url)| { + let node_id: NodeId = key.strip_prefix(keys::PEER_ADDR_PREFIX)?.parse().ok()?; + Some((node_id, url)) + }) + .collect() + } + + // ==================== DNS Credential Management ==================== + + /// Get a DNS credential by ID + pub fn get_dns_credential(&self, cred_id: &str) -> Option { + self.persistent.read().decode(&keys::dns_cred(cred_id)) + } + + /// Save a DNS credential + pub fn save_dns_credential(&self, cred: &DnsCredential) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::dns_cred(&cred.id), cred)?; + Ok(()) + } + + /// Delete a DNS credential + pub fn delete_dns_credential(&self, cred_id: &str) -> Result<()> { + self.persistent.write().delete(keys::dns_cred(cred_id))?; + Ok(()) + } + + /// List all DNS credentials + pub fn 
list_dns_credentials(&self) -> Vec { + self.persistent + .read() + .iter_decoded_values(keys::DNS_CRED_PREFIX) + .collect() + } + + /// Get the default DNS credential ID + pub fn get_default_dns_credential_id(&self) -> Option { + self.persistent.read().decode(keys::DNS_CRED_DEFAULT) + } + + /// Set the default DNS credential ID + pub fn set_default_dns_credential_id(&self, cred_id: &str) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::DNS_CRED_DEFAULT.to_string(), &cred_id)?; + Ok(()) + } + + /// Get the default DNS credential (resolves the ID to the actual credential) + pub fn get_default_dns_credential(&self) -> Option { + let cred_id = self.get_default_dns_credential_id()?; + self.get_dns_credential(&cred_id) + } + + // ==================== Global Certbot Config ==================== + + /// Get global certbot configuration (returns default if not set) + pub fn get_certbot_config(&self) -> GlobalCertbotConfig { + self.persistent + .read() + .decode(keys::GLOBAL_CERTBOT_CONFIG) + .unwrap_or_default() + } + + /// Set global certbot configuration + pub fn set_certbot_config(&self, config: &GlobalCertbotConfig) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::GLOBAL_CERTBOT_CONFIG.to_string(), config)?; + Ok(()) + } + + // ==================== ZT-Domain Config ==================== + + /// Get ZT-Domain configuration + pub fn get_zt_domain_config(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::zt_domain_config(domain)) + } + + /// Save ZT-Domain configuration + pub fn save_zt_domain_config(&self, config: &ZtDomainConfig) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::zt_domain_config(&config.domain), config)?; + Ok(()) + } + + /// Delete ZT-Domain configuration + pub fn delete_zt_domain_config(&self, domain: &str) -> Result<()> { + self.persistent + .write() + .delete(keys::zt_domain_config(domain))?; + Ok(()) + } + + /// List all ZT-Domain configurations + pub fn 
list_zt_domain_configs(&self) -> Vec { + let state = self.persistent.read(); + state + .iter_by_prefix(keys::CERT_PREFIX) + .filter_map(|(key, entry)| { + // Only decode config entries (not data/acme/lock/attestation) + if !key.ends_with("/config") { + return None; + } + let value = entry.value.as_ref()?; + match decode(value) { + Ok(config) => Some(config), + Err(e) => { + warn!("failed to decode cert config for key {key}: {e:?}"); + None + } + } + }) + .collect() + } + + /// Watch for ZT-Domain config changes + pub fn watch_zt_domain_configs(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::CERT_PREFIX) + } + + /// Get the best ZT-Domain config for this node. + /// + /// Selection rules: + /// 1. Only considers domains where node == None or node == my_node_id + /// 2. Higher priority wins + /// 3. If priority is equal, node == None wins (global domains preferred over node-specific) + /// + /// Returns (domain, port) of the best match, or None if no domains configured. 
+ pub fn get_best_zt_domain(&self) -> Option<(String, u16)> { + let my_node_id = self.my_node_id; + let configs = self.list_zt_domain_configs(); + + configs + .into_iter() + .filter(|c| c.node.is_none() || c.node == Some(my_node_id)) + .max_by(|a, b| { + // Compare by priority first (higher wins) + match a.priority.cmp(&b.priority) { + std::cmp::Ordering::Equal => { + // If priority equal, None (global) wins over Some (node-specific) + // None < Some in Option ordering, so we reverse + b.node.cmp(&a.node) + } + other => other, + } + }) + .map(|c| (c.domain, c.port)) + } + + // ==================== Certificate Data ==================== + + /// Get certificate data for a domain + pub fn get_cert_data(&self, domain: &str) -> Option { + self.persistent.read().decode(&keys::cert_data(domain)) + } + + /// Save certificate data for a domain + pub fn save_cert_data(&self, domain: &str, data: &CertData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::cert_data(domain), data)?; + Ok(()) + } + + /// Load all certificate data (for startup) + pub fn load_all_cert_data(&self) -> BTreeMap { + let state = self.persistent.read(); + state + .iter_by_prefix(keys::CERT_PREFIX) + .filter_map(|(key, entry)| { + // Only decode data entries (not config/acme/lock/attestation) + if !key.ends_with("/data") { + return None; + } + let domain = keys::parse_cert_domain(key)?; + let value = entry.value.as_ref()?; + match decode(value) { + Ok(data) => Some((domain.to_string(), data)), + Err(e) => { + warn!("failed to decode cert data for key {key}: {e:?}"); + None + } + } + }) + .collect() + } + + // ==================== Global ACME Credentials ==================== + + /// Get global ACME credentials (shared across all domains) + pub fn get_acme_credentials(&self) -> Option { + self.persistent.read().decode(keys::GLOBAL_ACME_CREDENTIALS) + } + + /// Save global ACME credentials + pub fn save_acme_credentials(&self, creds: &CertCredentials) -> Result<()> { + self.persistent + 
.write() + .put_encoded(keys::GLOBAL_ACME_CREDENTIALS.to_string(), creds)?; + Ok(()) + } + + /// Get global ACME attestation (TDX quote of account URI) + pub fn get_acme_attestation(&self) -> Option { + self.persistent.read().decode(keys::GLOBAL_ACME_ATTESTATION) + } + + /// Save global ACME attestation + pub fn save_acme_attestation(&self, attestation: &AcmeAttestation) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::GLOBAL_ACME_ATTESTATION.to_string(), attestation)?; + Ok(()) + } + + // ==================== Certificate Renew Lock ==================== + + /// Get certificate renew lock for a domain + pub fn get_cert_lock(&self, domain: &str) -> Option { + self.persistent.read().decode(&keys::cert_lock(domain)) + } + + /// Try to acquire certificate renew lock + /// Returns true if lock acquired, false if already locked by another node + pub fn try_acquire_cert_lock(&self, domain: &str, lock_timeout_secs: u64) -> bool { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if let Some(existing) = self.get_cert_lock(domain) { + // Check if lock is still valid (not expired) + if now < existing.started_at + lock_timeout_secs { + return false; + } + } + + // Acquire the lock + let lock = CertRenewLock { + started_at: now, + started_by: self.my_node_id, + }; + self.persistent + .write() + .put_encoded(keys::cert_lock(domain), &lock) + .is_ok() + } + + /// Release certificate renew lock + pub fn release_cert_lock(&self, domain: &str) -> Result<()> { + self.persistent.write().delete(keys::cert_lock(domain))?; + Ok(()) + } + + // ==================== Certificate Attestation ==================== + + /// Get the latest attestation for a domain + pub fn get_cert_attestation_latest(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::cert_attestation_latest(domain)) + } + + /// Save attestation for a domain (saves both latest and history) + pub fn 
save_cert_attestation(&self, domain: &str, attestation: &CertAttestation) -> Result<()> { + let mut state = self.persistent.write(); + // Save to history + state.put_encoded( + keys::cert_attestation_history(domain, attestation.generated_at), + attestation, + )?; + // Update latest + state.put_encoded(keys::cert_attestation_latest(domain), attestation)?; + Ok(()) + } + + /// List all attestation history for a domain (sorted by timestamp descending) + pub fn list_cert_attestations(&self, domain: &str) -> Vec<CertAttestation> { + let prefix = keys::cert_attestation_prefix(domain); + let latest_key = keys::cert_attestation_latest(domain); + let state = self.persistent.read(); + let mut attestations: Vec<CertAttestation> = state + .iter_by_prefix(&prefix) + .filter_map(|(key, entry)| { + // Skip the "latest" entry + if key == &latest_key { + return None; + } + let value = entry.value.as_ref()?; + match decode(value) { + Ok(att) => Some(att), + Err(e) => { + warn!("failed to decode attestation for key {key}: {e:?}"); + None + } + } + }) + .collect(); + // Sort by generated_at descending (newest first) + attestations.sort_by(|a, b| b.generated_at.cmp(&a.generated_at)); + attestations + } + + // ==================== Watch helpers ==================== + + /// Watch for certificate data changes (any domain) + pub fn watch_all_certs(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::CERT_PREFIX) + } +} diff --git a/gateway/src/kv/sync_service.rs b/gateway/src/kv/sync_service.rs new file mode 100644 index 00000000..f691595a --- /dev/null +++ b/gateway/src/kv/sync_service.rs @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync service - implements network transport for wavekv synchronization. +//! +//! Peer URLs are stored in the persistent KV store under `__peer_addr/{node_id}` keys. +//! This allows peer addresses to be automatically synced across nodes. 
+ +use std::sync::Arc; + +use anyhow::{Context, Result}; +use dstack_gateway_rpc::GetPeersResponse; +use tracing::{info, warn}; +use wavekv::{ + sync::{ExchangeInterface, SyncConfig as KvSyncConfig, SyncManager, SyncMessage, SyncResponse}, + types::NodeId, + Node, +}; + +use crate::config::SyncConfig as GwSyncConfig; + +use super::https_client::{HttpsClient, HttpsClientConfig}; +use super::KvStore; + +/// HTTP-based network transport for WaveKV sync. +/// Holds a reference to the persistent node for reading peer URLs. +#[derive(Clone)] +pub struct HttpSyncNetwork { + client: HttpsClient, + /// Reference to persistent node for reading peer URLs + kv_store: KvStore, + /// This node's UUID (for node ID reuse detection) + my_uuid: Vec<u8>, + /// URL path suffix for this store (e.g., "persistent" or "ephemeral") + store_path: &'static str, +} + +impl HttpSyncNetwork { + pub fn new( + kv_store: KvStore, + store_path: &'static str, + tls_config: &HttpsClientConfig, + ) -> Result<Self> { + let client = HttpsClient::new(tls_config)?; + let my_uuid = kv_store + .get_peer_uuid(kv_store.my_node_id) + .context("failed to get my UUID")?; + Ok(Self { + client, + kv_store, + my_uuid, + store_path, + }) + } + + /// Get peer URL from persistent node + fn get_peer_url(&self, peer_id: NodeId) -> Option<String> { + self.kv_store.get_peer_url(peer_id) + } +} + +impl ExchangeInterface for HttpSyncNetwork { + fn uuid(&self) -> Vec<u8> { + self.my_uuid.clone() + } + + fn query_uuid(&self, node_id: NodeId) -> Option<Vec<u8>> { + self.kv_store.get_peer_uuid(node_id) + } + + async fn sync_to(&self, _node: &Node, peer: NodeId, msg: SyncMessage) -> Result<SyncResponse> { + let url = self + .get_peer_url(peer) + .ok_or_else(|| anyhow::anyhow!("peer {} address not found in DB", peer))?; + + let sync_url = format!( + "{}/wavekv/sync/{}", + url.trim_end_matches('/'), + self.store_path + ); + + // Send request with msgpack + gzip encoding + // app_id verification happens during TLS handshake via AppIdVerifier + let sync_response: SyncResponse 
= self + .client + .post_compressed_msg(&sync_url, &msg) + .await + .with_context(|| format!("failed to sync to peer {peer} at {sync_url}"))?; + + // Update peer last_seen on successful sync + self.kv_store.update_peer_last_seen(peer); + + Ok(sync_response) + } +} + +/// WaveKV sync service that manages synchronization for both persistent and ephemeral stores +pub struct WaveKvSyncService { + pub persistent_manager: Arc>, + pub ephemeral_manager: Arc>, +} + +impl WaveKvSyncService { + /// Create a new WaveKV sync service + /// + /// # Arguments + /// * `kv_store` - The sync store containing persistent and ephemeral nodes + /// * `sync_config` - Sync configuration + /// * `tls_config` - TLS configuration for mTLS peer authentication + pub fn new( + kv_store: &KvStore, + sync_config: &GwSyncConfig, + tls_config: HttpsClientConfig, + ) -> Result { + let sync_config = KvSyncConfig { + interval: sync_config.interval, + timeout: sync_config.timeout, + }; + + // Both networks use the same persistent node for URL lookup, but different paths + let persistent_network = HttpSyncNetwork::new(kv_store.clone(), "persistent", &tls_config)?; + let ephemeral_network = HttpSyncNetwork::new(kv_store.clone(), "ephemeral", &tls_config)?; + + let persistent_manager = Arc::new(SyncManager::with_config( + kv_store.persistent().clone(), + persistent_network, + sync_config.clone(), + )); + let ephemeral_manager = Arc::new(SyncManager::with_config( + kv_store.ephemeral().clone(), + ephemeral_network, + sync_config, + )); + + Ok(Self { + persistent_manager, + ephemeral_manager, + }) + } + + /// Bootstrap from peers + pub async fn bootstrap(&self) -> Result<()> { + info!("bootstrapping persistent store..."); + if let Err(e) = self.persistent_manager.bootstrap().await { + warn!("failed to bootstrap persistent store: {e}"); + } + + info!("bootstrapping ephemeral store..."); + if let Err(e) = self.ephemeral_manager.bootstrap().await { + warn!("failed to bootstrap ephemeral store: {e}"); + } + + 
Ok(()) + } + + /// Start background sync tasks + pub async fn start_sync_tasks(&self) { + let persistent = self.persistent_manager.clone(); + let ephemeral = self.ephemeral_manager.clone(); + + tokio::join!(persistent.start_sync_tasks(), ephemeral.start_sync_tasks(),); + + info!("WaveKV sync tasks started"); + } + + /// Handle incoming sync request for persistent store + pub fn handle_persistent_sync(&self, msg: SyncMessage) -> Result { + self.persistent_manager.handle_sync(msg) + } + + /// Handle incoming sync request for ephemeral store + pub fn handle_ephemeral_sync(&self, msg: SyncMessage) -> Result { + self.ephemeral_manager.handle_sync(msg) + } +} + +/// Fetch peer list from bootnode and register them in KvStore. +/// +/// This is called during startup to bootstrap the peer list from a known bootnode. +/// Uses Gateway.GetPeers RPC which requires mTLS gateway authentication. +pub async fn fetch_peers_from_bootnode( + bootnode_url: &str, + kv_store: &KvStore, + my_node_id: NodeId, + tls_config: &HttpsClientConfig, +) -> Result<()> { + if bootnode_url.is_empty() { + info!("no bootnode configured, skipping peer fetch"); + return Ok(()); + } + + info!("fetching peers from bootnode: {}", bootnode_url); + + // Create HTTPS client for bootnode communication (with mTLS) + let client = HttpsClient::new(tls_config).context("failed to create HTTPS client")?; + + // Call Gateway.GetPeers RPC on bootnode (requires mTLS gateway auth) + let peers_url = format!("{}/prpc/GetPeers", bootnode_url.trim_end_matches('/')); + + let response: GetPeersResponse = client + .post_json(&peers_url, &()) + .await + .with_context(|| format!("failed to fetch peers from bootnode {bootnode_url}"))?; + + info!( + "bootnode returned {} peers (bootnode_id={})", + response.peers.len(), + response.my_id + ); + + // Register each peer + for peer in &response.peers { + if peer.id == my_node_id { + continue; // Skip self + } + + // Add peer to WaveKV + if let Err(e) = kv_store.add_peer(peer.id) { + 
warn!("failed to add peer {}: {}", peer.id, e); + continue; + } + + // Register peer URL + if !peer.url.is_empty() { + if let Err(e) = kv_store.register_peer_url(peer.id, &peer.url) { + warn!("failed to register peer URL for node {}: {}", peer.id, e); + } else { + info!( + "registered peer from bootnode: node {} -> {}", + peer.id, peer.url + ); + } + } + } + + Ok(()) +} diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 17ef2cf3..1b1c8b6e 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,29 +1,52 @@ -// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// SPDX-FileCopyrightText: 2024-2025 Phala Network dstack@phala.network // // SPDX-License-Identifier: Apache-2.0 use anyhow::{anyhow, Context, Result}; +use base64::{engine::general_purpose::STANDARD, Engine as _}; use clap::Parser; use config::{Config, TlsConfig}; use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, GetTlsKeyArgs}; +use dstack_kms_rpc::SignCertRequest; use http_client::prpc::PrpcClient; -use ra_rpc::{client::RaClient, rocket_helper::QuoteVerifier}; +use ra_rpc::{client::RaClient, prpc_routes as prpc, rocket_helper::QuoteVerifier}; +use ra_tls::cert::{CertConfigV2, CertSigningRequestV2, Csr}; +use ra_tls::rcgen::KeyPair; use rocket::{ fairing::AdHoc, figment::{providers::Serialized, Figment}, }; +use serde::{Deserialize, Serialize}; use tracing::info; use admin_service::AdminRpcHandler; -use main_service::{Proxy, RpcHandler}; +use main_service::{Proxy, ProxyOptions, RpcHandler}; + +use crate::debug_service::DebugRpcHandler; mod admin_service; +mod cert_store; mod config; +mod debug_service; +mod distributed_certbot; +mod kv; mod main_service; mod models; mod proxy; mod web_routes; +#[derive(Debug, Clone, Serialize, Deserialize)] +struct DebugKeyData { + /// Private key in PEM format + key_pem: String, + /// TDX quote in base64 format + quote_base64: String, + /// Event log in JSON string format + event_log: String, + /// VM config in JSON string format + 
vm_config: String, +} + #[global_allocator] static ALLOCATOR: jemallocator::Jemalloc = jemallocator::Jemalloc; @@ -67,55 +90,131 @@ async fn maybe_gen_certs(config: &Config, tls_config: &TlsConfig) -> Result<()> return Ok(()); } - if config.run_in_dstack { - info!("Using dstack guest agent for certificate generation"); - let agent_client = dstack_agent().context("Failed to create dstack client")?; - let response = agent_client - .get_tls_key(GetTlsKeyArgs { - subject: "dstack-gateway".to_string(), - alt_names: vec![config.rpc_domain.clone()], - usage_ra_tls: true, - usage_server_auth: true, - usage_client_auth: false, - not_before: None, - not_after: None, - }) - .await?; - - let ca_cert = response - .certificate_chain - .last() - .context("Empty certificate chain")? - .to_string(); - let certs = response.certificate_chain.join("\n"); - write_cert(&tls_config.mutual.ca_certs, &ca_cert)?; - write_cert(&tls_config.certs, &certs)?; - write_cert(&tls_config.key, &response.key)?; - return Ok(()); + // Build alt_names: include rpc_domain and hostname from my_url + let mut alt_names = vec![config.rpc_domain.clone()]; + if let Ok(url) = reqwest::Url::parse(&config.sync.my_url) { + if let Some(host) = url.host_str() { + if host != config.rpc_domain { + alt_names.push(host.to_string()); + } + } } + match config.debug.insecure_skip_attestation { + true => gen_debug_certs(config, tls_config, alt_names).await, + false => gen_prod_certs(tls_config, alt_names).await, + } +} + +async fn gen_prod_certs(tls_config: &TlsConfig, alt_names: Vec) -> Result<()> { + info!("Using dstack guest agent for certificate generation"); + let agent_client = dstack_agent().context("Failed to create dstack client")?; + let response = agent_client + .get_tls_key(GetTlsKeyArgs { + subject: "dstack-gateway".to_string(), + alt_names, + usage_ra_tls: true, + usage_server_auth: true, + usage_client_auth: true, + not_before: None, + not_after: None, + }) + .await?; + + let ca_cert = response + 
.certificate_chain + .last() + .context("Empty certificate chain")? + .to_string(); + let certs = response.certificate_chain.join("\n"); + write_cert(&tls_config.mutual.ca_certs, &ca_cert)?; + write_cert(&tls_config.certs, &certs)?; + write_cert(&tls_config.key, &response.key)?; + Ok(()) +} + +async fn gen_debug_certs( + config: &Config, + tls_config: &TlsConfig, + alt_names: Vec, +) -> Result<()> { let kms_url = config.kms_url.clone(); if kms_url.is_empty() { info!("KMS URL is empty, skipping cert generation"); return Ok(()); } + + // Check if debug key file is configured + if config.debug.key_file.is_empty() { + info!("Debug key file not configured, skipping cert generation"); + return Ok(()); + } + + // Load pre-generated key pair and quote data from JSON file + info!("Loading debug key data from: {}", config.debug.key_file); + let ctx = "Failed to read debug key, run `cargo run --bin gen_debug_key -- ` to generate it"; + let json_content = fs_err::read_to_string(&config.debug.key_file).context(ctx)?; + let debug_data: DebugKeyData = + serde_json::from_str(&json_content).context("Failed to parse debug key JSON")?; + + let key_pem = debug_data.key_pem; + let quote_bin = STANDARD + .decode(&debug_data.quote_base64) + .context("Failed to decode quote from base64")?; + let event_log_json = debug_data.event_log; + let vm_config_json = debug_data.vm_config; + + // Parse key pair + let key = KeyPair::from_pem(&key_pem).context("Failed to parse debug key")?; + let pubkey = key.public_key_der(); + + // Build CSR with attestation from debug quote + let attestation = + ra_tls::attestation::Attestation::from_tdx_quote(quote_bin, event_log_json.as_bytes()) + .context("Failed to create attestation from debug quote")? 
+ .into_versioned(); + + let csr = CertSigningRequestV2 { + confirm: "please sign cert:".to_string(), + pubkey, + config: CertConfigV2 { + org_name: None, + subject: "dstack-gateway".to_string(), + subject_alt_names: alt_names, + usage_server_auth: true, + usage_client_auth: true, + ext_quote: true, + not_before: None, + not_after: None, + }, + attestation, + }; + let signature = csr.signed_by(&key).context("Failed to sign CSR")?; + + // Send CSR to KMS for signing let kms_url = format!("{kms_url}/prpc"); - info!("Getting CA cert from {kms_url}"); - let client = RaClient::new(kms_url, true).context("Failed to create kms client")?; - let client = dstack_kms_rpc::kms_client::KmsClient::new(client); - let ca_cert = client.get_meta().await?.ca_cert; - let key = ra_tls::rcgen::KeyPair::generate().context("Failed to generate key")?; - let cert = ra_tls::cert::CertRequest::builder() - .key(&key) - .subject("dstack-gateway") - .alt_names(std::slice::from_ref(&config.rpc_domain)) - .usage_server_auth(true) - .build() - .self_signed() - .context("Failed to self-sign rpc cert")?; + info!("Sending CSR to KMS for signing: {kms_url}"); + let kms_client = RaClient::new(kms_url, true).context("Failed to create kms client")?; + let kms_client = dstack_kms_rpc::kms_client::KmsClient::new(kms_client); + let sign_response = kms_client + .sign_cert(SignCertRequest { + api_version: 2, + csr: csr.to_vec(), + signature, + vm_config: vm_config_json.to_string(), + }) + .await + .context("Failed to sign certificate via KMS")?; + + let ca_cert = sign_response + .certificate_chain + .last() + .context("Empty certificate chain")? 
+ .to_string(); + let certs = sign_response.certificate_chain.join("\n"); write_cert(&tls_config.mutual.ca_certs, &ca_cert)?; - write_cert(&tls_config.certs, &cert.pem())?; + write_cert(&tls_config.certs, &certs)?; write_cert(&tls_config.key, &key.serialize_pem())?; Ok(()) } @@ -131,7 +230,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let _ = rustls::crypto::ring::default_provider().install_default(); @@ -140,9 +239,18 @@ async fn main() -> Result<()> { let figment = config::load_config_figment(args.config.as_deref()); let config = figment.focus("core").extract::()?; + + // Validate node_id + if config.sync.enabled && config.sync.node_id == 0 { + anyhow::bail!("node_id must be greater than 0"); + } + config::setup_wireguard(&config.wg)?; - let tls_config = figment.focus("tls").extract::()?; + let tls_config = figment + .focus("tls") + .extract::() + .context("Failed to extract tls config")?; maybe_gen_certs(&config, &tls_config) .await .context("Failed to generate certs")?; @@ -152,40 +260,51 @@ async fn main() -> Result<()> { set_max_ulimit()?; } - let my_app_id = if config.run_in_dstack { + let my_app_id = if config.debug.insecure_skip_attestation { + None + } else { let dstack_client = dstack_agent().context("Failed to create dstack client")?; let info = dstack_client .info() .await .context("Failed to get app info")?; Some(info.app_id) - } else { - None }; let proxy_config = config.proxy.clone(); let pccs_url = config.pccs_url.clone(); let admin_enabled = config.admin.enabled; - let state = main_service::Proxy::new(config, my_app_id).await?; + let debug_config = config.debug.clone(); + let state = Proxy::new(ProxyOptions { + config, + my_app_id, + tls_config, + }) + .await?; info!("Starting background tasks"); state.start_bg_tasks().await?; 
state.lock().reconfigure()?; proxy::start(proxy_config, state.clone()).context("failed to start the proxy")?; - let admin_figment = - Figment::new() - .merge(rocket::Config::default()) - .merge(Serialized::defaults( - figment - .find_value("core.admin") - .context("admin section not found")?, - )); + let admin_value = figment + .find_value("core.admin") + .context("admin section not found")?; + let debug_value = figment + .find_value("core.debug") + .context("debug section not found")?; + + let admin_figment = Figment::new() + .merge(rocket::Config::default()) + .merge(Serialized::defaults(admin_value)); + + let debug_figment = Figment::new() + .merge(rocket::Config::default()) + .merge(Serialized::defaults(debug_value)); let mut rocket = rocket::custom(figment) - .mount( - "/prpc", - ra_rpc::prpc_routes!(Proxy, RpcHandler, trim: "Tproxy."), - ) + .mount("/prpc", prpc!(Proxy, RpcHandler, trim: "Tproxy.")) + // Mount WaveKV sync endpoint (requires mTLS gateway auth) + .mount("/", web_routes::wavekv_sync_routes()) .attach(AdHoc::on_response("Add app version header", |_req, res| { Box::pin(async move { res.set_raw_header("X-App-Version", app_version()); @@ -195,12 +314,27 @@ async fn main() -> Result<()> { let verifier = QuoteVerifier::new(pccs_url); rocket = rocket.manage(verifier); let main_srv = rocket.launch(); + let admin_state = state.clone(); + let debug_state = state; let admin_srv = async move { if admin_enabled { rocket::custom(admin_figment) .mount("/", web_routes::routes()) - .mount("/", ra_rpc::prpc_routes!(Proxy, AdminRpcHandler)) - .manage(state) + .mount("/", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .mount("/prpc", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .manage(admin_state) + .launch() + .await + } else { + std::future::pending().await + } + }; + let debug_srv = async move { + if debug_config.insecure_enable_debug_rpc { + rocket::custom(debug_figment) + .mount("/prpc", prpc!(Proxy, DebugRpcHandler, trim: "Debug.")) + .mount("/", 
web_routes::health_routes()) + .manage(debug_state) .launch() .await } else { @@ -214,6 +348,9 @@ async fn main() -> Result<()> { result = admin_srv => { result.map_err(|err| anyhow!("Failed to start admin server: {err:?}"))?; } + result = debug_srv => { + result.map_err(|err| anyhow!("Failed to start debug server: {err:?}"))?; + } } Ok(()) } diff --git a/gateway/src/main_service.rs b/gateway/src/main_service.rs index e6bc6775..894a4393 100644 --- a/gateway/src/main_service.rs +++ b/gateway/src/main_service.rs @@ -3,29 +3,25 @@ // SPDX-License-Identifier: Apache-2.0 use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, net::Ipv4Addr, ops::Deref, - path::Path, - sync::{Arc, Mutex, MutexGuard, RwLock}, + sync::{Arc, Mutex, MutexGuard}, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use anyhow::{bail, Context, Result}; use auth_client::AuthClient; -use certbot::{CertBot, WorkDir}; + +use crate::distributed_certbot::DistributedCertBot; use cmd_lib::run_cmd as cmd; use dstack_gateway_rpc::{ gateway_server::{GatewayRpc, GatewayServer}, - AcmeInfoResponse, GatewayState, GuestAgentConfig, InfoResponse, QuotedPublicKey, - RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, + AcmeInfoResponse, GatewayNodeInfo, GetPeersResponse, GuestAgentConfig, InfoResponse, PeerInfo, + QuotedPublicKey, RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, }; -use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, RawQuoteArgs}; -use fs_err as fs; -use http_client::prpc::PrpcClient; use or_panic::ResultOrPanic; use ra_rpc::{CallContext, RpcCall, VerifiedAttestation}; -use ra_tls::attestation::QuoteContentType; use rand::seq::IteratorRandom; use rinja::Template as _; use safe_write::safe_write; @@ -36,13 +32,16 @@ use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, warn}; use crate::{ - config::Config, + cert_store::{CertResolver, CertStoreBuilder}, + config::{Config, TlsConfig}, 
+ kv::{ + fetch_peers_from_bootnode, AppIdValidator, HttpsClientConfig, InstanceData, KvStore, + NodeData, NodeStatus, WaveKvSyncService, + }, models::{InstanceInfo, WgConf}, - proxy::{create_acceptor, AddressGroup, AddressInfo}, + proxy::{create_acceptor_with_cert_resolver, AddressGroup, AddressInfo}, }; -mod sync_client; - mod auth_client; #[derive(Clone)] @@ -59,26 +58,26 @@ impl Deref for Proxy { pub struct ProxyInner { pub(crate) config: Arc, - pub(crate) certbot: Option>, + /// Multi-domain certbot (from KvStore DNS credentials and domain configs) + pub(crate) certbot: Arc, my_app_id: Option>, state: Mutex, - notify_state_updated: Notify, + pub(crate) notify_state_updated: Notify, auth_client: AuthClient, - pub(crate) acceptor: RwLock, - pub(crate) h2_acceptor: RwLock, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct GatewayNodeInfo { - pub id: Vec, - pub url: String, - pub wg_peer: WireGuardPeer, - pub last_seen: SystemTime, + pub(crate) acceptor: TlsAcceptor, + pub(crate) h2_acceptor: TlsAcceptor, + /// Certificate resolver for SNI-based resolution (supports atomic updates) + pub(crate) cert_resolver: Arc, + /// WaveKV-based store for persistence (and cross-node sync when enabled) + kv_store: Arc, + /// WaveKV sync service for network synchronization + pub(crate) wavekv_sync: Option>, + /// HTTPS client config for mTLS (used for bootnode peer discovery) + https_config: Option, } #[derive(Debug, Serialize, Deserialize, Default)] pub(crate) struct ProxyStateMut { - pub(crate) nodes: BTreeMap, pub(crate) apps: BTreeMap>, pub(crate) instances: BTreeMap, pub(crate) allocated_addresses: BTreeSet, @@ -89,12 +88,22 @@ pub(crate) struct ProxyStateMut { pub(crate) struct ProxyState { pub(crate) config: Arc, pub(crate) state: ProxyStateMut, + /// Reference to KvStore for syncing changes + kv_store: Arc, +} + +/// Options for creating a Proxy instance +pub struct ProxyOptions { + pub config: Config, + pub my_app_id: Option>, + /// TLS 
configuration (from Rocket's tls config) + pub tls_config: TlsConfig, } impl Proxy { - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { Ok(Self { - _inner: Arc::new(ProxyInner::new(config, my_app_id).await?), + _inner: Arc::new(ProxyInner::new(options).await?), }) } } @@ -104,57 +113,149 @@ impl ProxyInner { self.state.lock().or_panic("Failed to lock AppState") } - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { + let ProxyOptions { + config, + my_app_id, + tls_config, + } = options; let config = Arc::new(config); - let mut state = fs::metadata(&config.state_path) - .is_ok() - .then(|| load_state(&config.state_path)) - .transpose() - .unwrap_or_else(|err| { - error!("Failed to load state: {err}"); - None - }) - .unwrap_or_default(); - state - .nodes - .retain(|_, info| info.wg_peer.ip != config.wg.ip.to_string()); - state.nodes.insert( - config.wg.public_key.clone(), - GatewayNodeInfo { - id: config.id(), - url: config.sync.my_url.clone(), - wg_peer: WireGuardPeer { - pk: config.wg.public_key.clone(), - ip: config.wg.ip.to_string(), - endpoint: config.wg.endpoint.clone(), - }, - last_seen: SystemTime::now(), - }, + + // Initialize WaveKV store without peers (peers will be added dynamically from bootnode) + let kv_store = Arc::new( + KvStore::new(config.sync.node_id, vec![], &config.sync.data_dir) + .context("failed to initialize WaveKV store")?, + ); + info!( + "WaveKV store initialized: node_id={}, sync_enabled={}", + config.sync.node_id, config.sync.enabled ); + + // Load state from WaveKV + let instances = kv_store.load_all_instances(); + let nodes = kv_store.load_all_nodes(); + info!( + "Loaded state from WaveKV: {} instances, {} nodes", + instances.len(), + nodes.len() + ); + let state = build_state_from_kv_store(instances); + + // Sync this node to KvStore + let node_data = NodeData { + uuid: config.uuid(), + url: 
config.sync.my_url.clone(), + wg_public_key: config.wg.public_key.clone(), + wg_endpoint: config.wg.endpoint.clone(), + wg_ip: config.wg.ip.to_string(), + }; + if let Err(err) = kv_store.sync_node(config.sync.node_id, &node_data) { + error!("Failed to sync this node to KvStore: {err:?}"); + } + // Set this node's status to Online + if let Err(err) = kv_store.set_node_status(config.sync.node_id, NodeStatus::Up) { + error!("Failed to set node status: {err:?}"); + } + // Register this node's sync URL in DB (for peer discovery) + if let Err(err) = kv_store.register_peer_url(config.sync.node_id, &config.sync.my_url) { + error!("Failed to register peer URL: {err:?}"); + } + + // Build HttpsClientConfig for mTLS communication + let https_config = { + let tls = &tls_config; + let cert_validator = my_app_id + .clone() + .map(|app_id| Arc::new(AppIdValidator::new(app_id)) as _); + HttpsClientConfig { + cert_path: tls.certs.clone(), + key_path: tls.key.clone(), + ca_cert_path: tls.mutual.ca_certs.clone(), + cert_validator, + } + }; + + // Fetch peers from bootnode if configured (only when sync is enabled) + if config.sync.enabled && !config.sync.bootnode.is_empty() { + if let Err(err) = fetch_peers_from_bootnode( + &config.sync.bootnode, + &kv_store, + config.sync.node_id, + &https_config, + ) + .await + { + warn!("Failed to fetch peers from bootnode: {err:?}"); + } + } + + // Create WaveKV sync service (only if sync is enabled) + let wavekv_sync = if config.sync.enabled { + match WaveKvSyncService::new(&kv_store, &config.sync, https_config.clone()) { + Ok(sync_service) => Some(Arc::new(sync_service)), + Err(err) => { + error!("Failed to create WaveKV sync service: {err:?}"); + None + } + } + } else { + None + }; + let state = Mutex::new(ProxyState { config: config.clone(), state, + kv_store: kv_store.clone(), }); let auth_client = AuthClient::new(config.auth.clone()); - let certbot = match config.certbot.enabled { - true => { - let certbot = config - .certbot - .build_bot() 
- .await - .context("Failed to build certbot")?; - info!("Certbot built, renewing..."); - // Try first renewal for the acceptor creation - certbot.renew(false).await.context("Failed to renew cert")?; - Some(Arc::new(certbot)) + // Bootstrap WaveKV first if sync is enabled, so certbot can load certs from peers + if let Some(ref wavekv_sync) = wavekv_sync { + info!("WaveKV: bootstrapping from peers..."); + if let Err(err) = wavekv_sync.bootstrap().await { + warn!("WaveKV bootstrap failed: {err:?}"); } - false => None, - }; - let acceptor = RwLock::new( - create_acceptor(&config.proxy, false).context("Failed to create acceptor")?, + } + + // Create CertResolver and load certificates from KvStore + let cert_resolver = Arc::new(CertResolver::new()); + let all_cert_data = kv_store.load_all_cert_data(); + if !all_cert_data.is_empty() { + let mut builder = CertStoreBuilder::new(); + for (domain, data) in &all_cert_data { + if let Err(err) = builder.add_cert(domain, data) { + warn!("failed to load certificate for {domain}: {err:?}"); + } + } + cert_resolver.set(Arc::new(builder.build())); + info!( + "CertStore: loaded {} certificates from KvStore", + all_cert_data.len() + ); + } + + // Create multi-domain certbot (uses KvStore configs for DNS credentials and domains) + let certbot = Arc::new(DistributedCertBot::new( + kv_store.clone(), + cert_resolver.clone(), + )); + // Initialize any configured domains + if let Err(err) = certbot.init_all().await { + warn!("Failed to initialize multi-domain certbot: {err:?}"); + } + + // Create TLS acceptors with CertResolver for SNI-based resolution + // CertResolver allows atomic certificate updates without recreating acceptors + info!( + "CertResolver initialized with {} domains", + cert_resolver.list_domains().len() ); + let acceptor = + create_acceptor_with_cert_resolver(&config.proxy, cert_resolver.clone(), false) + .context("failed to create acceptor with cert resolver")?; let h2_acceptor = - 
RwLock::new(create_acceptor(&config.proxy, true).context("Failed to create acceptor")?); + create_acceptor_with_cert_resolver(&config.proxy, cert_resolver.clone(), true) + .context("failed to create h2 acceptor with cert resolver")?; + Ok(Self { config, state, @@ -163,98 +264,202 @@ impl ProxyInner { auth_client, acceptor, h2_acceptor, + cert_resolver, certbot, + kv_store, + wavekv_sync, + https_config: Some(https_config), }) } + + pub(crate) fn kv_store(&self) -> &Arc { + &self.kv_store + } + + pub(crate) fn my_app_id(&self) -> Option<&[u8]> { + self.my_app_id.as_deref() + } } impl Proxy { pub(crate) async fn start_bg_tasks(&self) -> Result<()> { start_recycle_thread(self.clone()); - start_sync_task(self.clone()); - start_certbot_task(self.clone()).await?; + // Start WaveKV periodic sync (bootstrap already done in new()) + if let Some(ref wavekv_sync) = self.wavekv_sync { + start_wavekv_sync_task(self.clone(), wavekv_sync.clone()).await; + } + start_wavekv_watch_task(self.clone()).context("Failed to start WaveKV watch task")?; + start_certbot_task(self.clone()).await; + start_cert_store_watch_task(self.clone()); + start_zt_domain_watch_task(self.clone()); + start_bootnode_discovery_task(self.clone()); Ok(()) } - pub(crate) async fn renew_cert(&self, force: bool) -> Result { - let Some(certbot) = &self.certbot else { - return Ok(false); - }; - let renewed = certbot.renew(force).await.context("Failed to renew cert")?; - if renewed { - self.reload_certificates() - .context("Failed to reload certificates")?; - } - Ok(renewed) - } - - pub(crate) async fn acme_info(&self) -> Result { - let config = self.lock().config.clone(); - let workdir = WorkDir::new(&config.certbot.workdir); - let account_uri = workdir.acme_account_uri().unwrap_or_default(); - let keys = workdir.list_cert_public_keys().unwrap_or_default(); - let agent = crate::dstack_agent().context("Failed to get dstack agent")?; - let account_quote = get_or_generate_quote( - &agent, - 
QuoteContentType::Custom("acme-account"), - account_uri.as_bytes(), - workdir.acme_account_quote_path(), - ) - .await - .unwrap_or_default(); - let account_attestation = get_or_generate_attestation( - &agent, - QuoteContentType::Custom("acme-account"), - account_uri.as_bytes(), - workdir.acme_account_quote_path(), - ) - .await - .unwrap_or_default(); + /// Reload all certificates from KvStore into CertStore (atomic replacement) + pub(crate) fn reload_all_certs_from_kvstore(&self) -> Result<()> { + let all_cert_data = self.kv_store.load_all_cert_data(); + + // Build new CertStore from scratch + let mut builder = CertStoreBuilder::new(); + let mut loaded = 0; + for (domain, data) in &all_cert_data { + if let Err(err) = builder.add_cert(domain, data) { + warn!("failed to reload certificate for {domain}: {err:?}"); + } else { + loaded += 1; + } + } + + // Atomically replace the CertStore (no need to recreate acceptors) + self.cert_resolver.set(Arc::new(builder.build())); + info!("CertStore: reloaded {loaded} certificates from KvStore"); + Ok(()) + } + + /// Renew a specific domain certificate or all domains + pub(crate) async fn renew_cert(&self, domain: Option<&str>, force: bool) -> Result { + match domain { + Some(domain) => self + .certbot + .try_renew(domain, force) + .await + .context("failed to renew cert"), + None => { + // Renew all domains + self.certbot + .try_renew_all() + .await + .context("failed to renew all certs")?; + Ok(true) + } + } + } + + /// Get ACME info for all managed domains (or a specific domain) + pub(crate) fn acme_info(&self, domain: Option<&str>) -> Result { + let kv_store = self.kv_store.clone(); let mut quoted_hist_keys = vec![]; - for cert_path in workdir.list_certs().unwrap_or_default() { - let cert_pem = fs::read_to_string(&cert_path).context("Failed to read key")?; - let pubkey = certbot::read_pubkey(&cert_pem).context("Failed to read pubkey")?; - let quote = get_or_generate_quote( - &agent, - QuoteContentType::Custom("zt-cert"), - 
&pubkey, - cert_path.display().to_string() + ".quote", - ) - .await - .unwrap_or_default(); - let attestation = get_or_generate_attestation( - &agent, - QuoteContentType::Custom("zt-cert"), - &pubkey, - cert_path.display().to_string() + ".quote", - ) - .await + + // Get domains to query + let domains: Vec = match domain { + Some(d) => vec![d.to_string()], + None => kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect(), + }; + + // Get account_uri, account_quote and account_attestation from global ACME attestation + let (account_uri, account_quote, account_attestation) = kv_store + .get_acme_attestation() + .map(|att| (att.account_uri, att.quote, att.attestation)) .unwrap_or_default(); - quoted_hist_keys.push(QuotedPublicKey { - public_key: pubkey, - quote, - attestation, - }); - } - let active_cert = - fs::read_to_string(workdir.cert_path()).context("Failed to read active cert")?; + for domain in &domains { + // Get all attestations for this domain + let attestations = kv_store.list_cert_attestations(domain); + for att in attestations { + quoted_hist_keys.push(QuotedPublicKey { + public_key: att.public_key, + quote: att.quote, + attestation: att.attestation, + }); + } + } Ok(AcmeInfoResponse { account_uri, - hist_keys: keys.into_iter().collect(), account_quote, account_attestation, quoted_hist_keys, - active_cert, - base_domain: config.proxy.base_domain.clone(), }) } + + /// Register a CVM with the given app_id, instance_id and client_public_key + pub fn do_register_cvm( + &self, + app_id: &str, + instance_id: &str, + client_public_key: &str, + ) -> Result { + let mut state = self.lock(); + + // Check if this node is marked as down + let my_status = state.kv_store.get_node_status(state.config.sync.node_id); + if matches!(my_status, NodeStatus::Down) { + bail!("this gateway node is marked as down and cannot accept new registrations"); + } + + if app_id.is_empty() { + bail!("[{instance_id}] app id is empty"); + } + if 
instance_id.is_empty() { + bail!("[{instance_id}] instance id is empty"); + } + if client_public_key.is_empty() { + bail!("[{instance_id}] client public key is empty"); + } + let client_info = state + .new_client_by_id(instance_id, app_id, client_public_key) + .context("failed to allocate IP address for client")?; + if let Err(err) = state.reconfigure() { + error!("failed to reconfigure: {err:?}"); + } + let gateways = state.get_active_nodes(); + let servers = gateways + .iter() + .map(|n| WireGuardPeer { + pk: n.wg_public_key.clone(), + ip: n.wg_ip.clone(), + endpoint: n.wg_endpoint.clone(), + }) + .collect::>(); + let (base_domain, port) = state.kv_store.get_best_zt_domain().unwrap_or_default(); + let response = RegisterCvmResponse { + wg: Some(WireGuardConfig { + client_ip: client_info.ip.to_string(), + servers, + }), + agent: Some(GuestAgentConfig { + external_port: port.into(), + internal_port: 8090, + domain: base_domain, + app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), + }), + gateways, + }; + self.notify_state_updated.notify_one(); + Ok(response) + } } -fn load_state(state_path: &str) -> Result { - let state_str = fs::read_to_string(state_path).context("Failed to read state")?; - serde_json::from_str(&state_str).context("Failed to load state") +fn build_state_from_kv_store(instances: BTreeMap) -> ProxyStateMut { + let mut state = ProxyStateMut::default(); + + // Build instances + for (instance_id, data) in instances { + let info = InstanceInfo { + id: instance_id.clone(), + app_id: data.app_id.clone(), + ip: data.ip, + public_key: data.public_key, + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + state.allocated_addresses.insert(data.ip); + state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.instances.insert(instance_id, info); + } + + state } fn start_recycle_thread(proxy: Proxy) { @@ -265,42 +470,322 
@@ fn start_recycle_thread(proxy: Proxy) { std::thread::spawn(move || loop { std::thread::sleep(proxy.config.recycle.interval); if let Err(err) = proxy.lock().recycle() { - error!("failed to run recycle: {err}"); + error!("failed to run recycle: {err:?}"); }; }); } -async fn start_certbot_task(proxy: Proxy) -> Result<()> { - let Some(certbot) = proxy.certbot.clone() else { - info!("Certbot is not enabled"); - return Ok(()); +/// Start periodic certificate renewal task for multi-domain certbot +async fn start_certbot_task(proxy: Proxy) { + info!("starting certificate renewal task"); + + // Periodic renewal task for all domains + tokio::spawn(async move { + // Run once at startup to check for any pending renewals + info!("running initial certificate renewal check"); + if let Err(err) = proxy.renew_cert(None, false).await { + error!("failed initial certificate renewal: {err:?}"); + } + + loop { + // Get current config from KV store (allows dynamic updates) + let renew_interval = proxy.kv_store.get_certbot_config().renew_interval; + if renew_interval.is_zero() { + // Check again later if disabled + tokio::time::sleep(Duration::from_secs(60)).await; + continue; + } + + // Wait for the interval + tokio::time::sleep(renew_interval).await; + + // Renew certificates + if let Err(err) = proxy.renew_cert(None, false).await { + error!("failed to renew certificates: {err:?}"); + } + } + }); +} + +/// Watch for certificate changes from KvStore and update CertStore +fn start_cert_store_watch_task(proxy: Proxy) { + let kv_store = proxy.kv_store.clone(); + + // Watch for any certificate changes (all domains) + let mut rx = kv_store.watch_all_certs(); + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected certificate changes, reloading CertStore..."); + if let Err(err) = proxy.reload_all_certs_from_kvstore() { + error!("Failed to reload certificates from KvStore: {err:?}"); + } + } + }); + info!("CertStore watch task started"); 
+} + +/// Watch for ZT-Domain config changes and auto-renew certificates +fn start_zt_domain_watch_task(proxy: Proxy) { + let kv_store = proxy.kv_store.clone(); + let certbot = proxy.certbot.clone(); + + let mut rx = kv_store.watch_zt_domain_configs(); + tokio::spawn(async move { + // Track known domains to detect additions + let mut known_domains = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect::>(); + + loop { + if rx.changed().await.is_err() { + break; + } + + // Get current domains + let current_domains: HashSet = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect(); + + // Find newly added domains + let new_domains: Vec = current_domains + .iter() + .filter(|d| !known_domains.contains(*d)) + .cloned() + .collect(); + + // Update known domains + known_domains = current_domains; + + // Trigger renewal for new domains + for domain in new_domains { + info!("ZT-Domain added: {domain}, attempting certificate request..."); + let certbot = certbot.clone(); + tokio::spawn(async move { + match certbot.try_renew(&domain, false).await { + Ok(renewed) => { + if renewed { + info!("cert[{domain}]: successfully issued/renewed"); + } else { + info!("cert[{domain}]: renewal not needed or another node is handling it"); + } + } + Err(e) => { + warn!("cert[{domain}]: auto-renewal failed: {e:?}"); + } + } + }); + } + } + }); + info!("ZT-Domain watch task started"); +} + +/// Periodically retry bootnode peer discovery if no peers are available +fn start_bootnode_discovery_task(proxy: Proxy) { + if !proxy.config.sync.enabled || proxy.config.sync.bootnode.is_empty() { + return; + } + + let bootnode = proxy.config.sync.bootnode.clone(); + let node_id = proxy.config.sync.node_id; + let kv_store = proxy.kv_store.clone(); + let https_config = match &proxy.https_config { + Some(config) => config.clone(), + None => return, }; + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(10)); 
loop { - tokio::time::sleep(certbot.renew_interval()).await; - if let Err(err) = proxy.renew_cert(false).await { - error!("Failed to renew cert: {err}"); + interval.tick().await; + // Check if we already have peers + let n_peers = kv_store + .load_all_node_statuses() + .keys() + .filter(|&id| *id != node_id) + .count(); + if n_peers > 0 { + info!("bootnode peer discovery finished, {n_peers} peers found"); + break; + } + // Try to fetch peers from bootnode + debug!("retrying bootnode peer discovery..."); + if let Err(err) = + fetch_peers_from_bootnode(&bootnode, &kv_store, node_id, &https_config).await + { + warn!("bootnode discovery retry failed: {err:?}"); + } else { + info!("bootnode peer discovery succeeded"); } } }); - Ok(()) + info!("Bootnode discovery task started (will retry every 10s if no peers)"); } -fn start_sync_task(proxy: Proxy) { +async fn start_wavekv_sync_task(proxy: Proxy, wavekv_sync: Arc) { if !proxy.config.sync.enabled { - info!("sync is disabled"); + info!("WaveKV sync is disabled"); return; } + + // Bootstrap already done in ProxyInner::new() before certbot init + // Peers are discovered from bootnode or via Admin.SetNodeInfo RPC + + // Start periodic sync tasks (runs forever in background) + tokio::spawn(async move { + wavekv_sync.start_sync_tasks().await; + }); + info!("WaveKV sync tasks started"); +} + +fn start_wavekv_watch_task(proxy: Proxy) -> Result<()> { + let kv_store = proxy.kv_store.clone(); + + // Watch for instance changes + let proxy_clone = proxy.clone(); + let store_clone = kv_store.clone(); + // Register watcher first, then do initial load to avoid race condition + let mut rx = store_clone.watch_instances(); + reload_instances_from_kv_store(&proxy_clone, &store_clone) + .context("Failed to initial load instances from KvStore")?; tokio::spawn(async move { - match sync_client::sync_task(proxy).await { - Ok(_) => info!("Sync task exited"), - Err(err) => error!("Failed to run sync task: {err}"), + loop { + if 
rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote instance changes, reloading..."); + if let Err(err) = reload_instances_from_kv_store(&proxy_clone, &store_clone) { + error!("Failed to reload instances from KvStore: {err:?}"); + } } }); + + // Initial WireGuard configuration + proxy.lock().reconfigure()?; + + // Watch for node changes and reconfigure WireGuard + let mut rx = kv_store.watch_nodes(); + let proxy_for_nodes = proxy.clone(); + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote node changes, reconfiguring WireGuard..."); + if let Err(err) = proxy_for_nodes.lock().reconfigure() { + error!("Failed to reconfigure WireGuard: {err:?}"); + } + } + }); + + // Start periodic persistence task + let persist_interval = proxy.config.sync.persist_interval; + if !persist_interval.is_zero() { + let kv_store_for_persist = kv_store.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(persist_interval); + loop { + ticker.tick().await; + match kv_store_for_persist.persist_if_dirty() { + Ok(true) => info!("WaveKV: periodic persist completed"), + Ok(false) => {} // No changes to persist + Err(err) => error!("WaveKV: periodic persist failed: {err:?}"), + } + } + }); + info!("WaveKV: periodic persistence enabled (interval: {persist_interval:?})"); + } + + // Start periodic connection sync task + if proxy.config.sync.sync_connections_enabled { + let sync_interval = proxy.config.sync.sync_connections_interval; + let proxy_for_sync = proxy.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(sync_interval); + loop { + ticker.tick().await; + let state = proxy_for_sync.lock(); + for (instance_id, instance) in &state.state.instances { + let count = instance.num_connections(); + state.sync_connections(instance_id, count); + } + } + }); + info!( + "WaveKV: periodic connection sync enabled (interval: {:?})", + 
proxy.config.sync.sync_connections_interval + ); + } + + Ok(()) +} + +fn reload_instances_from_kv_store(proxy: &Proxy, store: &KvStore) -> Result<()> { + let instances = store.load_all_instances(); + let mut state = proxy.lock(); + let mut wg_changed = false; + + for (instance_id, data) in instances { + let new_info = InstanceInfo { + id: instance_id.clone(), + app_id: data.app_id.clone(), + ip: data.ip, + public_key: data.public_key.clone(), + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + + let old_ip = state.state.instances.get(&instance_id).map(|e| e.ip); + if let Some(existing) = state.state.instances.get(&instance_id) { + // Check if wg config needs update + if existing.public_key != data.public_key || existing.ip != data.ip { + wg_changed = true; + } + // Only update if remote is newer (based on reg_time) + if data.reg_time <= encode_ts(existing.reg_time) { + continue; + } + } else { + wg_changed = true; + } + + // Release old IP if it changed (prevent IP leak) + if let Some(old_ip) = old_ip { + if old_ip != data.ip { + state.state.allocated_addresses.remove(&old_ip); + } + } + state.state.allocated_addresses.insert(data.ip); + state + .state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.state.instances.insert(instance_id, new_info); + } + + if wg_changed { + state.reconfigure()?; + } + Ok(()) } impl ProxyState { fn valid_ip(&self, ip: Ipv4Addr) -> bool { + // Must be within client IP range + if !self.config.wg.client_ip_range.contains(&ip) { + return false; + } if self.config.wg.ip.broadcast() == ip { return false; } @@ -342,12 +827,25 @@ impl ProxyState { return None; } if let Some(existing) = self.state.instances.get_mut(id) { - if existing.public_key != public_key { + let pubkey_changed = existing.public_key != public_key; + if pubkey_changed { info!("public key changed for instance {id}, new key: {public_key}"); 
existing.public_key = public_key.to_string(); + // Update reg_time so other nodes will pick up the change + existing.reg_time = SystemTime::now(); } let existing = existing.clone(); if self.valid_ip(existing.ip) { + // Sync existing instance to KvStore (might be from legacy state) + let data = InstanceData { + app_id: existing.app_id.clone(), + ip: existing.ip, + public_key: existing.public_key.clone(), + reg_time: encode_ts(existing.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&existing.id, &data) { + error!("failed to sync existing instance to KvStore: {err:?}"); + } return Some(existing); } info!("ip {} is invalid, removing", existing.ip); @@ -360,7 +858,6 @@ impl ProxyState { ip, public_key: public_key.to_string(), reg_time: SystemTime::now(), - last_seen: SystemTime::now(), connections: Default::default(), }; self.add_instance(host_info.clone()); @@ -368,6 +865,17 @@ impl ProxyState { } fn add_instance(&mut self, info: InstanceInfo) { + // Sync to KvStore + let data = InstanceData { + app_id: info.app_id.clone(), + ip: info.ip, + public_key: info.public_key.clone(), + reg_time: encode_ts(info.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&info.id, &data) { + error!("failed to sync instance to KvStore: {err:?}"); + } + self.state .apps .entry(info.app_id.clone()) @@ -394,15 +902,8 @@ impl ProxyState { match cmd!(wg syncconf $ifname $config_path) { Ok(_) => info!("wg config updated"), - Err(e) => error!("failed to set wg config: {e}"), + Err(err) => error!("failed to set wg config: {err:?}"), } - self.save_state()?; - Ok(()) - } - - fn save_state(&self) -> Result<()> { - let state_str = serde_json::to_string(&self.state).context("Failed to serialize state")?; - safe_write(&self.config.state_path, state_str).context("Failed to write state")?; Ok(()) } @@ -437,7 +938,7 @@ impl ProxyState { let handshakes = self.latest_handshakes(None); let mut instances = match handshakes { Err(err) => { - warn!("Failed to get handshakes, 
fallback to random selection: {err}"); + warn!("Failed to get handshakes, fallback to random selection: {err:?}"); return Ok(self.random_select_a_host(id).unwrap_or_default()); } Ok(handshakes) => app_instances @@ -549,6 +1050,12 @@ impl ProxyState { .instances .remove(id) .context("instance not found")?; + + // Sync deletion to KvStore + if let Err(err) = self.kv_store.sync_delete_instance(id) { + error!("Failed to sync instance deletion to KvStore: {err:?}"); + } + self.state.allocated_addresses.remove(&info.ip); if let Some(app_instances) = self.state.apps.get_mut(&info.app_id) { app_instances.remove(id); @@ -560,48 +1067,50 @@ impl ProxyState { } fn recycle(&mut self) -> Result<()> { - // Recycle stale Gateway nodes - let mut staled_nodes = vec![]; - for node in self.state.nodes.values() { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.last_seen.elapsed().unwrap_or_default() > self.config.recycle.node_timeout { - staled_nodes.push(node.wg_peer.pk.clone()); - } - } - for id in staled_nodes { - self.state.nodes.remove(&id); + // Refresh state: sync local handshakes to KvStore, update local last_seen from global + if let Err(err) = self.refresh_state() { + warn!("failed to refresh state: {err:?}"); } - // Recycle stale CVM instances + // Note: Gateway nodes are not removed from KvStore, only marked offline/retired + + // Recycle stale CVM instances based on global last_seen (max across all nodes) let stale_timeout = self.config.recycle.timeout; - let stale_handshakes = self.latest_handshakes(Some(stale_timeout))?; - if tracing::enabled!(tracing::Level::DEBUG) { - for (pubkey, (ts, elapsed)) in &stale_handshakes { - debug!("stale instance: {pubkey} recent={ts} ({elapsed:?} ago)"); - } - } - // Find and remove instances with matching public keys + let now = SystemTime::now(); + let stale_instances: Vec<_> = self .state .instances .iter() - .filter(|(_, info)| { - stale_handshakes.contains_key(&info.public_key) && { - 
info.reg_time.elapsed().unwrap_or_default() > stale_timeout + .filter(|(id, info)| { + // Skip if instance was registered recently + if info.reg_time.elapsed().unwrap_or_default() <= stale_timeout { + return false; + } + // Check global last_seen from KvStore (max across all nodes) + let global_ts = self.kv_store.get_instance_latest_handshake(id); + let last_seen = global_ts.map(decode_ts).unwrap_or(info.reg_time); + let elapsed = now.duration_since(last_seen).unwrap_or_default(); + if elapsed > stale_timeout { + debug!( + "stale instance: {} last_seen={:?} ({:?} ago)", + id, last_seen, elapsed + ); + true + } else { + false } }) - .map(|(id, _info)| id.clone()) + .map(|(id, _)| id.clone()) .collect(); - debug!("stale instances: {:#?}", stale_instances); + let num_recycled = stale_instances.len(); for id in stale_instances { self.remove_instance(&id)?; } - info!("recycled {num_recycled} stale instances"); - // Reconfigure WireGuard with updated peers + if num_recycled > 0 { + info!("recycled {num_recycled} stale instances"); self.reconfigure()?; } Ok(()) @@ -611,89 +1120,94 @@ impl ProxyState { std::process::exit(0); } - fn dedup_nodes(&mut self) { - // Dedup nodes by URL, keeping the latest one - let mut node_map = BTreeMap::::new(); + pub(crate) fn refresh_state(&mut self) -> Result<()> { + // Get local WG handshakes and sync to KvStore + let handshakes = self.latest_handshakes(None)?; + + // Build a map from public_key to instance_id for lookup + let pk_to_id: BTreeMap<&str, &str> = self + .state + .instances + .iter() + .map(|(id, info)| (info.public_key.as_str(), id.as_str())) + .collect(); - for node in std::mem::take(&mut self.state.nodes).into_values() { - match node_map.get(&node.wg_peer.endpoint) { - Some(existing) if existing.last_seen >= node.last_seen => {} - _ => { - node_map.insert(node.wg_peer.endpoint.clone(), node); + // Sync local handshake observations to KvStore + for (pk, (ts, _)) in &handshakes { + if let Some(&instance_id) = 
pk_to_id.get(pk.as_str()) { + if let Err(err) = self.kv_store.sync_instance_handshake(instance_id, *ts) { + debug!("failed to sync instance handshake: {err:?}"); } } } - for node in node_map.into_values() { - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + + // Update this node's last_seen in KvStore + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + if let Err(err) = self + .kv_store + .sync_node_last_seen(self.config.sync.node_id, now) + { + debug!("failed to sync node last_seen: {err:?}"); } + Ok(()) } - fn update_state( - &mut self, - proxy_nodes: Vec, - apps: Vec, - ) -> Result<()> { - for node in proxy_nodes { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.url == self.config.sync.my_url { - continue; - } - if let Some(existing) = self.state.nodes.get(&node.wg_peer.pk) { - if node.last_seen <= existing.last_seen { - continue; - } - } - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + /// Sync connection count for an instance to KvStore + pub(crate) fn sync_connections(&self, instance_id: &str, count: u64) { + if let Err(err) = self.kv_store.sync_connections(instance_id, count) { + debug!("Failed to sync connections: {err:?}"); } - self.dedup_nodes(); + } - let mut wg_changed = false; - for app in apps { - if let Some(existing) = self.state.instances.get(&app.id) { - let existing_ts = (existing.reg_time, existing.last_seen); - let update_ts = (app.reg_time, app.last_seen); - if update_ts <= existing_ts { - continue; - } - if !wg_changed { - wg_changed = existing.public_key != app.public_key || existing.ip != app.ip; - } - } else { - wg_changed = true; - } - self.add_instance(app); - } - info!("updated, wg_changed: {wg_changed}"); - if wg_changed { - self.reconfigure()?; - } else { - self.save_state()?; - } - Ok(()) + /// Get latest handshake for an instance from KvStore (max across all nodes) + pub(crate) fn get_instance_latest_handshake(&self, instance_id: 
&str) -> Option { + self.kv_store.get_instance_latest_handshake(instance_id) } - fn dump_state(&mut self) -> (Vec, Vec) { - self.refresh_state().ok(); - ( - self.state.nodes.values().cloned().collect(), - self.state.instances.values().cloned().collect(), - ) + /// Get all nodes from KvStore (for admin API - includes all nodes) + pub(crate) fn get_all_nodes(&self) -> Vec { + self.get_all_nodes_filtered(false) } - pub(crate) fn refresh_state(&mut self) -> Result<()> { - let handshakes = self.latest_handshakes(None)?; - for instance in self.state.instances.values_mut() { - let Some((ts, _)) = handshakes.get(&instance.public_key).copied() else { - continue; - }; - instance.last_seen = decode_ts(ts); - } - if let Some(node) = self.state.nodes.get_mut(&self.config.wg.public_key) { - node.last_seen = SystemTime::now(); - } - Ok(()) + /// Get nodes for CVM registration (excludes nodes with status "down") + pub(crate) fn get_active_nodes(&self) -> Vec { + self.get_all_nodes_filtered(true) + } + + /// Get all nodes from KvStore with optional filtering + fn get_all_nodes_filtered(&self, exclude_down: bool) -> Vec { + let node_statuses = if exclude_down { + self.kv_store.load_all_node_statuses() + } else { + Default::default() + }; + + self.kv_store + .load_all_nodes() + .into_iter() + .filter(|(id, _)| { + if !exclude_down { + return true; + } + // Exclude nodes with status "down" + match node_statuses.get(id) { + Some(NodeStatus::Down) => false, + _ => true, // Include Up or nodes without explicit status + } + }) + .map(|(id, node)| GatewayNodeInfo { + id, + uuid: node.uuid, + wg_public_key: node.wg_public_key, + wg_ip: node.wg_ip, + wg_endpoint: node.wg_endpoint, + url: node.url, + last_seen: self.kv_store.get_node_latest_last_seen(id).unwrap_or(0), + }) + .collect() } } @@ -715,7 +1229,7 @@ pub struct RpcHandler { impl RpcHandler { fn ensure_from_gateway(&self) -> Result<()> { - if !self.state.config.run_in_dstack { + if self.state.config.debug.insecure_skip_attestation { 
return Ok(()); } if self.remote_app_id.is_none() { @@ -743,124 +1257,44 @@ impl GatewayRpc for RpcHandler { .context("App authorization failed")?; let app_id = hex::encode(&app_info.app_id); let instance_id = hex::encode(&app_info.instance_id); - - let mut state = self.state.lock(); - if request.client_public_key.is_empty() { - bail!("[{instance_id}] client public key is empty"); - } - let client_info = state - .new_client_by_id(&instance_id, &app_id, &request.client_public_key) - .context("failed to allocate IP address for client")?; - if let Err(err) = state.reconfigure() { - error!("failed to reconfigure: {}", err); - } - let servers = state - .state - .nodes - .values() - .map(|n| n.wg_peer.clone()) - .collect::>(); - let response = RegisterCvmResponse { - wg: Some(WireGuardConfig { - client_ip: client_info.ip.to_string(), - servers, - }), - agent: Some(GuestAgentConfig { - external_port: state.config.proxy.external_port as u32, - internal_port: state.config.proxy.agent_port as u32, - domain: state.config.proxy.base_domain.clone(), - app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), - }), - }; - self.state.notify_state_updated.notify_one(); - Ok(response) + self.state + .do_register_cvm(&app_id, &instance_id, &request.client_public_key) } async fn acme_info(self) -> Result { - self.state.acme_info().await - } - - async fn update_state(self, request: GatewayState) -> Result<()> { - self.ensure_from_gateway()?; - let mut nodes = vec![]; - let mut apps = vec![]; - - for node in request.nodes { - nodes.push(GatewayNodeInfo { - id: node.id, - wg_peer: node.wg_peer.context("wg_peer is missing")?, - last_seen: decode_ts(node.last_seen), - url: node.url, - }); - } - - for app in request.apps { - apps.push(InstanceInfo { - id: app.instance_id, - app_id: app.app_id, - ip: app.ip.parse().context("Invalid IP address")?, - public_key: app.public_key, - reg_time: decode_ts(app.reg_time), - last_seen: decode_ts(app.last_seen), - connections: 
Default::default(), - }); - } - - self.state - .lock() - .update_state(nodes, apps) - .context("failed to update state")?; - Ok(()) + self.state.acme_info(None) } async fn info(self) -> Result { let state = self.state.lock(); + let (base_domain, port) = state.kv_store.get_best_zt_domain().unwrap_or_default(); Ok(InfoResponse { - base_domain: state.config.proxy.base_domain.clone(), - external_port: state.config.proxy.external_port as u32, + base_domain, + external_port: port.into(), app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), }) } -} -async fn get_or_generate_quote( - agent: &DstackGuestClient, - content_type: QuoteContentType<'_>, - payload: &[u8], - quote_path: impl AsRef, -) -> Result { - let quote_path = quote_path.as_ref(); - if fs::metadata(quote_path).is_ok() { - return fs::read_to_string(quote_path).context("Failed to read quote"); - } - let report_data = content_type.to_report_data(payload).to_vec(); - let response = agent - .get_quote(RawQuoteArgs { report_data }) - .await - .context("Failed to get quote")?; - let quote = serde_json::to_string(&response).context("Failed to serialize quote")?; - safe_write(quote_path, "e).context("Failed to write quote")?; - Ok(quote) -} + async fn get_peers(self) -> Result { + self.ensure_from_gateway()?; + + let kv_store = self.state.kv_store(); + let config = &self.state.config; + + // Get all peer addresses from KvStore + let peer_addrs = kv_store.get_all_peer_addrs(); + + let peers: Vec = peer_addrs + .into_iter() + .map(|(id, url)| PeerInfo { id, url }) + .collect(); -async fn get_or_generate_attestation( - agent: &DstackGuestClient, - content_type: QuoteContentType<'_>, - payload: &[u8], - quote_path: impl AsRef, -) -> Result { - let quote_path = quote_path.as_ref(); - if fs::metadata(quote_path).is_ok() { - return fs::read_to_string(quote_path).context("Failed to read quote"); - } - let report_data = content_type.to_report_data(payload).to_vec(); - let response = agent - 
.attest(RawQuoteArgs { report_data }) - .await - .context("Failed to get quote")?; - let attestation = serde_json::to_string(&response).context("Failed to serialize quote")?; - safe_write(quote_path, &attestation).context("Failed to write quote")?; - Ok(attestation) + Ok(GetPeersResponse { + my_id: config.sync.node_id, + my_url: config.sync.my_url.clone(), + peers, + }) + } } impl RpcCall for RpcHandler { @@ -875,30 +1309,5 @@ impl RpcCall for RpcHandler { } } -impl From for dstack_gateway_rpc::GatewayNodeInfo { - fn from(node: GatewayNodeInfo) -> Self { - Self { - id: node.id, - wg_peer: Some(node.wg_peer), - last_seen: encode_ts(node.last_seen), - url: node.url, - } - } -} - -impl From for dstack_gateway_rpc::AppInstanceInfo { - fn from(app: InstanceInfo) -> Self { - Self { - num_connections: app.num_connections(), - instance_id: app.id, - app_id: app.app_id, - ip: app.ip.to_string(), - public_key: app.public_key, - reg_time: encode_ts(app.reg_time), - last_seen: encode_ts(app.last_seen), - } - } -} - #[cfg(test)] mod tests; diff --git a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap index 67b4180d..f211b458 100644 --- a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap @@ -1,6 +1,6 @@ --- source: gateway/src/main_service/tests.rs -assertion_line: 36 +assertion_line: 71 expression: info1 --- InstanceInfo { @@ -12,9 +12,5 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, - last_seen: SystemTime { - tv_sec: 0, - tv_nsec: 0, - }, connections: 0, } diff --git a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap index ef5978f3..5b07304c 100644 --- 
a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap @@ -1,6 +1,6 @@ --- source: gateway/src/main_service/tests.rs -assertion_line: 29 +assertion_line: 65 expression: info --- InstanceInfo { @@ -12,9 +12,5 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, - last_seen: SystemTime { - tv_sec: 0, - tv_nsec: 0, - }, connections: 0, } diff --git a/gateway/src/main_service/sync_client.rs b/gateway/src/main_service/sync_client.rs deleted file mode 100644 index be6985f7..00000000 --- a/gateway/src/main_service/sync_client.rs +++ /dev/null @@ -1,185 +0,0 @@ -// SPDX-FileCopyrightText: © 2025 Phala Network -// -// SPDX-License-Identifier: Apache-2.0 - -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use dstack_gateway_rpc::{gateway_client::GatewayClient, GatewayState}; -use dstack_guest_agent_rpc::GetTlsKeyArgs; -use ra_rpc::client::{RaClient, RaClientConfig}; -use tracing::{error, info}; - -use crate::{dstack_agent, main_service::Proxy}; - -struct SyncClient { - in_dstack: bool, - cert_pem: String, - key_pem: String, - ca_cert_pem: String, - app_id: Vec, - timeout: Duration, - pccs_url: Option, -} - -impl SyncClient { - fn create_rpc_client(&self, url: &str) -> Result> { - let app_id = self.app_id.clone(); - let url = format!("{}/prpc", url.trim_end_matches('/')); - let client = if self.in_dstack { - RaClientConfig::builder() - .remote_uri(url) - // Don't verify server RA because we use the CA cert from KMS to verify - // the server cert. 
- .verify_server_attestation(false) - .tls_no_check(true) - .tls_no_check_hostname(false) - .tls_client_cert(self.cert_pem.clone()) - .tls_client_key(self.key_pem.clone()) - .tls_ca_cert(self.ca_cert_pem.clone()) - .tls_built_in_root_certs(false) - .maybe_pccs_url(self.pccs_url.clone()) - .cert_validator(Box::new(move |cert| { - let cert = cert.context("TLS cert not found")?; - let remote_app_id = cert.app_id.context("App id not found")?; - if remote_app_id != app_id { - return Err(anyhow::anyhow!("Remote app id mismatch")); - } - Ok(()) - })) - .build() - .into_client() - .context("failed to create client")? - } else { - RaClient::new(url, true)? - }; - Ok(GatewayClient::new(client)) - } - - async fn sync_state(&self, url: &str, state: &GatewayState) -> Result<()> { - info!("Trying to sync state to {url}"); - let rpc = self.create_rpc_client(url)?; - tokio::time::timeout(self.timeout, rpc.update_state(state.clone())) - .await - .ok() - .context("Timeout while syncing state")? - .context("Failed to sync state")?; - info!("Synced state to {url}"); - Ok(()) - } - - async fn sync_state_ignore_error(&self, url: &str, state: &GatewayState) -> bool { - match self.sync_state(url, state).await { - Ok(_) => true, - Err(e) => { - error!("Failed to sync state to {url}: {e:?}"); - false - } - } - } -} - -pub(crate) async fn sync_task(proxy: Proxy) -> Result<()> { - let config = proxy.config.clone(); - let sync_client = if config.run_in_dstack { - let agent = dstack_agent().context("Failed to create dstack agent client")?; - let keys = agent - .get_tls_key(GetTlsKeyArgs { - subject: "dstack-gateway-sync-client".into(), - alt_names: vec![], - usage_ra_tls: false, - usage_server_auth: false, - usage_client_auth: true, - not_after: None, - not_before: None, - }) - .await - .context("Failed to get sync-client keys")?; - let my_app_id = agent - .info() - .await - .context("Failed to get guest info")? 
- .app_id; - SyncClient { - in_dstack: true, - cert_pem: keys.certificate_chain.join("\n"), - key_pem: keys.key, - ca_cert_pem: keys.certificate_chain.last().cloned().unwrap_or_default(), - app_id: my_app_id, - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - } else { - SyncClient { - in_dstack: false, - cert_pem: "".into(), - key_pem: "".into(), - ca_cert_pem: "".into(), - app_id: vec![], - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - }; - - let mut last_broadcast_time = Instant::now(); - let mut broadcast = false; - loop { - if broadcast { - last_broadcast_time = Instant::now(); - } - - let (mut nodes, apps) = proxy.lock().dump_state(); - // Sort nodes by pubkey - nodes.sort_by(|a, b| a.id.cmp(&b.id)); - - let self_idx = nodes - .iter() - .position(|n| n.wg_peer.pk == config.wg.public_key) - .unwrap_or(0); - - let state = GatewayState { - nodes: nodes.into_iter().map(|n| n.into()).collect(), - apps: apps.into_iter().map(|a| a.into()).collect(), - }; - - if state.nodes.is_empty() { - // If no nodes exist yet, sync with bootnode - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } else { - let nodes = &state.nodes; - // Try nodes after self, wrapping around to beginning - let mut success = false; - for i in 1..nodes.len() { - let idx = (self_idx + i) % nodes.len(); - if sync_client - .sync_state_ignore_error(&nodes[idx].url, &state) - .await - { - success = true; - if !broadcast { - break; - } - } - } - - // If no node succeeded, try bootnode as fallback - if !success { - info!("Fallback to sync with bootnode"); - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } - } - - tokio::select! 
{ - _ = proxy.notify_state_updated.notified() => { - broadcast = true; - } - _ = tokio::time::sleep(config.sync.interval) => { - broadcast = last_broadcast_time.elapsed() >= config.sync.broadcast_interval; - } - } - } -} diff --git a/gateway/src/main_service/tests.rs b/gateway/src/main_service/tests.rs index d98c0131..1a43b154 100644 --- a/gateway/src/main_service/tests.rs +++ b/gateway/src/main_service/tests.rs @@ -3,17 +3,44 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use crate::config::{load_config_figment, Config}; +use crate::config::{load_config_figment, Config, MutualConfig}; +use tempfile::TempDir; -async fn create_test_state() -> Proxy { +struct TestState { + proxy: Proxy, + _temp_dir: TempDir, +} + +impl std::ops::Deref for TestState { + type Target = Proxy; + fn deref(&self) -> &Self::Target { + &self.proxy + } +} + +async fn create_test_state() -> TestState { let figment = load_config_figment(None); let mut config = figment.focus("core").extract::().unwrap(); - let cargo_dir = env!("CARGO_MANIFEST_DIR"); - config.proxy.cert_chain = format!("{cargo_dir}/assets/cert.pem"); - config.proxy.cert_key = format!("{cargo_dir}/assets/cert.key"); - Proxy::new(config, None) + let temp_dir = TempDir::new().expect("failed to create temp dir"); + config.sync.data_dir = temp_dir.path().to_string_lossy().to_string(); + let options = ProxyOptions { + config, + my_app_id: None, + tls_config: TlsConfig { + certs: "".to_string(), + key: "".to_string(), + mutual: MutualConfig { + ca_certs: "".to_string(), + }, + }, + }; + let proxy = Proxy::new(options) .await - .expect("failed to create app state") + .expect("failed to create app state"); + TestState { + proxy, + _temp_dir: temp_dir, + } } #[tokio::test] @@ -32,14 +59,12 @@ async fn test_config() { .unwrap(); info.reg_time = SystemTime::UNIX_EPOCH; - info.last_seen = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info); let mut info1 = state .lock() .new_client_by_id("test-id-1", "app-id-1", 
"test-pubkey-1") .unwrap(); info1.reg_time = SystemTime::UNIX_EPOCH; - info1.last_seen = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info1); let wg_config = state.lock().generate_wg_config().unwrap(); insta::assert_snapshot!(wg_config); diff --git a/gateway/src/models.rs b/gateway/src/models.rs index ec476cff..37caa274 100644 --- a/gateway/src/models.rs +++ b/gateway/src/models.rs @@ -60,7 +60,6 @@ pub struct InstanceInfo { pub ip: Ipv4Addr, pub public_key: String, pub reg_time: SystemTime, - pub last_seen: SystemTime, #[serde(skip)] pub connections: Arc, } diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 73b947cc..26bc1f1b 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -11,11 +11,13 @@ use std::{ }; use anyhow::{bail, Context, Result}; +use or_panic::ResultOrPanic; use sni::extract_sni; -pub(crate) use tls_terminate::create_acceptor; +pub(crate) use tls_terminate::create_acceptor_with_cert_resolver; use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream}, + runtime::Runtime, time::timeout, }; use tracing::{debug, error, info, info_span, Instrument}; @@ -60,10 +62,6 @@ async fn take_sni(stream: &mut TcpStream) -> Result<(Option, Vec)> { Ok((None, buffer)) } -fn is_subdomain(sni: &str, base_domain: &str) -> bool { - sni.ends_with(base_domain) -} - #[derive(Debug)] struct DstInfo { app_id: String, @@ -72,14 +70,7 @@ struct DstInfo { is_h2: bool, } -fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { - // format: [-][s]. 
- let subdomain = sni - .strip_suffix(dotted_base_domain) - .context("invalid sni format")?; - if subdomain.contains('.') { - bail!("only one level of subdomain is supported, got sni={sni}, subdomain={subdomain}"); - } +fn parse_dst_info(subdomain: &str) -> Result { let mut parts = subdomain.split('-'); let app_id = parts.next().context("no app id found")?.to_owned(); if app_id.is_empty() { @@ -131,11 +122,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); -async fn handle_connection( - mut inbound: TcpStream, - state: Proxy, - dotted_base_domain: &str, -) -> Result<()> { +async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { let timeouts = &state.config.proxy.timeouts; let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound)) .await @@ -144,8 +131,10 @@ async fn handle_connection( let Some(sni) = sni else { bail!("no sni found"); }; - if is_subdomain(&sni, dotted_base_domain) { - let dst = parse_destination(&sni, dotted_base_domain)?; + + let (subdomain, base_domain) = sni.split_once('.').context("invalid sni")?; + if state.cert_resolver.get().contains_wildcard(base_domain) { + let dst = parse_dst_info(subdomain)?; debug!("dst: {dst:?}"); if dst.is_tls { tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await @@ -160,19 +149,7 @@ async fn handle_connection( } #[inline(never)] -pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { - let workers_rt = tokio::runtime::Builder::new_multi_thread() - .thread_name("proxy-worker") - .enable_all() - .worker_threads(config.workers) - .build() - .context("Failed to build Tokio runtime")?; - - let dotted_base_domain = { - let base_domain = config.base_domain.as_str(); - let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain); - Arc::new(format!(".{base_domain}")) - }; +pub async fn proxy_main(rt: &Runtime, config: &ProxyConfig, proxy: Proxy) 
-> Result<()> { let listener = TcpListener::bind((config.listen_addr, config.listen_port)) .await .with_context(|| { @@ -195,16 +172,12 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { info!(%from, "new connection"); let proxy = proxy.clone(); - let dotted_base_domain = dotted_base_domain.clone(); - workers_rt.spawn( + rt.spawn( async move { let _conn_entered = conn_entered; let timeouts = &proxy.config.proxy.timeouts; - let result = timeout( - timeouts.total, - handle_connection(inbound, proxy, &dotted_base_domain), - ) - .await; + let result = + timeout(timeouts.total, handle_connection(inbound, proxy)).await; match result { Ok(Ok(_)) => { info!("connection closed"); @@ -233,17 +206,24 @@ fn next_connection_id() -> usize { } pub fn start(config: ProxyConfig, app_state: Proxy) -> Result<()> { - // Create a new single-threaded runtime - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .context("Failed to build Tokio runtime")?; - std::thread::Builder::new() .name("proxy-main".to_string()) .spawn(move || { + // Create a new single-threaded runtime + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .or_panic("Failed to build Tokio runtime"); + + let worker_rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("proxy-worker") + .enable_all() + .worker_threads(config.workers) + .build() + .or_panic("Failed to build Tokio runtime"); + // Run the proxy_main function in this runtime - if let Err(err) = rt.block_on(proxy_main(&config, app_state)) { + if let Err(err) = rt.block_on(proxy_main(&worker_rt, &config, app_state)) { error!( "error on {}:{}: {err:?}", config.listen_addr, config.listen_port @@ -260,64 +240,40 @@ mod tests { #[test] fn test_parse_destination() { - let base_domain = ".example.com"; - // Test basic app_id only - let result = parse_destination("myapp.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp").unwrap(); 
assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 80); assert!(!result.is_tls); // Test app_id with custom port - let result = parse_destination("myapp-8080.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-8080").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 8080); assert!(!result.is_tls); // Test app_id with TLS - let result = parse_destination("myapp-443s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-443s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 443); assert!(result.is_tls); // Test app_id with custom port and TLS - let result = parse_destination("myapp-8443s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-8443s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 8443); assert!(result.is_tls); // Test default port but ends with s - let result = parse_destination("myapps.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapps").unwrap(); assert_eq!(result.app_id, "myapps"); assert_eq!(result.port, 80); assert!(!result.is_tls); // Test default port but ends with s in port part - let result = parse_destination("myapp-s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 443); assert!(result.is_tls); } - - #[test] - fn test_parse_destination_errors() { - let base_domain = ".example.com"; - - // Test invalid domain suffix - assert!(parse_destination("myapp.wrong.com", base_domain).is_err()); - - // Test multiple subdomains - assert!(parse_destination("invalid.myapp.example.com", base_domain).is_err()); - - // Test invalid port format - assert!(parse_destination("myapp-65536.example.com", base_domain).is_err()); - assert!(parse_destination("myapp-abc.example.com", base_domain).is_err()); - - // Test too many parts - assert!(parse_destination("myapp-8080-extra.example.com", 
base_domain).is_err()); - - // Test empty app_id - assert!(parse_destination("-8080.example.com", base_domain).is_err()); - assert!(parse_destination("myapp-8080ss.example.com", base_domain).is_err()); - } } diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index e2cea9d0..af3aba62 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -56,7 +56,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result Result { - let cert_pem = fs::read(&config.cert_chain).context("failed to read certificate")?; - let key_pem = fs::read(&config.cert_key).context("failed to read private key")?; - let certs = CertificateDer::pem_slice_iter(cert_pem.as_slice()) - .collect::, _>>() - .context("failed to parse certificate")?; - let key = - PrivateKeyDer::from_pem_slice(key_pem.as_slice()).context("failed to parse private key")?; - - let provider = match config.tls_crypto_provider { +/// Create a TLS acceptor using CertResolver for SNI-based certificate resolution +/// +/// The CertResolver allows atomic certificate updates without recreating the acceptor. +pub(crate) fn create_acceptor_with_cert_resolver( + proxy_config: &ProxyConfig, + cert_resolver: Arc, + h2: bool, +) -> Result { + let provider = match proxy_config.tls_crypto_provider { CryptoProvider::AwsLcRs => rustls::crypto::aws_lc_rs::default_provider(), CryptoProvider::Ring => rustls::crypto::ring::default_provider(), }; - let supported_versions = config + let supported_versions = proxy_config .tls_versions .iter() .map(|v| match v { @@ -120,11 +116,12 @@ pub(crate) fn create_acceptor(config: &ProxyConfig, h2: bool) -> Result &TLS13, }) .collect::>(); + let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider)) .with_protocol_versions(&supported_versions) - .context("Failed to build TLS config")? + .context("failed to build TLS config")? 
.with_no_client_auth() - .with_single_cert(certs, key)?; + .with_cert_resolver(cert_resolver); if h2 { config.alpn_protocols = vec![b"h2".to_vec()]; @@ -152,27 +149,6 @@ fn empty_response(status: StatusCode) -> Result> { } impl Proxy { - /// Reload the TLS acceptor with fresh certificates - pub fn reload_certificates(&self) -> Result<()> { - info!("Reloading TLS certificates"); - // Replace the acceptor with the new one - if let Ok(mut acceptor) = self.acceptor.write() { - *acceptor = create_acceptor(&self.config.proxy, false)?; - info!("TLS certificates successfully reloaded"); - } else { - bail!("Failed to acquire write lock for TLS acceptor"); - } - - if let Ok(mut acceptor) = self.h2_acceptor.write() { - *acceptor = create_acceptor(&self.config.proxy, true)?; - info!("TLS certificates successfully reloaded"); - } else { - bail!("Failed to acquire write lock for TLS acceptor"); - } - - Ok(()) - } - pub(crate) async fn handle_this_node( &self, inbound: TcpStream, @@ -213,7 +189,7 @@ impl Proxy { json_response(&app_info) } "/acme-info" => { - let acme_info = self.acme_info().await.context("Failed to get acme info")?; + let acme_info = self.acme_info(None).context("Failed to get acme info")?; json_response(&acme_info) } _ => empty_response(StatusCode::NOT_FOUND), @@ -278,15 +254,9 @@ impl Proxy { inbound, }; let acceptor = if h2 { - self.h2_acceptor - .read() - .or_panic("lock should never fail") - .clone() + &self.h2_acceptor } else { - self.acceptor - .read() - .or_panic("lock should never fail") - .clone() + &self.acceptor }; let tls_stream = timeout( self.config.proxy.timeouts.handshake, @@ -315,7 +285,7 @@ impl Proxy { let addresses = self .lock() .select_top_n_hosts(app_id) - .with_context(|| format!("app {app_id} not found"))?; + .with_context(|| format!("app <{app_id}> not found"))?; debug!("selected top n hosts: {addresses:?}"); let tls_stream = self.tls_accept(inbound, buffer, h2).await?; let (outbound, _counter) = timeout( diff --git 
a/gateway/src/web_routes.rs b/gateway/src/web_routes.rs index 1bd57f2b..5f72735d 100644 --- a/gateway/src/web_routes.rs +++ b/gateway/src/web_routes.rs @@ -7,12 +7,28 @@ use anyhow::Result; use rocket::{get, response::content::RawHtml, routes, Route, State}; mod route_index; +mod wavekv_sync; #[get("/")] async fn index(state: &State) -> Result, String> { route_index::index(state).await.map_err(|e| format!("{e}")) } +#[get("/health")] +fn health() -> &'static str { + "OK" +} + pub fn routes() -> Vec { routes![index] } + +/// Health endpoint for simple liveness checks +pub fn health_routes() -> Vec { + routes![health] +} + +/// WaveKV sync endpoint (for main server, requires mTLS gateway auth) +pub fn wavekv_sync_routes() -> Vec { + routes![wavekv_sync::sync_store] +} diff --git a/gateway/src/web_routes/wavekv_sync.rs b/gateway/src/web_routes/wavekv_sync.rs new file mode 100644 index 00000000..dead1141 --- /dev/null +++ b/gateway/src/web_routes/wavekv_sync.rs @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync HTTP endpoints +//! +//! Sync data is encoded using msgpack + gzip compression for efficiency. 
+ +use crate::{ + kv::{decode, encode}, + main_service::Proxy, +}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use ra_tls::traits::CertExt; +use rocket::{ + data::{Data, ToByteUnit}, + http::{ContentType, Status}, + mtls::{oid::Oid, Certificate}, + post, State, +}; +use std::io::{Read, Write}; +use tracing::warn; +use wavekv::sync::{SyncMessage, SyncResponse}; + +/// Wrapper to implement CertExt for Rocket's Certificate +struct RocketCert<'a>(&'a Certificate<'a>); + +impl CertExt for RocketCert<'_> { + fn get_extension_der(&self, oid: &[u64]) -> anyhow::Result>> { + let oid = Oid::from(oid).map_err(|_| anyhow::anyhow!("failed to create OID from slice"))?; + let Some(ext) = self.0.extensions().iter().find(|ext| ext.oid == oid) else { + return Ok(None); + }; + Ok(Some(ext.value.to_vec())) + } +} + +/// Decode compressed msgpack data +fn decode_sync_message(data: &[u8]) -> Result { + // Decompress + let mut decoder = GzDecoder::new(data); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed).map_err(|e| { + warn!("failed to decompress sync message: {e}"); + Status::BadRequest + })?; + + decode(&decompressed).map_err(|e| { + warn!("failed to decode sync message: {e}"); + Status::BadRequest + }) +} + +/// Encode and compress sync response +fn encode_sync_response(response: &SyncResponse) -> Result, Status> { + let encoded = encode(response).map_err(|e| { + warn!("failed to encode sync response: {e}"); + Status::InternalServerError + })?; + + // Compress + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(&encoded).map_err(|e| { + warn!("failed to compress sync response: {e}"); + Status::InternalServerError + })?; + encoder.finish().map_err(|e| { + warn!("failed to finish compression: {e}"); + Status::InternalServerError + }) +} + +/// Verify that the request is from a gateway with the same app_id (mTLS verification) +fn verify_gateway_peer(state: &Proxy, cert: Option>) -> Result<(), Status> 
{ + // Skip verification if not running in dstack (test mode) + if state.config.debug.insecure_skip_attestation { + return Ok(()); + } + + let Some(cert) = cert else { + warn!("WaveKV sync: client certificate required but not provided"); + return Err(Status::Unauthorized); + }; + + let remote_app_id = RocketCert(&cert).get_app_id().map_err(|e| { + warn!("WaveKV sync: failed to extract app_id from certificate: {e}"); + Status::Unauthorized + })?; + + let Some(remote_app_id) = remote_app_id else { + warn!("WaveKV sync: certificate does not contain app_id"); + return Err(Status::Unauthorized); + }; + + if state.my_app_id() != Some(remote_app_id.as_slice()) { + warn!( + "WaveKV sync: app_id mismatch, expected {:?}, got {:?}", + state.my_app_id(), + remote_app_id + ); + return Err(Status::Forbidden); + } + + Ok(()) +} + +/// Handle sync request (msgpack + gzip encoded) +#[post("/wavekv/sync/", data = "")] +pub async fn sync_store( + state: &State, + cert: Option>, + store: &str, + data: Data<'_>, +) -> Result<(ContentType, Vec), Status> { + verify_gateway_peer(state, cert)?; + + let Some(ref wavekv_sync) = state.wavekv_sync else { + return Err(Status::ServiceUnavailable); + }; + + // Read and decode request + let bytes = data + .open(16.mebibytes()) + .into_bytes() + .await + .map_err(|_| Status::BadRequest)?; + let msg = decode_sync_message(&bytes)?; + + // Reject sync from node_id == 0 + if msg.sender_id == 0 { + warn!("rejected sync from invalid node_id 0"); + return Err(Status::BadRequest); + } + + // Handle sync based on store type + let response = match store { + "persistent" => wavekv_sync.handle_persistent_sync(msg), + "ephemeral" => wavekv_sync.handle_ephemeral_sync(msg), + _ => return Err(Status::NotFound), + } + .map_err(|e| { + tracing::error!("{store} sync failed: {e}"); + Status::InternalServerError + })?; + + // Encode response + let encoded = encode_sync_response(&response)?; + + Ok((ContentType::new("application", "x-msgpack-gz"), encoded)) +} diff 
--git a/gateway/templates/dashboard.html b/gateway/templates/dashboard.html index 56750204..1ec54f16 100644 --- a/gateway/templates/dashboard.html +++ b/gateway/templates/dashboard.html @@ -34,7 +34,7 @@ border-collapse: collapse; background-color: white; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); margin: 20px 0; } @@ -93,14 +93,14 @@ font-size: 12px; white-space: nowrap; z-index: 1; - box-shadow: 0 2px 4px rgba(0,0,0,0.2); + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2); } .info-section { background: white; padding: 20px; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); } .info-group { @@ -152,6 +152,242 @@ text-overflow: ellipsis; white-space: nowrap; } + + .last-seen-cell { + white-space: nowrap; + } + + .last-seen-row { + margin-bottom: 4px; + } + + .last-seen-row:last-child { + margin-bottom: 0; + } + + .observer-label { + color: #666; + font-size: 0.9em; + } + + .node-status { + font-weight: bold; + } + + .node-status.up { + color: #4CAF50; + } + + .node-status.down { + color: #f44336; + } + + .status-controls { + display: flex; + gap: 5px; + } + + .status-btn { + padding: 4px 8px; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 12px; + font-weight: bold; + transition: opacity 0.2s; + } + + .status-btn:hover { + opacity: 0.8; + } + + .status-btn.up { + background-color: #4CAF50; + color: white; + } + + .status-btn.down { + background-color: #f44336; + color: white; + } + + .status-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + /* Certificate config styles */ + .action-btn { + padding: 6px 12px; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 13px; + font-weight: bold; + transition: opacity 0.2s; + margin-right: 5px; + } + + .action-btn:hover { + opacity: 0.8; + } + + .action-btn.primary { + background-color: #4CAF50; + color: white; + } + + .action-btn.danger { + background-color: #f44336; 
+ color: white; + } + + .action-btn.secondary { + background-color: #2196F3; + color: white; + } + + .action-btn.warning { + background-color: #ff9800; + color: white; + } + + .action-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .default-badge { + background-color: #4CAF50; + color: white; + padding: 2px 8px; + border-radius: 12px; + font-size: 11px; + font-weight: bold; + } + + .cert-status { + display: flex; + flex-direction: column; + gap: 4px; + } + + .cert-status-item { + font-size: 12px; + } + + .cert-status-item.has-cert { + color: #4CAF50; + } + + .cert-status-item.no-cert { + color: #f44336; + } + + .modal-overlay { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.5); + z-index: 1000; + justify-content: center; + align-items: center; + } + + .modal-overlay.active { + display: flex; + } + + .modal { + background: white; + padding: 30px; + border-radius: 8px; + box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3); + min-width: 400px; + max-width: 500px; + } + + .modal h3 { + margin-top: 0; + color: #333; + border-bottom: 2px solid #4CAF50; + padding-bottom: 10px; + } + + .modal-field { + margin-bottom: 15px; + } + + .modal-field label { + display: block; + margin-bottom: 5px; + font-weight: bold; + color: #555; + } + + .modal-field input, + .modal-field select { + width: 100%; + padding: 10px; + border: 1px solid #ddd; + border-radius: 4px; + box-sizing: border-box; + } + + .modal-field input[type="checkbox"] { + width: auto; + } + + .modal-actions { + display: flex; + justify-content: flex-end; + gap: 10px; + margin-top: 20px; + } + + .toast { + position: fixed; + bottom: 20px; + right: 20px; + padding: 15px 25px; + border-radius: 4px; + color: white; + font-weight: bold; + z-index: 2000; + animation: slideIn 0.3s ease; + } + + .toast.success { + background-color: #4CAF50; + } + + .toast.error { + background-color: #f44336; + } + + @keyframes slideIn { + from { + transform: 
translateX(100%); + opacity: 0; + } + to { + transform: translateX(0); + opacity: 1; + } + } + + .section-header { + display: flex; + justify-content: space-between; + align-items: center; + } + + .section-header h2 { + margin: 0; + } Dashboard \ No newline at end of file diff --git a/gateway/test-run/.env.example b/gateway/test-run/.env.example new file mode 100644 index 00000000..ff657175 --- /dev/null +++ b/gateway/test-run/.env.example @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Cloudflare API token with DNS edit permissions +# Required scopes: Zone.DNS (Edit), Zone.Zone (Read) +CF_API_TOKEN=your_cloudflare_api_token_here + +# Cloudflare Zone ID for your domain +CF_ZONE_ID=your_zone_id_here + +# Test domain (must be a wildcard domain managed by Cloudflare) +# Example: *.test.example.com +TEST_DOMAIN=*.test.example.com diff --git a/gateway/test-run/.gitignore b/gateway/test-run/.gitignore new file mode 100644 index 00000000..25d3f97c --- /dev/null +++ b/gateway/test-run/.gitignore @@ -0,0 +1,3 @@ +/run/ +.env +/e2e/dstack-gateway diff --git a/gateway/test-run/cluster.sh b/gateway/test-run/cluster.sh new file mode 100755 index 00000000..23521bd1 --- /dev/null +++ b/gateway/test-run/cluster.sh @@ -0,0 +1,442 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Gateway cluster management script for manual testing + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="${SCRIPT_DIR}/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +TMUX_SESSION="gateway-cluster" + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e 
"${RED}[ERROR]${NC} $1"; } + +show_help() { + echo "Gateway Cluster Management Script" + echo "" + echo "Usage: $0 " + echo "" + echo "Commands:" + echo " start Start a 3-node gateway cluster in tmux" + echo " stop Stop the cluster (keep tmux session)" + echo " reg Register a random instance" + echo " status Show cluster status" + echo " clean Destroy cluster and clean all data" + echo " attach Attach to tmux session" + echo " help Show this help" + echo "" +} + +# Generate certificates +generate_certs() { + mkdir -p "$CERTS_DIR" + mkdir -p "$RUN_DIR/certbot/live" + + # Generate CA certificate + if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]]; then + log_info "Creating CA certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null + openssl req -x509 -new -nodes \ + -key "$CERTS_DIR/gateway-ca.key" \ + -sha256 -days 365 \ + -out "$CERTS_DIR/gateway-ca.cert" \ + -subj "/CN=Test CA/O=Gateway Test" \ + 2>/dev/null + fi + + # Generate RPC certificate signed by CA + if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]]; then + log_info "Creating RPC certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null + openssl req -new \ + -key "$CERTS_DIR/gateway-rpc.key" \ + -out "$CERTS_DIR/gateway-rpc.csr" \ + -subj "/CN=localhost" \ + 2>/dev/null + cat > "$CERTS_DIR/ext.cnf" << EXTEOF +authorityKeyIdentifier=keyid,issuer +basicConstraints=CA:FALSE +keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +EXTEOF + openssl x509 -req \ + -in "$CERTS_DIR/gateway-rpc.csr" \ + -CA "$CERTS_DIR/gateway-ca.cert" \ + -CAkey "$CERTS_DIR/gateway-ca.key" \ + -CAcreateserial \ + -out "$CERTS_DIR/gateway-rpc.cert" \ + -days 365 \ + -sha256 \ + -extfile "$CERTS_DIR/ext.cnf" \ + 2>/dev/null + rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf" + fi + + # Generate proxy certificates + local proxy_cert_dir="$RUN_DIR/certbot/live" + if [[ ! 
-f "$proxy_cert_dir/cert.pem" ]]; then + log_info "Creating proxy certificates..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$proxy_cert_dir/key.pem" \ + -out "$proxy_cert_dir/cert.pem" \ + -days 365 \ + -subj "/CN=localhost" \ + 2>/dev/null + fi + + # Generate unique WireGuard key pair for each node + for i in 1 2 3; do + if [[ ! -f "$CERTS_DIR/wg-node${i}.key" ]]; then + log_info "Generating WireGuard keys for node ${i}..." + wg genkey > "$CERTS_DIR/wg-node${i}.key" + wg pubkey < "$CERTS_DIR/wg-node${i}.key" > "$CERTS_DIR/wg-node${i}.pub" + fi + done +} + +# Generate node config +generate_config() { + local node_id=$1 + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + local other_nodes="" + local peer_urls="" + + # Read WireGuard keys for this node + local wg_private_key=$(cat "$CERTS_DIR/wg-node${node_id}.key") + local wg_public_key=$(cat "$CERTS_DIR/wg-node${node_id}.pub") + + for i in 1 2 3; do + if [[ $i -ne $node_id ]]; then + local peer_rpc_port=$((13000 + i * 10 + 2)) + if [[ -n "$other_nodes" ]]; then + other_nodes="$other_nodes, $i" + peer_urls="$peer_urls, \"$i:https://localhost:$peer_rpc_port\"" + else + other_nodes="$i" + peer_urls="\"$i:https://localhost:$peer_rpc_port\"" + fi + fi + done + + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat > "$RUN_DIR/node${node_id}.toml" << EOF +log_level = "info" +address = "0.0.0.0" +port = ${rpc_port} + +[tls] +key = "${abs_run_dir}/certs/gateway-rpc.key" +certs = "${abs_run_dir}/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "${abs_run_dir}/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.test.local" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = ${debug_port} +address = "127.0.0.1" + 
+[core.admin] +enabled = true +port = ${admin_port} +address = "127.0.0.1" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://localhost:${rpc_port}" +bootnode = "" +node_id = ${node_id} +data_dir = "${RUN_DIR}/wavekv_node${node_id}" + +[core.certbot] +enabled = false + +[core.wg] +private_key = "${wg_private_key}" +public_key = "${wg_public_key}" +listen_port = ${wg_port} +ip = "${wg_ip}" +reserved_net = ["10.0.3${node_id}.1/31"] +client_ip_range = "10.0.3${node_id}.1/24" +config_path = "${RUN_DIR}/wg_node${node_id}.conf" +interface = "gw-test${node_id}" +endpoint = "127.0.0.1:${wg_port}" + +[core.proxy] +cert_chain = "${RUN_DIR}/certbot/live/cert.pem" +cert_key = "${RUN_DIR}/certbot/live/key.pem" +base_domain = "test.local" +listen_addr = "0.0.0.0" +listen_port = ${proxy_port} +tappd_port = 8090 +external_port = ${proxy_port} + +[core.recycle] +enabled = true +interval = "30s" +timeout = "120s" +node_timeout = "300s" +EOF +} + +# Build gateway binary +build_gateway() { + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_info "Building gateway..." + (cd "$SCRIPT_DIR/.." && cargo build --release) + fi +} + +# Start cluster +cmd_start() { + build_gateway + generate_certs + + # Check if tmux session exists + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + log_warn "Cluster already running. Use 'clean' to restart." + cmd_status + return 0 + fi + + log_info "Generating configs..." + mkdir -p "$RUN_DIR" "$LOG_DIR" + for i in 1 2 3; do + generate_config $i + mkdir -p "$RUN_DIR/wavekv_node${i}" + done + + log_info "Starting cluster in tmux session '$TMUX_SESSION'..." + + # Create wrapper scripts that keep running even if gateway exits + for i in 1 2 3; do + cat > "$RUN_DIR/run_node${i}.sh" << RUNEOF +#!/bin/bash +cd "$SCRIPT_DIR" +while true; do + echo "Starting node ${i}..." + sudo RUST_LOG=info $GATEWAY_BIN -c $RUN_DIR/node${i}.toml 2>&1 | tee -a $LOG_DIR/node${i}.log + echo "Node ${i} exited. 
Press Ctrl+C to stop, or wait 3s to restart..." + sleep 3 +done +RUNEOF + chmod +x "$RUN_DIR/run_node${i}.sh" + done + + # Create tmux session + tmux new-session -d -s "$TMUX_SESSION" -n "node1" + tmux send-keys -t "$TMUX_SESSION:node1" "$RUN_DIR/run_node1.sh" Enter + + sleep 1 + + # Add windows for other nodes + tmux new-window -t "$TMUX_SESSION" -n "node2" + tmux send-keys -t "$TMUX_SESSION:node2" "$RUN_DIR/run_node2.sh" Enter + + tmux new-window -t "$TMUX_SESSION" -n "node3" + tmux send-keys -t "$TMUX_SESSION:node3" "$RUN_DIR/run_node3.sh" Enter + + # Add a shell window + tmux new-window -t "$TMUX_SESSION" -n "shell" + + sleep 3 + + log_info "Cluster started!" + echo "" + cmd_status + echo "" + log_info "Use '$0 attach' to view logs" +} + +# Stop cluster +cmd_stop() { + log_info "Stopping cluster..." + sudo pkill -9 -f "dstack-gateway.*node[123].toml" 2>/dev/null || true + sudo ip link delete gw-test1 2>/dev/null || true + sudo ip link delete gw-test2 2>/dev/null || true + sudo ip link delete gw-test3 2>/dev/null || true + log_info "Cluster stopped" +} + +# Clean everything +cmd_clean() { + cmd_stop + + # Kill tmux session + tmux kill-session -t "$TMUX_SESSION" 2>/dev/null || true + + log_info "Cleaning data..." 
+ sudo rm -rf "$RUN_DIR/wavekv_node"* + sudo rm -f "$RUN_DIR/gateway-state-node"*.json + rm -f "$RUN_DIR/wg_node"*.conf + rm -f "$RUN_DIR/node"*.toml + rm -f "$RUN_DIR/run_node"*.sh + rm -rf "$LOG_DIR" + + log_info "Cleaned" +} + +# Show status +cmd_status() { + echo -e "${BLUE}=== Gateway Cluster Status ===${NC}" + echo "" + + for i in 1 2 3; do + local rpc_port=$((13000 + i * 10 + 2)) + local proxy_port=$((13000 + i * 10 + 4)) + local debug_port=$((13000 + i * 10 + 5)) + local admin_port=$((13000 + i * 10 + 6)) + + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + echo -e "Node $i: ${GREEN}RUNNING${NC}" + else + echo -e "Node $i: ${RED}STOPPED${NC}" + fi + echo " RPC: https://localhost:${rpc_port}" + echo " Proxy: https://localhost:${proxy_port}" + echo " Debug: http://localhost:${debug_port}" + echo " Admin: http://localhost:${admin_port}" + echo "" + done + + # Show instance count from first running node + for i in 1 2 3; do + local debug_port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if [[ -n "$response" ]]; then + local n_instances=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('instances', [])))" 2>/dev/null || echo "?") + local n_nodes=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('nodes', [])))" 2>/dev/null || echo "?") + echo -e "${BLUE}Cluster State:${NC}" + echo " Nodes: $n_nodes" + echo " Instances: $n_instances" + fi + break + fi + done +} + +# Register a random instance +cmd_reg() { + # Find a running node + local debug_port="" + for i in 1 2 3; do + local port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + debug_port=$port + break + fi + done + + if [[ -z "$debug_port" ]]; then + log_error "No running nodes 
found. Start cluster first." + exit 1 + fi + + # Generate random WireGuard key pair + local private_key=$(wg genkey) + local public_key=$(echo "$private_key" | wg pubkey) + + # Generate random IDs + local app_id="app-$(openssl rand -hex 4)" + local instance_id="inst-$(openssl rand -hex 4)" + + log_info "Registering instance..." + log_info " App ID: $app_id" + log_info " Instance ID: $instance_id" + log_info " Public Key: $public_key" + + local response=$(curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null) + + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'wg' in d" 2>/dev/null; then + local client_ip=$(echo "$response" | python3 -c "import sys,json; print(json.load(sys.stdin)['wg']['client_ip'])" 2>/dev/null) + log_info "Registered successfully!" + echo -e " Client IP: ${GREEN}$client_ip${NC}" + echo "" + echo "Instance details:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + else + log_error "Registration failed:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + exit 1 + fi +} + +# Attach to tmux +cmd_attach() { + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + tmux attach -t "$TMUX_SESSION" + else + log_error "No cluster running" + exit 1 + fi +} + +# Main +case "${1:-help}" in + start) + cmd_start + ;; + stop) + cmd_stop + ;; + clean) + cmd_clean + ;; + status) + cmd_status + ;; + reg) + cmd_reg + ;; + attach) + cmd_attach + ;; + help|--help|-h) + show_help + ;; + *) + log_error "Unknown command: $1" + show_help + exit 1 + ;; +esac diff --git a/gateway/test-run/e2e/configs/gateway-1.toml b/gateway/test-run/e2e/configs/gateway-1.toml new file mode 100644 index 00000000..dfbe1609 --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-1.toml @@ -0,0 +1,53 @@ +# Gateway Node 1 
configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/var/lib/gateway/certs/gateway-rpc.key" +certs = "/var/lib/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/var/lib/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway-1" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = false +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://gateway-1:9012" +bootnode = "https://gateway-2:9012" +node_id = 1 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.41.1/24" +reserved_net = ["10.0.41.1/31"] +client_ip_range = "10.0.41.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test1" +endpoint = "gateway-1:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/configs/gateway-2.toml b/gateway/test-run/e2e/configs/gateway-2.toml new file mode 100644 index 00000000..b825fda5 --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-2.toml @@ -0,0 +1,53 @@ +# Gateway Node 2 configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/var/lib/gateway/certs/gateway-rpc.key" +certs = "/var/lib/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/var/lib/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway-2" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = false +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = 
"https://gateway-2:9012" +bootnode = "https://gateway-1:9012" +node_id = 2 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.42.1/24" +reserved_net = ["10.0.42.1/31"] +client_ip_range = "10.0.42.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test2" +endpoint = "gateway-2:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/configs/gateway-3.toml b/gateway/test-run/e2e/configs/gateway-3.toml new file mode 100644 index 00000000..f30cb6a1 --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-3.toml @@ -0,0 +1,53 @@ +# Gateway Node 3 configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/var/lib/gateway/certs/gateway-rpc.key" +certs = "/var/lib/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/var/lib/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway-3" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = false +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://gateway-3:9012" +bootnode = "https://gateway-1:9012" +node_id = 3 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.43.1/24" +reserved_net = ["10.0.43.1/31"] +client_ip_range = "10.0.43.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test3" +endpoint = "gateway-3:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/docker-compose.yml 
b/gateway/test-run/e2e/docker-compose.yml new file mode 100644 index 00000000..2374b073 --- /dev/null +++ b/gateway/test-run/e2e/docker-compose.yml @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# E2E test environment for dstack-gateway certbot functionality +# Uses mock services: Pebble (ACME) + mock-cf-dns-api (Cloudflare DNS) +# Uses real TDX endpoint for attestation + +networks: + certbot-test: + driver: bridge + ipam: + config: + - subnet: 172.30.0.0/24 + +volumes: + pebble-certs: + +services: + # ==================== Mock Services ==================== + + # Mock Cloudflare DNS API + mock-cf-dns-api: + image: kvin/mock-cf-dns-api:latest + container_name: mock-cf-dns-api + networks: + certbot-test: + ipv4_address: 172.30.0.10 + ports: + - "18080:8080" + environment: + - PORT=8080 + - DEBUG=true + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"] + interval: 5s + timeout: 3s + retries: 5 + + # Pebble - Let's Encrypt test server (custom build with HTTP support) + pebble: + image: kvin/pebble:latest + container_name: pebble + command: ["-http", "-dnsserver", "172.30.0.10:53"] + networks: + certbot-test: + ipv4_address: 172.30.0.11 + ports: + - "14000:14000" # ACME directory + - "15000:15000" # Management interface + environment: + - PEBBLE_VA_NOSLEEP=1 + - PEBBLE_VA_ALWAYS_VALID=1 # Skip actual DNS validation for testing + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:14000/dir"] + interval: 5s + timeout: 3s + retries: 10 + + # ==================== Gateway Cluster ==================== + + # Gateway Node 1 - Will be the first to request certificate + gateway-1: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-1 + networks: + certbot-test: + ipv4_address: 172.30.0.21 + ports: + - "19012:9012" # RPC + - "19014:9014" # Proxy + - "19015:9015" # Debug + - "19016:9016" # Admin + 
volumes: + - ./configs/gateway-1.toml:/etc/gateway/gateway.toml:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - DSTACK_AGENT_ADDRESS=https://712eab2f507b963e11144ae67218177e93ac2a24-3000.tdxlab.dstack.org:12004/ + depends_on: + mock-cf-dns-api: + condition: service_healthy + pebble: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + extra_hosts: + # Pebble returns localhost in directory URLs, so we need to resolve localhost to pebble's IP + - "localhost:172.30.0.11" + + # Gateway Node 2 - Will sync certificate from Node 1 + gateway-2: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-2 + networks: + certbot-test: + ipv4_address: 172.30.0.22 + ports: + - "19022:9012" # RPC + - "19024:9014" # Proxy + - "19025:9015" # Debug + - "19026:9016" # Admin + volumes: + - ./configs/gateway-2.toml:/etc/gateway/gateway.toml:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - DSTACK_AGENT_ADDRESS=https://712eab2f507b963e11144ae67218177e93ac2a24-3000.tdxlab.dstack.org:12004/ + depends_on: + gateway-1: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + + # Gateway Node 3 - Will sync certificate from cluster + gateway-3: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-3 + networks: + certbot-test: + ipv4_address: 172.30.0.23 + ports: + - "19032:9012" # RPC + - "19034:9014" # Proxy + - "19035:9015" # Debug + - "19036:9016" # Admin + volumes: + - ./configs/gateway-3.toml:/etc/gateway/gateway.toml:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - 
DSTACK_AGENT_ADDRESS=https://712eab2f507b963e11144ae67218177e93ac2a24-3000.tdxlab.dstack.org:12004/ + depends_on: + gateway-2: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + + # ==================== Test Runner ==================== + + test-runner: + image: alpine:latest + container_name: test-runner + networks: + certbot-test: + ipv4_address: 172.30.0.100 + volumes: + - ./test.sh:/test.sh:ro + entrypoint: ["/bin/sh", "-c", "apk add --no-cache curl openssl jq && /bin/sh /test.sh"] + depends_on: + gateway-1: + condition: service_healthy + gateway-2: + condition: service_healthy + gateway-3: + condition: service_healthy diff --git a/gateway/test-run/e2e/pebble-config.json b/gateway/test-run/e2e/pebble-config.json new file mode 100644 index 00000000..41411088 --- /dev/null +++ b/gateway/test-run/e2e/pebble-config.json @@ -0,0 +1,18 @@ +{ + "pebble": { + "listenAddress": "0.0.0.0:14000", + "managementListenAddress": "0.0.0.0:15000", + "certificate": "/etc/pebble/certs/localhost/cert.pem", + "privateKey": "/etc/pebble/certs/localhost/key.pem", + "httpPort": 5002, + "tlsPort": 5001, + "ocspResponderURL": "", + "externalAccountBindingRequired": false, + "domainBlocklist": [], + "retryAfter": { + "authz": 3, + "order": 5 + }, + "certificateValidityPeriod": 157680000 + } +} diff --git a/gateway/test-run/e2e/run-e2e.sh b/gateway/test-run/e2e/run-e2e.sh new file mode 100755 index 00000000..835d31fb --- /dev/null +++ b/gateway/test-run/e2e/run-e2e.sh @@ -0,0 +1,159 @@ +#!/bin/bash +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# E2E test runner for dstack-gateway +# Builds gateway image, then runs the test suite using real TDX endpoint + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." 
&& pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +log_info() { echo -e "${BLUE}[INFO]${NC} $1"; } +log_success() { echo -e "${GREEN}[OK]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +# Parse arguments +SKIP_BUILD=false +KEEP_RUNNING=false +CLEAN=false + +while [[ $# -gt 0 ]]; do + case $1 in + --skip-build) + SKIP_BUILD=true + shift + ;; + --keep-running) + KEEP_RUNNING=true + shift + ;; + --clean) + CLEAN=true + shift + ;; + down) + cd "$SCRIPT_DIR" + log_info "Stopping containers..." + docker compose down -v --remove-orphans 2>/dev/null || true + log_success "Containers stopped" + exit 0 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS|COMMAND]" + echo "" + echo "Commands:" + echo " down Stop all containers" + echo "" + echo "Options:" + echo " --skip-build Skip building gateway image" + echo " --keep-running Keep containers running after test" + echo " --clean Clean up containers and images" + echo " -h, --help Show this help" + exit 0 + ;; + *) + log_error "Unknown option: $1" + exit 1 + ;; + esac +done + +cd "$SCRIPT_DIR" + +# Cleanup function +cleanup() { + if ! $KEEP_RUNNING; then + log_info "Stopping containers..." + docker compose down -v --remove-orphans 2>/dev/null || true + fi +} + +# Trap to ensure cleanup on exit/interrupt +trap cleanup EXIT + +# Clean up if requested +if $CLEAN; then + log_info "Cleaning up..." + docker compose down -v --remove-orphans 2>/dev/null || true + docker rmi dstack-gateway:test 2>/dev/null || true + log_success "Cleanup complete" + exit 0 +fi + +# Stop any running containers first (to release file handles) +log_info "Stopping any existing containers..." +docker compose down -v --remove-orphans 2>/dev/null || true + +# Step 1: Build gateway if needed (musl static build) +if ! $SKIP_BUILD; then + log_info "Building dstack-gateway (musl static)..." 
+ cd "$REPO_ROOT" + cargo build --release -p dstack-gateway --target x86_64-unknown-linux-musl + + # Copy binary to e2e directory + cp target/x86_64-unknown-linux-musl/release/dstack-gateway "$SCRIPT_DIR/" + log_success "Gateway built: $SCRIPT_DIR/dstack-gateway" +fi + +# Step 2: Create gateway docker image (alpine for musl) +log_info "Creating gateway docker image..." +cd "$SCRIPT_DIR" + +cat > Dockerfile.gateway << 'EOF' +FROM alpine:latest + +RUN apk add --no-cache \ + wireguard-tools \ + iproute2 \ + curl \ + ca-certificates + +COPY dstack-gateway /usr/local/bin/dstack-gateway + +RUN chmod +x /usr/local/bin/dstack-gateway && \ + mkdir -p /etc/gateway/certs /var/lib/gateway + +ENTRYPOINT ["/usr/local/bin/dstack-gateway", "-c", "/etc/gateway/gateway.toml"] +EOF + +docker build -t dstack-gateway:test -f Dockerfile.gateway . +rm Dockerfile.gateway +log_success "Gateway image created: dstack-gateway:test" + +# Step 3: Run docker compose +log_info "Starting e2e test environment..." + +export GATEWAY_IMAGE=dstack-gateway:test + +docker compose up -d mock-cf-dns-api pebble +log_info "Waiting for mock services to be healthy..." +sleep 5 + +docker compose up -d gateway-1 gateway-2 gateway-3 +log_info "Waiting for gateway cluster to be healthy..." +sleep 10 + +# Step 4: Run tests +log_info "Running tests..." +docker compose run --rm test-runner +TEST_EXIT_CODE=$? + +# Step 5: Report result (cleanup handled by trap) +if [ $TEST_EXIT_CODE -eq 0 ]; then + log_success "All tests passed!" 
+else + log_error "Tests failed with exit code: $TEST_EXIT_CODE" +fi + +exit $TEST_EXIT_CODE diff --git a/gateway/test-run/e2e/test.sh b/gateway/test-run/e2e/test.sh new file mode 100755 index 00000000..9c1db1a2 --- /dev/null +++ b/gateway/test-run/e2e/test.sh @@ -0,0 +1,332 @@ +#!/bin/sh +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# E2E test script for dstack-gateway certbot functionality +# This script runs inside the test-runner container + +set -e + +# ==================== Configuration ==================== + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +# Gateway endpoints +GATEWAY_PROXIES="gateway-1:9014 gateway-2:9014 gateway-3:9014" +GATEWAY_DEBUG_URLS="http://gateway-1:9015 http://gateway-2:9015 http://gateway-3:9015" +GATEWAY_ADMIN="http://gateway-1:9016" + +# External services +MOCK_CF_API="http://mock-cf-dns-api:8080" +PEBBLE_DIR="http://pebble:14000/dir" + +# Certificate domains to test (base domains, certs will be issued for *.domain) +CERT_DOMAINS="test0.local test1.local test2.local" + +# Cloudflare mock settings +CF_API_TOKEN="test-token" +CF_API_URL="http://mock-cf-dns-api:8080/client/v4" +ACME_URL="http://pebble:14000/dir" + +# Test counters +TESTS_PASSED=0 +TESTS_FAILED=0 + +# ==================== Logging ==================== + +log_info() { printf "${BLUE}[INFO]${NC} %s\n" "$1"; } +log_warn() { printf "${YELLOW}[WARN]${NC} %s\n" "$1"; } +log_error() { printf "${RED}[ERROR]${NC} %s\n" "$1"; } +log_success() { printf "${GREEN}[PASS]${NC} %s\n" "$1"; } +log_fail() { printf "${RED}[FAIL]${NC} %s\n" "$1"; } + +log_section() { + printf "\n" + log_info "==========================================" + log_info "$1" + log_info "==========================================" +} + +log_phase() { + printf "\n" + log_info "Phase $1: $2" + log_info "------------------------------------------" +} + +# ==================== Test Utilities 
==================== + +# Run a test and record result +run_test() { + local name="$1" + local result="$2" + + if [ "$result" = "0" ]; then + log_success "$name" + TESTS_PASSED=$((TESTS_PASSED + 1)) + else + log_fail "$name" + TESTS_FAILED=$((TESTS_FAILED + 1)) + fi +} + +# Wait for HTTP service to respond +wait_for_service() { + local url="$1" + local name="$2" + local max_wait="${3:-60}" + local waited=0 + + log_info "Waiting for $name..." + while [ $waited -lt $max_wait ]; do + if curl -sf "$url" > /dev/null 2>&1; then + log_info "$name is ready" + return 0 + fi + sleep 2 + waited=$((waited + 2)) + done + + log_error "$name failed to become ready within ${max_wait}s" + return 1 +} + +# ==================== Domain Helpers ==================== + +# Convert base domain to test SNI: test0.local -> gateway.test0.local +# Uses "gateway" as it's a special app_id that proxies to gateway's own endpoints +get_test_sni() { + echo "gateway.${1}" +} + +# Convert base domain to wildcard format for certificate SAN check +get_wildcard_domain() { + echo "*.${1}" +} + +# ==================== Certificate Helpers ==================== + +# Get certificate via openssl s_client +get_cert_pem() { + local host="$1" + local sni="$2" + echo | timeout 5 openssl s_client -connect "$host" -servername "$sni" 2>/dev/null +} + +get_cert_serial() { + get_cert_pem "$1" "$2" | openssl x509 -noout -serial 2>/dev/null | cut -d= -f2 +} + +get_cert_issuer() { + get_cert_pem "$1" "$2" | openssl x509 -noout -issuer 2>/dev/null +} + +get_cert_san() { + get_cert_pem "$1" "$2" | openssl x509 -noout -ext subjectAltName 2>/dev/null +} + +# ==================== Test Functions ==================== + +test_http_health() { + curl -sf "$1" > /dev/null +} + +test_certificate_issued() { + local host="$1" + local sni="$2" + [ -n "$(get_cert_serial "$host" "$sni")" ] +} + +test_certificates_match() { + local sni="$1" + local serial1="" serial2="" serial3="" + local i=1 + + for proxy in $GATEWAY_PROXIES; do + eval 
"serial${i}=\"\$(get_cert_serial \"\$proxy\" \"\$sni\")\"" + log_info "Gateway $i cert serial ($sni): $(eval echo \$serial$i)" >&2 + i=$((i + 1)) + done + + [ "$serial1" = "$serial2" ] && [ "$serial2" = "$serial3" ] && [ -n "$serial1" ] +} + +test_certificate_from_pebble() { + local sni="$1" + local proxy=$(echo "$GATEWAY_PROXIES" | cut -d' ' -f1) + get_cert_issuer "$proxy" "$sni" | grep -qi "pebble" +} + +test_sni_cert_selection() { + local host="$1" + local sni="$2" + local expected_wildcard="$3" + get_cert_san "$host" "$sni" | grep -q "$expected_wildcard" +} + +test_proxy_tls_health() { + local host="$1" + local gateway_sni="$2" + curl -sf --connect-to "${gateway_sni}:9014:${host}" -k "https://${gateway_sni}:9014/health" > /dev/null 2>&1 +} + +# ==================== Setup ==================== + +setup_certbot_config() { + log_info "Configuring certbot via Admin API..." + + # Set ACME URL + log_info "Setting ACME URL: ${ACME_URL}" + if ! curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.SetCertbotConfig" \ + -H "Content-Type: application/json" \ + -d '{"acme_url": "'"${ACME_URL}"'"}' > /dev/null; then + log_error "Failed to set certbot config" + return 1 + fi + + # Create DNS credential + log_info "Creating DNS credential..." + if ! 
curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.CreateDnsCredential" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "test-cloudflare", + "provider_type": "cloudflare", + "cf_api_token": "'"${CF_API_TOKEN}"'", + "cf_api_url": "'"${CF_API_URL}"'", + "set_as_default": true, + "dns_txt_ttl": 1, + "max_dns_wait": 0 + }' > /dev/null; then + log_error "Failed to create DNS credential" + return 1 + fi + + # Add domains and trigger renewal + for domain in $CERT_DOMAINS; do + log_info "Adding domain: $domain" + curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.AddZtDomain" \ + -H "Content-Type: application/json" \ + -d '{"domain": "'"${domain}"'"}' > /dev/null || true + + log_info "Triggering renewal for: $domain" + curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.RenewZtDomainCert" \ + -H "Content-Type: application/json" \ + -d '{"domain": "'"${domain}"'", "force": true}' > /dev/null || \ + log_warn "Renewal request failed for $domain (may retry)" + done + + return 0 +} + +# ==================== Main ==================== + +main() { + log_section "dstack-gateway Certbot E2E Test" + + # Phase 1: Mock services + log_phase 1 "Verify mock services" + run_test "Mock CF DNS API health" "$(test_http_health "${MOCK_CF_API}/health"; echo $?)" + run_test "Pebble ACME directory" "$(test_http_health "${PEBBLE_DIR}"; echo $?)" + + # Phase 2: Gateway cluster + log_phase 2 "Verify gateway cluster" + local i=1 + for url in $GATEWAY_DEBUG_URLS; do + run_test "Gateway $i health" "$(test_http_health "${url}/health"; echo $?)" + i=$((i + 1)) + done + + # Phase 3: Configure certbot + log_phase 3 "Configure certbot" + if ! 
setup_certbot_config; then + log_error "Failed to setup certbot configuration" + fi + + # Phase 4: Certificate issuance + log_phase 4 "Certificate issuance" + local first_domain=$(echo "$CERT_DOMAINS" | cut -d' ' -f1) + local first_sni=$(get_test_sni "$first_domain") + local first_proxy=$(echo "$GATEWAY_PROXIES" | cut -d' ' -f1) + + log_info "Waiting for certificates (up to 120s)..." + local waited=0 + while [ $waited -lt 120 ]; do + if test_certificate_issued "$first_proxy" "$first_sni"; then + log_info "Certificate detected for $first_sni" + break + fi + sleep 5 + waited=$((waited + 5)) + log_info "Waiting... (${waited}s)" + done + + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + run_test "Certificate issued for $domain" \ + "$(test_certificate_issued "$first_proxy" "$sni"; echo $?)" + done + + log_info "Waiting 20s for cluster sync..." + sleep 20 + + # Phase 5: Certificate consistency + log_phase 5 "Certificate consistency" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + run_test "All gateways have same cert for $domain" \ + "$(test_certificates_match "$sni"; echo $?)" + run_test "Cert for $domain issued by Pebble" \ + "$(test_certificate_from_pebble "$sni"; echo $?)" + done + + # Phase 6: SNI-based selection + log_phase 6 "SNI-based certificate selection" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + local wildcard=$(get_wildcard_domain "$domain") + run_test "SNI $sni returns $wildcard cert" \ + "$(test_sni_cert_selection "$first_proxy" "$sni" "$wildcard"; echo $?)" + done + + # Phase 7: Proxy TLS health + log_phase 7 "Proxy TLS health endpoint" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + local i=1 + for proxy in $GATEWAY_PROXIES; do + run_test "Gateway $i TLS health ($sni)" \ + "$(test_proxy_tls_health "$proxy" "$sni"; echo $?)" + i=$((i + 1)) + done + done + + # Phase 8: DNS records (informational) + log_phase 8 "DNS-01 challenge records" + local 
records=$(curl -sf "${MOCK_CF_API}/api/records" 2>/dev/null || echo "") + if echo "$records" | grep -q "TXT"; then + log_success "DNS TXT records found" + else + log_info "No DNS TXT records (expected if certs cached)" + fi + + # Summary + log_section "Test Summary" + log_info "Passed: $TESTS_PASSED" + log_info "Failed: $TESTS_FAILED" + log_info "Domains: $(echo "$CERT_DOMAINS" | wc -w)" + + if [ $TESTS_FAILED -eq 0 ]; then + log_success "All tests passed!" + exit 0 + else + log_fail "Some tests failed!" + exit 1 + fi +} + +main diff --git a/gateway/test-run/test_certbot.sh b/gateway/test-run/test_certbot.sh new file mode 100755 index 00000000..26a6ba97 --- /dev/null +++ b/gateway/test-run/test_certbot.sh @@ -0,0 +1,563 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Distributed Certbot E2E test script +# Tests certificate issuance and synchronization across gateway nodes + +set -m + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Show help +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Distributed Certbot E2E Test" + echo "" + echo "Options:" + echo " --fresh Clean everything and request new certificate from ACME" + echo " --sync-only Keep existing cert, only test sync between nodes" + echo " --clean Clean all test data and exit" + echo " -h, --help Show this help message" + echo "" + echo "Default (no options): Keep ACME account, request new certificate" + echo "" + echo "Examples:" + echo " $0 # Keep account, new cert" + echo " $0 --fresh # Fresh start, new account and cert" + echo " $0 --sync-only # Test sync with existing cert" + echo " $0 --clean # Clean up all test data" +} + +# Parse arguments +MODE="default" +while [[ $# -gt 0 ]]; do + case $1 in + --fresh) + MODE="fresh" + shift + ;; + --sync-only) + MODE="sync-only" + shift + ;; + --clean) + MODE="clean" + shift + ;; + -h|--help) + show_help + exit 0 + ;; + *) + echo "Unknown option: $1" + 
show_help + exit 1 + ;; + esac +done + +# Load environment variables from .env +if [[ -f ".env" ]]; then + source ".env" +else + echo "ERROR: .env file not found!" + echo "" + echo "Please create a .env file with the following variables:" + echo " CF_API_TOKEN=" + echo " CF_ZONE_ID=" + echo " TEST_DOMAIN=" + echo "" + echo "The domain must be managed by Cloudflare and the API token must have" + echo "permissions to manage DNS records and CAA records." + exit 1 +fi + +# Validate required environment variables +if [[ -z "$CF_API_TOKEN" ]]; then + echo "ERROR: CF_API_TOKEN is not set in .env" + exit 1 +fi + +if [[ -z "$CF_ZONE_ID" ]]; then + echo "ERROR: CF_ZONE_ID is not set in .env" + exit 1 +fi + +if [[ -z "$TEST_DOMAIN" ]]; then + echo "ERROR: TEST_DOMAIN is not set in .env" + exit 1 +fi + +GATEWAY_BIN="$SCRIPT_DIR/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +CURRENT_TEST="test_certbot" + +# Let's Encrypt staging URL (for testing without rate limits) +ACME_STAGING_URL="https://acme-staging-v02.api.letsencrypt.org/directory" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +cleanup() { + log_info "Cleaning up..." 
+ sudo pkill -9 -f "dstack-gateway.*certbot_node[12].toml" >/dev/null 2>&1 || true + sudo ip link delete certbot-test1 2>/dev/null || true + sudo ip link delete certbot-test2 2>/dev/null || true + sleep 1 + stty sane 2>/dev/null || true +} + +trap cleanup EXIT + +# Generate node config with certbot enabled +generate_certbot_config() { + local node_id=$1 + local rpc_port=$((14000 + node_id * 10 + 2)) + local wg_port=$((14000 + node_id * 10 + 3)) + local proxy_port=$((14000 + node_id * 10 + 4)) + local debug_port=$((14000 + node_id * 10 + 5)) + local wg_ip="10.0.4${node_id}.1/24" + + # Build peer config + local other_node=$((3 - node_id)) # If node_id=1, other=2; if node_id=2, other=1 + local other_rpc_port=$((14000 + other_node * 10 + 2)) + + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + local certbot_dir="$abs_run_dir/certbot_node${node_id}" + + mkdir -p "$certbot_dir" + + cat > "$RUN_DIR/certbot_node${node_id}.toml" << EOF +log_level = "info" +address = "0.0.0.0" +port = ${rpc_port} + +[tls] +key = "${abs_run_dir}/certs/gateway-rpc.key" +certs = "${abs_run_dir}/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "${abs_run_dir}/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.tdxlab.dstack.org" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = ${debug_port} +address = "127.0.0.1" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://localhost:${rpc_port}" +bootnode = "https://localhost:${other_rpc_port}" +node_id = ${node_id} +data_dir = "${RUN_DIR}/wavekv_certbot_node${node_id}" + +[core.certbot] +enabled = true +workdir = "${certbot_dir}" +acme_url = "${ACME_STAGING_URL}" +cf_api_token = "${CF_API_TOKEN}" +cf_zone_id = "${CF_ZONE_ID}" +auto_set_caa = true +domain = "${TEST_DOMAIN}" +renew_interval = "1h" +renew_before_expiration = "720h" +renew_timeout = "5m" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = 
"xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = ${wg_port} +ip = "${wg_ip}" +reserved_net = ["10.0.4${node_id}.1/31"] +client_ip_range = "10.0.4${node_id}.1/24" +config_path = "${RUN_DIR}/wg_certbot_node${node_id}.conf" +interface = "certbot-test${node_id}" +endpoint = "127.0.0.1:${wg_port}" + +[core.proxy] +cert_chain = "${certbot_dir}/live/cert.pem" +cert_key = "${certbot_dir}/live/key.pem" +base_domain = "tdxlab.dstack.org" +listen_addr = "0.0.0.0" +listen_port = ${proxy_port} +tappd_port = 8090 +external_port = ${proxy_port} +EOF + log_info "Generated certbot_node${node_id}.toml (rpc=${rpc_port}, debug=${debug_port}, proxy=${proxy_port})" +} + +start_certbot_node() { + local node_id=$1 + local config="$RUN_DIR/certbot_node${node_id}.toml" + local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log" + + log_info "Starting certbot node ${node_id}..." + mkdir -p "$RUN_DIR/wavekv_certbot_node${node_id}" + mkdir -p "$LOG_DIR" + ( sudo RUST_LOG=info "$GATEWAY_BIN" -c "$config" > "$log_file" 2>&1 & ) + + # Wait for process to either stabilize or fail + local max_wait=30 + local waited=0 + while [[ $waited -lt $max_wait ]]; do + sleep 2 + waited=$((waited + 2)) + + if ! pgrep -f "dstack-gateway.*${config}" > /dev/null; then + # Process exited, check why + log_error "Certbot node ${node_id} exited after ${waited}s" + echo "--- Log output ---" + cat "$log_file" + echo "--- End log ---" + + # Check for rate limit error + if grep -q "rateLimited" "$log_file"; then + log_error "Let's Encrypt rate limit hit. Wait a few minutes and retry." + fi + return 1 + fi + + # Check if cert files exist (indicates successful init) + local certbot_dir="$RUN_DIR/certbot_node${node_id}" + if [[ -f "$certbot_dir/live/cert.pem" ]] && [[ -f "$certbot_dir/live/key.pem" ]]; then + log_info "Certbot node ${node_id} started and certificate obtained" + return 0 + fi + + log_info "Waiting for node ${node_id} to initialize... 
(${waited}s)" + done + + # Process still running but no cert yet - might still be requesting + if pgrep -f "dstack-gateway.*${config}" > /dev/null; then + log_info "Certbot node ${node_id} still running, certificate request in progress" + return 0 + fi + + log_error "Certbot node ${node_id} failed to start within ${max_wait}s" + cat "$log_file" + return 1 +} + +stop_certbot_node() { + local node_id=$1 + log_info "Stopping certbot node ${node_id}..." + sudo pkill -9 -f "dstack-gateway.*certbot_node${node_id}.toml" >/dev/null 2>&1 || true + sleep 1 +} + +# Get debug sync data from a node +debug_get_sync_data() { + local debug_port=$1 + curl -s "http://localhost:${debug_port}/prpc/GetSyncData" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +# Check if KvStore has cert data for the domain +check_kvstore_cert() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + + # The cert data would be in the persistent store + # For now, check if we can get any data + if [[ -z "$response" ]]; then + return 1 + fi + + # Check for cert-related keys in the response + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + # Check if there are any keys that start with 'cert/' + # This is a simplified check + print('ok') + sys.exit(0) +except Exception as e: + print(f'error: {e}', file=sys.stderr) + sys.exit(1) +" 2>/dev/null +} + +# Check if proxy is using a valid certificate by connecting via TLS +check_proxy_cert() { + local proxy_port=$1 + + # Use gateway.{base_domain} as the SNI for health endpoint + local gateway_host="gateway.tdxlab.dstack.org" + + # Use openssl to check the certificate + local cert_info=$(echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null) + + if [[ -z "$cert_info" ]]; then + log_error "Failed to connect to proxy on port ${proxy_port}" + return 1 + fi + + # Check if the certificate is valid (not self-signed test cert) + # For 
staging certs, the issuer should contain "Staging" or "(STAGING)" + local issuer=$(echo "$cert_info" | openssl x509 -noout -issuer 2>/dev/null) + + if echo "$issuer" | grep -qi "staging\|fake\|test"; then + log_info "Proxy on port ${proxy_port} is using Let's Encrypt staging certificate" + log_info "Issuer: $issuer" + return 0 + elif echo "$issuer" | grep -qi "let's encrypt\|letsencrypt"; then + log_info "Proxy on port ${proxy_port} is using Let's Encrypt certificate" + log_info "Issuer: $issuer" + return 0 + else + log_warn "Proxy on port ${proxy_port} certificate issuer: $issuer" + # Still return success if we got a certificate + return 0 + fi +} + +# Get certificate expiry from proxy health endpoint +get_proxy_cert_expiry() { + local proxy_port=$1 + # Use gateway.{base_domain} as the SNI for health endpoint + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -enddate 2>/dev/null | \ + cut -d= -f2 +} + +# Get certificate serial from proxy health endpoint +get_proxy_cert_serial() { + local proxy_port=$1 + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -serial 2>/dev/null | \ + cut -d= -f2 +} + +# Get certificate issuer from proxy +get_proxy_cert_issuer() { + local proxy_port=$1 + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -issuer 2>/dev/null +} + +# Wait for certificate to be issued (with timeout) +wait_for_cert() { + local proxy_port=$1 + local timeout_secs=${2:-300} # Default 5 minutes + local start_time=$(date +%s) + + log_info "Waiting for certificate to be issued (timeout: ${timeout_secs}s)..." 
+ + while true; do + local current_time=$(date +%s) + local elapsed=$((current_time - start_time)) + + if [[ $elapsed -ge $timeout_secs ]]; then + log_error "Timeout waiting for certificate" + return 1 + fi + + # Try to get certificate info + local expiry=$(get_proxy_cert_expiry "$proxy_port") + if [[ -n "$expiry" ]]; then + log_info "Certificate detected! Expiry: $expiry" + return 0 + fi + + log_info "Waiting... (${elapsed}s elapsed)" + sleep 10 + done +} + +# ============================================================ +# Main Test +# ============================================================ + +do_clean() { + log_info "Cleaning all certbot test data..." + cleanup + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2" + sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json" + log_info "Done." +} + +main() { + log_info "==========================================" + log_info "Distributed Certbot E2E Test" + log_info "==========================================" + log_info "Test domain: $TEST_DOMAIN" + log_info "ACME URL: $ACME_STAGING_URL" + log_info "Mode: $MODE" + log_info "" + + # Handle --clean mode + if [[ "$MODE" == "clean" ]]; then + do_clean + return 0 + fi + + # Handle --sync-only mode: check if cert exists + if [[ "$MODE" == "sync-only" ]]; then + if [[ ! -f "$RUN_DIR/certbot_node1/live/cert.pem" ]]; then + log_error "No existing certificate found. Run without --sync-only first." 
+ return 1 + fi + log_info "Using existing certificate for sync test" + fi + + # Clean up processes and state + cleanup + + # Decide what to clean based on mode + case "$MODE" in + fresh) + # Clean everything including ACME account + log_info "Fresh mode: cleaning all data including ACME account" + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + ;; + sync-only) + # Keep node1 cert, only clean node2 and wavekv + log_info "Sync-only mode: keeping node1 certificate" + sudo rm -rf "$RUN_DIR/certbot_node2" + ;; + *) + # Default: keep ACME account (credentials.json), clean certs + log_info "Default mode: keeping ACME account, requesting new certificate" + # Backup credentials if exists + if [[ -f "$RUN_DIR/certbot_node1/credentials.json" ]]; then + sudo cp "$RUN_DIR/certbot_node1/credentials.json" /tmp/certbot_credentials_backup.json + fi + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + # Restore credentials + if [[ -f /tmp/certbot_credentials_backup.json ]]; then + mkdir -p "$RUN_DIR/certbot_node1" + sudo mv /tmp/certbot_credentials_backup.json "$RUN_DIR/certbot_node1/credentials.json" + fi + ;; + esac + + # Always clean wavekv and gateway state + sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2" + sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json" + + # Generate configs + log_info "Generating node configurations..." + generate_certbot_config 1 + generate_certbot_config 2 + + # Start Node 1 first - it will request the certificate + log_info "" + log_info "==========================================" + log_info "Phase 1: Start Node 1 and request certificate" + log_info "==========================================" + + if ! start_certbot_node 1; then + log_error "Failed to start node 1" + return 1 + fi + + # Wait for certificate to be issued + local proxy_port_1=14014 + if ! 
wait_for_cert "$proxy_port_1" 300; then + log_error "Node 1 failed to obtain certificate" + cat "$LOG_DIR/${CURRENT_TEST}_node1.log" | tail -50 + return 1 + fi + + # Get Node 1's certificate info + local node1_serial=$(get_proxy_cert_serial "$proxy_port_1") + local node1_expiry=$(get_proxy_cert_expiry "$proxy_port_1") + log_info "Node 1 certificate serial: $node1_serial" + log_info "Node 1 certificate expiry: $node1_expiry" + + # Show certificate source logs for Node 1 + log_info "" + log_info "Node 1 certificate source:" + grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node1.log" 2>/dev/null | sed 's/^/ /' + + # Start Node 2 - it should sync the certificate from Node 1 + log_info "" + log_info "==========================================" + log_info "Phase 2: Start Node 2 and verify sync" + log_info "==========================================" + + if ! start_certbot_node 2; then + log_error "Failed to start node 2" + return 1 + fi + + # Wait for Node 2 to sync and load the certificate + local proxy_port_2=14024 + sleep 10 # Give time for sync + + if ! 
wait_for_cert "$proxy_port_2" 60; then + log_error "Node 2 failed to obtain certificate via sync" + cat "$LOG_DIR/${CURRENT_TEST}_node2.log" | tail -50 + return 1 + fi + + # Get Node 2's certificate info + local node2_serial=$(get_proxy_cert_serial "$proxy_port_2") + local node2_expiry=$(get_proxy_cert_expiry "$proxy_port_2") + log_info "Node 2 certificate serial: $node2_serial" + log_info "Node 2 certificate expiry: $node2_expiry" + + # Show certificate source logs for Node 2 + log_info "" + log_info "Node 2 certificate source:" + grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node2.log" 2>/dev/null | sed 's/^/ /' + + # Verify both nodes have the same certificate + log_info "" + log_info "==========================================" + log_info "Verification" + log_info "==========================================" + + if [[ "$node1_serial" == "$node2_serial" ]]; then + log_info "SUCCESS: Both nodes have the same certificate (serial: $node1_serial)" + else + log_error "FAILURE: Certificate mismatch!" + log_error " Node 1 serial: $node1_serial" + log_error " Node 2 serial: $node2_serial" + return 1 + fi + + # Check that proxy is actually using the certificate + check_proxy_cert "$proxy_port_1" + check_proxy_cert "$proxy_port_2" + + log_info "" + log_info "==========================================" + log_info "All tests passed!" + log_info "==========================================" + + return 0 +} + +# Run main +main +exit $? 
diff --git a/gateway/test-run/test_suite.sh b/gateway/test-run/test_suite.sh new file mode 100755 index 00000000..ddb00814 --- /dev/null +++ b/gateway/test-run/test_suite.sh @@ -0,0 +1,2130 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# WaveKV integration test script + +# Don't use set -e as it causes issues with cleanup and test flow +# set -e + +# Disable job control messages (prevents "Killed" messages from messing up output) +set +m + +# Fix terminal output - ensure proper line endings +stty -echoctl 2>/dev/null || true + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="$SCRIPT_DIR/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +CURRENT_TEST="" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +cleanup() { + log_info "Cleaning up..." 
+ # Kill only dstack-gateway processes started by this test (matching our specific config path) + # Use absolute path to avoid killing system dstack-gateway processes + pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + sleep 1 + # Only delete WireGuard interfaces with sudo (these are our test interfaces) + sudo ip link delete wavekv-test1 2>/dev/null || true + sudo ip link delete wavekv-test2 2>/dev/null || true + sudo ip link delete wavekv-test3 2>/dev/null || true + # Clean up all wavekv data directories to prevent peer list contamination + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" 2>/dev/null || true + rm -f "$RUN_DIR/gateway-state-node"*.json 2>/dev/null || true + sleep 1 + stty sane 2>/dev/null || true +} + +trap cleanup EXIT + +# Generate node configs +# Usage: generate_config [bootnode_url] +generate_config() { + local node_id=$1 + local bootnode_url=${2:-""} + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + + # Use absolute paths to avoid Rocket's relative path resolution issues + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat >"$RUN_DIR/node${node_id}.toml" </dev/null | grep -q ":${port} "; then + return 0 + fi + sleep 1 + ((waited++)) + done + return 1 +} + +ensure_wg_interface() { + local node_id=$1 + local iface="wavekv-test${node_id}" + + # Check if interface exists, create if not + if ! ip link show "$iface" >/dev/null 2>&1; then + log_info "Creating WireGuard interface ${iface}..." 
+ sudo ip link add "$iface" type wireguard || { + log_error "Failed to create WireGuard interface ${iface}" + return 1 + } + fi + return 0 +} + +start_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log" + + # Calculate ports for this node + local admin_port=$((13000 + node_id * 10 + 6)) + local rpc_port=$((13000 + node_id * 10 + 2)) + + log_info "Starting node ${node_id}..." + + # Kill any existing test process for this node first (use absolute path to be precise) + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Wait for ports to be free + if ! wait_for_port_free $admin_port; then + log_error "Port $admin_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${admin_port} " || true + return 1 + fi + if ! wait_for_port_free $rpc_port; then + log_error "Port $rpc_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${rpc_port} " || true + return 1 + fi + + # Ensure WireGuard interface exists before starting + if ! ensure_wg_interface "$node_id"; then + return 1 + fi + + mkdir -p "$RUN_DIR/wavekv_node${node_id}" + mkdir -p "$LOG_DIR" + (RUST_LOG=info "$GATEWAY_BIN" -c "$config" >"$log_file" 2>&1 &) + sleep 2 + + if pgrep -f "dstack-gateway.*${config}" >/dev/null; then + log_info "Node ${node_id} started successfully" + return 0 + else + log_error "Node ${node_id} failed to start" + cat "$log_file" + return 1 + fi +} + +stop_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local admin_port=$((13000 + node_id * 10 + 6)) + + log_info "Stopping node ${node_id}..." 
+ # Kill only the specific test process using absolute config path + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Verify the port is free, otherwise force kill by PID + if ! wait_for_port_free $admin_port; then + log_warn "Node ${node_id} port still in use, forcing cleanup..." + # Find and kill the process holding the port + local pid=$(netstat -tlnp 2>/dev/null | grep ":${admin_port} " | awk '{print $7}' | cut -d'/' -f1) + if [[ -n "$pid" ]]; then + kill -9 "$pid" 2>/dev/null || true + sleep 1 + fi + fi + + # Reset terminal to fix any broken line endings + stty sane 2>/dev/null || true +} + +# Get WaveKV status via Admin.WaveKvStatus RPC +# Usage: get_status +get_status() { + local admin_port=$1 + curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.WaveKvStatus" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +get_n_keys() { + local admin_port=$1 + get_status "$admin_port" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d['persistent']['n_keys'])" 2>/dev/null || echo "0" +} + +# Register CVM via debug port (no attestation required) +# Usage: debug_register_cvm +# Returns: JSON response +debug_register_cvm() { + local debug_port=$1 + local public_key=$2 + local app_id=${3:-"testapp"} + local instance_id=${4:-"testinstance"} + curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null +} + +# Check if debug service is available +# Usage: check_debug_service +check_debug_service() { + local debug_port=$1 + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.Info" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'base_domain' in d" 
2>/dev/null; then + return 0 + else + return 1 + fi +} + +# Verify register response is successful (has wg config, no error) +# Usage: verify_register_response +verify_register_response() { + local response="$1" + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) + assert 'wg' in d, 'missing wg config' + assert 'client_ip' in d['wg'], 'missing client_ip' + print(d['wg']['client_ip']) +except Exception as e: + print(f'ERROR: {e}', file=sys.stderr) + sys.exit(1) +" 2>/dev/null +} + +# Get sync data from debug port (peer_addrs, nodes, instances) +# Usage: debug_get_sync_data +# Returns: JSON response with my_node_id, peer_addrs, nodes, instances +debug_get_sync_data() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Check if node has synced peer address from another node +# Usage: has_peer_addr +# Returns: 0 if peer address exists, 1 otherwise +has_peer_addr() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + peer_addrs = d.get('peer_addrs', []) + for pa in peer_addrs: + if pa.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Check if node has synced node info from another node +# Usage: has_node_info +# Returns: 0 if node info exists, 1 otherwise +has_node_info() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + nodes = d.get('nodes', []) + for n in nodes: + if n.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Get number of peer addresses from 
sync data +# Usage: get_n_peer_addrs +get_n_peer_addrs() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('peer_addrs', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of node infos from sync data +# Usage: get_n_nodes +get_n_nodes() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('nodes', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of instances from KvStore sync data +# Usage: get_n_instances +get_n_instances() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Get Proxy State from debug port (in-memory state) +# Usage: debug_get_proxy_state +# Returns: JSON response with instances and allocated_addresses +debug_get_proxy_state() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/GetProxyState" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Get number of instances from ProxyState (in-memory) +# Usage: get_n_proxy_state_instances +get_n_proxy_state_instances() { + local debug_port=$1 + local response=$(debug_get_proxy_state "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Check KvStore and ProxyState instance consistency +# Usage: check_instance_consistency +# Returns: 0 if consistent, 1 otherwise +check_instance_consistency() { + local debug_port=$1 + local kvstore_instances=$(get_n_instances "$debug_port") + local proxystate_instances=$(get_n_proxy_state_instances "$debug_port") + + if [[ "$kvstore_instances" 
-eq "$proxystate_instances" ]]; then + return 0 + else + log_error "Instance count mismatch: KvStore=$kvstore_instances, ProxyState=$proxystate_instances" + return 1 + fi +} + +# ============================================================================= +# Test 1: Single node persistence +# ============================================================================= +test_persistence() { + log_info "========== Test 1: Persistence ==========" + cleanup + + generate_config 1 + + # Start node and let it write some data + start_node 1 + + local admin_port=13016 + local initial_keys=$(get_n_keys $admin_port) + log_info "Initial keys: $initial_keys" + + # The gateway auto-writes some data (peer_addr, etc) + sleep 2 + local keys_after_write=$(get_n_keys $admin_port) + log_info "Keys after startup: $keys_after_write" + + # Stop and restart + stop_node 1 + log_info "Restarting node 1..." + start_node 1 + + local keys_after_restart=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after_restart" + + if [[ "$keys_after_restart" -ge "$keys_after_write" ]]; then + log_info "Persistence test PASSED" + return 0 + else + log_error "Persistence test FAILED: expected >= $keys_after_write keys, got $keys_after_restart" + return 1 + fi +} + +# ============================================================================= +# Test 2: Multi-node sync +# ============================================================================= +test_multi_node_sync() { + log_info "========== Test 2: Multi-node Sync ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # 
Wait for sync + log_info "Waiting for nodes to sync..." + sleep 10 + + # Use debug RPC to check actual synced data + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local nodes1=$(get_n_nodes $debug_port1) + local nodes2=$(get_n_nodes $debug_port2) + + log_info "Node 1: peer_addrs=$peer_addrs1, nodes=$nodes1" + log_info "Node 2: peer_addrs=$peer_addrs2, nodes=$nodes2" + + # For true sync, each node should have: + # - At least 2 peer addresses (both nodes' addresses) + # - At least 2 node infos (both nodes' info) + local sync_ok=true + + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + sync_ok=false + fi + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + sync_ok=false + fi + if ! has_node_info $debug_port1 2; then + log_error "Node 1 missing node_info for node 2" + sync_ok=false + fi + if ! has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Multi-node sync test PASSED" + return 0 + else + log_error "Multi-node sync test FAILED: nodes did not sync peer data" + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 3: Node recovery after disconnect +# ============================================================================= +test_node_recovery() { + log_info "========== Test 3: Node Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other 
+ setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial sync + sleep 5 + + # Stop node 2 + log_info "Stopping node 2 to simulate disconnect..." + stop_node 2 + + # Wait and let node 1 continue + sleep 3 + + # Check node 1 has its own data + local peer_addrs1_before=$(get_n_peer_addrs $debug_port1) + log_info "Node 1 peer_addrs before node 2 restart: $peer_addrs1_before" + + # Restart node 2 + log_info "Restarting node 2..." + start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 10 + + # After recovery, node 2 should have synced node 1's data + local sync_ok=true + + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1 after recovery" + sync_ok=false + fi + if ! has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1 after recovery" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Node recovery test PASSED" + return 0 + else + log_error "Node recovery test FAILED: node 2 did not sync data from node 1" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 4: Status endpoint structure (Admin.WaveKvStatus RPC) +# ============================================================================= +test_status_endpoint() { + log_info "========== Test 4: Status Endpoint ==========" + cleanup + + generate_config 1 + start_node 1 + + local admin_port=13016 + local status=$(get_status $admin_port) + + # Verify all expected fields exist + local checks_passed=0 + local total_checks=6 + + echo "$status" | python3 -c " +import sys, json +d = json.load(sys.stdin) +assert d['enabled'] == True, 'enabled should be True' +assert 'persistent' in d, 'missing persistent' +assert 'ephemeral' in d, 'missing ephemeral' +assert d['persistent']['wal_enabled'] == True, 'persistent wal should be enabled' 
+assert d['ephemeral']['wal_enabled'] == False, 'ephemeral wal should be disabled' +assert 'peers' in d['persistent'], 'missing peers in persistent' +print('All status checks passed') +" && checks_passed=1 + + if [[ $checks_passed -eq 1 ]]; then + log_info "Status endpoint test PASSED" + return 0 + else + log_error "Status endpoint test FAILED" + log_info "Status response: $status" + return 1 + fi +} + +# ============================================================================= +# Test 5: Cross-node data sync verification (KvStore + ProxyState) +# ============================================================================= +test_cross_node_data_sync() { + log_info "========== Test 5: Cross-node Data Sync ==========" + cleanup + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial connection + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register a client on node 1 via debug port + log_info "Registering client on node 1 via debug port..." + local register_response=$(debug_register_cvm $debug_port1 "testkey12345678901234567890123456789012345=" "app1" "inst1") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync (need at least 3 sync intervals of 5s for data to propagate) + log_info "Waiting for sync..." 
+ sleep 20 + + # Check KvStore instance count on both nodes + local kv_instances1=$(get_n_instances $debug_port1) + local kv_instances2=$(get_n_instances $debug_port2) + + # Check ProxyState instance count on both nodes + local ps_instances1=$(get_n_proxy_state_instances $debug_port1) + local ps_instances2=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1: KvStore=$kv_instances1, ProxyState=$ps_instances1" + log_info "Node 2: KvStore=$kv_instances2, ProxyState=$ps_instances2" + + local test_passed=true + + # Verify KvStore sync + if [[ "$kv_instances1" -lt 1 ]] || [[ "$kv_instances2" -lt 1 ]]; then + log_error "KvStore sync failed: kv_instances1=$kv_instances1, kv_instances2=$kv_instances2" + test_passed=false + fi + + # Verify ProxyState sync (node 2 should have loaded instance from KvStore) + if [[ "$ps_instances1" -lt 1 ]] || [[ "$ps_instances2" -lt 1 ]]; then + log_error "ProxyState sync failed: ps_instances1=$ps_instances1, ps_instances2=$ps_instances2" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv_instances1" -ne "$ps_instances1" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv_instances1, ProxyState=$ps_instances1" + test_passed=false + fi + if [[ "$kv_instances2" -ne "$ps_instances2" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv_instances2, ProxyState=$ps_instances2" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Cross-node data sync test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 6: prpc DebugRegisterCvm endpoint (on separate debug port) +# 
============================================================================= +test_prpc_register() { + log_info "========== Test 6: prpc DebugRegisterCvm ==========" + cleanup + + generate_config 1 + start_node 1 + + local debug_port=13015 + + # Verify debug service is available first + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + log_info "Debug service is available" + + # Register via debug port + local register_response=$(debug_register_cvm $debug_port "prpctest12345678901234567890123456789012=" "deadbeef" "cafebabe") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "prpc DebugRegisterCvm test FAILED" + return 1 + fi + + log_info "DebugRegisterCvm success: client_ip=$client_ip" + log_info "prpc DebugRegisterCvm test PASSED" + return 0 +} + +# ============================================================================= +# Test 7: prpc Info endpoint +# ============================================================================= +test_prpc_info() { + log_info "========== Test 7: prpc Info ==========" + cleanup + + generate_config 1 + start_node 1 + + local port=13012 + + # Call Info via prpc + # Note: trim: "Tproxy." removes "Tproxy.Gateway." 
prefix, so endpoint is just /prpc/Info + local info_response=$(curl -sk --cacert "$CA_CERT" \ + -X POST "https://localhost:${port}/prpc/Info" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null) + + log_info "Info response: $info_response" + + # Verify response has expected fields and no error + echo "$info_response" | python3 -c " +import sys, json +d = json.load(sys.stdin) +if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) +assert 'base_domain' in d, 'missing base_domain' +assert 'external_port' in d, 'missing external_port' +print('prpc Info check passed') +" && { + log_info "prpc Info test PASSED" + return 0 + } || { + log_error "prpc Info test FAILED" + return 1 + } +} + +# ============================================================================= +# Test 8: Client registration and data persistence +# ============================================================================= +test_client_registration_persistence() { + log_info "========== Test 8: Client Registration Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client via debug port + log_info "Registering client..." 
+ local register_response=$(debug_register_cvm $debug_port "persisttest1234567890123456789012345678901=" "persist_app" "persist_inst") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before restart: $keys_before" + + # Restart node + stop_node 1 + start_node 1 + + # Check keys after restart + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after" + + if [[ "$keys_after" -ge "$keys_before" ]] && [[ "$keys_before" -gt 2 ]]; then + log_info "Client registration persistence test PASSED" + return 0 + else + log_error "Client registration persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 9: Stress test - multiple writes +# ============================================================================= +test_stress_writes() { + log_info "========== Test 9: Stress Test ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + local num_clients=10 + local success_count=0 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + log_info "Registering $num_clients clients via debug port..." 
+ for i in $(seq 1 $num_clients); do + local key=$(printf "stresstest%02d12345678901234567890123456=" "$i") + local app_id=$(printf "stressapp%02d" "$i") + local inst_id=$(printf "stressinst%02d" "$i") + local response=$(debug_register_cvm $debug_port "$key" "$app_id" "$inst_id") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + + log_info "Successfully registered $success_count/$num_clients clients" + + sleep 2 + + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after stress test: $keys_after" + + # We expect successful registrations to create keys + if [[ "$success_count" -eq "$num_clients" ]] && [[ "$keys_after" -gt 2 ]]; then + log_info "Stress test PASSED" + return 0 + else + log_error "Stress test FAILED: success_count=$success_count, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 10: Network partition simulation (KvStore + ProxyState consistency) +# ============================================================================= +test_network_partition() { + log_info "========== Test 10: Network Partition Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Let them sync initially + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Stop node 2 (simulate partition) + log_info "Simulating network partition - stopping node 2..." 
+ stop_node 2 + + # Register clients on node 1 while node 2 is down + log_info "Registering clients on node 1 during partition..." + local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "partition%02d123456789012345678901234567=" "$i") + local response=$(debug_register_cvm $debug_port1 "$key" "partition_app$i" "partition_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients during partition" + + local kv1_during=$(get_n_instances $debug_port1) + local ps1_during=$(get_n_proxy_state_instances $debug_port1) + log_info "Node 1 during partition: KvStore=$kv1_during, ProxyState=$ps1_during" + + # Restore node 2 + log_info "Healing partition - restarting node 2..." + start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 15 + + # Check KvStore and ProxyState on both nodes after recovery + local kv1_after=$(get_n_instances $debug_port1) + local kv2_after=$(get_n_instances $debug_port2) + local ps1_after=$(get_n_proxy_state_instances $debug_port1) + local ps2_after=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1 after recovery: KvStore=$kv1_after, ProxyState=$ps1_after" + log_info "Node 2 after recovery: KvStore=$kv2_after, ProxyState=$ps2_after" + + local test_passed=true + + # Verify basic sync + if [[ "$success_count" -ne 3 ]] || [[ "$kv1_during" -lt 3 ]]; then + log_error "Registration or KvStore write failed during partition" + test_passed=false + fi + + # Verify node 2 synced KvStore + if [[ "$kv2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 KvStore sync failed: kv2_after=$kv2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify node 2 ProxyState sync + if [[ "$ps2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 ProxyState sync failed: ps2_after=$ps2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify consistency on each node + if [[ 
"$kv1_after" -ne "$ps1_after" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv1_after, ProxyState=$ps1_after" + test_passed=false + fi + if [[ "$kv2_after" -ne "$ps2_after" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv2_after, ProxyState=$ps2_after" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Network partition recovery test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 11: Three-node cluster (KvStore + ProxyState consistency) +# ============================================================================= +test_three_node_cluster() { + log_info "========== Test 11: Three-node Cluster ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + generate_config 3 + + start_node 1 + start_node 2 + start_node 3 + + # Register peers so all nodes can discover each other + setup_peers 1 2 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form + sleep 10 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register client on node 1 + log_info "Registering client on node 1..." 
+ local response=$(debug_register_cvm $debug_port1 "threenode12345678901234567890123456789=" "threenode_app" "threenode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes (need at least 2 sync intervals of 5s) + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1: KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node cluster test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "KvStore from node 3: $(debug_get_sync_data $debug_port3)" + log_info "ProxyState from 
node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + log_info "ProxyState from node 3: $(debug_get_proxy_state $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 12: WAL file integrity +# ============================================================================= +test_wal_integrity() { + log_info "========== Test 12: WAL File Integrity ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local success_count=0 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients via debug port + for i in $(seq 1 5); do + local key=$(printf "waltest%02d1234567890123456789012345678901=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "wal_app$i" "wal_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/5 clients" + + if [[ "$success_count" -ne 5 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + sleep 2 + stop_node 1 + + # Check WAL file exists and has content + local wal_file="$RUN_DIR/wavekv_node1/node_1.wal" + if [[ -f "$wal_file" ]]; then + local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null) + log_info "WAL file size: $wal_size bytes" + + if [[ "$wal_size" -gt 100 ]]; then + log_info "WAL file integrity test PASSED" + return 0 + else + log_error "WAL file integrity test FAILED: WAL file too small ($wal_size bytes)" + return 1 + fi + else + log_error "WAL file not found: $wal_file" + return 1 + fi +} + +# ============================================================================= +# Test 13: Three-node cluster with bootnode (no dynamic peer setup) +# 
============================================================================= +test_three_node_bootnode() { + log_info "========== Test 13: Three-node Cluster with Bootnode ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + # Node 1 is the bootnode (no bootnode config) + # Node 2 and 3 use node 1 as bootnode + local bootnode_url="https://localhost:13012" + + generate_config 1 "" + generate_config 2 "$bootnode_url" + generate_config 3 "$bootnode_url" + + # Start node 1 first (bootnode) + start_node 1 + sleep 2 + + # Start node 2 and 3, they will discover each other via bootnode + start_node 2 + start_node 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form via bootnode discovery + log_info "Waiting for nodes to discover each other via bootnode..." + sleep 15 + + # Verify debug service is available on all nodes + for port in $debug_port1 $debug_port2 $debug_port3; do + if ! check_debug_service $port; then + log_error "Debug service not available on port $port" + return 1 + fi + done + + # Check peer discovery - each node should know about the others + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local peer_addrs3=$(get_n_peer_addrs $debug_port3) + + log_info "Peer addresses: node1=$peer_addrs1, node2=$peer_addrs2, node3=$peer_addrs3" + + # Register client on node 2 (not the bootnode) + log_info "Registering client on node 2..." 
+ local response=$(debug_register_cvm $debug_port2 "bootnode12345678901234567890123456789=" "bootnode_app" "bootnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1 (bootnode): KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify peer discovery worked (each node should have at least 2 peer addresses) + if [[ "$peer_addrs1" -lt 2 ]] || [[ "$peer_addrs2" -lt 2 ]] || [[ "$peer_addrs3" -lt 2 ]]; then + log_error "Peer discovery via bootnode failed: peer_addrs1=$peer_addrs1, peer_addrs2=$peer_addrs2, peer_addrs3=$peer_addrs3" + test_passed=false + fi + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node bootnode 
cluster test PASSED" + return 0 + else + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + log_info "Sync data from node 3: $(debug_get_sync_data $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 14: Node ID reuse rejection +# ============================================================================= +test_node_id_reuse_rejected() { + log_info "========== Test 14: Node ID Reuse Rejected ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + # Start node 1 and node 2, let them sync + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + local admin_port1=13016 + + # Wait for initial sync + log_info "Waiting for initial sync between node 1 and node 2..." + sleep 10 + + # Verify both nodes have synced + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + return 1 + fi + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + return 1 + fi + log_info "Initial sync completed successfully" + + # Get initial key count on node 1 + local keys_before=$(get_n_keys $admin_port1) + log_info "Keys on node 1 before node 2 restart: $keys_before" + + # Stop node 2 and delete its data (simulating a fresh node trying to reuse the ID) + log_info "Stopping node 2 and deleting its data..." + stop_node 2 + rm -rf "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node2.json" + + # Restart node 2 - it will have a new UUID but same node_id + log_info "Restarting node 2 with fresh data (new UUID, same node_id)..." 
+ start_node 2 + + # Re-register peers + setup_peers 1 2 + + # Wait for sync attempt + sleep 15 + + # Check node 2's log for UUID mismatch error + local log_file="${LOG_DIR}/${CURRENT_TEST}_node2.log" + if grep -q "UUID mismatch" "$log_file" 2>/dev/null; then + log_info "Found UUID mismatch error in node 2 log (expected)" + else + log_warn "UUID mismatch error not found in log (may still be rejected)" + fi + + # Node 1 should have rejected sync from new node 2 + # Check if node 1's data is still intact (keys should not decrease) + local keys_after=$(get_n_keys $admin_port1) + log_info "Keys on node 1 after node 2 restart: $keys_after" + + # The new node 2 should NOT have received data from node 1 + # because node 1 should reject sync due to UUID mismatch + local kv2=$(get_n_instances $debug_port2) + log_info "Node 2 instances after restart: $kv2" + + # Verify node 1's data is intact + if [[ "$keys_after" -lt "$keys_before" ]]; then + log_error "Node 1 lost data after node 2 restart with reused ID" + return 1 + fi + + # The test passes if: + # 1. Node 1's data is intact + # 2. Either UUID mismatch was logged OR node 2 didn't get full sync + log_info "Node ID reuse rejection test PASSED" + return 0 +} + +# ============================================================================= +# Test 15: Periodic persistence +# ============================================================================= +test_periodic_persistence() { + log_info "========== Test 15: Periodic Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients to create data + log_info "Registering clients to create data..." 
+ local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "persist%02d123456789012345678901234567890=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "persist_app$i" "persist_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients" + + if [[ "$success_count" -ne 3 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before waiting for persist: $keys_before" + + # Wait for periodic persistence (persist_interval is 5s in test config) + log_info "Waiting for periodic persistence (8s)..." + sleep 8 + + # Check log for periodic persist message + local log_file="${LOG_DIR}/${CURRENT_TEST}_node1.log" + if grep -q "periodic persist completed" "$log_file" 2>/dev/null; then + log_info "Found periodic persist message in log" + else + log_error "Periodic persist message not found in log - test FAILED" + return 1 + fi + + # Stop node + stop_node 1 + + # Check WAL file exists and has content + local wal_file="$RUN_DIR/wavekv_node1/node_1.wal" + if [[ ! -f "$wal_file" ]]; then + log_error "WAL file not found: $wal_file" + return 1 + fi + + local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null) + log_info "WAL file size after periodic persist: $wal_size bytes" + + # Restart node and verify data is recovered + log_info "Restarting node to verify persistence..." 
+ start_node 1 + + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after" + + if [[ "$keys_after" -ge "$keys_before" ]]; then + log_info "Periodic persistence test PASSED" + return 0 + else + log_error "Periodic persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Admin RPC helper functions +# ============================================================================= + +# Call Admin.SetNodeUrl RPC +# Usage: admin_set_node_url +admin_set_node_url() { + local admin_port=$1 + local node_id=$2 + local url=$3 + curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeUrl" \ + -H "Content-Type: application/json" \ + -d "{\"id\": $node_id, \"url\": \"$url\"}" 2>/dev/null +} + +# Register peers between nodes via Admin RPC +# This is needed since we removed peer_node_ids/peer_urls from config +# Usage: setup_peers +# Example: setup_peers 1 2 3 # Sets up peers between nodes 1, 2, and 3 +setup_peers() { + local node_ids=("$@") + + for src_node in "${node_ids[@]}"; do + local src_admin_port=$((13000 + src_node * 10 + 6)) + + for dst_node in "${node_ids[@]}"; do + if [[ "$src_node" != "$dst_node" ]]; then + local dst_rpc_port=$((13000 + dst_node * 10 + 2)) + local dst_url="https://localhost:${dst_rpc_port}" + admin_set_node_url "$src_admin_port" "$dst_node" "$dst_url" + fi + done + done + + # Wait for peers to be registered + sleep 1 +} + +# Call Admin.SetNodeStatus RPC +# Usage: admin_set_node_status +# status: "up" or "down" +admin_set_node_status() { + local admin_port=$1 + local node_id=$2 + local status=$3 + curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeStatus" \ + -H "Content-Type: application/json" \ + -d "{\"id\": $node_id, \"status\": \"$status\"}" 2>/dev/null +} + +# Call Admin.Status RPC to get all nodes +# Usage: admin_get_status +admin_get_status() { + local admin_port=$1 + 
curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.Status" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +# Get peer URL from sync data +# Usage: get_peer_url +get_peer_url_from_sync() { + local debug_port=$1 + local node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + for pa in d.get('peer_addrs', []): + if pa.get('node_id') == $node_id: + print(pa.get('url', '')) + sys.exit(0) + print('') +except: + print('') +" 2>/dev/null +} + +# ============================================================================= +# Test 16: Admin.SetNodeUrl RPC +# ============================================================================= +test_admin_set_node_url() { + log_info "========== Test 16: Admin.SetNodeUrl RPC ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Set URL for a new node (node 2) via Admin RPC + local new_url="https://new-node2.example.com:8011" + log_info "Setting node 2 URL via Admin.SetNodeUrl..." 
+ local response=$(admin_set_node_url $admin_port 2 "$new_url") + log_info "SetNodeUrl response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeUrl returned error: $response" + return 1 + fi + + # Wait for data to be written + sleep 2 + + # Verify the URL was stored in KvStore + local stored_url=$(get_peer_url_from_sync $debug_port 2) + log_info "Stored URL for node 2: $stored_url" + + if [[ "$stored_url" == "$new_url" ]]; then + log_info "Admin.SetNodeUrl test PASSED" + return 0 + else + log_error "Admin.SetNodeUrl test FAILED: expected '$new_url', got '$stored_url'" + log_info "Sync data: $(debug_get_sync_data $debug_port)" + return 1 + fi +} + +# ============================================================================= +# Test 17: Admin.SetNodeStatus RPC +# ============================================================================= +test_admin_set_node_status() { + log_info "========== Test 17: Admin.SetNodeStatus RPC ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # First set a URL for node 2 so we have a peer + admin_set_node_url $admin_port 2 "https://node2.example.com:8011" + sleep 1 + + # Set node 2 status to "down" + log_info "Setting node 2 status to 'down' via Admin.SetNodeStatus..." + local response=$(admin_set_node_status $admin_port 2 "down") + log_info "SetNodeStatus response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + sleep 1 + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up' via Admin.SetNodeStatus..." 
+ response=$(admin_set_node_status $admin_port 2 "up") + log_info "SetNodeStatus response: $response" + + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + # Test invalid status + log_info "Testing invalid status..." + response=$(admin_set_node_status $admin_port 2 "invalid") + if echo "$response" | grep -q '"error"'; then + log_info "Invalid status correctly rejected" + else + log_warn "Invalid status was not rejected (may be acceptable)" + fi + + log_info "Admin.SetNodeStatus test PASSED" + return 0 +} + +# ============================================================================= +# Test 18: Node down excluded from RegisterCvm response +# ============================================================================= +test_node_status_register_exclude() { + log_info "========== Test 18: Node Down Excluded from Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local admin_port1=13016 + local admin_port2=13026 + local debug_port1=13015 + + # Wait for sync + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Set node 2 status to "down" via node 1's admin API + log_info "Setting node 2 status to 'down'..." + admin_set_node_status $admin_port1 2 "down" + sleep 2 + + # Register a client on node 1 + log_info "Registering client on node 1 (node 2 is down)..." 
+ local response=$(debug_register_cvm $debug_port1 "downtest12345678901234567890123456789012=" "downtest_app" "downtest_inst") + log_info "Register response: $response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Check gateways list in response - should NOT include node 2 + local has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "yes" ]]; then + log_error "Node 2 (down) was included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (down) correctly excluded from registration response" + fi + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up'..." + admin_set_node_status $admin_port1 2 "up" + sleep 2 + + # Register another client + log_info "Registering client on node 1 (node 2 is now up)..." 
+ response=$(debug_register_cvm $debug_port1 "uptest123456789012345678901234567890123=" "uptest_app" "uptest_inst2") + + # Check gateways list - should now include node 2 + has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "no" ]]; then + log_error "Node 2 (up) was NOT included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (up) correctly included in registration response" + fi + + log_info "Node down excluded from registration test PASSED" + return 0 +} + +# ============================================================================= +# Test 19: Node down rejects RegisterCvm requests +# ============================================================================= +test_node_status_register_reject() { + log_info "========== Test 19: Node Down Rejects Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client when node is up (should succeed) + log_info "Registering client when node 1 is up..." + local response=$(debug_register_cvm $debug_port "upnode123456789012345678901234567890123=" "upnode_app" "upnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed when node was up" + return 1 + fi + log_info "Registration succeeded when node was up (IP: $client_ip)" + + # Set node 1 status to "down" (marking itself as down) + log_info "Setting node 1 status to 'down'..." 
+ admin_set_node_status $admin_port 1 "down" + sleep 2 + + # Try to register a client when node is down (should fail) + log_info "Attempting to register client when node 1 is down..." + response=$(debug_register_cvm $debug_port "downnode12345678901234567890123456789012=" "downnode_app" "downnode_inst") + log_info "Register response: $response" + + # Check if response contains error about node being down + if echo "$response" | grep -qi "error"; then + log_info "Registration correctly rejected when node is down" + if echo "$response" | grep -qi "marked as down"; then + log_info "Error message mentions 'marked as down' (correct)" + fi + else + log_error "Registration was NOT rejected when node is down" + log_info "Response: $response" + return 1 + fi + + # Set node 1 status back to "up" + log_info "Setting node 1 status to 'up'..." + admin_set_node_status $admin_port 1 "up" + sleep 2 + + # Register a client again (should succeed) + log_info "Registering client when node 1 is back up..." + response=$(debug_register_cvm $debug_port "backup123456789012345678901234567890123=" "backup_app" "backup_inst") + client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed when node was back up" + return 1 + fi + log_info "Registration succeeded when node was back up (IP: $client_ip)" + + log_info "Node down rejects registration test PASSED" + return 0 +} + +# ============================================================================= +# Clean command - remove all generated files +# ============================================================================= +clean() { + log_info "Cleaning up generated files..." 
+ + # Kill only test gateway processes (matching our specific config path) + pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + sleep 1 + + # Remove WireGuard interfaces (only our test interfaces need sudo) + sudo ip link delete wavekv-test1 2>/dev/null || true + sudo ip link delete wavekv-test2 2>/dev/null || true + sudo ip link delete wavekv-test3 2>/dev/null || true + + # Remove run directory (contains all generated files including certs) + rm -rf "$RUN_DIR" + + log_info "Cleanup complete" +} + +# ============================================================================= +# Ensure all certificates exist (CA + RPC + proxy) +# ============================================================================= +ensure_certs() { + # Create directories + mkdir -p "$CERTS_DIR" + mkdir -p "$RUN_DIR/certbot/live" + + # Generate CA certificate if not exists + if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]] || [[ ! -f "$CERTS_DIR/gateway-ca.cert" ]]; then + log_info "Creating CA certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null + openssl req -x509 -new -nodes \ + -key "$CERTS_DIR/gateway-ca.key" \ + -sha256 -days 365 \ + -out "$CERTS_DIR/gateway-ca.cert" \ + -subj "/CN=Test CA/O=WaveKV Test" \ + 2>/dev/null + fi + + # Generate RPC certificate signed by CA if not exists + if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]] || [[ ! -f "$CERTS_DIR/gateway-rpc.cert" ]]; then + log_info "Creating RPC certificate signed by CA..." 
+ openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null + openssl req -new \ + -key "$CERTS_DIR/gateway-rpc.key" \ + -out "$CERTS_DIR/gateway-rpc.csr" \ + -subj "/CN=localhost" \ + 2>/dev/null + # Create certificate with SAN for localhost + cat >"$CERTS_DIR/ext.cnf" </dev/null + rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf" + fi + + # Generate proxy certificates (for TLS termination) + local proxy_cert_dir="$RUN_DIR/certbot/live" + if [[ ! -f "$proxy_cert_dir/cert.pem" ]] || [[ ! -f "$proxy_cert_dir/key.pem" ]]; then + log_info "Creating proxy certificates..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$proxy_cert_dir/key.pem" \ + -out "$proxy_cert_dir/cert.pem" \ + -days 365 \ + -subj "/CN=localhost" \ + 2>/dev/null + fi +} + +# ============================================================================= +# Main +# ============================================================================= +main() { + # Handle clean command + if [[ "${1:-}" == "clean" ]]; then + clean + exit 0 + fi + + # Handle cfg command - generate node configuration + if [[ "${1:-}" == "cfg" ]]; then + local node_id="${2:-}" + if [[ -z "$node_id" ]]; then + log_error "Usage: $0 cfg " + log_info "Example: $0 cfg 1" + exit 1 + fi + + # Ensure certificates exist + ensure_certs + + # Generate config for the specified node + generate_config "$node_id" + log_info "Configuration generated: $RUN_DIR/node${node_id}.toml" + exit 0 + fi + + # Handle ls command - list all test cases + if [[ "${1:-}" == "ls" ]]; then + echo "Available test cases:" + echo "" + echo "Quick tests:" + echo " test_persistence - Single node persistence" + echo " test_status_endpoint - Status endpoint structure" + echo " test_prpc_register - prpc DebugRegisterCvm endpoint" + echo " test_prpc_info - prpc Info endpoint" + echo " test_wal_integrity - WAL file integrity" + echo "" + echo "Sync tests:" + echo " test_multi_node_sync - Multi-node sync" + echo " test_node_recovery - Node recovery 
after disconnect" + echo " test_cross_node_data_sync - Cross-node data sync verification" + echo "" + echo "Advanced tests:" + echo " test_client_registration_persistence - Client registration and persistence" + echo " test_stress_writes - Stress test - multiple writes" + echo " test_network_partition - Network partition simulation" + echo " test_three_node_cluster - Three-node cluster" + echo " test_three_node_bootnode - Three-node cluster with bootnode" + echo " test_node_id_reuse_rejected - Node ID reuse rejection" + echo " test_periodic_persistence - Periodic persistence" + echo "" + echo "Admin RPC tests:" + echo " test_admin_set_node_url - Admin.SetNodeUrl RPC" + echo " test_admin_set_node_status - Admin.SetNodeStatus RPC" + echo " test_node_status_register_exclude - Node down excluded from registration" + echo " test_node_status_register_reject - Node down rejects registration" + echo "" + echo "Usage:" + echo " $0 - Run all tests" + echo " $0 quick - Run quick tests only" + echo " $0 sync - Run sync tests only" + echo " $0 advanced - Run advanced tests only" + echo " $0 admin - Run admin RPC tests only" + echo " $0 case - Run specific test case" + echo " $0 ls - List all test cases" + echo " $0 clean - Clean up generated files" + exit 0 + fi + + # Handle case command - run specific test case + if [[ "${1:-}" == "case" ]]; then + local test_case="${2:-}" + if [[ -z "$test_case" ]]; then + log_error "Usage: $0 case " + log_info "Run '$0 ls' to see all available test cases" + exit 1 + fi + + # Check if gateway binary exists + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_error "Gateway binary not found: $GATEWAY_BIN" + log_info "Please run: cargo build --release" + exit 1 + fi + + # Ensure certificates exist + ensure_certs + + # Check if test function exists + if ! 
declare -f "$test_case" >/dev/null; then + log_error "Test case not found: $test_case" + log_info "Use '$0 case' to see available test cases" + exit 1 + fi + + # Run the specific test + log_info "Running test case: $test_case" + CURRENT_TEST="$test_case" + if $test_case; then + log_info "Test PASSED: $test_case" + cleanup + exit 0 + else + log_error "Test FAILED: $test_case" + cleanup + exit 1 + fi + fi + + log_info "Starting WaveKV integration tests..." + + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_error "Gateway binary not found: $GATEWAY_BIN" + log_info "Please run: cargo build --release" + exit 1 + fi + + # Ensure all certificates exist (RPC + proxy) + ensure_certs + + local failed=0 + local passed=0 + local failed_tests=() + + run_test() { + local test_name=$1 + CURRENT_TEST="$test_name" + if $test_name; then + ((passed++)) + else + ((failed++)) + failed_tests+=("$test_name") + fi + cleanup + } + + # Run selected test or all tests + local test_filter="${1:-all}" + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "quick" ]]; then + run_test test_persistence + run_test test_status_endpoint + run_test test_prpc_register + run_test test_prpc_info + run_test test_wal_integrity + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "sync" ]]; then + run_test test_multi_node_sync + run_test test_node_recovery + run_test test_cross_node_data_sync + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "advanced" ]]; then + run_test test_client_registration_persistence + run_test test_stress_writes + run_test test_network_partition + run_test test_three_node_cluster + run_test test_three_node_bootnode + run_test test_node_id_reuse_rejected + run_test test_periodic_persistence + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "admin" ]]; then + run_test test_admin_set_node_url + run_test test_admin_set_node_status + run_test test_node_status_register_exclude + run_test test_node_status_register_reject + fi + + echo "" + 
log_info "==========================================" + log_info "Tests passed: $passed" + if [[ $failed -gt 0 ]]; then + log_error "Tests failed: $failed" + echo "" + log_error "Failed test cases:" + for test_name in "${failed_tests[@]}"; do + log_error " - $test_name" + done + echo "" + log_info "To rerun a failed test:" + log_info " $0 case " + log_info "Example:" + if [[ ${#failed_tests[@]} -gt 0 ]]; then + log_info " $0 case ${failed_tests[0]}" + fi + fi + log_info "==========================================" + + return $failed +} + +# Run if executed directly +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + main "$@" +fi diff --git a/guest-agent/src/main.rs b/guest-agent/src/main.rs index 3183e61f..de4fb3c5 100644 --- a/guest-agent/src/main.rs +++ b/guest-agent/src/main.rs @@ -205,7 +205,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); let figment = config::load_config_figment(args.config.as_deref()); diff --git a/kms/src/main.rs b/kms/src/main.rs index 8584eec9..eddfbdc9 100644 --- a/kms/src/main.rs +++ b/kms/src/main.rs @@ -82,7 +82,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); diff --git a/sdk/rust/tests/test_tappd_client.rs b/sdk/rust/tests/test_tappd_client.rs index 0d363b94..b5afb571 100644 --- a/sdk/rust/tests/test_tappd_client.rs +++ b/sdk/rust/tests/test_tappd_client.rs @@ -11,27 +11,18 @@ use std::env; async fn test_tappd_client_creation() { // Test client creation with default endpoint let _client = TappdClient::new(None); - - // This should succeed without 
panicking - assert!(true); } #[tokio::test] async fn test_tappd_client_with_custom_endpoint() { // Test client creation with custom endpoint let _client = TappdClient::new(Some("/custom/path/tappd.sock")); - - // This should succeed without panicking - assert!(true); } #[tokio::test] async fn test_tappd_client_with_http_endpoint() { // Test client creation with HTTP endpoint let _client = TappdClient::new(Some("http://localhost:8080")); - - // This should succeed without panicking - assert!(true); } // Integration tests that require a running tappd service diff --git a/supervisor/client/src/main.rs b/supervisor/client/src/main.rs index b993984c..c3b13abd 100644 --- a/supervisor/client/src/main.rs +++ b/supervisor/client/src/main.rs @@ -50,7 +50,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let cli = Cli::parse(); diff --git a/supervisor/src/main.rs b/supervisor/src/main.rs index 752b511c..291ef320 100644 --- a/supervisor/src/main.rs +++ b/supervisor/src/main.rs @@ -90,10 +90,12 @@ fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .with_writer(file) + .with_ansi(false) .init(); } else { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) + .with_ansi(false) .init(); } #[cfg(unix)] diff --git a/tdx-attest/src/dummy.rs b/tdx-attest/src/dummy.rs index 65314a4f..77991e37 100644 --- a/tdx-attest/src/dummy.rs +++ b/tdx-attest/src/dummy.rs @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 -use cc_eventlog::TdxEventLog; use num_enum::FromPrimitive; use thiserror::Error; @@ -48,10 +47,7 @@ pub fn extend_rtmr(_index: u32, _event_type: u32, _digest: [u8; 48]) -> Result<( pub fn get_report(_report_data: &TdxReportData) -> Result { Err(TdxAttestError::NotSupported) } -pub fn 
get_quote( - _report_data: &TdxReportData, - _att_key_id_list: Option<&[TdxUuid]>, -) -> Result<(TdxUuid, Vec)> { +pub fn get_quote(_report_data: &TdxReportData) -> Result> { let _ = _report_data; Err(TdxAttestError::NotSupported) } diff --git a/tools/mock-cf-dns-api/Dockerfile b/tools/mock-cf-dns-api/Dockerfile new file mode 100644 index 00000000..081b9625 --- /dev/null +++ b/tools/mock-cf-dns-api/Dockerfile @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +FROM python:3.12-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application +COPY app.py . + +# Environment variables +ENV PORT=8080 +ENV DEBUG=false + +# Expose port +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Run with gunicorn for production +CMD ["gunicorn", "--bind", "0.0.0.0:8080", "--workers", "2", "--threads", "4", "app:app"] diff --git a/tools/mock-cf-dns-api/app.py b/tools/mock-cf-dns-api/app.py new file mode 100644 index 00000000..5703db79 --- /dev/null +++ b/tools/mock-cf-dns-api/app.py @@ -0,0 +1,848 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Mock Cloudflare DNS API Server + +A mock server that simulates Cloudflare's DNS API for testing purposes. 
+Supports the following endpoints used by certbot: +- POST /client/v4/zones/{zone_id}/dns_records - Create DNS record +- GET /client/v4/zones/{zone_id}/dns_records - List DNS records +- DELETE /client/v4/zones/{zone_id}/dns_records/{record_id} - Delete DNS record +""" + +import os +import uuid +import time +import json +from datetime import datetime +from flask import Flask, request, jsonify, render_template_string +from functools import wraps + +app = Flask(__name__) + +# In-memory storage for DNS records +# Structure: {zone_id: {record_id: record_data}} +dns_records = {} + +# Request/Response logs for debugging +request_logs = [] +MAX_LOGS = 100 + +# Valid API tokens (for testing, accept any non-empty token or use env var) +VALID_TOKENS = os.environ.get("CF_API_TOKENS", "").split(",") if os.environ.get("CF_API_TOKENS") else None + + +def log_request(zone_id, method, path, req_data, resp_data, status_code): + """Log API requests for the management UI.""" + log_entry = { + "timestamp": datetime.now().isoformat(), + "zone_id": zone_id, + "method": method, + "path": path, + "request": req_data, + "response": resp_data, + "status_code": status_code, + } + request_logs.insert(0, log_entry) + if len(request_logs) > MAX_LOGS: + request_logs.pop() + + +def generate_record_id(): + """Generate a Cloudflare-style record ID.""" + return uuid.uuid4().hex[:32] + + +def get_current_time(): + """Get current time in Cloudflare format.""" + return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000000Z") + + +def verify_auth(f): + """Decorator to verify Bearer token authentication.""" + @wraps(f) + def decorated(*args, **kwargs): + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return jsonify({ + "success": False, + "errors": [{"code": 10000, "message": "Authentication error"}], + "messages": [], + "result": None + }), 401 + + token = auth_header[7:] # Remove "Bearer " prefix + + # If VALID_TOKENS is set, validate against it; 
otherwise accept any token + if VALID_TOKENS and token not in VALID_TOKENS: + return jsonify({ + "success": False, + "errors": [{"code": 10000, "message": "Invalid API token"}], + "messages": [], + "result": None + }), 403 + + return f(*args, **kwargs) + return decorated + + +def cf_response(result, success=True, errors=None, messages=None): + """Create a Cloudflare-style API response.""" + return { + "success": success, + "errors": errors or [], + "messages": messages or [], + "result": result + } + + +def cf_error(message, code=1000): + """Create a Cloudflare-style error response.""" + return cf_response(None, success=False, errors=[{"code": code, "message": message}]) + + +# ==================== DNS Record Endpoints ==================== + +@app.route("/client/v4/zones//dns_records", methods=["POST"]) +@verify_auth +def create_dns_record(zone_id): + """Create a new DNS record.""" + data = request.get_json() + + if not data: + resp = cf_error("Invalid request body") + log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", None, resp, 400) + return jsonify(resp), 400 + + record_type = data.get("type") + name = data.get("name") + + if not record_type or not name: + resp = cf_error("Missing required fields: type, name") + log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 400) + return jsonify(resp), 400 + + # Initialize zone if not exists + if zone_id not in dns_records: + dns_records[zone_id] = {} + + record_id = generate_record_id() + now = get_current_time() + + # Build record based on type + record = { + "id": record_id, + "zone_id": zone_id, + "zone_name": f"zone-{zone_id[:8]}.example.com", + "name": name, + "type": record_type, + "ttl": data.get("ttl", 1), + "proxied": data.get("proxied", False), + "proxiable": False, + "locked": False, + "created_on": now, + "modified_on": now, + "meta": { + "auto_added": False, + "managed_by_apps": False, + "managed_by_argo_tunnel": False + } + } + + # Handle different record types + if 
record_type == "TXT": + record["content"] = data.get("content", "") + elif record_type == "CAA": + caa_data = data.get("data", {}) + record["data"] = caa_data + # Format content as Cloudflare does + flags = caa_data.get("flags", 0) + tag = caa_data.get("tag", "") + value = caa_data.get("value", "") + record["content"] = f'{flags} {tag} "{value}"' + elif record_type == "A": + record["content"] = data.get("content", "") + elif record_type == "AAAA": + record["content"] = data.get("content", "") + elif record_type == "CNAME": + record["content"] = data.get("content", "") + else: + record["content"] = data.get("content", "") + + dns_records[zone_id][record_id] = record + + resp = cf_response(record) + log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 200) + + print(f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}") + + return jsonify(resp), 200 + + +@app.route("/client/v4/zones//dns_records", methods=["GET"]) +@verify_auth +def list_dns_records(zone_id): + """List DNS records for a zone.""" + zone_records = dns_records.get(zone_id, {}) + records_list = list(zone_records.values()) + + # Filter by type if specified + record_type = request.args.get("type") + if record_type: + records_list = [r for r in records_list if r["type"] == record_type] + + # Filter by name if specified + name = request.args.get("name") + if name: + records_list = [r for r in records_list if r["name"] == name] + + # Get pagination params + page = int(request.args.get("page", 1)) + per_page = int(request.args.get("per_page", 100)) + + # Pagination + total_count = len(records_list) + total_pages = max(1, (total_count + per_page - 1) // per_page) + start_idx = (page - 1) * per_page + end_idx = start_idx + per_page + page_records = records_list[start_idx:end_idx] + + resp = { + "success": True, + "errors": [], + "messages": [], + "result": page_records, + "result_info": { + "page": page, + "per_page": per_page, + "count": len(page_records), + 
"total_count": total_count, + "total_pages": total_pages + } + } + log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200) + + return jsonify(resp), 200 + + +@app.route("/client/v4/zones//dns_records/", methods=["GET"]) +@verify_auth +def get_dns_record(zone_id, record_id): + """Get a specific DNS record.""" + zone_records = dns_records.get(zone_id, {}) + record = zone_records.get(record_id) + + if not record: + resp = cf_error("Record not found", 81044) + log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + return jsonify(resp), 404 + + resp = cf_response(record) + log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + + return jsonify(resp), 200 + + +@app.route("/client/v4/zones//dns_records/", methods=["PUT"]) +@verify_auth +def update_dns_record(zone_id, record_id): + """Update a DNS record.""" + zone_records = dns_records.get(zone_id, {}) + record = zone_records.get(record_id) + + if not record: + resp = cf_error("Record not found", 81044) + log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + return jsonify(resp), 404 + + data = request.get_json() + if not data: + resp = cf_error("Invalid request body") + log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400) + return jsonify(resp), 400 + + # Update allowed fields + for field in ["name", "type", "content", "ttl", "proxied", "data"]: + if field in data: + record[field] = data[field] + + record["modified_on"] = get_current_time() + + resp = cf_response(record) + log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200) + + print(f"[UPDATE] Zone: {zone_id}, Record: {record_id}") + + return jsonify(resp), 200 + + +@app.route("/client/v4/zones//dns_records/", methods=["DELETE"]) +@verify_auth +def delete_dns_record(zone_id, record_id): + """Delete a DNS record.""" + zone_records = 
dns_records.get(zone_id, {}) + + if record_id not in zone_records: + resp = cf_error("Record not found", 81044) + log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + return jsonify(resp), 404 + + del zone_records[record_id] + + resp = cf_response({"id": record_id}) + log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + + print(f"[DELETE] Zone: {zone_id}, Record: {record_id}") + + return jsonify(resp), 200 + + +# ==================== Zone Endpoints ==================== + +# Pre-configured zones for testing +# Can be configured via MOCK_ZONES environment variable (JSON format) +# Example: MOCK_ZONES='[{"id":"zone123","name":"example.com"},{"id":"zone456","name":"test.local"}]' +DEFAULT_ZONES = [ + {"id": "mock-zone-test-local", "name": "test.local"}, + {"id": "mock-zone-example-com", "name": "example.com"}, + {"id": "mock-zone-test0-local", "name": "test0.local"}, + {"id": "mock-zone-test1-local", "name": "test1.local"}, + {"id": "mock-zone-test2-local", "name": "test2.local"}, +] + + +def get_configured_zones(): + """Get zones from environment or use defaults.""" + zones_json = os.environ.get("MOCK_ZONES") + if zones_json: + try: + return json.loads(zones_json) + except json.JSONDecodeError: + print(f"Warning: Invalid MOCK_ZONES JSON, using defaults") + return DEFAULT_ZONES + + +@app.route("/client/v4/zones", methods=["GET"]) +@verify_auth +def list_zones(): + """List all zones (paginated).""" + page = int(request.args.get("page", 1)) + per_page = int(request.args.get("per_page", 50)) + name_filter = request.args.get("name") + + zones = get_configured_zones() + + # Filter by name if specified + if name_filter: + zones = [z for z in zones if z["name"] == name_filter] + + # Build full zone objects + full_zones = [] + for z in zones: + full_zones.append({ + "id": z["id"], + "name": z["name"], + "status": "active", + "paused": False, + "type": "full", + "development_mode": 0, + 
"name_servers": [ + "ns1.mock-cloudflare.com", + "ns2.mock-cloudflare.com" + ], + "created_on": "2024-01-01T00:00:00.000000Z", + "modified_on": get_current_time(), + }) + + # Pagination + total_count = len(full_zones) + total_pages = max(1, (total_count + per_page - 1) // per_page) + start_idx = (page - 1) * per_page + end_idx = start_idx + per_page + page_zones = full_zones[start_idx:end_idx] + + result = { + "success": True, + "errors": [], + "messages": [], + "result": page_zones, + "result_info": { + "page": page, + "per_page": per_page, + "count": len(page_zones), + "total_count": total_count, + "total_pages": total_pages + } + } + + log_request("*", "GET", "/zones", dict(request.args), result, 200) + print(f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}") + + return jsonify(result), 200 + + +@app.route("/client/v4/zones/", methods=["GET"]) +@verify_auth +def get_zone(zone_id): + """Get zone details (mock).""" + # Try to find zone in configured zones + zones = get_configured_zones() + zone_info = next((z for z in zones if z["id"] == zone_id), None) + + if zone_info: + zone_name = zone_info["name"] + else: + # Fallback for unknown zone IDs + zone_name = f"zone-{zone_id[:8]}.example.com" + + zone = { + "id": zone_id, + "name": zone_name, + "status": "active", + "paused": False, + "type": "full", + "development_mode": 0, + "name_servers": [ + "ns1.mock-cloudflare.com", + "ns2.mock-cloudflare.com" + ], + "created_on": "2024-01-01T00:00:00.000000Z", + "modified_on": get_current_time(), + } + + resp = cf_response(zone) + log_request(zone_id, "GET", f"/zones/{zone_id}", None, resp, 200) + + return jsonify(resp), 200 + + +# ==================== Management UI ==================== + +MANAGEMENT_HTML = """ + + + + + + Mock Cloudflare DNS API - Management + + + +
+
+

Mock Cloudflare DNS API

+

Testing server for ACME DNS-01 challenges

+
+ +
+
+

{{ zone_count }}

+

Zones

+
+
+

{{ record_count }}

+

DNS Records

+
+
+

{{ request_count }}

+

API Requests

+
+
+ +
+
+

DNS Records

+ +
+
+ {% if records %} + + + + + + + + + + + + + {% for record in records %} + + + + + + + + + {% endfor %} + +
Zone IDTypeNameContentCreatedActions
{{ record.zone_id[:12] }}...{{ record.type }}{{ record.name }}{{ record.content }}{{ record.created_on[:19] }} + +
+ {% else %} +
+

No DNS records yet. Records created via API will appear here.

+
+ {% endif %} +
+
+ +
+
+

Recent API Requests

+ +
+
+ {% if logs %} + {% for log in logs %} +
+ {{ log.timestamp }} + {{ log.method }} + {{ log.path }} + + ({{ log.status_code }}) + + {% if log.request %} +
+ Request/Response +
Request: {{ log.request | tojson(indent=2) }}
+
Response: {{ log.response | tojson(indent=2) }}
+
+ {% endif %} +
+ {% endfor %} + {% else %} +
+

No API requests yet.

+
+ {% endif %} +
+
+ +
+
+

API Usage

+
+
+

Base URL:

+

+            
+
+
+ + + + + + +""" + + +@app.route("/") +def management_ui(): + """Render the management UI.""" + all_records = [] + for zone_id, records in dns_records.items(): + all_records.extend(records.values()) + + # Sort by created time, newest first + all_records.sort(key=lambda r: r.get("created_on", ""), reverse=True) + + return render_template_string( + MANAGEMENT_HTML, + zone_count=len(dns_records), + record_count=sum(len(r) for r in dns_records.values()), + request_count=len(request_logs), + records=all_records, + logs=request_logs[:20], + port=os.environ.get("PORT", 8080) + ) + + +# ==================== Management API ==================== + +@app.route("/api/records", methods=["DELETE"]) +def clear_all_records(): + """Clear all DNS records.""" + dns_records.clear() + return jsonify({"success": True}) + + +@app.route("/api/records//", methods=["DELETE"]) +def delete_record_ui(zone_id, record_id): + """Delete a specific record from UI.""" + if zone_id in dns_records and record_id in dns_records[zone_id]: + del dns_records[zone_id][record_id] + return jsonify({"success": True}) + + +@app.route("/api/logs", methods=["DELETE"]) +def clear_logs(): + """Clear request logs.""" + request_logs.clear() + return jsonify({"success": True}) + + +@app.route("/api/records", methods=["GET"]) +def get_all_records(): + """Get all records as JSON.""" + all_records = [] + for zone_id, records in dns_records.items(): + all_records.extend(records.values()) + return jsonify(all_records) + + +@app.route("/health") +def health(): + """Health check endpoint.""" + return jsonify({"status": "healthy", "records": sum(len(r) for r in dns_records.values())}) + + +if __name__ == "__main__": + port = int(os.environ.get("PORT", 8080)) + debug = os.environ.get("DEBUG", "false").lower() == "true" + + print(f""" + ╔═══════════════════════════════════════════════════════════════╗ + ║ Mock Cloudflare DNS API Server ║ + ╠═══════════════════════════════════════════════════════════════╣ + ║ Management UI: 
http://localhost:{port}/ ║ + ║ API Base URL: http://localhost:{port}/client/v4 ║ + ╚═══════════════════════════════════════════════════════════════╝ + """) + + app.run(host="0.0.0.0", port=port, debug=debug) diff --git a/tools/mock-cf-dns-api/docker-compose.yml b/tools/mock-cf-dns-api/docker-compose.yml new file mode 100644 index 00000000..951bd68b --- /dev/null +++ b/tools/mock-cf-dns-api/docker-compose.yml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +version: "3.8" + +services: + mock-cf-dns-api: + image: kvin/mock-cf-dns-api:latest + ports: + - "8080:8080" + environment: + - PORT=8080 + - DEBUG=false + # Optional: comma-separated list of valid API tokens + # - CF_API_TOKENS=token1,token2 + restart: unless-stopped diff --git a/tools/mock-cf-dns-api/requirements.txt b/tools/mock-cf-dns-api/requirements.txt new file mode 100644 index 00000000..a36229c7 --- /dev/null +++ b/tools/mock-cf-dns-api/requirements.txt @@ -0,0 +1,2 @@ +flask>=3.0.0 +gunicorn>=21.0.0 diff --git a/verifier/src/main.rs b/verifier/src/main.rs index d832a6ba..ea9cabd5 100644 --- a/verifier/src/main.rs +++ b/verifier/src/main.rs @@ -159,7 +159,11 @@ async fn run_oneshot(file_path: &str, config: &Config) -> anyhow::Result<()> { #[rocket::main] async fn main() -> Result<()> { - tracing_subscriber::fmt::try_init().ok(); + { + use tracing_subscriber::{fmt, EnvFilter}; + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + fmt().with_env_filter(filter).with_ansi(false).init(); + } let cli = Cli::parse(); diff --git a/vmm/src/main.rs b/vmm/src/main.rs index bb1f7873..40033ed5 100644 --- a/vmm/src/main.rs +++ b/vmm/src/main.rs @@ -159,7 +159,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + 
fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse();