Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions crates/openshell-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,11 @@ pub struct TlsConfig {
/// When `false`, client certificates are accepted but not required.
#[serde(default)]
pub require_client_auth: bool,

/// Interval in seconds for polling TLS certificate files for changes.
/// When `0`, certificate reload is disabled (default).
#[serde(default)]
pub reload_interval_secs: u64,
}

/// OIDC (`OpenID` Connect) configuration for JWT-based authentication.
Expand Down
1 change: 1 addition & 0 deletions crates/openshell-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ ipnet = "2"
tempfile = "3"
rustix = { workspace = true }
x509-parser = "0.16"
arc-swap = "1"

[features]
bundled-z3 = ["openshell-prover/bundled-z3"]
Expand Down
7 changes: 7 additions & 0 deletions crates/openshell-server/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,18 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
let key_path = args.tls_key.clone().ok_or_else(|| {
miette::miette!("--tls-key is required when TLS is enabled (use --disable-tls to skip)")
})?;

let reload_interval_secs = file
.as_ref()
.and_then(|f| f.openshell.gateway.tls.as_ref())
.map_or(0, |t| t.reload_interval_secs);

Some(openshell_core::TlsConfig {
cert_path,
key_path,
require_client_auth: has_client_ca && !has_oidc,
client_ca_path: args.tls_client_ca.clone(),
reload_interval_secs,
})
};

Expand Down
59 changes: 23 additions & 36 deletions crates/openshell-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ mod ssh_sessions;
pub mod supervisor_session;
mod telemetry;
mod tls;
#[cfg(test)]
pub(crate) mod tls_test_utils;
pub mod tracing_bus;
mod ws_tunnel;

Expand Down Expand Up @@ -413,20 +415,31 @@ pub async fn run_server(
info!("Metrics server disabled");
}

let (shutdown_tx, shutdown_rx) = watch::channel(false);

// Build TLS acceptor when TLS is configured; otherwise serve plaintext.
let tls_acceptor = if let Some(tls) = &config.tls {
Some(TlsAcceptor::from_files(
let acceptor = TlsAcceptor::from_files(
&tls.cert_path,
&tls.key_path,
tls.client_ca_path.as_deref(),
tls.require_client_auth,
)?)
)?;

// Spawn TLS certificate reload worker if enabled
if tls.reload_interval_secs > 0 {
acceptor.spawn_reload_worker(
Duration::from_secs(tls.reload_interval_secs),
shutdown_rx.clone(),
);
}

Some(acceptor)
} else {
info!("TLS disabled — accepting plaintext connections");
None
};

let (shutdown_tx, shutdown_rx) = watch::channel(false);
let mut listener_tasks = Vec::with_capacity(gateway_listeners.len());
let enable_loopback_service_http = config.service_routing.enable_loopback_service_http;
for (listener, listen_addr) in gateway_listeners {
Expand Down Expand Up @@ -615,7 +628,10 @@ fn spawn_gateway_connection(
warn!(client = %addr, listen = %listen_addr, "Rejected plaintext HTTP on non-loopback gateway listener");
}
Ok(ConnectionProtocol::Tls | ConnectionProtocol::Unknown) => {
match acceptor.inner().accept(stream).await {
// acceptor.acceptor() snapshots the current TLS config;
// the returned acceptor owns an Arc that stays alive for
// the full duration of the handshake.
match acceptor.acceptor().accept(stream).await {
Ok(tls_stream) => {
let peer_identity = multiplex::extract_peer_identity(&tls_stream);
if let Err(e) = service
Expand Down Expand Up @@ -908,8 +924,7 @@ mod tests {
ComputeDriverKind, Config,
proto::{HealthRequest, open_shell_client::OpenShellClient},
};
use rcgen::{CertificateParams, IsCa, KeyPair};
use std::io::{Error, ErrorKind, Write};
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -918,41 +933,13 @@ mod tests {
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;

fn install_rustls_provider() {
let _ = rustls::crypto::ring::default_provider().install_default();
}
use crate::tls_test_utils::{generate_test_certs_with_ca, install_rustls_provider};

fn test_tls_acceptor() -> (TempDir, TlsAcceptor) {
install_rustls_provider();

let mut ca_params =
CertificateParams::new(Vec::<String>::new()).expect("failed to create CA params");
ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
ca_params
.distinguished_name
.push(rcgen::DnType::CommonName, "test-ca");
let ca_key = KeyPair::generate().expect("failed to generate CA key");
let ca_cert = ca_params
.self_signed(&ca_key)
.expect("failed to sign CA cert");

let server_params = CertificateParams::new(vec!["localhost".to_string()])
.expect("failed to create server params");
let server_key = KeyPair::generate().expect("failed to generate server key");
let server_cert = server_params
.signed_by(&server_key, &ca_cert, &ca_key)
.expect("failed to sign server cert");

let dir = tempdir().expect("failed to create tempdir");
let write_file = |name: &str, data: &[u8]| {
let path = dir.path().join(name);
std::fs::File::create(&path)
.and_then(|mut file| file.write_all(data))
.expect("failed to write tls test file");
};
write_file("ca.pem", ca_cert.pem().as_bytes());
write_file("server-cert.pem", server_cert.pem().as_bytes());
write_file("server-key.pem", server_key.serialize_pem().as_bytes());
generate_test_certs_with_ca(dir.path());

let acceptor = TlsAcceptor::from_files(
&dir.path().join("server-cert.pem"),
Expand Down
1 change: 1 addition & 0 deletions crates/openshell-server/src/service_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ mod tests {
key_path: "server.key".into(),
client_ca_path: Some("ca.crt".into()),
require_client_auth: false,
reload_interval_secs: 0,
}
}

Expand Down
Loading
Loading