Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 270 additions & 32 deletions crates/openshell-cli/src/ssh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ use crate::tls::{TlsOptions, grpc_client};
use miette::{IntoDiagnostic, Result, WrapErr};
#[cfg(unix)]
use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction};
#[cfg(unix)]
use nix::unistd::Pid;
use openshell_core::ObjectId;
#[cfg(unix)]
use openshell_core::forward::pid_matches_forward;
use openshell_core::forward::{
build_proxy_command, find_ssh_forward_pid, format_gateway_url, resolve_ssh_gateway,
shell_escape, validate_ssh_session_response, write_forward_pid,
ForwardSpec, build_proxy_command, find_ssh_forward_pid, format_gateway_url,
resolve_ssh_gateway, shell_escape, validate_ssh_session_response, write_forward_pid,
};
use openshell_core::proto::{
CreateSshSessionRequest, GetSandboxRequest, SshRelayTarget, TcpForwardFrame, TcpForwardInit,
Expand All @@ -25,10 +29,23 @@ use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::Command as TokioCommand;
use tokio::net::TcpStream;
use tokio::process::{Child, Command as TokioCommand};
use tokio_stream::wrappers::ReceiverStream;

const FOREGROUND_FORWARD_STARTUP_GRACE_PERIOD: Duration = Duration::from_secs(2);
/// Time budget for finding the OpenSSH background process after `ssh -f`
/// returns. PID discovery is separate from listener readiness so missing
/// process tracking still fails quickly.
const FORWARD_PID_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(2);
/// Time budget for the local listener to become reachable after `ssh` starts.
/// This is a user-visible readiness deadline for both foreground and background
/// forwards, not a soft cleanup grace period.
const FORWARD_LISTENER_READINESS_TIMEOUT: Duration = Duration::from_secs(10);
/// Delay between consecutive polls. Shared by the listener probe loop and the
/// PID discovery loop, each of which enforces its own overall timeout.
const FORWARD_LISTENER_PROBE_INTERVAL: Duration = Duration::from_millis(50);
/// Per-attempt connect timeout, so one hung probe cannot consume the whole
/// readiness deadline.
const FORWARD_LISTENER_CONNECT_TIMEOUT: Duration = Duration::from_millis(200);

#[derive(Clone, Copy, Debug)]
pub enum Editor {
Expand Down Expand Up @@ -320,7 +337,7 @@ pub async fn sandbox_connect_editor(
pub async fn sandbox_forward(
server: &str,
name: &str,
spec: &openshell_core::forward::ForwardSpec,
spec: &ForwardSpec,
background: bool,
tls: &TlsOptions,
) -> Result<()> {
Expand Down Expand Up @@ -349,44 +366,181 @@ pub async fn sandbox_forward(

let port = spec.port;

let status = if background {
command.status().await.into_diagnostic()?
} else {
let mut child = command.spawn().into_diagnostic()?;
if let Ok(status) =
tokio::time::timeout(FOREGROUND_FORWARD_STARTUP_GRACE_PERIOD, child.wait()).await
if background {
let status = command.status().await.into_diagnostic()?;
if !status.success() {
return Err(miette::miette!("ssh exited with status {status}"));
}

// Background forwards must be both reachable and tracked. Without a PID
// file, `openshell forward stop/list` cannot manage the tunnel, so this
// path fails closed instead of printing a URL for an orphaned process.
let pid = wait_for_ssh_forward_pid(&session.sandbox_id, port)
.await
.ok_or_else(|| {
miette::miette!(
"could not discover backgrounded SSH process for sandbox {name} port {port}; \
refusing to report an untracked forward"
)
})?;

if let Err(err) = wait_for_forward_listener(spec, FORWARD_LISTENER_READINESS_TIMEOUT)
.await
.wrap_err("ssh process started but local forward listener was not reachable")
{
status.into_diagnostic()?
} else {
eprintln!("{}", foreground_forward_started_message(name, spec));
child.wait().await.into_diagnostic()?
terminate_forward_pid(pid, port, &session.sandbox_id);
return Err(err);
}

write_forward_pid(name, port, pid, &session.sandbox_id, &spec.bind_addr)?;
return Ok(());
}

let status = {
let mut child = command.spawn().into_diagnostic()?;
if let Err(err) = wait_for_foreground_forward_start(&mut child, spec).await {
let _ = child.kill().await;
return Err(err);
}
eprintln!("{}", foreground_forward_started_message(name, spec));
child.wait().await.into_diagnostic()?
};

if !status.success() {
return Err(miette::miette!("ssh exited with status {status}"));
}

if background {
// SSH has forked — find its PID and record it.
if let Some(pid) = find_ssh_forward_pid(&session.sandbox_id, port) {
write_forward_pid(name, port, pid, &session.sandbox_id, &spec.bind_addr)?;
} else {
eprintln!(
"{} Could not discover backgrounded SSH process; \
forward may be running but is not tracked",
"!".yellow(),
);
Ok(())
}

/// Wait for the local listener, racing the probe against the `ssh` child
/// exiting. An early exit (e.g. `ExitOnForwardFailure=yes` tearing down the
/// session) means forwarding never came up, so it errors instead of waiting
/// out the grace period.
async fn wait_for_foreground_forward_start(child: &mut Child, spec: &ForwardSpec) -> Result<()> {
let listener = wait_for_forward_listener(spec, FORWARD_LISTENER_READINESS_TIMEOUT);
tokio::pin!(listener);
tokio::select! {
result = &mut listener => result,
status = child.wait() => {
let status = status.into_diagnostic()?;
if status.success() {
Err(miette::miette!(
"ssh exited before local forward listener opened on {}:{}",
forward_probe_host(spec),
spec.port,
))
} else {
Err(miette::miette!(
"ssh exited with status {status} before local forward listener opened on {}:{}",
forward_probe_host(spec),
spec.port,
))
}
}
}
}

Ok(())
/// Repeatedly look up the PID of the backgrounded (`ssh -f`) forward.
///
/// `ssh -f` forks after authentication, so the PID is not known when the
/// launching command returns and has to be discovered afterwards. The lookup
/// runs at least once, then is retried every
/// `FORWARD_LISTENER_PROBE_INTERVAL` until `FORWARD_PID_DISCOVERY_TIMEOUT`
/// has elapsed.
///
/// Returns `Some(pid)` on the first successful lookup, or `None` if no
/// matching process was found before the deadline.
async fn wait_for_ssh_forward_pid(sandbox_id: &str, port: u16) -> Option<u32> {
    let give_up_at = tokio::time::Instant::now() + FORWARD_PID_DISCOVERY_TIMEOUT;
    loop {
        match find_ssh_forward_pid(sandbox_id, port) {
            Some(pid) => break Some(pid),
            // The deadline is only checked after a failed lookup, so the
            // final attempt always gets to run.
            None if tokio::time::Instant::now() >= give_up_at => break None,
            None => tokio::time::sleep(FORWARD_LISTENER_PROBE_INTERVAL).await,
        }
    }
}

fn foreground_forward_started_message(
name: &str,
spec: &openshell_core::forward::ForwardSpec,
) -> String {
/// Probe the local forward endpoint until it accepts a connection or the
/// `wait_for` budget is spent.
///
/// At least one probe is always made. After each failed probe the deadline is
/// checked: if it has passed, the error reports the endpoint, the budget in
/// milliseconds, and the reason the last probe failed, so the diagnostic
/// explains why the listener never opened rather than only that it timed out.
/// Otherwise the loop sleeps for `FORWARD_LISTENER_PROBE_INTERVAL` and retries.
async fn wait_for_forward_listener(spec: &ForwardSpec, wait_for: Duration) -> Result<()> {
    let give_up_at = tokio::time::Instant::now() + wait_for;
    loop {
        let Err(probe_error) = probe_forward_listener(spec).await else {
            return Ok(());
        };

        if tokio::time::Instant::now() < give_up_at {
            tokio::time::sleep(FORWARD_LISTENER_PROBE_INTERVAL).await;
            continue;
        }

        return Err(miette::miette!(
            "local forward listener did not open on {}:{} within {}ms: last probe failed with {probe_error}",
            forward_probe_host(spec),
            spec.port,
            wait_for.as_millis(),
        ));
    }
}

/// Make a single TCP connect attempt to the forward endpoint, bounded by
/// `FORWARD_LISTENER_CONNECT_TIMEOUT`.
///
/// The error is a plain `String` rather than a `miette` diagnostic so the
/// polling caller stays cheap. A successful connection only proves
/// reachability and is dropped immediately; SSH forwards this throwaway
/// connect to the sandbox-side target.
async fn probe_forward_listener(spec: &ForwardSpec) -> std::result::Result<(), String> {
    let endpoint = (forward_probe_host(spec), spec.port);
    let attempt =
        tokio::time::timeout(FORWARD_LISTENER_CONNECT_TIMEOUT, TcpStream::connect(endpoint)).await;

    // The outer layer is the timeout; the inner layer is the connect result.
    let Ok(connected) = attempt else {
        return Err(format!(
            "connect timed out after {}ms",
            FORWARD_LISTENER_CONNECT_TIMEOUT.as_millis()
        ));
    };

    connected.map(drop).map_err(|err| err.to_string())
}

/// Pick the host to connect to when probing the forward's listener.
///
/// Wildcard binds describe "listen on any address" and are not valid connect
/// targets, so they are replaced by the loopback of the same family: an empty
/// bind address or `0.0.0.0` becomes `127.0.0.1`, and `::` becomes `::1`.
/// Any other bind address is returned unchanged and probed as-is.
fn forward_probe_host(spec: &ForwardSpec) -> &str {
    let bind = spec.bind_addr.as_str();
    if bind == "::" {
        "::1"
    } else if bind.is_empty() || bind == "0.0.0.0" {
        "127.0.0.1"
    } else {
        bind
    }
}

/// Send `SIGTERM` to a forward whose listener never came up, on a best-effort
/// basis.
///
/// Nothing is reported to the caller: the process may already be exiting, and
/// the caller surfaces the original listener error regardless. The signal is
/// skipped entirely when:
///
/// * `pid` does not convert to a strictly positive `i32`, or
/// * the process no longer matches this forward's port and sandbox.
#[cfg(unix)]
fn terminate_forward_pid(pid: u32, port: u16, sandbox_id: &str) {
    let signal_target = match i32::try_from(pid) {
        Ok(raw) if raw > 0 => Pid::from_raw(raw),
        _ => return,
    };

    // The PID came from a process-table scan, not a file we own. Re-check
    // immediately before signaling so a stale or spoofed match is left
    // untouched instead of terminating an unrelated process.
    if pid_matches_forward(pid, port, Some(sandbox_id)) {
        let _ = nix::sys::signal::kill(signal_target, Signal::SIGTERM);
    }
}

/// Non-Unix builds cannot manage OpenSSH process IDs with Unix signals, so
/// this variant is a deliberate no-op. It keeps the same signature as the Unix
/// implementation so callers do not need their own `cfg` branches.
#[cfg(not(unix))]
fn terminate_forward_pid(_pid: u32, _port: u16, _sandbox_id: &str) {}

fn foreground_forward_started_message(name: &str, spec: &ForwardSpec) -> String {
format!(
"{} Forwarding port {} to sandbox {name}\n Access at: {}\n Press Ctrl+C to stop\n {}",
"✓".green().bold(),
Expand Down Expand Up @@ -1590,7 +1744,7 @@ mod tests {

#[test]
fn foreground_forward_started_message_includes_port_and_stop_hint() {
let spec = openshell_core::forward::ForwardSpec::new(8080);
let spec = ForwardSpec::new(8080);
let message = foreground_forward_started_message("demo", &spec);
assert!(message.contains("Forwarding port 8080 to sandbox demo"));
assert!(message.contains("Access at: http://127.0.0.1:8080/"));
Expand All @@ -1603,12 +1757,96 @@ mod tests {

#[test]
fn foreground_forward_started_message_custom_bind_addr() {
let spec = openshell_core::forward::ForwardSpec::parse("0.0.0.0:3000").unwrap();
let spec = ForwardSpec::parse("0.0.0.0:3000").unwrap();
let message = foreground_forward_started_message("demo", &spec);
assert!(message.contains("Forwarding port 3000 to sandbox demo"));
assert!(message.contains("Access at: http://localhost:3000/"));
}

/// Wildcard binds must be probed through the loopback of the same address
/// family, and the default spec must be probed on IPv4 loopback.
#[test]
fn forward_probe_host_uses_connectable_loopback_for_wildcard_binds() {
    let cases = [
        (ForwardSpec::parse("0.0.0.0:3000").unwrap(), "127.0.0.1"),
        (ForwardSpec::parse(":::3000").unwrap(), "::1"),
        (ForwardSpec::new(3000), "127.0.0.1"),
    ];

    for (spec, expected_host) in &cases {
        assert_eq!(forward_probe_host(spec), *expected_host);
    }
}

/// A listener that is already bound must be reported as ready well within the
/// timeout.
#[tokio::test]
async fn wait_for_forward_listener_accepts_ready_listener() {
    // Port 0 asks the OS for any free port, so the test cannot collide with
    // other listeners on the machine.
    let server = tokio::net::TcpListener::bind(("127.0.0.1", 0))
        .await
        .unwrap();
    let spec = ForwardSpec::new(server.local_addr().unwrap().port());

    // Accept the probe's connection in the background so the task can finish.
    let acceptor = tokio::spawn(async move {
        let _ = server.accept().await;
    });

    let outcome = wait_for_forward_listener(&spec, Duration::from_secs(1)).await;

    outcome.unwrap();
    acceptor.await.unwrap();
}

/// With nothing listening on the port, the wait must fail with the
/// "did not open" diagnostic once the (very short) budget is spent.
#[tokio::test]
async fn wait_for_forward_listener_rejects_missing_listener() {
    // Bind and immediately release a port to obtain one that is very likely
    // unused for the duration of the test.
    let unused_port = {
        let placeholder = std::net::TcpListener::bind(("127.0.0.1", 0)).unwrap();
        placeholder.local_addr().unwrap().port()
    };
    let spec = ForwardSpec::new(unused_port);

    let failure = wait_for_forward_listener(&spec, Duration::from_millis(20))
        .await
        .unwrap_err();

    let rendered = format!("{failure:?}");
    assert!(rendered.contains("local forward listener did not open"));
}

/// A live process that does not match the forward's port and sandbox must not
/// be signaled by `terminate_forward_pid`.
///
/// The victim is a plain `sleep` child, which keeps the default `SIGTERM`
/// disposition: any delivered `SIGTERM` terminates it, and `try_wait` then
/// reports an exit status. This avoids depending on a `python3` interpreter
/// and on a signal handler being installed before the signal could arrive.
#[cfg(unix)]
#[test]
fn terminate_forward_pid_skips_process_that_no_longer_matches_forward() {
    let mut child = Command::new("sleep").arg("30").spawn().unwrap();

    terminate_forward_pid(child.id(), 43210, "id-spoofed-forward");
    // Give a (wrongly) delivered signal time to take effect before checking.
    std::thread::sleep(Duration::from_millis(200));

    // `Ok(None)` means the child has not exited, i.e. it was left untouched.
    let still_running = matches!(child.try_wait(), Ok(None));

    // Reap the child before asserting so a failure does not leak a process.
    let _ = child.kill();
    let _ = child.wait();

    assert!(
        still_running,
        "mismatched process should not receive SIGTERM"
    );
}

#[test]
fn split_sandbox_path_separates_parent_and_basename() {
assert_eq!(
Expand Down
Loading
Loading