diff --git a/src/cmdext.rs b/src/cmdext.rs
index 2fabe33..e53ac6d 100644
--- a/src/cmdext.rs
+++ b/src/cmdext.rs
@@ -4,6 +4,7 @@
 //!
 //! - File descriptor passing
 //! - Changing to a file-descriptor relative directory
+//! - Systemd socket activation fd passing
 
 use cap_std::fs::Dir;
 use cap_std::io_lifetimes;
@@ -11,17 +12,205 @@ use cap_tempfile::cap_std;
 use io_lifetimes::OwnedFd;
 use rustix::fd::{AsFd, FromRawFd, IntoRawFd};
 use rustix::io::FdFlags;
+use std::collections::BTreeSet;
+use std::ffi::{c_char, CString};
 use std::os::fd::AsRawFd;
 use std::os::unix::process::CommandExt;
 use std::sync::Arc;
 
+/// The file descriptor number at which systemd passes the first socket.
+/// See `sd_listen_fds(3)`.
+const SD_LISTEN_FDS_START: i32 = 3;
+
+/// A validated name for a systemd socket-activation file descriptor.
+///
+/// Names appear in the `LISTEN_FDNAMES` environment variable as
+/// colon-separated values, so they must not contain `:`. They are also
+/// handed to `setenv(3)` as a C string, so they must not contain NUL
+/// bytes either. The constructor validates both at construction time.
+///
+/// ```
+/// use cap_std_ext::cmdext::SystemdFdName;
+/// let name = SystemdFdName::new("varlink");
+/// ```
+#[derive(Debug, Clone, Copy)]
+pub struct SystemdFdName<'a>(&'a str);
+
+impl<'a> SystemdFdName<'a> {
+    /// Create a new `SystemdFdName`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `name` contains `:` or a NUL byte.
+    pub fn new(name: &'a str) -> Self {
+        assert!(!name.contains(':'), "systemd fd name must not contain ':'");
+        assert!(!name.contains('\0'), "systemd fd name must not contain NUL");
+        Self(name)
+    }
+
+    /// Return the name as a string slice.
+    pub fn as_str(&self) -> &'a str {
+        self.0
+    }
+}
+
+/// File descriptor allocator for child processes.
+///
+/// Collects fd assignments and optional systemd socket-activation
+/// configuration, then applies them all at once via
+/// [`CapStdExtCommandExt::take_fds`].
+///
+/// - [`new_systemd_fds`](Self::new_systemd_fds) creates an allocator
+///   with systemd socket-activation fds at 3, 4, … (`SD_LISTEN_FDS_START`).
+/// - [`take_fd`](Self::take_fd) auto-assigns the next fd above all
+///   previously assigned ones (minimum 3).
+/// - [`take_fd_n`](Self::take_fd_n) places an fd at an explicit number,
+///   panicking on overlap.
+///
+/// ```no_run
+/// # use std::sync::Arc;
+/// # use cap_std_ext::cmdext::{CmdFds, CapStdExtCommandExt, SystemdFdName};
+/// # let varlink_fd: Arc<std::os::fd::OwnedFd> = todo!();
+/// # let extra_fd: Arc<std::os::fd::OwnedFd> = todo!();
+/// let mut cmd = std::process::Command::new("myservice");
+/// let mut fds = CmdFds::new_systemd_fds([(varlink_fd, SystemdFdName::new("varlink"))]);
+/// let extra_n = fds.take_fd(extra_fd);
+/// cmd.take_fds(fds);
+/// ```
+#[derive(Debug)]
+pub struct CmdFds {
+    /// Every fd number assigned so far, used for overlap detection.
+    taken: BTreeSet<i32>,
+    /// `(target fd number, fd)` pairs in registration order.
+    fds: Vec<(i32, Arc<OwnedFd>)>,
+    /// Pre-built CStrings for the systemd env vars, set by new_systemd_fds.
+    systemd_env: Option<(CString, CString)>,
+}
+
+impl Default for CmdFds {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CmdFds {
+    /// Create a new fd allocator.
+    pub fn new() -> Self {
+        Self {
+            taken: BTreeSet::new(),
+            fds: Vec::new(),
+            systemd_env: None,
+        }
+    }
+
+    /// Create a new fd allocator with systemd socket-activation fds.
+    ///
+    /// Each `(fd, name)` pair is assigned a consecutive fd number starting
+    /// at `SD_LISTEN_FDS_START` (3). The `LISTEN_PID`, `LISTEN_FDS`, and
+    /// `LISTEN_FDNAMES` environment variables will be set in the child
+    /// when [`CapStdExtCommandExt::take_fds`] is called.
+    /// See [sd_listen_fds] for the protocol.
+    ///
+    /// Additional (non-systemd) fds can be registered afterwards via
+    /// [`take_fd`](Self::take_fd) or [`take_fd_n`](Self::take_fd_n).
+    ///
+    /// [sd_listen_fds]: https://www.freedesktop.org/software/systemd/man/latest/sd_listen_fds.html
+    pub fn new_systemd_fds<'a>(
+        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
+    ) -> Self {
+        let mut this = Self::new();
+        this.register_systemd_fds(fds);
+        this
+    }
+
+    /// Compute the next fd number above everything already taken,
+    /// clamped to at least `SD_LISTEN_FDS_START`.
+    fn next_fd(&self) -> i32 {
+        let above_taken = self
+            .taken
+            .last()
+            .map(|n| n.checked_add(1).expect("fd number overflow"));
+        // Clamp: an explicit `take_fd_n` below 3 must not make
+        // auto-assignment hand out a stdio slot (0, 1 or 2).
+        above_taken.map_or(SD_LISTEN_FDS_START, |n| n.max(SD_LISTEN_FDS_START))
+    }
+
+    /// Record `n` as taken, panicking if it was already assigned.
+    fn insert_fd(&mut self, n: i32) {
+        let inserted = self.taken.insert(n);
+        assert!(inserted, "fd {n} is already assigned");
+    }
+
+    /// Register a file descriptor at the next available fd number.
+    ///
+    /// Returns the fd number that will be assigned in the child.
+    /// Call [`CapStdExtCommandExt::take_fds`] to apply.
+    pub fn take_fd(&mut self, fd: Arc<OwnedFd>) -> i32 {
+        let n = self.next_fd();
+        self.insert_fd(n);
+        self.fds.push((n, fd));
+        n
+    }
+
+    /// Register a file descriptor at a specific fd number.
+    ///
+    /// Call [`CapStdExtCommandExt::take_fds`] to apply.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `target` has already been assigned.
+    pub fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
+        self.insert_fd(target);
+        self.fds.push((target, fd));
+        self
+    }
+
+    /// Assign the given fds consecutive numbers starting at
+    /// `SD_LISTEN_FDS_START` and pre-build the `LISTEN_FDS` and
+    /// `LISTEN_FDNAMES` values for the child.
+    fn register_systemd_fds<'a>(
+        &mut self,
+        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
+    ) {
+        let mut n_fds: i32 = 0;
+        let mut names = Vec::new();
+        for (fd, name) in fds {
+            let target = SD_LISTEN_FDS_START
+                .checked_add(n_fds)
+                .expect("too many fds");
+            self.insert_fd(target);
+            self.fds.push((target, fd));
+            names.push(name.as_str());
+            n_fds = n_fds.checked_add(1).expect("too many fds");
+        }
+
+        // Neither value can contain a NUL byte: one is a decimal integer,
+        // the other is built from names `SystemdFdName::new` has validated.
+        let fd_count = CString::new(n_fds.to_string()).expect("decimal digits contain no NUL");
+        let fd_names = CString::new(names.join(":")).expect("fd names contain no NUL");
+        self.systemd_env = Some((fd_count, fd_names));
+    }
+}
+
 /// Extension trait for [`std::process::Command`].
 ///
 /// [`cap_std::fs::Dir`]: https://docs.rs/cap-std/latest/cap_std/fs/struct.Dir.html
 pub trait CapStdExtCommandExt {
-    /// Pass a file descriptor into the target process.
+    /// Pass a file descriptor into the target process at a specific fd number.
+    #[deprecated = "Use CmdFds with take_fds() instead"]
     fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self;
 
+    /// Apply a [`CmdFds`] to this command, passing all registered file
+    /// descriptors and (if configured) setting up the systemd
+    /// socket-activation environment.
+    ///
+    /// The `LISTEN_*` variables are set with `setenv(3)` in the forked
+    /// child rather than through [`std::process::Command::env`]. Avoid
+    /// combining systemd fds with `env()`, `env_remove()` or `env_clear()`
+    /// on the same command: std then passes its own environment block to
+    /// `exec`, which replaces the variables set here.
+    fn take_fds(&mut self, fds: CmdFds) -> &mut Self;
+
     /// Use the given directory as the current working directory for the process.
     fn cwd_dir(&mut self, dir: Dir) -> &mut Self;
 
@@ -39,7 +228,25 @@ pub trait CapStdExtCommandExt {
     fn lifecycle_bind_to_parent_thread(&mut self) -> &mut Self;
 }
 
+/// Wrapper around `libc::setenv` that checks the return value.
+///
+/// # Safety
+///
+/// - `key` and `val` must point to valid NUL-terminated C strings.
+/// - Must only be called in a single-threaded context (e.g. after `fork()`
+///   and before `exec()`).
 #[allow(unsafe_code)]
+unsafe fn check_setenv(key: *const c_char, val: *const c_char) -> std::io::Result<()> {
+    // SAFETY: Caller guarantees we are in a single-threaded context
+    // with valid nul-terminated C strings.
+    if unsafe { libc::setenv(key, val, 1) } != 0 {
+        return Err(std::io::Error::last_os_error());
+    }
+    Ok(())
+}
+
+#[allow(unsafe_code)]
+#[allow(deprecated)]
 impl CapStdExtCommandExt for std::process::Command {
     fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
         unsafe {
@@ -62,6 +269,36 @@ impl CapStdExtCommandExt for std::process::Command {
         self
     }
 
+    fn take_fds(&mut self, fds: CmdFds) -> &mut Self {
+        for (target, fd) in fds.fds {
+            self.take_fd_n(fd, target);
+        }
+        if let Some((fd_count, fd_names)) = fds.systemd_env {
+            // Set LISTEN_PID/FDS/FDNAMES in the forked child via setenv(3).
+            // We cannot use Command::env() because it causes Rust to build
+            // an envp array that replaces environ after our pre_exec setenv
+            // calls.
+            // NOTE(review): setenv(3) may allocate; std's `pre_exec` docs
+            // warn that malloc is not guaranteed to work in the child when
+            // the parent was multi-threaded at fork time.
+            unsafe {
+                self.pre_exec(move || {
+                    let pid = rustix::process::getpid();
+                    let pid_dec = rustix::path::DecInt::new(pid.as_raw_nonzero().get());
+                    // SAFETY: After fork() and before exec(), the child is
+                    // single-threaded, so setenv (which is not thread-safe)
+                    // is safe to call here. Every pointer comes from a C
+                    // string that outlives the call.
+                    check_setenv(c"LISTEN_PID".as_ptr(), pid_dec.as_c_str().as_ptr())?;
+                    check_setenv(c"LISTEN_FDS".as_ptr(), fd_count.as_ptr())?;
+                    check_setenv(c"LISTEN_FDNAMES".as_ptr(), fd_names.as_ptr())?;
+                    Ok(())
+                });
+            }
+        }
+        self
+    }
+
     fn cwd_dir(&mut self, dir: Dir) -> &mut Self {
         unsafe {
             self.pre_exec(move || {
@@ -92,6 +329,7 @@ mod tests {
     use super::*;
     use std::sync::Arc;
 
+    #[allow(deprecated)]
     #[test]
     fn test_take_fdn() -> anyhow::Result<()> {
         // Pass srcfd == destfd and srcfd != destfd
diff --git a/src/lib.rs b/src/lib.rs
index e78618f..1ef22b0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -42,7 +42,7 @@ pub(crate) fn escape_attempt() -> io::Error {
 /// Prelude, intended for glob import.
 pub mod prelude {
     #[cfg(not(windows))]
-    pub use super::cmdext::CapStdExtCommandExt;
+    pub use super::cmdext::{CapStdExtCommandExt, CmdFds, SystemdFdName};
     pub use super::dirext::CapStdExtDirExt;
     #[cfg(feature = "fs_utf8")]
     pub use super::dirext::CapStdExtDirExtUtf8;
diff --git a/tests/it/main.rs b/tests/it/main.rs
index 5589650..462e224 100644
--- a/tests/it/main.rs
+++ b/tests/it/main.rs
@@ -5,7 +5,7 @@ use cap_std::fs::PermissionsExt;
 use cap_std::fs::{Dir, File, Permissions};
 use cap_std_ext::cap_std;
 #[cfg(not(windows))]
-use cap_std_ext::cmdext::CapStdExtCommandExt;
+use cap_std_ext::cmdext::{CapStdExtCommandExt, CmdFds, SystemdFdName};
 use cap_std_ext::dirext::{CapStdExtDirExt, WalkConfiguration};
 #[cfg(any(target_os = "android", target_os = "linux"))]
 use cap_std_ext::RootDir;
@@ -21,6 +21,7 @@ use std::{process::Command, sync::Arc};
 
 #[test]
 #[cfg(not(windows))]
+#[allow(deprecated)]
 fn take_fd() -> Result<()> {
     let mut c = Command::new("/bin/bash");
     c.arg("-c");
@@ -860,3 +861,159 @@ fn test_lifecycle_bind_to_parent_thread() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+#[cfg(not(windows))]
+fn test_pass_systemd_fds() -> Result<()> {
+    // Verify the child sees the correct LISTEN_* env vars and can read from the fd.
+    let (r, w) = rustix::pipe::pipe()?;
+    let r = Arc::new(r);
+    let mut w: cap_std::fs::File = w.into();
+    write!(w, "sd-activate-test")?;
+    drop(w);
+
+    // The child: verify LISTEN_PID matches $$, print the other vars, read fd 3.
+    let script = r#"
+test "$LISTEN_PID" = "$$" || { echo "LISTEN_PID=$LISTEN_PID but $$=$$" >&2; exit 1; }
+printf '%s\n' "$LISTEN_FDS" "$LISTEN_FDNAMES"
+cat <&3
+"#;
+    let mut c = Command::new("/bin/bash");
+    c.arg("-c").arg(script);
+    c.stdout(std::process::Stdio::piped());
+    let fds = CmdFds::new_systemd_fds([(r, SystemdFdName::new("myproto"))]);
+    c.take_fds(fds);
+    let out = c.output()?;
+    assert!(
+        out.status.success(),
+        "child failed: {}",
+        String::from_utf8_lossy(&out.stderr)
+    );
+    let stdout = String::from_utf8_lossy(&out.stdout);
+    let lines: Vec<&str> = stdout.lines().collect();
+    assert_eq!(lines.len(), 3, "unexpected output: {stdout}");
+    assert_eq!(lines[0], "1");
+    assert_eq!(lines[1], "myproto");
+    assert_eq!(lines[2], "sd-activate-test");
+
+    Ok(())
+}
+
+#[test]
+#[cfg(not(windows))]
+fn test_systemd_fds_then_take_fd() -> Result<()> {
+    // Systemd fds at 3 and 4, then auto-assigned take_fd gets 5.
+    let (r1, w1) = rustix::pipe::pipe()?;
+    let (r2, w2) = rustix::pipe::pipe()?;
+    let (r_extra, w_extra) = rustix::pipe::pipe()?;
+    let r1 = Arc::new(r1);
+    let r2 = Arc::new(r2);
+    let r_extra = Arc::new(r_extra);
+    let mut w1: cap_std::fs::File = w1.into();
+    let mut w2: cap_std::fs::File = w2.into();
+    let mut w_extra: cap_std::fs::File = w_extra.into();
+    write!(w1, "first")?;
+    write!(w2, "second")?;
+    write!(w_extra, "extra")?;
+    drop(w1);
+    drop(w2);
+    drop(w_extra);
+
+    let script = r#"
+printf '%s\n' "$LISTEN_FDS" "$LISTEN_FDNAMES"
+cat <&3
+printf '\n'
+cat <&4
+printf '\n'
+cat <&5
+"#;
+    let mut c = Command::new("/bin/bash");
+    c.arg("-c").arg(script);
+    c.stdout(std::process::Stdio::piped());
+    let mut fds = CmdFds::new_systemd_fds([
+        (r1, SystemdFdName::new("alpha")),
+        (r2, SystemdFdName::new("beta")),
+    ]);
+    let extra_n = fds.take_fd(r_extra);
+    assert_eq!(extra_n, 5);
+    c.take_fds(fds);
+    let out = c.output()?;
+    assert!(out.status.success(), "child failed: {:?}", out);
+    let stdout = String::from_utf8_lossy(&out.stdout);
+    let lines: Vec<&str> = stdout.lines().collect();
+    assert_eq!(lines.len(), 5, "unexpected output: {stdout}");
+    assert_eq!(lines[0], "2");
+    assert_eq!(lines[1], "alpha:beta");
+    assert_eq!(lines[2], "first");
+    assert_eq!(lines[3], "second");
+    assert_eq!(lines[4], "extra");
+
+    Ok(())
+}
+
+#[test]
+#[cfg(not(windows))]
+fn test_cmd_fds_take_fd_n_then_systemd() -> Result<()> {
+    // Reserve fd 10 explicitly, then systemd fds at 3 — no conflict.
+    let (r_explicit, w_explicit) = rustix::pipe::pipe()?;
+    let (r_sd, w_sd) = rustix::pipe::pipe()?;
+    let r_explicit = Arc::new(r_explicit);
+    let r_sd = Arc::new(r_sd);
+    let mut w_explicit: cap_std::fs::File = w_explicit.into();
+    let mut w_sd: cap_std::fs::File = w_sd.into();
+    write!(w_explicit, "explicit")?;
+    write!(w_sd, "systemd")?;
+    drop(w_explicit);
+    drop(w_sd);
+
+    let script = r#"
+cat <&10
+printf '\n'
+printf '%s\n' "$LISTEN_FDS" "$LISTEN_FDNAMES"
+cat <&3
+"#;
+    let mut c = Command::new("/bin/bash");
+    c.arg("-c").arg(script);
+    c.stdout(std::process::Stdio::piped());
+    let mut fds = CmdFds::new_systemd_fds([(r_sd, SystemdFdName::new("varlink"))]);
+    fds.take_fd_n(r_explicit, 10);
+    c.take_fds(fds);
+    let out = c.output()?;
+    assert!(out.status.success(), "child failed: {:?}", out);
+    let stdout = String::from_utf8_lossy(&out.stdout);
+    let lines: Vec<&str> = stdout.lines().collect();
+    assert_eq!(lines.len(), 4, "unexpected output: {stdout}");
+    assert_eq!(lines[0], "explicit");
+    assert_eq!(lines[1], "1");
+    assert_eq!(lines[2], "varlink");
+    assert_eq!(lines[3], "systemd");
+
+    Ok(())
+}
+
+#[test]
+#[cfg(not(windows))]
+#[should_panic(expected = "fd 3 is already assigned")]
+fn test_cmd_fds_overlap_panics() {
+    let (r1, _w1) = rustix::pipe::pipe().unwrap();
+    let (r2, _w2) = rustix::pipe::pipe().unwrap();
+    let r1 = Arc::new(r1);
+    let r2 = Arc::new(r2);
+
+    let mut fds = CmdFds::new_systemd_fds([(r1, SystemdFdName::new("x"))]);
+    // This should panic: fd 3 is already taken by systemd.
+    fds.take_fd_n(r2, 3);
+}
+
+#[test]
+#[cfg(not(windows))]
+fn test_cmd_fds_take_fd_skips_stdio() {
+    // An explicit assignment below 3 must not drag auto-assignment into
+    // the stdio range.
+    let (r1, _w1) = rustix::pipe::pipe().unwrap();
+    let (r2, _w2) = rustix::pipe::pipe().unwrap();
+
+    let mut fds = CmdFds::new();
+    fds.take_fd_n(Arc::new(r1), 0);
+    assert_eq!(fds.take_fd(Arc::new(r2)), 3);
+}