diff --git a/AGENTS.md b/AGENTS.md index fe417baad..0123615c0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,6 +37,7 @@ These pipelines connect skills into end-to-end workflows. Individual skill files | `crates/openshell-bootstrap/` | Gateway metadata | Gateway registration metadata, auth token storage, mTLS bundle storage | | `crates/openshell-ocsf/` | OCSF logging | OCSF v1.7.0 event types, builders, shorthand/JSONL formatters, tracing layers | | `crates/openshell-core/` | Shared core | Common types, configuration, error handling | +| `crates/openshell-sdk/` | Shared client SDK | Async Rust gateway client (gRPC transport, TLS, OIDC refresh, edge tunnel); consumed by CLI, TUI, and `@openshell/sdk` | | `crates/openshell-providers/` | Provider management | Credential provider backends | | `crates/openshell-tui/` | Terminal UI | Ratatui-based dashboard for monitoring | | `crates/openshell-driver-kubernetes/` | Kubernetes compute driver | In-process `ComputeDriver` backend for K8s sandbox pods | diff --git a/Cargo.lock b/Cargo.lock index 366f001a6..e28bafe6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3346,6 +3346,7 @@ dependencies = [ "openshell-policy", "openshell-prover", "openshell-providers", + "openshell-sdk", "openshell-tui", "owo-colors", "prost-types", @@ -3361,11 +3362,8 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", - "tokio-rustls", "tokio-stream", - "tokio-tungstenite 0.26.2", "tonic", - "tower 0.5.3", "tracing", "tracing-subscriber", "url", @@ -3615,6 +3613,31 @@ dependencies = [ "webpki-roots 1.0.7", ] +[[package]] +name = "openshell-sdk" +version = "0.0.0" +dependencies = [ + "async-trait", + "futures", + "hyper", + "hyper-util", + "miette", + "oauth2", + "openshell-core", + "reqwest 0.12.28", + "rustls", + "rustls-pemfile", + "serde", + "thiserror 2.0.18", + "tokio", + "tokio-rustls", + "tokio-stream", + "tokio-tungstenite 0.26.2", + "tonic", + "tower 0.5.3", + "tracing", +] + [[package]] name = "openshell-server" version = "0.0.0" @@ -3706,6 +3729,7 @@ dependencies = [ "openshell-core", "openshell-policy", "openshell-providers", + "openshell-sdk", "owo-colors", "ratatui", "serde", diff --git a/crates/openshell-cli/Cargo.toml b/crates/openshell-cli/Cargo.toml index 577fb73b9..d9ff84f11 100644 --- a/crates/openshell-cli/Cargo.toml +++ b/crates/openshell-cli/Cargo.toml @@ -20,6 +20,7 @@ openshell-core = { path = "../openshell-core", default-features = false } openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } openshell-prover = { path = "../openshell-prover" } +openshell-sdk = { path = "../openshell-sdk" } openshell-tui = { path = "../openshell-tui" } serde = { workspace = true } serde_json = { workspace = true } @@ -49,8 +50,6 @@ hyper-util = { workspace = true } hyper-rustls = { version = "0.27", default-features = false, features = ["native-tokio", "http1", "http2", "tls12", "logging", "ring", "webpki-tokio"] } rustls = { workspace = true } rustls-pemfile = { workspace = true } -tokio-rustls = { workspace = true } -tower = { workspace = true } reqwest = { workspace = true } # Error handling @@ -66,9 +65,6 @@ tempfile = "3" oauth2 = "5" base64 = { workspace = true } -# WebSocket (Cloudflare tunnel proxy) -tokio-tungstenite = { workspace = true } - # Streams futures = { workspace = true } tokio-stream = { workspace = true } diff --git a/crates/openshell-cli/src/completers.rs b/crates/openshell-cli/src/completers.rs index a421b418a..ff8713dcb 100644 --- a/crates/openshell-cli/src/completers.rs +++ b/crates/openshell-cli/src/completers.rs @@ -9,9 +9,9 @@ use openshell_bootstrap::edge_token::load_edge_token; use openshell_bootstrap::oidc_token::{is_token_expired, load_oidc_token, store_oidc_token}; use openshell_bootstrap::{list_gateways, load_active_gateway, load_gateway_metadata}; use openshell_core::ObjectName; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ListProvidersRequest, ListSandboxesRequest}; +use openshell_sdk::EdgeAuthInterceptor; use tonic::service::interceptor::InterceptedService; use tonic::transport::Channel; diff --git a/crates/openshell-cli/src/lib.rs b/crates/openshell-cli/src/lib.rs index 156668951..84999b42d 100644 --- a/crates/openshell-cli/src/lib.rs +++ b/crates/openshell-cli/src/lib.rs @@ -10,7 +10,6 @@ pub(crate) static TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(() pub mod auth; pub mod completers; -pub mod edge_tunnel; pub mod oidc_auth; pub mod output; pub(crate) mod policy_update; diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 109b67596..d56887be9 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -2970,7 +2970,7 @@ async fn main() -> Result<()> { let mut tls = tls.with_gateway_name(&ctx.name); apply_auth(&mut tls, &ctx.name); let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?; - let interceptor = openshell_core::auth::EdgeAuthInterceptor::new( + let interceptor = openshell_sdk::EdgeAuthInterceptor::new( tls.oidc_token.as_deref(), tls.edge_token.as_deref(), )?; diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index 379a53112..bdc30e902 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -17,10 +17,10 @@ use miette::{IntoDiagnostic, Result}; use oauth2::basic::BasicClient; use oauth2::{ AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, - RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, + RedirectUrl, Scope, TokenResponse, TokenUrl, }; use openshell_bootstrap::oidc_token::OidcTokenBundle; -use serde::Deserialize; +use openshell_sdk::oidc::{RefreshTokenInput, discover, http_client, refresh_token}; use std::convert::Infallible; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -30,50 +30,6 @@ use tracing::debug; const AUTH_TIMEOUT: Duration = Duration::from_secs(120); -/// OIDC discovery document (subset of fields we need). -#[derive(Debug, Deserialize)] -struct OidcDiscovery { - issuer: String, - authorization_endpoint: String, - token_endpoint: String, -} - -/// Discover OIDC endpoints from the issuer's well-known configuration. -/// -/// Validates that the discovery document's `issuer` field matches the -/// configured issuer URL to prevent SSRF or misdirection. -async fn discover(issuer: &str, insecure: bool) -> Result { - let normalized_issuer = issuer.trim_end_matches('/'); - let url = format!("{normalized_issuer}/.well-known/openid-configuration"); - let client = http_client(insecure); - let resp: OidcDiscovery = client - .get(&url) - .send() - .await - .into_diagnostic()? - .json() - .await - .into_diagnostic()?; - - let discovered_issuer = resp.issuer.trim_end_matches('/'); - if discovered_issuer != normalized_issuer { - return Err(miette::miette!( - "OIDC discovery issuer mismatch: expected '{}', got '{}'", - normalized_issuer, - discovered_issuer - )); - } - Ok(resp) -} - -fn http_client(insecure: bool) -> reqwest::Client { - let mut builder = reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); - if insecure { - builder = builder.danger_accept_invalid_certs(true); - } - builder.build().expect("failed to build HTTP client") -} - fn build_scopes(scopes: Option<&str>) -> Vec { let mut result = vec![Scope::new("openid".to_string())]; if let Some(s) = scopes { @@ -227,36 +183,33 @@ pub async fn oidc_client_credentials_flow( /// Refresh an OIDC token using the `refresh_token` grant. /// -/// Preserves the existing refresh token if the server does not return a new -/// one (per OAuth 2.0 spec, the refresh response may omit `refresh_token`). +/// Wraps [`openshell_sdk::oidc::refresh_token`] with the CLI's +/// [`OidcTokenBundle`] storage shape. Preserves the existing refresh +/// token when the server omits one (per OAuth 2.0 the refresh response +/// is allowed to leave `refresh_token` out). pub async fn oidc_refresh_token( bundle: &OidcTokenBundle, insecure: bool, ) -> Result { - let refresh_token = bundle.refresh_token.as_deref().ok_or_else(|| { + let refresh = bundle.refresh_token.as_deref().ok_or_else(|| { miette::miette!( "no refresh token available — re-authenticate with: openshell gateway login" ) })?; - let discovery = discover(&bundle.issuer, insecure).await?; - - let client = BasicClient::new(ClientId::new(bundle.client_id.clone())) - .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?); - - let http = http_client(insecure); - let token_response = client - .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) - .request_async(&http) - .await - .map_err(|e| miette::miette!("token refresh failed: {e}"))?; - - let mut refreshed = - bundle_from_oauth2_response(&token_response, &bundle.issuer, &bundle.client_id); - if refreshed.refresh_token.is_none() { - refreshed.refresh_token.clone_from(&bundle.refresh_token); - } - Ok(refreshed) + let input = + RefreshTokenInput::new(refresh, &bundle.issuer, &bundle.client_id).with_insecure(insecure); + let output = refresh_token(&input).await.into_diagnostic()?; + + Ok(OidcTokenBundle { + access_token: output.access_token, + refresh_token: output + .refresh_token + .or_else(|| bundle.refresh_token.clone()), + expires_at: output.expires_at, + issuer: bundle.issuer.clone(), + client_id: bundle.client_id.clone(), + }) } /// Ensure we have a valid OIDC token for the given gateway, refreshing if needed. diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index 10df401a5..89097b0de 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -2,25 +2,26 @@ // SPDX-License-Identifier: Apache-2.0 use miette::{IntoDiagnostic, Result, WrapErr}; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::inference_client::InferenceClient; use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_sdk::EdgeAuthInterceptor; use rustls::{ RootCertStore, - client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, - pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, + pki_types::{CertificateDer, PrivateKeyDer}, }; -use std::collections::HashMap; -use std::future::Future; use std::io::Cursor; -use std::net::SocketAddr; use std::path::PathBuf; -use std::sync::OnceLock; use std::time::Duration; -use tokio::sync::Mutex; use tonic::service::interceptor::InterceptedService; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; -use tracing::debug; + +// `build_insecure_rustls_config` lives in the SDK (used by the SDK's +// transport stack and by CLI's HTTP health check). The other former +// `tls.rs` helpers (`build_rustls_config`, `build_tonic_tls_config`, +// `load_private_key`, `TlsMaterials`) were tied to mTLS and now live +// below as CLI-private legacy code — they will go away when mTLS is +// retired as an auth method. +pub use openshell_sdk::transport::build_insecure_rustls_config; /// Concrete gRPC client type used by all commands. pub type GrpcClient = OpenShellClient>; @@ -104,12 +105,44 @@ impl TlsOptions { pub fn is_bearer_auth(&self) -> bool { self.edge_token.is_some() || self.oidc_token.is_some() } -} -pub struct TlsMaterials { - ca: Vec, - cert: Vec, - key: Vec, + /// Returns `true` when this `TlsOptions` carries a full mTLS client + /// identity (cert + key on disk). Used by [`build_channel`] to route + /// mTLS-authenticated gateways through the legacy inline path. + pub fn has_mtls_identity(&self, server: &str) -> bool { + let resolved = self.with_default_paths(server); + resolved.cert.as_ref().is_some_and(|p| p.exists()) + && resolved.key.as_ref().is_some_and(|p| p.exists()) + } + + /// Convert this CLI-side `TlsOptions` into an SDK [`openshell_sdk::ClientConfig`] + /// for non-mTLS gateways. + /// + /// Reads the CA cert from disk if a path resolves; a missing file is + /// non-fatal and falls back to system roots (matches today's OIDC + /// fallback behavior). Maps tokens to [`openshell_sdk::AuthConfig`] + /// with OIDC taking precedence over `EdgeJwt` when both are set. + /// + /// mTLS materials are intentionally not carried through; gateways + /// requiring client certificates are dispatched to the legacy inline + /// path in [`build_channel`] before this conversion is reached. + pub fn to_client_config(&self, server: &str) -> openshell_sdk::ClientConfig { + let resolved = self.with_default_paths(server); + let ca_cert = resolved + .ca + .as_ref() + .and_then(|ca_path| std::fs::read(ca_path).ok()); + let auth = match (&resolved.oidc_token, &resolved.edge_token) { + (Some(token), _) => Some(openshell_sdk::AuthConfig::Oidc(token.clone())), + (None, Some(token)) => Some(openshell_sdk::AuthConfig::EdgeJwt(token.clone())), + (None, None) => None, + }; + let mut config = openshell_sdk::ClientConfig::new(server); + config.ca_cert = ca_cert; + config.auth = auth; + config.insecure_skip_verify = resolved.gateway_insecure; + config + } } /// Resolve the TLS cert directory for a known gateway name. @@ -163,6 +196,20 @@ fn xdg_config_dir() -> Result { openshell_core::paths::xdg_config_dir() } +// ── Legacy mTLS path ───────────────────────────────────────────────── +// Everything in this section supports gateways that authenticate clients +// with an mTLS certificate. mTLS is being retired as an auth method, and +// the SDK does not speak it. Until product removes mTLS support, these +// helpers stay in CLI for the `else { full mTLS }` branch of +// `build_channel` and the matching branch of `http_health_check`. + +/// In-memory mTLS materials read from disk by [`require_tls_materials`]. +pub struct TlsMaterials { + pub ca: Vec, + pub cert: Vec, + pub key: Vec, +} + pub fn require_tls_materials(server: &str, tls: &TlsOptions) -> Result { let resolved = tls.with_default_paths(server); let default_hint = default_tls_dir(server).map_or_else(String::new, |dir| { @@ -192,6 +239,7 @@ pub fn require_tls_materials(server: &str, tls: &TlsOptions) -> Result Result> { let mut cursor = Cursor::new(pem); let key = rustls_pemfile::private_key(&mut cursor) @@ -200,11 +248,12 @@ fn load_private_key(pem: &[u8]) -> Result> { Ok(key) } +/// Build a `rustls` mTLS client config (used by `http_health_check`). pub fn build_rustls_config(materials: &TlsMaterials) -> Result { let mut roots = RootCertStore::empty(); let mut ca_cursor = Cursor::new(&materials.ca); let ca_certs = rustls_pemfile::certs(&mut ca_cursor) - .collect::>, _>>() + .collect::>, _>>() .into_diagnostic()?; for cert in ca_certs { roots.add(cert).into_diagnostic()?; @@ -212,7 +261,7 @@ pub fn build_rustls_config(materials: &TlsMaterials) -> Result>, _>>() + .collect::>, _>>() .into_diagnostic()?; let key = load_private_key(&materials.key)?; @@ -222,6 +271,8 @@ pub fn build_rustls_config(materials: &TlsMaterials) -> Result ClientTlsConfig { let ca_cert = Certificate::from_pem(materials.ca.clone()); let identity = Identity::from_pem(materials.cert.clone(), materials.key.clone()); @@ -230,202 +281,48 @@ pub fn build_tonic_tls_config(materials: &TlsMaterials) -> ClientTlsConfig { .identity(identity) } -#[derive(Debug)] -struct InsecureServerCertVerifier; - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - rustls::crypto::ring::default_provider() - .signature_verification_algorithms - .supported_schemes() - } -} - -#[derive(Clone)] -struct InsecureTlsConnector { - tls_connector: tokio_rustls::TlsConnector, -} - -impl tower::Service for InsecureTlsConnector { - type Response = hyper_util::rt::TokioIo>; - type Error = Box; - type Future = - std::pin::Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::task::Poll::Ready(Ok(())) - } +// ── Channel construction (legacy mTLS dispatcher) ──────────────────── +// `build_channel` is a thin dispatcher: gateways that authenticate +// clients with mTLS take the inline `build_legacy_mtls_channel` path +// below; everything else converts to a `ClientConfig` and delegates to +// `openshell_sdk::transport::build_channel`. When mTLS retires as an +// auth method, `needs_legacy_mtls` and `build_legacy_mtls_channel` go +// with it. - fn call(&mut self, uri: hyper::Uri) -> Self::Future { - let tls_connector = self.tls_connector.clone(); - Box::pin(async move { - let host = uri.host().unwrap_or("localhost").to_string(); - let port = uri.port_u16().unwrap_or(443); - let addr = format!("{host}:{port}"); - let tcp = tokio::net::TcpStream::connect(addr).await?; - let server_name = ServerName::try_from(host)?; - let tls_stream = tls_connector.connect(server_name, tcp).await?; - Ok(hyper_util::rt::TokioIo::new(tls_stream)) - }) +pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { + if needs_legacy_mtls(tls, server) { + return build_legacy_mtls_channel(server, tls).await; } + let config = tls.to_client_config(server); + Ok(openshell_sdk::transport::build_channel(&config).await?) } -pub fn build_insecure_rustls_config() -> Result { - let config = rustls::ClientConfig::builder() - .dangerous() - .with_custom_certificate_verifier(std::sync::Arc::new(InsecureServerCertVerifier)) - .with_no_client_auth(); - Ok(config) -} - -/// Tunnel proxy addresses keyed by upstream endpoint + token. -/// -/// Each distinct edge-authenticated gateway gets its own local proxy instead of -/// reusing the first gateway touched in the current process. -static EDGE_TUNNEL_ADDRS: OnceLock>> = OnceLock::new(); - -async fn edge_tunnel_addr(server: &str, token: &str) -> Result { - let key = (server.to_string(), token.to_string()); - let registry = EDGE_TUNNEL_ADDRS.get_or_init(|| Mutex::new(HashMap::new())); - - { - let addrs = registry.lock().await; - if let Some(addr) = addrs.get(&key).copied() { - return Ok(addr); - } - } - - let proxy = crate::edge_tunnel::start_tunnel_proxy(server, token).await?; - debug!( - local_addr = %proxy.local_addr, - server, - "edge tunnel proxy started, routing gRPC through local proxy" - ); - - let mut addrs = registry.lock().await; - Ok(*addrs.entry(key).or_insert(proxy.local_addr)) +/// Returns `true` when this connection should run through the CLI's +/// inline mTLS path: HTTPS, no insecure-skip, no edge tunnel, and either +/// no OIDC token or OIDC paired with mTLS materials on disk. The combined +/// mTLS+OIDC case preserves the documented "mTLS as transport trust +/// boundary, Bearer for full scope" deployment model. +fn needs_legacy_mtls(tls: &TlsOptions, server: &str) -> bool { + server.starts_with("https://") + && !tls.gateway_insecure + && tls.edge_token.is_none() + && (tls.oidc_token.is_none() || tls.has_mtls_identity(server)) } -pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { - if server.starts_with("http://") { - let endpoint = Endpoint::from_shared(server.to_string()) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_adaptive_window(true) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint.connect().await.into_diagnostic(); - } - - // When Cloudflare edge bearer auth is active and the server is HTTPS, - // route traffic through a local WebSocket tunnel proxy instead. - // OIDC tokens bypass the tunnel — they connect directly. - if tls.edge_token.is_some() && server.starts_with("https://") { - let token = tls - .edge_token - .as_deref() - .ok_or_else(|| miette::miette!("edge token required for tunnel"))?; - let local_addr = edge_tunnel_addr(server, token).await?; - - // Connect to the local tunnel proxy over plaintext HTTP/2. - let local_url = format!("http://{local_addr}"); - let endpoint = Endpoint::from_shared(local_url) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_adaptive_window(true) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint.connect().await.into_diagnostic(); - } - - if tls.gateway_insecure && server.starts_with("https://") { - tracing::warn!("TLS certificate verification is disabled — do not use in production"); - let rustls_config = build_insecure_rustls_config()?; - let tls_connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)); - let connector = InsecureTlsConnector { tls_connector }; - // Use http:// so tonic does not layer its own TLS on top — our - // connector performs TLS with the insecure config. - let http_uri = server.replacen("https://", "http://", 1); - let endpoint = Endpoint::from_shared(http_uri) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint - .connect_with_connector(connector) - .await - .into_diagnostic(); - } - - let mut endpoint = Endpoint::from_shared(server.to_string()) +/// Inline mTLS channel construction for gateways that require client +/// certificates as the transport-level trust boundary. Goes away when +/// mTLS is retired as an auth method. +async fn build_legacy_mtls_channel(server: &str, tls: &TlsOptions) -> Result { + let materials = require_tls_materials(server, tls)?; + let tls_config = build_tonic_tls_config(&materials); + let endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - - let tls_config = if tls.oidc_token.is_some() { - // Bearer auth over HTTPS: use mTLS certs for the transport layer when - // available (server may still require client certs), and layer the - // Bearer token on top via the interceptor. - require_tls_materials(server, tls).map_or_else( - |_| { - let resolved = tls.with_default_paths(server); - resolved - .ca - .as_ref() - .and_then(|ca_path| std::fs::read(ca_path).ok()) - .map_or_else( - || ClientTlsConfig::new().with_enabled_roots(), - |ca_pem| { - ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_pem)) - }, - ) - }, - |materials| build_tonic_tls_config(&materials), - ) - } else if tls.edge_token.is_some() { - // Edge bearer mode — routed through tunnel above; if we reach here - // the server is not HTTPS so connect plaintext. - return endpoint.connect().await.into_diagnostic(); - } else { - // Standard mTLS: private CA + client cert. - let materials = require_tls_materials(server, tls)?; - build_tonic_tls_config(&materials) - }; - endpoint = endpoint.tls_config(tls_config).into_diagnostic()?; + .keep_alive_while_idle(true) + .tls_config(tls_config) + .into_diagnostic()?; endpoint.connect().await.into_diagnostic() } @@ -441,7 +338,10 @@ pub async fn grpc_client(server: &str, tls: &TlsOptions) -> Result { } fn interceptor_from_tls(tls: &TlsOptions) -> Result { - EdgeAuthInterceptor::new(tls.oidc_token.as_deref(), tls.edge_token.as_deref()) + Ok(EdgeAuthInterceptor::new( + tls.oidc_token.as_deref(), + tls.edge_token.as_deref(), + )?) } pub async fn grpc_inference_client(server: &str, tls: &TlsOptions) -> Result { diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index ceec0d617..7d2dc1fba 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -9,7 +9,6 @@ //! - Common error types //! - Build version metadata -pub mod auth; pub mod config; pub mod driver_mounts; pub mod driver_utils; diff --git a/crates/openshell-sdk/Cargo.toml b/crates/openshell-sdk/Cargo.toml new file mode 100644 index 000000000..420948e67 --- /dev/null +++ b/crates/openshell-sdk/Cargo.toml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-sdk" +description = "Shared async Rust client for OpenShell gateways" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core" } +async-trait = "0.1" +futures = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } +miette = { workspace = true } +oauth2 = "5" +reqwest = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tokio-rustls = { workspace = true } +tokio-stream = { workspace = true } +tokio-tungstenite = { workspace = true } +tonic = { workspace = true, features = ["tls-native-roots"] } +tower = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + +[lints] +workspace = true diff --git a/crates/openshell-sdk/README.md b/crates/openshell-sdk/README.md new file mode 100644 index 000000000..67b27c270 --- /dev/null +++ b/crates/openshell-sdk/README.md @@ -0,0 +1,82 @@ +# openshell-sdk + +`openshell-sdk` is the shared async Rust client for OpenShell gateways. It owns +the gRPC transport and auth stack so the CLI, the TUI, and language bindings +share one implementation of channel setup, TLS, OIDC refresh, and the +Cloudflare Access tunnel. + +Designed in [RFC 0005](../../rfc/0005-shared-sdk-core-and-ts-binding/README.md). + +## Two layers + +- `OpenShellClient` — the curated, sandbox-focused surface: health, sandbox + CRUD, readiness/deletion waits, and non-streaming exec. +- `raw` — direct access to the generated tonic clients for RPCs the curated + surface doesn't yet cover (inference, providers, policy, logs, settings, SSH, + forwarding). + +## Responsibilities + +- Construct the gRPC channel and select the transport (plaintext vs TLS). +- Load TLS material and set up mTLS channels. +- Attach edge-auth bearer tokens and refresh OIDC tokens, with single-flight + coalescing so only one refresh is in flight at a time. +- Proxy connections through the Cloudflare Access tunnel for hosted gateways. +- Map transport and gateway failures to a typed `SdkError` with a discriminable + kind. + +## Non-responsibilities + +- Gateway-name resolution, default config-path lookups, and the OIDC browser + flow. These are user-facing concerns owned by `openshell-cli`; the SDK + consumes a `Refresh` trait the CLI implements. +- Reading tokens from disk. Callers pass an explicit token; the SDK performs no + filesystem access. +- Defining the gRPC contract. The protos and generated types are owned by + `openshell-core`. + +## Transport and auth modes + +Mirrors the modes the CLI exercises in production, so callers can move to the +SDK without losing connectivity options: + +- Plaintext (local development) +- mTLS (self-deployed gateways with client certs) +- OIDC bearer over HTTPS (gateways behind an OAuth2/OIDC IdP) +- Cloudflare Access tunnel (hosted gateways) +- Insecure TLS (development/debug; certificate verification disabled) + +## Public surface + +`OpenShellClient::connect(ClientConfig)` returns a connected client exposing +`health`, `create_sandbox`, `get_sandbox`, `list_sandboxes`, `delete_sandbox`, +`wait_ready`, `wait_deleted`, and `exec`. Curated types (`SandboxSpec`, +`SandboxRef`, `Health`, `ListOptions`, `ExecOptions`, `SandboxPhase`) use +SDK-shaped enums rather than raw proto integers. + +## Modules + +| Module | Purpose | +|---|---| +| `client` | High-level `OpenShellClient` and the curated sandbox surface. | +| `config` | `ClientConfig`, `AuthConfig`. | +| `transport` | Channel construction, TLS resolution, request interceptors. | +| `auth` | `EdgeAuthInterceptor` for bearer-token attachment. | +| `oidc` | OIDC token handling at the transport layer. | +| `refresh` | `Refresh` trait and single-flight refresh coalescing. | +| `edge_tunnel` | Cloudflare Access tunnel dialer. | +| `error` | `SdkError` taxonomy. | +| `types` | Curated request/response types and proto conversions. | +| `raw` | Escape hatch re-exporting the generated tonic clients. | + +## Consumers + +`openshell-cli`, `openshell-tui`, and `openshell-sdk-node` (published as +`@openshell/sdk`). + +## Notes + +- Async-only. Tonic is async-native; callers needing a blocking call can wrap + with their own runtime. +- The surface is alpha and will grow as more RPCs graduate from `raw` into the + curated client. diff --git a/crates/openshell-core/src/auth.rs b/crates/openshell-sdk/src/auth.rs similarity index 78% rename from crates/openshell-core/src/auth.rs rename to crates/openshell-sdk/src/auth.rs index 16d513346..79e6a1fc0 100644 --- a/crates/openshell-core/src/auth.rs +++ b/crates/openshell-sdk/src/auth.rs @@ -1,15 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! gRPC authentication interceptor shared by CLI and TUI. +//! Bearer-token authentication interceptor for outgoing gRPC requests. -use miette::Result; +use crate::error::{Result, SdkError}; /// Interceptor that injects authentication headers into every outgoing gRPC request. /// -/// Supports application-layer Bearer tokens (standard `authorization` -/// header) and Cloudflare Access tokens (custom headers). When no token is -/// set, acts as a no-op. OIDC takes precedence over edge tokens. +/// Supports OIDC Bearer tokens (standard `authorization` header) and +/// Cloudflare Access tokens (custom headers). When no token is set, acts +/// as a no-op. OIDC takes precedence over edge tokens. #[derive(Clone)] #[allow(clippy::struct_field_names)] pub struct EdgeAuthInterceptor { @@ -21,14 +21,14 @@ pub struct EdgeAuthInterceptor { impl EdgeAuthInterceptor { /// Create an interceptor from optional token strings. /// - /// OIDC bearer tokens take precedence over edge tokens. Returns a no-op - /// interceptor when no token is provided. + /// OIDC bearer token takes precedence over edge token. Returns a no-op + /// interceptor when neither token is provided. pub fn new(oidc_token: Option<&str>, edge_token: Option<&str>) -> Result { if let Some(token) = oidc_token { let bearer: tonic::metadata::MetadataValue = format!("Bearer {token}") .parse() - .map_err(|_| miette::miette!("invalid bearer token value"))?; + .map_err(|_| SdkError::auth("invalid OIDC token value"))?; return Ok(Self { bearer_value: Some(bearer), header_value: None, @@ -40,11 +40,11 @@ impl EdgeAuthInterceptor { Some(t) => { let hv: tonic::metadata::MetadataValue = t .parse() - .map_err(|_| miette::miette!("invalid edge token value"))?; + .map_err(|_| SdkError::auth("invalid edge token value"))?; let cv: tonic::metadata::MetadataValue = format!("CF_Authorization={t}") .parse() - .map_err(|_| miette::miette!("invalid edge token value for cookie"))?; + .map_err(|_| SdkError::auth("invalid edge token value for cookie"))?; (Some(hv), Some(cv)) } None => (None, None), diff --git a/crates/openshell-sdk/src/client.rs b/crates/openshell-sdk/src/client.rs new file mode 100644 index 000000000..fb3209cc9 --- /dev/null +++ b/crates/openshell-sdk/src/client.rs @@ -0,0 +1,331 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! High-level async client over the gateway gRPC surface. +//! +//! Covers the sandbox-focused MVP slice: health, sandbox CRUD, readiness / +//! deletion waits, and non-streaming exec. Other RPCs (inference, providers, +//! policy, logs, settings, SSH, forwarding) are reachable via +//! [`OpenShellClient::raw_grpc`] / [`OpenShellClient::raw_inference`]. + +use crate::auth::EdgeAuthInterceptor; +use crate::config::{AuthConfig, ClientConfig}; +use crate::error::{Result, SdkError}; +use crate::raw::{AuthedGrpcClient, AuthedInferenceClient}; +use crate::transport; +use crate::types::{ + ExecOptions, ExecResult, Health, ListOptions, SandboxPhase, SandboxRef, SandboxSpec, +}; +use futures::StreamExt; +use openshell_core::proto; +use std::time::{Duration, Instant}; +use tonic::transport::Channel; + +/// Async client for a single `OpenShell` gateway. +/// +/// Cheap to clone — the underlying tonic [`Channel`] multiplexes RPCs over a +/// shared HTTP/2 connection. Construct one per logical gateway and share it +/// across tasks; do not call [`OpenShellClient::connect`] per request. +#[derive(Clone)] +pub struct OpenShellClient { + channel: Channel, + interceptor: EdgeAuthInterceptor, +} + +impl OpenShellClient { + /// Open a connection to the gateway described by `config`. + /// + /// Performs the gRPC channel handshake immediately; subsequent RPCs reuse + /// the connection. + pub async fn connect(config: ClientConfig) -> Result { + let channel = transport::build_channel(&config).await?; + let interceptor = interceptor_from_config(&config)?; + Ok(Self { + channel, + interceptor, + }) + } + + /// Construct from an already-built [`Channel`] and interceptor. + /// + /// Use when the caller needs to customize channel construction beyond + /// what [`ClientConfig`] exposes. + pub fn from_parts(channel: Channel, interceptor: EdgeAuthInterceptor) -> Self { + Self { + channel, + interceptor, + } + } + + /// Underlying tonic [`Channel`]. + pub fn channel(&self) -> Channel { + self.channel.clone() + } + + /// Authenticated gRPC client for the main `OpenShell` service. + /// + /// Use this when the curated surface below doesn't expose the RPC or + /// field you need. + pub fn raw_grpc(&self) -> AuthedGrpcClient { + proto::open_shell_client::OpenShellClient::with_interceptor( + self.channel.clone(), + self.interceptor.clone(), + ) + } + + /// Authenticated gRPC client for the inference service. + pub fn raw_inference(&self) -> AuthedInferenceClient { + proto::inference_client::InferenceClient::with_interceptor( + self.channel.clone(), + self.interceptor.clone(), + ) + } + + /// Gateway health snapshot. + pub async fn health(&self) -> Result { + let mut grpc = self.raw_grpc(); + let resp = grpc + .health(proto::HealthRequest {}) + .await + .map_err(map_status)? + .into_inner(); + Ok(Health { + status: resp.status.into(), + version: resp.version, + }) + } + + /// Create a new sandbox from a curated [`SandboxSpec`]. + pub async fn create_sandbox(&self, spec: SandboxSpec) -> Result { + let request = create_sandbox_request(spec); + let mut grpc = self.raw_grpc(); + let response = grpc + .create_sandbox(request) + .await + .map_err(map_status)? + .into_inner(); + sandbox_from_response(response.sandbox) + } + + /// Fetch a sandbox by name. + pub async fn get_sandbox(&self, name: &str) -> Result { + let mut grpc = self.raw_grpc(); + let response = grpc + .get_sandbox(proto::GetSandboxRequest { + name: name.to_string(), + }) + .await + .map_err(map_status)? + .into_inner(); + sandbox_from_response(response.sandbox) + } + + /// List sandboxes. + pub async fn list_sandboxes(&self, opts: ListOptions) -> Result> { + let mut grpc = self.raw_grpc(); + let response = grpc + .list_sandboxes(proto::ListSandboxesRequest { + limit: opts.limit, + offset: opts.offset, + label_selector: opts.label_selector.unwrap_or_default(), + }) + .await + .map_err(map_status)? + .into_inner(); + Ok(response + .sandboxes + .into_iter() + .map(SandboxRef::from_proto) + .collect()) + } + + /// Delete a sandbox by name. + /// + /// Returns `true` when the gateway acknowledges the deletion, `false` + /// when it was already absent. The sandbox may still be in + /// [`SandboxPhase::Deleting`] when this returns — pair with + /// [`OpenShellClient::wait_deleted`] when you need a terminal guarantee. + pub async fn delete_sandbox(&self, name: &str) -> Result { + let mut grpc = self.raw_grpc(); + let response = grpc + .delete_sandbox(proto::DeleteSandboxRequest { + name: name.to_string(), + }) + .await + .map_err(map_status)? + .into_inner(); + Ok(response.deleted) + } + + /// Poll [`OpenShellClient::get_sandbox`] until the sandbox reaches + /// [`SandboxPhase::Ready`] or the `timeout` elapses. + /// + /// Returns the terminal sandbox snapshot on success. Returns an + /// [`SdkError::Connect`] when the timeout expires, or whatever error + /// the gateway returns if the sandbox transitions into + /// [`SandboxPhase::Error`]. + pub async fn wait_ready(&self, name: &str, timeout: Duration) -> Result { + self.wait_for(name, timeout, |phase| match phase { + SandboxPhase::Ready => Some(Ok(())), + SandboxPhase::Error => Some(Err(SdkError::connect(format!( + "sandbox '{name}' entered error phase" + )))), + _ => None, + }) + .await + } + + /// Poll until the sandbox is gone (gRPC `NotFound`) or the `timeout` + /// elapses. + pub async fn wait_deleted(&self, name: &str, timeout: Duration) -> Result<()> { + let deadline = Instant::now() + timeout; + let mut delay = Duration::from_millis(250); + loop { + match self.get_sandbox(name).await { + Err(SdkError::NotFound { .. }) => return Ok(()), + Err(other) => return Err(other), + Ok(snapshot) if snapshot.phase == SandboxPhase::Deleting => {} + Ok(_) => {} + } + if Instant::now() >= deadline { + return Err(SdkError::connect(format!( + "timed out waiting for sandbox '{name}' to delete" + ))); + } + tokio::time::sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(2)); + } + } + + /// Run a command inside a sandbox and buffer stdout/stderr to the end. + /// + /// For streaming output, drop down to [`OpenShellClient::raw_grpc`] and + /// call `exec_sandbox` directly. + pub async fn exec(&self, name: &str, cmd: &[String], opts: ExecOptions) -> Result { + let sandbox = self.get_sandbox(name).await?; + let request = proto::ExecSandboxRequest { + sandbox_id: sandbox.id, + command: cmd.to_vec(), + workdir: opts.workdir.unwrap_or_default(), + environment: opts.environment, + timeout_seconds: opts + .timeout + .map_or(0, |d| u32::try_from(d.as_secs()).unwrap_or(u32::MAX)), + stdin: opts.stdin.unwrap_or_default(), + tty: false, + cols: 0, + rows: 0, + }; + + let mut grpc = self.raw_grpc(); + let mut stream = grpc + .exec_sandbox(request) + .await + .map_err(map_status)? + .into_inner(); + + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + let mut exit_code: Option = None; + + while let Some(event) = stream.next().await { + let event = event.map_err(map_status)?; + match event.payload { + Some(proto::exec_sandbox_event::Payload::Stdout(chunk)) => { + stdout.extend_from_slice(&chunk.data); + } + Some(proto::exec_sandbox_event::Payload::Stderr(chunk)) => { + stderr.extend_from_slice(&chunk.data); + } + Some(proto::exec_sandbox_event::Payload::Exit(exit)) => { + exit_code = Some(exit.exit_code); + } + None => {} + } + } + + Ok(ExecResult { + exit_code: exit_code.unwrap_or(-1), + stdout, + stderr, + }) + } + + async fn wait_for(&self, name: &str, timeout: Duration, mut decide: F) -> Result + where + F: FnMut(SandboxPhase) -> Option>, + { + let deadline = Instant::now() + timeout; + let mut delay = Duration::from_millis(250); + loop { + let snapshot = self.get_sandbox(name).await?; + if let Some(verdict) = decide(snapshot.phase) { + verdict?; + return Ok(snapshot); + } + if Instant::now() >= deadline { + return Err(SdkError::connect(format!( + "timed out waiting for sandbox '{name}'" + ))); + } + tokio::time::sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(2)); + } + } +} + +fn interceptor_from_config(config: &ClientConfig) -> Result { + match &config.auth { + None => Ok(EdgeAuthInterceptor::noop()), + Some(AuthConfig::Oidc(token)) => EdgeAuthInterceptor::new(Some(token), None), + Some(AuthConfig::EdgeJwt(token)) => EdgeAuthInterceptor::new(None, Some(token)), + } +} + +fn create_sandbox_request(spec: SandboxSpec) -> proto::CreateSandboxRequest { + let SandboxSpec { + name, + image, + labels, + environment, + providers, + gpu, + gpu_device, + } = spec; + let template = image.map(|image| proto::SandboxTemplate { + image, + ..proto::SandboxTemplate::default() + }); + proto::CreateSandboxRequest { + spec: Some(proto::SandboxSpec { + environment, + template, + providers, + gpu, + gpu_device: gpu_device.unwrap_or_default(), + ..proto::SandboxSpec::default() + }), + name: name.unwrap_or_default(), + labels, + } +} + +fn sandbox_from_response(sandbox: Option) -> Result { + sandbox + .map(SandboxRef::from_proto) + .ok_or_else(|| SdkError::invalid_config("sandbox missing from gateway response")) +} + +fn map_status(status: tonic::Status) -> SdkError { + let message = status.message().to_string(); + match status.code() { + tonic::Code::NotFound => SdkError::NotFound { message }, + tonic::Code::AlreadyExists => SdkError::AlreadyExists { message }, + tonic::Code::InvalidArgument => SdkError::invalid_config(message), + tonic::Code::Unauthenticated | tonic::Code::PermissionDenied => SdkError::auth(message), + _ => SdkError::Rpc { + code: status.code() as i32, + message, + }, + } +} diff --git a/crates/openshell-sdk/src/config.rs b/crates/openshell-sdk/src/config.rs new file mode 100644 index 000000000..f54cac5e7 --- /dev/null +++ b/crates/openshell-sdk/src/config.rs @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Public input types for the SDK: how callers describe a gateway and the +//! credentials used to talk to it. +//! +//! The CLI keeps its own filesystem-aware `TlsOptions` for plumbing; it +//! converts to a `ClientConfig` at the moment of dialing the gateway. + +/// Authentication mode for outgoing gRPC requests. +/// +/// The two variants are functionally distinct in the transport layer: +/// `EdgeJwt` routes through a local WebSocket tunnel (the only way to get a +/// browser-flow JWT past Cloudflare Access on POST/HTTP2), while `Oidc` +/// connects directly over HTTPS and adds an `authorization: Bearer ...` +/// header. +// +// `#[non_exhaustive]` keeps phase 2 additive: when we promote `Oidc(String)` +// to `Oidc { token, refresh: Option> }` or add a third +// variant, downstream `match` arms aren't forced to break. +#[derive(Clone)] +#[non_exhaustive] +pub enum AuthConfig { + /// Cloudflare Access JWT — routes through the edge WebSocket tunnel. + EdgeJwt(String), + /// OIDC bearer token — direct HTTPS, `authorization` header. + Oidc(String), +} + +/// Configuration for opening a gRPC channel to an `OpenShell` gateway. +/// +/// Consumed by `openshell_sdk::transport::grpc_client` and the +/// inference-client equivalent. One `ClientConfig` per logical connection; +/// callers that want connection pooling cache the resulting `tonic::Channel`. +// +// NOTE: +// - `gateway` is a full URL (`http://...` or `https://...`) so the scheme +// tells the transport layer whether to use plaintext or TLS. Matches +// today's CLI convention; matches the RFC's `pub gateway: String`. +// - `ca_cert` pins a private-CA certificate (PEM-encoded). `None` falls +// back to the platform's system roots. +// - This SDK does not speak mTLS. Gateways requiring client certificates +// are handled by `openshell-cli`'s legacy mTLS path until product +// retires that auth method. +// - `insecure_skip_verify` is a separate flag rather than a third +// `AuthConfig` variant because it's a transport concern (cert +// verification) that's orthogonal to auth. +// - No `timeout` field yet. The RFC mentions one but today's behavior is +// `connect_timeout(10s)` hard-coded; introducing a configurable timeout +// here would be a behavior change. Phase 2 territory. +// - No `Debug` derive: `auth` carries secrets; `ca_cert` is fine but we +// redact the whole struct for safety. If callers want ergonomic printing +// we can implement `Debug` manually with a redacted token field. +// - `#[non_exhaustive]` + `Default` lets phase 2 add fields (timeout, retry +// policy, `Refresh` trait) without breaking literal-construct callers. +// Idiom is `ClientConfig { gateway: g, ..Default::default() }`. +#[derive(Clone, Default)] +#[non_exhaustive] +pub struct ClientConfig { + /// Gateway URL, e.g. `http://127.0.0.1:8080` or `https://gw.example.com`. + pub gateway: String, + /// CA certificate (PEM) for private-CA gateways. `None` uses system + /// roots. Ignored for plaintext gateways and when + /// `insecure_skip_verify` is enabled. + pub ca_cert: Option>, + /// Bearer-token auth mode. `None` = anonymous TLS over HTTPS, or + /// plaintext when `gateway` is `http://`. + pub auth: Option, + /// Disable TLS certificate verification (development/debug only). + /// Ignored for plaintext gateways. **Do not enable in production.** + pub insecure_skip_verify: bool, +} + +impl ClientConfig { + pub fn new(gateway: impl Into) -> Self { + Self { + gateway: gateway.into(), + ..Default::default() + } + } +} diff --git a/crates/openshell-cli/src/edge_tunnel.rs b/crates/openshell-sdk/src/edge_tunnel.rs similarity index 92% rename from crates/openshell-cli/src/edge_tunnel.rs rename to crates/openshell-sdk/src/edge_tunnel.rs index 814e245f3..5ced5fc35 100644 --- a/crates/openshell-cli/src/edge_tunnel.rs +++ b/crates/openshell-sdk/src/edge_tunnel.rs @@ -19,13 +19,13 @@ //! 3. Bidirectionally pipe bytes between the local TCP stream and the //! WebSocket. //! -//! The gRPC [`Channel`] then connects to `http://127.0.0.1:` +//! The gRPC `Channel` then connects to `http://127.0.0.1:` //! (plaintext) — the edge handles TLS, and the WebSocket carries the raw //! bytes to the origin. +use crate::error::{Result, SdkError}; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; -use miette::{IntoDiagnostic, Result}; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -63,8 +63,8 @@ pub async fn start_tunnel_proxy( gateway_endpoint: &str, edge_token: &str, ) -> Result { - let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?; - let local_addr = listener.local_addr().into_diagnostic()?; + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; // Convert the gateway endpoint to a WebSocket URL. // https://foo.com -> wss://foo.com @@ -88,7 +88,6 @@ pub async fn start_tunnel_proxy( "starting edge tunnel proxy" ); - // Spawn the accept loop. tokio::spawn(accept_loop(listener, config)); Ok(EdgeTunnelProxy { local_addr }) @@ -149,13 +148,15 @@ async fn handle_connection(tcp_stream: TcpStream, config: &TunnelConfig) -> Resu /// Open a WebSocket connection to the edge proxy. async fn open_ws(config: &TunnelConfig) -> Result>> { - let mut request = (&config.ws_url).into_client_request().into_diagnostic()?; + let mut request = (&config.ws_url) + .into_client_request() + .map_err(|e| SdkError::invalid_config(format!("invalid tunnel URL: {e}")))?; // Inject the bearer token via multiple headers for compatibility with // Cloudflare Access (which checks `Cf-Access-Token`, the // `CF_Authorization` cookie, and the `Cf-Access-Jwt-Assertion` header). let token_val = HeaderValue::from_str(&config.edge_token) - .map_err(|e| miette::miette!("invalid edge token header value: {e}"))?; + .map_err(|e| SdkError::auth(format!("invalid edge token header value: {e}")))?; request .headers_mut() .insert("Cf-Access-Token", token_val.clone()); @@ -165,14 +166,14 @@ async fn open_ws(config: &TunnelConfig) -> Result = std::result::Result; + +/// Errors produced by `openshell-sdk`. +/// +/// CLI consumers convert these to `miette::Report` at the call boundary; +/// future TS/Python bindings will map them to language-native exceptions +/// via the [`SdkError::code`] accessor. +#[derive(Debug, Error, Diagnostic)] +pub enum SdkError { + /// Caller-supplied configuration is invalid (URL parse, missing field, + /// illegal token characters). + #[error("invalid configuration: {message}")] + #[diagnostic(code(openshell::sdk::invalid_config))] + InvalidConfig { + /// Error message. + message: String, + }, + + /// TLS material parse or rustls config build failure. + #[error("TLS error: {message}")] + #[diagnostic(code(openshell::sdk::tls))] + Tls { + /// Error message. + message: String, + }, + + /// Failed to establish a connection to the gateway (TCP, TLS handshake, + /// HTTP/2, WebSocket upgrade). + #[error("connect error: {message}")] + #[diagnostic(code(openshell::sdk::connect))] + Connect { + /// Error message. + message: String, + }, + + /// Auth-related failure: OIDC discovery / refresh, token format invalid + /// for header injection. + #[error("auth error: {message}")] + #[diagnostic(code(openshell::sdk::auth))] + Auth { + /// Error message. + message: String, + }, + + /// Local IO failure (file read, listener bind, socket). + #[error("I/O error: {source}")] + #[diagnostic(code(openshell::sdk::io))] + Io { + /// Underlying I/O error. + #[from] + source: std::io::Error, + }, + + /// Gateway reported the requested object does not exist (gRPC `NotFound`). + #[error("not found: {message}")] + #[diagnostic(code(openshell::sdk::not_found))] + NotFound { + /// Error message. + message: String, + }, + + /// Gateway reported the requested object already exists (gRPC `AlreadyExists`). + #[error("already exists: {message}")] + #[diagnostic(code(openshell::sdk::already_exists))] + AlreadyExists { + /// Error message. + message: String, + }, + + /// Catch-all for gRPC errors not mapped to a more specific variant. + #[error("gateway error ({code}): {message}")] + #[diagnostic(code(openshell::sdk::rpc))] + Rpc { + /// Numeric gRPC status code (see [`tonic::Code`]). + code: i32, + /// Error message. + message: String, + }, +} + +impl SdkError { + /// Create an `InvalidConfig` error. + pub fn invalid_config(message: impl Into) -> Self { + Self::InvalidConfig { + message: message.into(), + } + } + + /// Create a `Tls` error. + pub fn tls(message: impl Into) -> Self { + Self::Tls { + message: message.into(), + } + } + + /// Create a `Connect` error. + pub fn connect(message: impl Into) -> Self { + Self::Connect { + message: message.into(), + } + } + + /// Create an `Auth` error. + pub fn auth(message: impl Into) -> Self { + Self::Auth { + message: message.into(), + } + } + + /// Stable string code for cross-language binding consumers. + /// + /// Returns one of: `invalid_config`, `tls`, `connect`, `auth`, `io`, + /// `not_found`, `already_exists`, `rpc`. Phase 3 (napi binding) will + /// surface this as the JS error's `code` field for discriminated-union + /// ergonomics. + pub const fn code(&self) -> &'static str { + match self { + Self::InvalidConfig { .. } => "invalid_config", + Self::Tls { .. } => "tls", + Self::Connect { .. } => "connect", + Self::Auth { .. } => "auth", + Self::Io { .. } => "io", + Self::NotFound { .. } => "not_found", + Self::AlreadyExists { .. } => "already_exists", + Self::Rpc { .. } => "rpc", + } + } +} diff --git a/crates/openshell-sdk/src/lib.rs b/crates/openshell-sdk/src/lib.rs new file mode 100644 index 000000000..53bcb336a --- /dev/null +++ b/crates/openshell-sdk/src/lib.rs @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared async Rust client for `OpenShell` gateways. +//! +//! Two layers: +//! +//! - [`OpenShellClient`] — the high-level sandbox-focused MVP surface: +//! health, sandbox CRUD, readiness/deletion waits, non-streaming exec. +//! - [`raw`] — direct access to the generated tonic clients for RPCs the +//! curated surface doesn't yet cover (inference, providers, policy, logs, +//! settings, SSH, forwarding). +//! +//! Owns the gRPC transport stack — channel construction, TLS material +//! handling, request interceptors, OIDC token refresh, and the Cloudflare +//! Access tunnel proxy. Consumed by `openshell-cli`, `openshell-tui`, and +//! the napi-rs wrapper that ships as `@openshell/sdk`. +//! +//! # Quick start +//! +//! ```ignore +//! use openshell_sdk::{ClientConfig, ListOptions, OpenShellClient}; +//! +//! # async fn run() -> Result<(), openshell_sdk::SdkError> { +//! let client = OpenShellClient::connect(ClientConfig::new("http://127.0.0.1:8080")).await?; +//! let health = client.health().await?; +//! let sandboxes = client.list_sandboxes(ListOptions::default()).await?; +//! # Ok(()) +//! # } +//! ``` + +pub mod auth; +pub mod client; +pub mod config; +pub mod edge_tunnel; +pub mod error; +pub mod oidc; +pub mod raw; +pub mod refresh; +pub mod transport; +pub mod types; + +pub use auth::EdgeAuthInterceptor; +pub use client::OpenShellClient; +pub use config::{AuthConfig, ClientConfig}; +pub use error::SdkError; +pub use refresh::{Refresh, RefreshError, RefreshedToken, TokenSource}; +pub use types::{ + ExecOptions, ExecResult, Health, ListOptions, SandboxPhase, SandboxRef, SandboxSpec, + ServiceStatus, +}; diff --git a/crates/openshell-sdk/src/oidc.rs b/crates/openshell-sdk/src/oidc.rs new file mode 100644 index 000000000..6a26678bb --- /dev/null +++ b/crates/openshell-sdk/src/oidc.rs @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! OIDC discovery and refresh-token flow (non-interactive). +//! +//! Browser-based authorization flows live in `openshell-cli` since they +//! require a local callback HTTP server and an OS browser launcher. + +use crate::error::{Result, SdkError}; +use oauth2::basic::BasicClient; +use oauth2::{ClientId, RefreshToken, TokenResponse, TokenUrl}; +use serde::Deserialize; + +/// OIDC discovery document (subset of fields callers consume). +#[derive(Debug, Deserialize)] +#[non_exhaustive] +pub struct OidcDiscovery { + pub issuer: String, + pub authorization_endpoint: String, + pub token_endpoint: String, +} + +/// Input to [`refresh_token`]. +/// +/// Constructed by the caller from whatever bundle / storage shape they +/// use — the SDK does not assume any particular persistence model. +#[derive(Clone)] +#[non_exhaustive] +pub struct RefreshTokenInput { + pub refresh_token: String, + pub issuer: String, + pub client_id: String, + pub insecure: bool, +} + +impl RefreshTokenInput { + pub fn new( + refresh_token: impl Into, + issuer: impl Into, + client_id: impl Into, + ) -> Self { + Self { + refresh_token: refresh_token.into(), + issuer: issuer.into(), + client_id: client_id.into(), + insecure: false, + } + } + + #[must_use] + pub fn with_insecure(mut self, insecure: bool) -> Self { + self.insecure = insecure; + self + } +} + +/// Output from [`refresh_token`]. +/// +/// `refresh_token` is `None` when the OIDC server did not return a new +/// refresh token; per OAuth 2.0, callers should preserve the previous +/// refresh token in that case. `expires_at` is a Unix timestamp (seconds +/// since epoch); `None` when the server omits `expires_in`. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct RefreshTokenOutput { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, +} + +/// Discover OIDC endpoints from the issuer's well-known configuration. +/// +/// Validates that the discovery document's `issuer` field matches the +/// configured issuer URL to prevent SSRF or misdirection. When `insecure` +/// is true, TLS certificate verification is disabled (intended for +/// development against self-signed gateways). +pub async fn discover(issuer: &str, insecure: bool) -> Result { + let normalized_issuer = issuer.trim_end_matches('/'); + let url = format!("{normalized_issuer}/.well-known/openid-configuration"); + let client = http_client(insecure); + let resp: OidcDiscovery = client + .get(&url) + .send() + .await + .map_err(|e| SdkError::auth(format!("OIDC discovery request failed: {e}")))? + .json() + .await + .map_err(|e| SdkError::auth(format!("OIDC discovery JSON parse failed: {e}")))?; + + let discovered_issuer = resp.issuer.trim_end_matches('/'); + if discovered_issuer != normalized_issuer { + return Err(SdkError::auth(format!( + "OIDC discovery issuer mismatch: expected '{normalized_issuer}', got '{discovered_issuer}'" + ))); + } + Ok(resp) +} + +/// Build an HTTP client suitable for OIDC token-endpoint requests. +/// +/// Disables redirects so token-endpoint responses aren't accidentally +/// followed; OIDC providers should not redirect on the token endpoint. +/// When `insecure` is true, TLS certificate verification is disabled. +pub fn http_client(insecure: bool) -> reqwest::Client { + let mut builder = reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); + if insecure { + builder = builder.danger_accept_invalid_certs(true); + } + builder.build().expect("failed to build HTTP client") +} + +/// Refresh an OIDC access token using the `refresh_token` grant. +/// +/// The caller is responsible for preserving the prior refresh token when +/// the output's `refresh_token` is `None` — per OAuth 2.0 the server may +/// omit it from the refresh response. +pub async fn refresh_token(input: &RefreshTokenInput) -> Result { + let discovery = discover(&input.issuer, input.insecure).await?; + + let client = BasicClient::new(ClientId::new(input.client_id.clone())).set_token_uri( + TokenUrl::new(discovery.token_endpoint) + .map_err(|e| SdkError::auth(format!("invalid token endpoint URL: {e}")))?, + ); + + let http = http_client(input.insecure); + let token_response = client + .exchange_refresh_token(&RefreshToken::new(input.refresh_token.clone())) + .request_async(&http) + .await + .map_err(|e| SdkError::auth(format!("token refresh failed: {e}")))?; + + Ok(output_from_oauth2_response(&token_response)) +} + +fn output_from_oauth2_response(resp: &oauth2::basic::BasicTokenResponse) -> RefreshTokenOutput { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + RefreshTokenOutput { + access_token: resp.access_token().secret().clone(), + refresh_token: resp.refresh_token().map(|rt| rt.secret().clone()), + expires_at: resp.expires_in().map(|ei| now + ei.as_secs()), + } +} diff --git a/crates/openshell-sdk/src/raw.rs b/crates/openshell-sdk/src/raw.rs new file mode 100644 index 000000000..0b3b18a04 --- /dev/null +++ b/crates/openshell-sdk/src/raw.rs @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Escape hatch — direct access to the generated tonic clients and protobuf +//! types. +//! +//! Use this module when the curated high-level surface in +//! [`crate::client::OpenShellClient`] doesn't expose the RPC or field you +//! need. The high-level surface is sandbox-focused for MVP; inference, +//! providers, policy, logs, settings, SSH, and forwarding all live here. +//! +//! ```ignore +//! use openshell_sdk::{ClientConfig, OpenShellClient}; +//! use openshell_sdk::raw::ListProvidersRequest; +//! +//! let client = OpenShellClient::connect(ClientConfig::new("http://127.0.0.1:8080")).await?; +//! let mut grpc = client.raw_grpc(); +//! let providers = grpc.list_providers(ListProvidersRequest::default()).await?; +//! ``` + +pub use openshell_core::proto; +pub use openshell_core::proto::inference_client::InferenceClient; +pub use openshell_core::proto::open_shell_client::OpenShellClient as GrpcClient; +pub use openshell_core::proto::{ + CreateSandboxRequest, DeleteSandboxRequest, ExecSandboxRequest, GetSandboxRequest, + HealthRequest, ListProvidersRequest, ListSandboxesRequest, Sandbox, + SandboxPhase as ProtoSandboxPhase, SandboxSpec as ProtoSandboxSpec, SandboxTemplate, + ServiceStatus as ProtoServiceStatus, +}; + +/// Type alias for the gRPC client wrapped in the SDK's auth interceptor. +pub type AuthedGrpcClient = GrpcClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + crate::EdgeAuthInterceptor, + >, +>; + +/// Type alias for the inference client wrapped in the SDK's auth interceptor. +pub type AuthedInferenceClient = InferenceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + crate::EdgeAuthInterceptor, + >, +>; diff --git a/crates/openshell-sdk/src/refresh.rs b/crates/openshell-sdk/src/refresh.rs new file mode 100644 index 000000000..46809700d --- /dev/null +++ b/crates/openshell-sdk/src/refresh.rs @@ -0,0 +1,301 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! OIDC bearer-token refresh contract. +//! +//! The SDK never talks to a browser or any specific `IdP`. Callers that need +//! the SDK to rotate an OIDC bearer mid-session implement [`Refresh`] and +//! construct a [`TokenSource`] around it. Implementations live where the +//! browser flow / token store / FFI callback belongs — in `openshell-cli` +//! for the desktop browser flow, in `openshell-sdk-node` for a JS callback. +//! +//! The trait is intentionally minimal. Single-flight coalescing (one refresh +//! in flight at a time, with all waiters sharing the result) is the SDK's +//! responsibility, not the implementer's; see [`TokenSource`]. +//! +//! TODO(rfc-0004): plumb [`TokenSource`] into the gRPC auth interceptor so +//! refreshes happen automatically before each request. Today the napi +//! binding exposes [`TokenSource::refresh_now`] / [`TokenSource::current`] +//! directly to JS callers, which can rotate the token by calling +//! `set_oidc_token` on a future iteration of the SDK client. + +use crate::error::{Result, SdkError}; +use std::fmt; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::{Mutex, RwLock}; + +/// Errors a refresher can return. +/// +/// Domain-specific, deliberately not coupled to `tonic`, `napi`, or any +/// FFI-facing error type. The SDK maps these into [`SdkError::Auth`] before +/// surfacing to callers. +#[derive(Debug)] +#[non_exhaustive] +pub enum RefreshError { + /// Refresh failed but a retry might succeed (network blip, transient + /// `IdP` error). + Transient(String), + /// Refresh cannot succeed without user interaction (refresh token + /// expired, `IdP` revoked the session). Callers should not retry; they + /// should re-authenticate. + Terminal(String), +} + +impl fmt::Display for RefreshError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Transient(msg) => write!(f, "transient refresh error: {msg}"), + Self::Terminal(msg) => write!(f, "terminal refresh error: {msg}"), + } + } +} + +impl std::error::Error for RefreshError {} + +impl From for SdkError { + fn from(value: RefreshError) -> Self { + Self::auth(value.to_string()) + } +} + +/// A freshly minted access token + its absolute expiry. +/// +/// `expires_at` is seconds since the Unix epoch. `None` means the token's +/// expiry was not advertised — the SDK will not refresh it proactively but +/// may refresh on demand if [`Refresh::refresh`] is called. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct RefreshedToken { + pub access_token: String, + pub expires_at: Option, +} + +impl RefreshedToken { + pub fn new(access_token: impl Into) -> Self { + Self { + access_token: access_token.into(), + expires_at: None, + } + } + + #[must_use] + pub fn with_expires_at(mut self, expires_at: u64) -> Self { + self.expires_at = Some(expires_at); + self + } +} + +/// Pluggable OIDC refresher. +/// +/// Implementations should be cheap to clone and safe to call from any tokio +/// task. They MUST NOT do their own single-flight coalescing — that's the +/// SDK's job (see [`TokenSource`]). +#[async_trait::async_trait] +pub trait Refresh: Send + Sync + 'static { + /// Mint a fresh access token. Called by the SDK when it determines the + /// current token is near expiry (or has been explicitly invalidated). + async fn refresh(&self) -> std::result::Result; +} + +/// Mutable token state shared between the auth interceptor and the +/// background refresh task. +#[derive(Debug)] +struct TokenState { + token: String, + expires_at: Option, +} + +/// A bearer-token source with single-flight refresh coalescing. +/// +/// Wraps a [`Refresh`] implementation and tracks the current token + its +/// advertised expiry. Phase 3 of the RFC plumbs this into the auth path; for +/// now language bindings hand it out directly so JS/Python code can drive +/// refreshes externally. +#[derive(Clone)] +pub struct TokenSource { + state: Arc>, + refresher: Arc, + in_flight: Arc>, + /// Refresh `skew` seconds before the advertised `expires_at`. Tokens + /// without `expires_at` are not auto-refreshed. + skew: Duration, +} + +impl TokenSource { + /// Construct a token source backed by `refresher`. Use this when wiring + /// an FFI callback or browser flow into the SDK. + pub fn new(initial: RefreshedToken, refresher: Arc) -> Self { + Self { + state: Arc::new(RwLock::new(TokenState { + token: initial.access_token, + expires_at: initial.expires_at, + })), + refresher, + in_flight: Arc::new(Mutex::new(())), + skew: Duration::from_secs(60), + } + } + + /// Current token without checking expiry. Used by the sync gRPC + /// interceptor, which can't await. + pub fn snapshot(&self) -> String { + self.state + .try_read() + .map(|s| s.token.clone()) + .unwrap_or_default() + } + + /// Async-fetch the current token, refreshing if it's within `skew` of + /// expiry. Single-flight: concurrent callers share one refresh. + pub async fn current(&self) -> Result { + if !self.needs_refresh().await { + return Ok(self.state.read().await.token.clone()); + } + self.refresh_now().await + } + + /// Force a refresh regardless of expiry. Used on `Unauthenticated` + /// responses from the gateway. + pub async fn refresh_now(&self) -> Result { + // Single-flight: only one refresh in flight at a time. Other waiters + // block here and then see the updated state on re-check. + let _guard = self.in_flight.lock().await; + + // Re-check inside the critical section: another caller may have just + // refreshed while we were waiting on the lock. + if !self.needs_refresh().await { + return Ok(self.state.read().await.token.clone()); + } + + let refreshed = self.refresher.refresh().await?; + let mut state = self.state.write().await; + state.token.clone_from(&refreshed.access_token); + state.expires_at = refreshed.expires_at; + Ok(refreshed.access_token) + } + + async fn needs_refresh(&self) -> bool { + let state = self.state.read().await; + let Some(expires_at) = state.expires_at else { + return false; + }; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + now + self.skew.as_secs() >= expires_at + } + + /// Replace the current token without invoking the refresher. + /// + /// Used by callers that manage refresh externally (e.g. the napi + /// binding's JS-side timer) or for testing. + pub async fn replace(&self, token: RefreshedToken) { + let mut state = self.state.write().await; + state.token = token.access_token; + state.expires_at = token.expires_at; + } +} + +impl fmt::Debug for TokenSource { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TokenSource") + .field("skew", &self.skew) + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct CountingRefresher { + calls: Arc, + delay: Duration, + } + + #[async_trait::async_trait] + impl Refresh for CountingRefresher { + async fn refresh(&self) -> std::result::Result { + tokio::time::sleep(self.delay).await; + let n = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + Ok(RefreshedToken::new(format!("token-{n}")).with_expires_at( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600, + )) + } + } + + #[tokio::test] + async fn refresh_now_coalesces_concurrent_callers() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(50), + }); + let source = TokenSource::new(RefreshedToken::new("initial").with_expires_at(0), refresher); + + let tasks = (0..5).map(|_| { + let src = source.clone(); + tokio::spawn(async move { src.refresh_now().await }) + }); + for t in tasks { + t.await.unwrap().unwrap(); + } + + assert_eq!( + calls.load(Ordering::SeqCst), + 1, + "single-flight should have collapsed 5 concurrent calls into 1 refresh" + ); + } + + #[tokio::test] + async fn current_returns_cached_when_not_near_expiry() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(0), + }); + let future = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600; + let source = TokenSource::new( + RefreshedToken::new("fresh").with_expires_at(future), + refresher, + ); + + let token = source.current().await.unwrap(); + assert_eq!(token, "fresh"); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn current_refreshes_when_within_skew() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(0), + }); + let near = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 5; + let source = TokenSource::new( + RefreshedToken::new("stale").with_expires_at(near), + refresher, + ); + + let token = source.current().await.unwrap(); + assert_eq!(token, "token-1"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } +} diff --git a/crates/openshell-sdk/src/transport.rs b/crates/openshell-sdk/src/transport.rs new file mode 100644 index 000000000..d930aee5e --- /dev/null +++ b/crates/openshell-sdk/src/transport.rs @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! gRPC transport stack: channel construction and the insecure-TLS connector. +//! +//! mTLS is intentionally out of scope here. Gateways that require client +//! certificates are handled by `openshell-cli`'s legacy path until the auth +//! method is retired. + +use crate::config::{AuthConfig, ClientConfig}; +use crate::edge_tunnel; +use crate::error::{Result, SdkError}; +use rustls::{ + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::sync::OnceLock; +use std::time::Duration; +use tokio::sync::Mutex; +use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint}; +use tracing::debug; + +/// Standard endpoint settings used by every dialed connection. +/// +/// Centralizes timeouts and HTTP/2 keepalive so behavior is consistent across +/// transport branches. Returns an `Endpoint` ready for `connect()` / +/// `connect_with_connector()`. +fn standard_endpoint(uri: String) -> Result { + Endpoint::from_shared(uri) + .map(|ep| { + ep.connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) + .http2_keep_alive_interval(Duration::from_secs(10)) + .keep_alive_while_idle(true) + }) + .map_err(|e| SdkError::invalid_config(format!("invalid gateway URL: {e}"))) +} + +// ── Edge tunnel registry ───────────────────────────────────────────── +// Each distinct edge-authenticated gateway gets its own local proxy +// instead of reusing the first gateway touched in the current process. +static EDGE_TUNNEL_ADDRS: OnceLock>> = OnceLock::new(); + +/// Look up (or start) the local tunnel proxy for an edge-authenticated +/// gateway. Subsequent calls with the same `(server, token)` reuse the +/// existing proxy. +async fn edge_tunnel_addr(server: &str, token: &str) -> Result { + let key = (server.to_string(), token.to_string()); + let registry = EDGE_TUNNEL_ADDRS.get_or_init(|| Mutex::new(HashMap::new())); + + { + let addrs = registry.lock().await; + if let Some(addr) = addrs.get(&key).copied() { + return Ok(addr); + } + } + + let proxy = edge_tunnel::start_tunnel_proxy(server, token).await?; + debug!( + local_addr = %proxy.local_addr, + server, + "edge tunnel proxy started, routing gRPC through local proxy" + ); + + let mut addrs = registry.lock().await; + Ok(*addrs.entry(key).or_insert(proxy.local_addr)) +} + +// ── Channel construction ───────────────────────────────────────────── + +/// Open a gRPC channel to the gateway described by `config`. +/// +/// Routing is determined by `gateway` scheme + `auth` variant + +/// `insecure_skip_verify`. Reference today's CLI implementation in +/// `openshell-cli/src/tls.rs::build_channel` (lines 219–308) for behavior +/// the SDK needs to preserve. +/// +/// **Branch table:** +/// +/// | `gateway` scheme | `auth` | `insecure_skip_verify` | TLS handling | +/// |------------------|--------|------------------------|-------------| +/// | `http://` | (any) | (any) | plaintext, ignore tls | +/// | `https://` | `Some(EdgeJwt)` | (any) | tunnel proxy + plaintext to local proxy | +/// | `https://` | (any) | `true` | `InsecureTlsConnector`, no verification | +/// | `https://` | `Some(Oidc)` or `None` | `false` | tonic TLS, pin `ca_cert` if set, system roots otherwise | +pub async fn build_channel(config: &ClientConfig) -> Result { + let gateway = &config.gateway; + + // Branch 1 — plaintext. + // Reference: cli/tls.rs:220-228 (http:// branch). + if gateway.starts_with("http://") { + return standard_endpoint(gateway.clone())? + .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + if !gateway.starts_with("https://") { + return Err(SdkError::invalid_config(format!( + "gateway URL must start with http:// or https://: {gateway}" + ))); + } + + // Branch 2 — Cloudflare Access edge JWT: tunnel proxy + plaintext-to-local. + // Reference: cli/tls.rs:233-249 (https:// + edge_token branch). Use + // `edge_tunnel_addr(gateway, token).await?` to get the local proxy + // address, then `standard_endpoint(format!("http://{local_addr}"))?.connect()`. + if let Some(AuthConfig::EdgeJwt(token)) = &config.auth { + let local_addr = edge_tunnel_addr(gateway, token).await?; + return standard_endpoint(format!("http://{local_addr}"))? + .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + // Branch 3 — insecure TLS (skip cert verification). + // Reference: cli/tls.rs:251-268 (gateway_insecure branch). Build the + // insecure rustls config, wrap it in `InsecureTlsConnector`, swap the + // gateway scheme to http:// (so tonic doesn't double-layer TLS), and + // call `endpoint.connect_with_connector(connector)`. + if config.insecure_skip_verify { + tracing::warn!("TLS certificate verification is disabled — do not use in production"); + let rustls_config = build_insecure_rustls_config()?; + let tls_connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)); + let connector = InsecureTlsConnector { tls_connector }; + let http_uri = gateway.replacen("https://", "http://", 1); + return standard_endpoint(http_uri)? + .connect_with_connector(connector) + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + // Branch 4 — anonymous TLS or OIDC bearer over HTTPS. + // Reference: cli/tls.rs:270-307 (the `oidc_token` and final mTLS + // branches collapsed). Build a `ClientTlsConfig`: + // - if `config.ca_cert` is `Some(pem)`, pin it via `.ca_certificate(...)` + // - else fall back to `.with_enabled_roots()` (system roots) + // Then `endpoint.tls_config(tls_config)?.connect()`. + // + // The OIDC bearer header is added by the gRPC interceptor at request + // time, not here — `build_channel` only owns the TLS layer. + + let tls_config = config.ca_cert.as_ref().map_or_else( + || ClientTlsConfig::new().with_enabled_roots(), + |ca_cert| ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_cert)), + ); + standard_endpoint(gateway.clone())? + .tls_config(tls_config) + .map_err(|e| SdkError::tls(format!("{e}")))? + .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))) +} + +/// rustls verifier that accepts any server certificate. +/// +/// Used only when the caller explicitly opts into +/// [`ClientConfig::insecure_skip_verify`]. Do not use in production. +/// +/// [`ClientConfig::insecure_skip_verify`]: crate::config::ClientConfig::insecure_skip_verify +#[derive(Debug)] +pub struct InsecureServerCertVerifier; + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> std::result::Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +/// rustls client config that disables server certificate verification. +/// +/// Pairs with [`InsecureTlsConnector`] for transports that need to skip +/// verification (development, debug). Returns `Result` for symmetry with +/// future verifying variants; the current implementation cannot fail. +pub fn build_insecure_rustls_config() -> Result { + Ok(rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(std::sync::Arc::new(InsecureServerCertVerifier)) + .with_no_client_auth()) +} + +/// `tower::Service` connector that performs TLS using the supplied rustls +/// connector, bypassing tonic's built-in TLS layering. +/// +/// Used to plumb [`InsecureServerCertVerifier`]-backed configs into a tonic +/// `Endpoint` via `connect_with_connector`. +#[derive(Clone)] +pub struct InsecureTlsConnector { + /// Inner rustls connector configured by the caller. + pub tls_connector: tokio_rustls::TlsConnector, +} + +impl tower::Service for InsecureTlsConnector { + type Response = hyper_util::rt::TokioIo>; + type Error = Box; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, uri: hyper::Uri) -> Self::Future { + let tls_connector = self.tls_connector.clone(); + Box::pin(async move { + let host = uri.host().unwrap_or("localhost").to_string(); + let port = uri.port_u16().unwrap_or(443); + let addr = format!("{host}:{port}"); + let tcp = tokio::net::TcpStream::connect(addr).await?; + let server_name = ServerName::try_from(host)?; + let tls_stream = tls_connector.connect(server_name, tcp).await?; + Ok(hyper_util::rt::TokioIo::new(tls_stream)) + }) + } +} diff --git a/crates/openshell-sdk/src/types.rs b/crates/openshell-sdk/src/types.rs new file mode 100644 index 000000000..059b00d1f --- /dev/null +++ b/crates/openshell-sdk/src/types.rs @@ -0,0 +1,170 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Curated public types for the high-level SDK surface. +//! +//! These types intentionally diverge from the raw protobuf shapes so future +//! language bindings (TypeScript via napi, Python via `PyO3`) can render them +//! idiomatically. In particular, enum-valued fields use Rust enums that map +//! to string literals in TypeScript rather than numeric proto enums; nested +//! `Option<...>` chains from proto are flattened where one of the wrappers +//! is structurally meaningless. +//! +//! The raw proto clients are still accessible via [`crate::raw`] as an +//! escape hatch for callers who need fields not exposed here. + +use openshell_core::proto; +use std::collections::HashMap; +use std::time::Duration; + +/// Gateway health snapshot. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct Health { + pub status: ServiceStatus, + pub version: String, +} + +/// Coarse gateway service status. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum ServiceStatus { + Unspecified, + Healthy, + Degraded, + Unhealthy, +} + +impl From for ServiceStatus { + fn from(value: proto::ServiceStatus) -> Self { + match value { + proto::ServiceStatus::Healthy => Self::Healthy, + proto::ServiceStatus::Degraded => Self::Degraded, + proto::ServiceStatus::Unhealthy => Self::Unhealthy, + proto::ServiceStatus::Unspecified => Self::Unspecified, + } + } +} + +impl From for ServiceStatus { + fn from(value: i32) -> Self { + proto::ServiceStatus::try_from(value).map_or(Self::Unspecified, Self::from) + } +} + +/// High-level sandbox lifecycle phase. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum SandboxPhase { + Unspecified, + Provisioning, + Ready, + Error, + Deleting, + Unknown, +} + +impl From for SandboxPhase { + fn from(value: proto::SandboxPhase) -> Self { + match value { + proto::SandboxPhase::Unspecified => Self::Unspecified, + proto::SandboxPhase::Provisioning => Self::Provisioning, + proto::SandboxPhase::Ready => Self::Ready, + proto::SandboxPhase::Error => Self::Error, + proto::SandboxPhase::Deleting => Self::Deleting, + proto::SandboxPhase::Unknown => Self::Unknown, + } + } +} + +impl From for SandboxPhase { + fn from(value: i32) -> Self { + proto::SandboxPhase::try_from(value).map_or(Self::Unspecified, Self::from) + } +} + +/// Caller intent for a new sandbox. +/// +/// Only the most commonly used fields are exposed. Callers that need the +/// full proto surface (volume claim templates, runtime classes, struct +/// resources, etc.) should drop down to [`crate::raw`]. +#[derive(Clone, Debug, Default)] +pub struct SandboxSpec { + /// Optional user-supplied sandbox name. When empty the server generates one. + pub name: Option, + /// Container image reference (e.g. `ghcr.io/nvidia/openshell-community/sandboxes/python:latest`). + pub image: Option, + /// Labels attached to the sandbox. + pub labels: HashMap, + /// Environment variables injected into the sandbox runtime. + pub environment: HashMap, + /// Provider names to attach. + pub providers: Vec, + /// Request a GPU. + pub gpu: bool, + /// Driver-specific GPU device selector (CDI ID for Docker/Podman; BDF or + /// index for VM). Only meaningful when `gpu` is true; empty defers to the + /// driver's default selection. + pub gpu_device: Option, +} + +/// Reference to a sandbox owned by the gateway. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct SandboxRef { + pub id: String, + pub name: String, + pub phase: SandboxPhase, + pub labels: HashMap, + pub resource_version: u64, +} + +impl SandboxRef { + pub(crate) fn from_proto(sandbox: proto::Sandbox) -> Self { + let phase = sandbox.phase().into(); + let meta = sandbox.metadata.unwrap_or_default(); + Self { + id: meta.id, + name: meta.name, + phase, + labels: meta.labels, + resource_version: meta.resource_version, + } + } +} + +/// Options for listing sandboxes. +#[derive(Clone, Debug, Default)] +pub struct ListOptions { + /// Maximum sandboxes to return. `0` defers to the server default. + pub limit: u32, + /// Offset into the result list. + pub offset: u32, + /// Optional Kubernetes-style label selector (e.g. `env=prod,team=core`). + pub label_selector: Option, +} + +/// Options for [`crate::client::OpenShellClient::exec`]. +#[derive(Clone, Debug, Default)] +pub struct ExecOptions { + /// Working directory inside the sandbox. + pub workdir: Option, + /// Environment overrides for the exec. + pub environment: HashMap, + /// Optional command timeout. `None` lets the gateway choose. + pub timeout: Option, + /// Optional stdin payload. + pub stdin: Option>, +} + +/// Result of a non-streaming exec call. +/// +/// `stdout` and `stderr` are buffered to the end of the command. Use the +/// raw streaming RPC ([`crate::raw`]) for long-running output. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct ExecResult { + pub exit_code: i32, + pub stdout: Vec, + pub stderr: Vec, +} diff --git a/crates/openshell-sdk/tests/client_mock.rs b/crates/openshell-sdk/tests/client_mock.rs new file mode 100644 index 000000000..12c3aaa06 --- /dev/null +++ b/crates/openshell-sdk/tests/client_mock.rs @@ -0,0 +1,765 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! High-level [`OpenShellClient`] tests against an in-process mock gateway. +//! +//! The mock binds an ephemeral plaintext TCP listener and serves the +//! `OpenShell` gRPC service. Tests dial it via `http://127.0.0.1:` so +//! TLS and auth code paths are skipped — those are exercised by the CLI's +//! `mtls_integration` and OIDC tests. + +use openshell_core::proto; +use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; +use openshell_sdk::{ + ClientConfig, ExecOptions, ListOptions, OpenShellClient, SandboxPhase, SandboxSpec, + ServiceStatus as SdkServiceStatus, +}; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::{Response, Status}; + +/// Captured fixture state — what the mock observed and the canned replies it +/// returned. One per test so assertions are scoped. +#[derive(Default)] +struct MockState { + last_get_name: Mutex>, + last_create: Mutex>, + last_delete_name: Mutex>, + last_list_request: Mutex>, + last_exec_request: Mutex>, + get_calls: AtomicU32, + phase_sequence: Vec, + get_returns_not_found: bool, + not_found_after: Option, +} + +#[derive(Clone)] +struct TestOpenShell { + state: Arc, +} + +fn sandbox_with_phase(name: &str, phase: proto::SandboxPhase) -> proto::Sandbox { + proto::Sandbox { + metadata: Some(proto::datamodel::v1::ObjectMeta { + id: format!("id-{name}"), + name: name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 1, + }), + spec: None, + status: Some(proto::SandboxStatus { + phase: phase.into(), + ..Default::default() + }), + } +} + +#[tonic::async_trait] +impl OpenShell for TestOpenShell { + async fn health( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::HealthResponse { + status: proto::ServiceStatus::Healthy.into(), + version: "test-1.2.3".to_string(), + })) + } + + async fn create_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let req = request.into_inner(); + let name = if req.name.is_empty() { + "generated".to_string() + } else { + req.name.clone() + }; + *self.state.last_create.lock().await = Some(req); + Ok(Response::new(proto::SandboxResponse { + sandbox: Some(sandbox_with_phase(&name, proto::SandboxPhase::Provisioning)), + })) + } + + async fn get_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + *self.state.last_get_name.lock().await = Some(name.clone()); + let count = self.state.get_calls.fetch_add(1, Ordering::SeqCst); + + if self.state.get_returns_not_found { + return Err(Status::not_found(format!("sandbox '{name}' not found"))); + } + if let Some(threshold) = self.state.not_found_after + && count >= threshold + { + return Err(Status::not_found(format!("sandbox '{name}' not found"))); + } + + let phase = self + .state + .phase_sequence + .get(count as usize) + .copied() + .or_else(|| self.state.phase_sequence.last().copied()) + .unwrap_or(proto::SandboxPhase::Ready); + + Ok(Response::new(proto::SandboxResponse { + sandbox: Some(sandbox_with_phase(&name, phase)), + })) + } + + async fn list_sandboxes( + &self, + request: tonic::Request, + ) -> Result, Status> { + *self.state.last_list_request.lock().await = Some(request.into_inner()); + Ok(Response::new(proto::ListSandboxesResponse { + sandboxes: vec![ + sandbox_with_phase("alpha", proto::SandboxPhase::Ready), + sandbox_with_phase("beta", proto::SandboxPhase::Provisioning), + ], + })) + } + + async fn list_sandbox_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::DetachSandboxProviderResponse::default(), + )) + } + + async fn delete_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + *self.state.last_delete_name.lock().await = Some(name); + Ok(Response::new(proto::DeleteSandboxResponse { + deleted: true, + })) + } + + async fn create_ssh_session( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::CreateSshSessionResponse::default())) + } + + async fn expose_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ServiceEndpointResponse::default())) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn revoke_ssh_session( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::RevokeSshSessionResponse::default())) + } + + type ExecSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn exec_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + *self.state.last_exec_request.lock().await = Some(request.into_inner()); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Stdout( + proto::ExecSandboxStdout { + data: b"hello\n".to_vec(), + }, + )), + })) + .await; + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Stderr( + proto::ExecSandboxStderr { + data: b"warn\n".to_vec(), + }, + )), + })) + .await; + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Exit( + proto::ExecSandboxExit { exit_code: 7 }, + )), + })) + .await; + }); + Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new( + rx, + ))) + } + + type ForwardTcpStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn forward_tcp( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn exec_sandbox_interactive( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn create_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn get_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn list_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ListProvidersResponse::default())) + } + + async fn list_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_provider_profile( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn import_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn lint_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn update_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::DeleteProviderResponse::default())) + } + + async fn delete_provider_profile( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::GetSandboxConfigResponse::default())) + } + + async fn get_gateway_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::GetGatewayConfigResponse::default())) + } + + async fn update_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_policy_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::GetSandboxPolicyStatusResponse::default(), + )) + } + + async fn list_sandbox_policies( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn report_policy_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_provider_environment( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::GetSandboxProviderEnvironmentResponse::default(), + )) + } + + async fn get_sandbox_logs( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn push_sandbox_logs( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type WatchSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn watch_sandbox( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn submit_policy_analysis( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_draft_policy( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn approve_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn approve_all_draft_chunks( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn reject_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn edit_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn undo_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn clear_draft_chunks( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_draft_history( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn connect_supervisor( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type RelayStreamStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn relay_stream( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn issue_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn refresh_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } +} + +/// Spin up the mock gateway, return its endpoint URL. +async fn start_mock(state: Arc) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let endpoint = format!("http://{addr}"); + let stream = TcpListenerStream::new(listener); + let svc = OpenShellServer::new(TestOpenShell { state }); + tokio::spawn(async move { + let _ = tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming(stream) + .await; + }); + endpoint +} + +async fn connect(endpoint: &str) -> OpenShellClient { + OpenShellClient::connect(ClientConfig::new(endpoint)) + .await + .expect("connect should succeed against local mock") +} + +#[tokio::test] +async fn health_returns_curated_snapshot() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let h = client.health().await.unwrap(); + assert_eq!(h.status, SdkServiceStatus::Healthy); + assert_eq!(h.version, "test-1.2.3"); +} + +#[tokio::test] +async fn create_sandbox_passes_spec_through() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let mut labels = HashMap::new(); + labels.insert("team".to_string(), "core".to_string()); + + let spec = SandboxSpec { + name: Some("my-box".to_string()), + image: Some("ghcr.io/foo:bar".to_string()), + labels: labels.clone(), + gpu: true, + ..Default::default() + }; + + let result = client.create_sandbox(spec).await.unwrap(); + assert_eq!(result.name, "my-box"); + assert_eq!(result.phase, SandboxPhase::Provisioning); + + let observed = state.last_create.lock().await.clone().unwrap(); + assert_eq!(observed.name, "my-box"); + assert_eq!(observed.labels, labels); + let observed_spec = observed.spec.unwrap(); + assert!(observed_spec.gpu); + assert_eq!( + observed_spec.template.as_ref().unwrap().image, + "ghcr.io/foo:bar" + ); +} + +#[tokio::test] +async fn get_sandbox_sends_name_and_maps_phase() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Ready], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let sandbox = client.get_sandbox("my-box").await.unwrap(); + assert_eq!(sandbox.name, "my-box"); + assert_eq!(sandbox.id, "id-my-box"); + assert_eq!(sandbox.phase, SandboxPhase::Ready); + + let observed = state.last_get_name.lock().await.clone(); + assert_eq!(observed.as_deref(), Some("my-box")); +} + +#[tokio::test] +async fn list_sandboxes_propagates_filters() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let opts = ListOptions { + limit: 25, + offset: 5, + label_selector: Some("team=core".to_string()), + }; + let items = client.list_sandboxes(opts).await.unwrap(); + assert_eq!(items.len(), 2); + assert_eq!(items[0].name, "alpha"); + assert_eq!(items[0].phase, SandboxPhase::Ready); + assert_eq!(items[1].phase, SandboxPhase::Provisioning); + + let observed = state.last_list_request.lock().await.clone().unwrap(); + assert_eq!(observed.limit, 25); + assert_eq!(observed.offset, 5); + assert_eq!(observed.label_selector, "team=core"); +} + +#[tokio::test] +async fn delete_sandbox_returns_server_ack() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let deleted = client.delete_sandbox("doomed").await.unwrap(); + assert!(deleted); + + let observed = state.last_delete_name.lock().await.clone(); + assert_eq!(observed.as_deref(), Some("doomed")); +} + +#[tokio::test] +async fn wait_ready_transitions_through_phases() { + let state = Arc::new(MockState { + phase_sequence: vec![ + proto::SandboxPhase::Provisioning, + proto::SandboxPhase::Provisioning, + proto::SandboxPhase::Ready, + ], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let sandbox = client + .wait_ready("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap(); + assert_eq!(sandbox.phase, SandboxPhase::Ready); + assert!(state.get_calls.load(Ordering::SeqCst) >= 3); +} + +#[tokio::test] +async fn wait_ready_surfaces_error_phase() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Error], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let err = client + .wait_ready("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap_err(); + assert_eq!(err.code(), "connect"); +} + +#[tokio::test] +async fn wait_deleted_returns_when_get_reports_not_found() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Deleting], + not_found_after: Some(2), + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + client + .wait_deleted("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap(); + assert!(state.get_calls.load(Ordering::SeqCst) >= 3); +} + +#[tokio::test] +async fn get_sandbox_not_found_maps_to_typed_error() { + let state = Arc::new(MockState { + get_returns_not_found: true, + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let err = client.get_sandbox("missing").await.unwrap_err(); + assert_eq!(err.code(), "not_found"); +} + +#[tokio::test] +async fn exec_buffers_stdout_stderr_and_exit() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Ready], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let result = client + .exec( + "my-box", + &["echo".to_string(), "hello".to_string()], + ExecOptions { + workdir: Some("/work".to_string()), + timeout: Some(std::time::Duration::from_secs(10)), + ..Default::default() + }, + ) + .await + .unwrap(); + + assert_eq!(result.exit_code, 7); + assert_eq!(result.stdout, b"hello\n"); + assert_eq!(result.stderr, b"warn\n"); + + let observed = state.last_exec_request.lock().await.clone().unwrap(); + assert_eq!(observed.sandbox_id, "id-my-box"); + assert_eq!( + observed.command, + vec!["echo".to_string(), "hello".to_string()] + ); + assert_eq!(observed.workdir, "/work"); + assert_eq!(observed.timeout_seconds, 10); +} diff --git a/crates/openshell-tui/Cargo.toml b/crates/openshell-tui/Cargo.toml index 238166136..92b993057 100644 --- a/crates/openshell-tui/Cargo.toml +++ b/crates/openshell-tui/Cargo.toml @@ -15,6 +15,7 @@ openshell-core = { path = "../openshell-core", default-features = false } openshell-bootstrap = { path = "../openshell-bootstrap" } openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } +openshell-sdk = { path = "../openshell-sdk" } base64 = { workspace = true } ratatui = { workspace = true } diff --git a/crates/openshell-tui/src/app.rs b/crates/openshell-tui/src/app.rs index cb02c8c24..2a985f99f 100644 --- a/crates/openshell-tui/src/app.rs +++ b/crates/openshell-tui/src/app.rs @@ -5,10 +5,10 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::setting_value; use openshell_core::settings::{self, SettingValueKind}; +use openshell_sdk::EdgeAuthInterceptor; use tonic::service::interceptor::InterceptedService; use tonic::transport::Channel; diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index b7e03c6ac..d504f325b 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -18,10 +18,10 @@ use crossterm::terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, }; use miette::{IntoDiagnostic, Result}; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::metadata::{ObjectId, ObjectLabels, ObjectName}; use openshell_core::proto::SandboxPhase; use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_sdk::EdgeAuthInterceptor; use ratatui::Terminal; use ratatui::backend::CrosstermBackend; use tokio::sync::mpsc;