diff --git a/Cargo.lock b/Cargo.lock index c363f2b9..0c1c903e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2246,6 +2246,7 @@ dependencies = [ "hex", "host-api", "k256", + "listenfd", "load_config", "or-panic", "ra-rpc", @@ -4202,6 +4203,17 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "listenfd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b87bc54a4629b4294d0b3ef041b64c40c611097a677d9dc07b2c67739fe39dba" +dependencies = [ + "libc", + "uuid", + "winapi", +] + [[package]] name = "litemap" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index 3a75109b..218c01a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,6 +105,7 @@ tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } safe-write = "0.1.2" nix = "0.29.0" sd-notify = "0.4.5" +listenfd = "1.0" jemallocator = "0.5.4" # Serialization/Parsing diff --git a/basefiles/dstack-guest-agent.service b/basefiles/dstack-guest-agent.service index b9277cae..a88d395d 100644 --- a/basefiles/dstack-guest-agent.service +++ b/basefiles/dstack-guest-agent.service @@ -1,6 +1,7 @@ [Unit] Description=dstack Guest Agent Service -After=network.target tboot.service +Requires=dstack-guest-agent.socket +After=network.target tboot.service dstack-guest-agent.socket Before=docker.service [Service] diff --git a/basefiles/dstack-guest-agent.socket b/basefiles/dstack-guest-agent.socket new file mode 100644 index 00000000..6d9bf39c --- /dev/null +++ b/basefiles/dstack-guest-agent.socket @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Socket activation for dstack-guest-agent. +# Provides backward compatibility for containers that mount sockets directly. +# Socket order: dstack.sock (index 0), tappd.sock (index 1) + +[Unit] +Description=dstack guest agent sockets + +[Socket] +ListenStream=/run/dstack.sock +ListenStream=/run/tappd.sock +SocketMode=0777 + +[Install] +WantedBy=sockets.target diff --git a/basefiles/dstack-socket.service b/basefiles/dstack-socket.service deleted file mode 100644 index 903037c3..00000000 --- a/basefiles/dstack-socket.service +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: 2025 Phala Network -# -# SPDX-License-Identifier: Apache-2.0 - -# Proxy service that forwards connections from /var/run/dstack.sock to /var/run/dstack/dstack.sock. -# Used for backward compatibility with containers that mount the socket file directly. - -[Unit] -Description=dstack socket proxy for backward compatibility -Requires=dstack-socket.socket -After=dstack-guest-agent.service - -[Service] -ExecStart=/usr/lib/systemd/systemd-socket-proxyd /var/run/dstack/dstack.sock -Type=notify diff --git a/basefiles/dstack-socket.socket b/basefiles/dstack-socket.socket deleted file mode 100644 index 1a2af9e4..00000000 --- a/basefiles/dstack-socket.socket +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: 2025 Phala Network -# -# SPDX-License-Identifier: Apache-2.0 - -# Backward compatibility socket for containers that mount /var/run/dstack.sock directly. -# The socket is owned by systemd, so it survives service restarts without inode changes. - -[Unit] -Description=dstack backward compatibility socket (dstack.sock) - -[Socket] -ListenStream=/var/run/dstack.sock -SocketMode=0777 - -[Install] -WantedBy=sockets.target diff --git a/basefiles/tappd-socket.service b/basefiles/tappd-socket.service deleted file mode 100644 index c99293e6..00000000 --- a/basefiles/tappd-socket.service +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: 2025 Phala Network -# -# SPDX-License-Identifier: Apache-2.0 - -# Proxy service that forwards connections from /var/run/tappd.sock to /var/run/dstack/tappd.sock. -# Used for backward compatibility with containers that mount the socket file directly. - -[Unit] -Description=tappd socket proxy for backward compatibility -Requires=tappd-socket.socket -After=dstack-guest-agent.service - -[Service] -ExecStart=/usr/lib/systemd/systemd-socket-proxyd /var/run/dstack/tappd.sock -Type=notify diff --git a/basefiles/tappd-socket.socket b/basefiles/tappd-socket.socket deleted file mode 100644 index 9e109d5b..00000000 --- a/basefiles/tappd-socket.socket +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: 2025 Phala Network -# -# SPDX-License-Identifier: Apache-2.0 - -# Backward compatibility socket for containers that mount /var/run/tappd.sock directly. -# The socket is owned by systemd, so it survives service restarts without inode changes. - -[Unit] -Description=dstack backward compatibility socket (tappd.sock) - -[Socket] -ListenStream=/var/run/tappd.sock -SocketMode=0777 - -[Install] -WantedBy=sockets.target diff --git a/guest-agent/Cargo.toml b/guest-agent/Cargo.toml index 9bf291b9..30206868 100644 --- a/guest-agent/Cargo.toml +++ b/guest-agent/Cargo.toml @@ -53,3 +53,4 @@ tempfile.workspace = true rand.workspace = true or-panic.workspace = true cc-eventlog.workspace = true +listenfd.workspace = true diff --git a/guest-agent/src/main.rs b/guest-agent/src/main.rs index 3183e61f..6b615ffc 100644 --- a/guest-agent/src/main.rs +++ b/guest-agent/src/main.rs @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::{fs::Permissions, future::pending, os::unix::fs::PermissionsExt}; +use std::{future::pending, os::unix::net::UnixListener as StdUnixListener}; use anyhow::{anyhow, Context, Result}; use clap::Parser; @@ -16,6 +16,7 @@ use rocket::{ use rocket_vsock_listener::VsockListener; use rpc_service::{AppState, ExternalRpcHandler, InternalRpcHandler, InternalRpcHandlerV0}; use sd_notify::{notify as sd_notify, NotifyState}; +use socket_activation::{ActivatedSockets, ActivatedUnixListener}; use std::time::Duration; use tokio::sync::oneshot; use tracing::{error, info}; @@ -25,6 +26,7 @@ mod guest_api_service; mod http_routes; mod models; mod rpc_service; +mod socket_activation; const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); const GIT_REV: &str = git_version::git_version!( @@ -52,6 +54,7 @@ struct Args { async fn run_internal_v0( state: AppState, figment: Figment, + activated_socket: Option, sock_ready_tx: oneshot::Sender<()>, ) -> Result<()> { let rocket = rocket::custom(figment) @@ -64,26 +67,36 @@ async fn run_internal_v0( .ignite() .await .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - let endpoint = DefaultListener::bind_endpoint(&ignite) - .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; - if let Some(path) = endpoint.unix() { - // Allow any user to connect to the socket - fs_err::set_permissions(path, Permissions::from_mode(0o777))?; + + if let Some(std_listener) = activated_socket { + // Use systemd-activated socket + info!("Using systemd-activated socket for tappd.sock"); + let listener = ActivatedUnixListener::new(std_listener)?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; + } else { + // Fall back to binding our own socket + let endpoint = DefaultListener::bind_endpoint(&ignite) + .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; } - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; Ok(()) } async fn run_internal( state: AppState, figment: Figment, + activated_socket: Option, sock_ready_tx: oneshot::Sender<()>, ) -> Result<()> { let rocket = rocket::custom(figment) @@ -93,20 +106,29 @@ async fn run_internal( .ignite() .await .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - let endpoint = DefaultListener::bind_endpoint(&ignite) - .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; - if let Some(path) = endpoint.unix() { - // Allow any user to connect to the socket - fs_err::set_permissions(path, Permissions::from_mode(0o777))?; + + if let Some(std_listener) = activated_socket { + // Use systemd-activated socket + info!("Using systemd-activated socket for dstack.sock"); + let listener = ActivatedUnixListener::new(std_listener)?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; + } else { + // Fall back to binding our own socket + let endpoint = DefaultListener::bind_endpoint(&ignite) + .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; } - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; Ok(()) } @@ -219,11 +241,18 @@ async fn main() -> Result<()> { .extract() .context("Failed to extract bind address")?; let guest_api_figment = figment.select("guest-api"); + + // Get systemd-activated sockets if available + let activated = ActivatedSockets::from_env(); + if activated.any_activated() { + info!("Systemd socket activation detected"); + } + let (tappd_ready_tx, tappd_ready_rx) = oneshot::channel(); let (sock_ready_tx, sock_ready_rx) = oneshot::channel(); tokio::select!( - res = run_internal_v0(state.clone(), internal_v0_figment, tappd_ready_tx) => res?, - res = run_internal(state.clone(), internal_figment, sock_ready_tx) => res?, + res = run_internal_v0(state.clone(), internal_v0_figment, activated.tappd, tappd_ready_tx) => res?, + res = run_internal(state.clone(), internal_figment, activated.dstack, sock_ready_tx) => res?, res = run_external(state.clone(), external_figment) => res?, res = run_guest_api(state.clone(), guest_api_figment) => res?, _ = async { diff --git a/guest-agent/src/socket_activation.rs b/guest-agent/src/socket_activation.rs new file mode 100644 index 00000000..29f1faad --- /dev/null +++ b/guest-agent/src/socket_activation.rs @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: © 2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Systemd socket activation support for dstack-guest-agent. +//! +//! This module provides utilities for receiving pre-created sockets from systemd +//! via the LISTEN_FDS mechanism, allowing the service to use sockets that survive +//! service restarts. + +use std::{io, os::unix::net::UnixListener as StdUnixListener}; + +use listenfd::ListenFd; +use rocket::listener::{unix::UnixStream, Endpoint, Listener}; + +/// Socket indices for systemd socket activation. +/// Order matches ListenStream declarations in dstack-guest-agent.socket. +mod socket_index { + pub const DSTACK: usize = 0; + pub const TAPPD: usize = 1; +} + +/// Systemd-activated sockets passed via LISTEN_FDS. +pub struct ActivatedSockets { + pub dstack: Option, + pub tappd: Option, +} + +impl ActivatedSockets { + /// Retrieve activated sockets from systemd environment variables. + pub fn from_env() -> Self { + let mut listenfd = ListenFd::from_env(); + Self { + dstack: listenfd + .take_unix_listener(socket_index::DSTACK) + .ok() + .flatten(), + tappd: listenfd + .take_unix_listener(socket_index::TAPPD) + .ok() + .flatten(), + } + } + + /// Check if any sockets were activated. + pub fn any_activated(&self) -> bool { + self.dstack.is_some() || self.tappd.is_some() + } +} + +/// Wrapper for systemd-activated Unix socket that implements rocket's Listener trait. +pub struct ActivatedUnixListener { + listener: tokio::net::UnixListener, +} + +impl ActivatedUnixListener { + /// Create a new listener from a standard library UnixListener. + pub fn new(std_listener: StdUnixListener) -> io::Result { + std_listener.set_nonblocking(true)?; + let listener = tokio::net::UnixListener::from_std(std_listener)?; + Ok(Self { listener }) + } +} + +impl Listener for ActivatedUnixListener { + type Accept = UnixStream; + type Connection = UnixStream; + + async fn accept(&self) -> io::Result { + Ok(self.listener.accept().await?.0) + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + Ok(accept) + } + + fn endpoint(&self) -> io::Result { + self.listener.local_addr()?.try_into() + } +}