diff --git a/Cargo.lock b/Cargo.lock index 366f001a6..aa2683cc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3701,6 +3701,7 @@ version = "0.0.0" dependencies = [ "base64 0.22.1", "crossterm 0.28.1", + "futures", "miette", "openshell-bootstrap", "openshell-core", @@ -3711,6 +3712,7 @@ dependencies = [ "serde", "terminal-colorsaurus", "tokio", + "tokio-tungstenite 0.26.2", "tonic", "tracing", "url", diff --git a/crates/openshell-tui/Cargo.toml b/crates/openshell-tui/Cargo.toml index 238166136..aaeacf2f2 100644 --- a/crates/openshell-tui/Cargo.toml +++ b/crates/openshell-tui/Cargo.toml @@ -27,6 +27,8 @@ owo-colors = { workspace = true } serde = { workspace = true } tracing = { workspace = true } url = { workspace = true } +tokio-tungstenite = { workspace = true } +futures = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-tui/src/edge_tunnel.rs b/crates/openshell-tui/src/edge_tunnel.rs new file mode 100644 index 000000000..4ebc67e0a --- /dev/null +++ b/crates/openshell-tui/src/edge_tunnel.rs @@ -0,0 +1,173 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Edge-authenticated WebSocket tunnel proxy for TUI gateway switching. + +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use miette::{IntoDiagnostic, Result}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::HeaderValue; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tracing::{debug, error, warn}; + +pub struct EdgeTunnelProxy { + pub local_addr: SocketAddr, +} + +#[derive(Clone)] +struct TunnelConfig { + ws_url: String, + edge_token: String, +} + +pub async fn start_tunnel_proxy( + gateway_endpoint: &str, + edge_token: &str, +) -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?; + let local_addr = listener.local_addr().into_diagnostic()?; + let ws_url = format!( + "{}/_ws_tunnel", + gateway_endpoint + .replacen("https://", "wss://", 1) + .replacen("http://", "ws://", 1) + .trim_end_matches('/') + ); + let config = Arc::new(TunnelConfig { + ws_url, + edge_token: edge_token.to_string(), + }); + + debug!( + local_addr = %local_addr, + gateway = %gateway_endpoint, + "starting TUI edge tunnel proxy" + ); + tokio::spawn(accept_loop(listener, config)); + Ok(EdgeTunnelProxy { local_addr }) +} + +async fn accept_loop(listener: TcpListener, config: Arc) { + loop { + match listener.accept().await { + Ok((stream, peer)) => { + let config = Arc::clone(&config); + tokio::spawn(async move { + if let Err(err) = handle_connection(stream, &config).await { + warn!(peer = %peer, error = %err, "TUI edge tunnel connection failed"); + } + }); + } + Err(err) => { + error!(error = %err, "failed to accept TUI edge tunnel connection"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + } + } +} + +async fn handle_connection(tcp_stream: TcpStream, config: &TunnelConfig) -> Result<()> { + let ws_stream = open_ws(config).await?; + let (ws_sink, ws_source) = ws_stream.split(); + let (tcp_read, tcp_write) = tokio::io::split(tcp_stream); + + let mut tcp_to_ws = tokio::spawn(copy_tcp_to_ws(tcp_read, ws_sink)); + let mut ws_to_tcp = tokio::spawn(copy_ws_to_tcp(ws_source, tcp_write)); + + tokio::select! { + res = &mut tcp_to_ws => { + if let Err(err) = res { + debug!(error = %err, "TUI tcp->ws task panicked"); + } + ws_to_tcp.abort(); + } + res = &mut ws_to_tcp => { + if let Err(err) = res { + debug!(error = %err, "TUI ws->tcp task panicked"); + } + tcp_to_ws.abort(); + } + } + + Ok(()) +} + +async fn open_ws(config: &TunnelConfig) -> Result>> { + let mut request = (&config.ws_url).into_client_request().into_diagnostic()?; + let token_val = HeaderValue::from_str(&config.edge_token) + .map_err(|err| miette::miette!("invalid edge token header value: {err}"))?; + request + .headers_mut() + .insert("Cf-Access-Token", token_val.clone()); + request + .headers_mut() + .insert("Cf-Access-Jwt-Assertion", token_val); + request.headers_mut().insert( + "Cookie", + HeaderValue::from_str(&format!("CF_Authorization={}", config.edge_token)) + .map_err(|err| miette::miette!("invalid edge token cookie value: {err}"))?, + ); + + let (ws_stream, response) = tokio_tungstenite::connect_async(request) + .await + .map_err(|err| miette::miette!("WebSocket connect failed: {err}"))?; + debug!(status = %response.status(), "TUI edge WebSocket connected"); + Ok(ws_stream) +} + +async fn copy_tcp_to_ws( + mut tcp_read: tokio::io::ReadHalf, + mut ws_sink: SplitSink>, Message>, +) { + let mut buf = vec![0_u8; 32 * 1024]; + loop { + match tcp_read.read(&mut buf).await { + Ok(0) => { + let _ = ws_sink.close().await; + break; + } + Ok(n) => { + if ws_sink + .send(Message::Binary(buf[..n].to_vec().into())) + .await + .is_err() + { + break; + } + } + Err(err) => { + debug!(error = %err, "TUI tcp read error"); + let _ = ws_sink.close().await; + break; + } + } + } +} + +async fn copy_ws_to_tcp( + mut ws_source: SplitStream>>, + mut tcp_write: tokio::io::WriteHalf, +) { + while let Some(msg) = ws_source.next().await { + match msg { + Ok(Message::Binary(data)) => { + if tcp_write.write_all(&data).await.is_err() { + break; + } + } + Ok(Message::Close(_)) => break, + Ok(Message::Ping(_) | Message::Pong(_) | Message::Text(_) | Message::Frame(_)) => {} + Err(err) => { + debug!(error = %err, "TUI WebSocket read error"); + break; + } + } + } + let _ = tcp_write.shutdown().await; +} diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index b7e03c6ac..95b35225b 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -3,6 +3,7 @@ mod app; mod clipboard; +mod edge_tunnel; mod event; pub mod theme; mod ui; @@ -18,6 +19,7 @@ use crossterm::terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, }; use miette::{IntoDiagnostic, Result}; +use openshell_bootstrap::{GatewayMetadata, edge_token::load_edge_token}; use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::metadata::{ObjectId, ObjectLabels, ObjectName}; use openshell_core::proto::SandboxPhase; @@ -492,6 +494,24 @@ async fn handle_gateway_switch(app: &mut App) { } } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum GatewayChannelMode { + Oidc, + Edge, + Plaintext, + Mtls, +} + +fn gateway_channel_mode(meta: Option<&GatewayMetadata>, endpoint: &str) -> GatewayChannelMode { + match meta.and_then(|m| m.auth_mode.as_deref()) { + Some("oidc") => GatewayChannelMode::Oidc, + Some("cloudflare_jwt") => GatewayChannelMode::Edge, + Some("plaintext") => GatewayChannelMode::Plaintext, + _ if endpoint.starts_with("http://") => GatewayChannelMode::Plaintext, + _ => GatewayChannelMode::Mtls, + } +} + /// Build a gRPC channel and auth interceptor for a gateway. /// /// Checks gateway metadata for the auth mode and loads the appropriate @@ -499,26 +519,64 @@ async fn handle_gateway_switch(app: &mut App) { async fn connect_to_gateway(name: &str, endpoint: &str) -> Result<(Channel, EdgeAuthInterceptor)> { let meta = openshell_bootstrap::get_gateway_metadata(name); - if meta.as_ref().and_then(|m| m.auth_mode.as_deref()) == Some("oidc") { - let bundle = openshell_bootstrap::oidc_token::load_oidc_token(name).ok_or_else(|| { - miette::miette!( - "No OIDC token for gateway '{name}'.\n\ + match gateway_channel_mode(meta.as_ref(), endpoint) { + GatewayChannelMode::Oidc => { + let bundle = + openshell_bootstrap::oidc_token::load_oidc_token(name).ok_or_else(|| { + miette::miette!( + "No OIDC token for gateway '{name}'.\n\ Authenticate with: openshell gateway login" - ) - })?; - if openshell_bootstrap::oidc_token::is_token_expired(&bundle) { - miette::bail!( - "OIDC token for gateway '{name}' has expired.\n\ - Re-authenticate with: openshell gateway login" - ); - } - let interceptor = EdgeAuthInterceptor::new(Some(&bundle.access_token), None)?; - let channel = build_oidc_channel(name, endpoint).await?; - Ok((channel, interceptor)) - } else { - let channel = build_mtls_channel(name, endpoint).await?; - Ok((channel, EdgeAuthInterceptor::noop())) + ) + })?; + if openshell_bootstrap::oidc_token::is_token_expired(&bundle) { + miette::bail!( + "OIDC token for gateway '{name}' has expired.\n\ + Re-authenticate with: openshell gateway login" + ); + } + let interceptor = EdgeAuthInterceptor::new(Some(&bundle.access_token), None)?; + let channel = build_oidc_channel(name, endpoint).await?; + Ok((channel, interceptor)) + } + GatewayChannelMode::Edge => { + let token = load_edge_token(name).ok_or_else(|| { + miette::miette!( + "No edge token for gateway '{name}'.\n\ + Authenticate with: openshell gateway login" + ) + })?; + let interceptor = EdgeAuthInterceptor::new(None, Some(&token))?; + let channel = build_edge_channel(endpoint, &token).await?; + Ok((channel, interceptor)) + } + GatewayChannelMode::Plaintext => { + let channel = build_plaintext_channel(endpoint).await?; + Ok((channel, EdgeAuthInterceptor::noop())) + } + GatewayChannelMode::Mtls => { + let channel = build_mtls_channel(name, endpoint).await?; + Ok((channel, EdgeAuthInterceptor::noop())) + } + } +} + +async fn build_edge_channel(endpoint: &str, token: &str) -> Result { + if endpoint.starts_with("https://") { + let proxy = edge_tunnel::start_tunnel_proxy(endpoint, token).await?; + return build_plaintext_channel(&format!("http://{}", proxy.local_addr)).await; } + build_plaintext_channel(endpoint).await +} + +async fn build_plaintext_channel(endpoint: &str) -> Result { + Endpoint::from_shared(endpoint.to_string()) + .into_diagnostic()? + .connect_timeout(Duration::from_secs(10)) + .http2_keep_alive_interval(Duration::from_secs(10)) + .keep_alive_while_idle(true) + .connect() + .await + .into_diagnostic() } /// Build an HTTPS channel for OIDC-authenticated gateways. @@ -2508,3 +2566,53 @@ fn days_to_ymd(days: i64) -> (i64, i64, i64) { let y = if m <= 2 { y + 1 } else { y }; (y, m, d) } + +#[cfg(test)] +mod tests { + use super::{GatewayChannelMode, gateway_channel_mode}; + use openshell_bootstrap::GatewayMetadata; + + #[test] + fn gateway_channel_mode_uses_plaintext_auth_mode() { + let meta = GatewayMetadata { + auth_mode: Some("plaintext".to_string()), + ..Default::default() + }; + assert_eq!( + gateway_channel_mode(Some(&meta), "https://gateway.example.com"), + GatewayChannelMode::Plaintext + ); + } + + #[test] + fn gateway_channel_mode_prefers_edge_metadata() { + let meta = GatewayMetadata { + auth_mode: Some("cloudflare_jwt".to_string()), + ..Default::default() + }; + assert_eq!( + gateway_channel_mode(Some(&meta), "https://gateway.example.com"), + GatewayChannelMode::Edge + ); + } + + #[test] + fn gateway_channel_mode_uses_http_endpoint_fallback() { + assert_eq!( + gateway_channel_mode(None, "http://127.0.0.1:17670"), + GatewayChannelMode::Plaintext + ); + } + + #[test] + fn gateway_channel_mode_prefers_oidc_metadata() { + let meta = GatewayMetadata { + auth_mode: Some("oidc".to_string()), + ..Default::default() + }; + assert_eq!( + gateway_channel_mode(Some(&meta), "https://gateway.example.com"), + GatewayChannelMode::Oidc + ); + } +}