Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 211 additions & 1 deletion src/cmdext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,190 @@
//!
//! - File descriptor passing
//! - Changing to a file-descriptor relative directory
//! - Systemd socket activation fd passing

use cap_std::fs::Dir;
use cap_std::io_lifetimes;
use cap_tempfile::cap_std;
use io_lifetimes::OwnedFd;
use rustix::fd::{AsFd, FromRawFd, IntoRawFd};
use rustix::io::FdFlags;
use std::collections::BTreeSet;
use std::ffi::CString;
use std::os::fd::AsRawFd;
use std::os::unix::process::CommandExt;
use std::sync::Arc;

/// The file descriptor number at which systemd passes the first socket.
/// See `sd_listen_fds(3)`.
///
/// Within this module it is both the number given to the first
/// socket-activation fd and the lowest number handed out when an fd
/// number is auto-assigned.
const SD_LISTEN_FDS_START: i32 = 3;

/// A validated name for a systemd socket-activation file descriptor.
///
/// Names appear in the `LISTEN_FDNAMES` environment variable as
/// colon-separated values, so they must not contain `:`. Mirroring
/// systemd's own `fdname_is_valid()` check, a name must additionally
/// consist solely of printable ASCII characters (space through `~`, which
/// also rules out NUL bytes) and be at most 255 bytes long. The empty
/// string is allowed. All of this is checked once, at construction time.
///
/// ```
/// use cap_std_ext::cmdext::SystemdFdName;
/// let name = SystemdFdName::new("varlink");
/// assert_eq!(name.as_str(), "varlink");
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SystemdFdName<'a>(&'a str);

impl<'a> SystemdFdName<'a> {
    /// Maximum length of a name in bytes (systemd's `FDNAME_MAX`).
    const MAX_LEN: usize = 255;

    /// Check `name` against the rules described on [`SystemdFdName`],
    /// returning a human-readable reason when it is rejected.
    fn validate(name: &str) -> Result<(), &'static str> {
        if name.contains(':') {
            return Err("systemd fd name must not contain ':'");
        }
        // Everything below ' ' is a control character (including NUL, which
        // could not be stored in an environment variable anyway); everything
        // above '~' is DEL or non-ASCII.
        if !name.bytes().all(|b| (b' '..=b'~').contains(&b)) {
            return Err("systemd fd name must consist of printable ASCII characters");
        }
        if name.len() > Self::MAX_LEN {
            return Err("systemd fd name must be at most 255 bytes long");
        }
        Ok(())
    }

    /// Create a new `SystemdFdName`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains `:`, contains a control or non-ASCII
    /// character, or is longer than 255 bytes.
    pub fn new(name: &'a str) -> Self {
        if let Err(reason) = Self::validate(name) {
            panic!("{reason}");
        }
        Self(name)
    }

    /// Return the name as a string slice.
    pub fn as_str(&self) -> &'a str {
        self.0
    }
}

/// File descriptor allocator for child processes.
///
/// Collects fd assignments and optional systemd socket-activation
/// configuration, then applies them all at once via
/// [`CapStdExtCommandExt::take_fds`].
///
/// - [`new_systemd_fds`](Self::new_systemd_fds) creates an allocator
///   with systemd socket-activation fds at 3, 4, … (`SD_LISTEN_FDS_START`).
/// - [`take_fd`](Self::take_fd) auto-assigns the next fd above all
///   previously assigned ones (minimum 3).
/// - [`take_fd_n`](Self::take_fd_n) places an fd at an explicit number,
///   panicking on overlap.
///
/// ```no_run
/// # use std::sync::Arc;
/// # use cap_std_ext::cmdext::{CmdFds, CapStdExtCommandExt, SystemdFdName};
/// # let varlink_fd: Arc<rustix::fd::OwnedFd> = todo!();
/// # let extra_fd: Arc<rustix::fd::OwnedFd> = todo!();
/// let mut cmd = std::process::Command::new("myservice");
/// let mut fds = CmdFds::new_systemd_fds([(varlink_fd, SystemdFdName::new("varlink"))]);
/// let extra_n = fds.take_fd(extra_fd);
/// cmd.take_fds(fds);
/// ```
#[derive(Debug)]
pub struct CmdFds {
/// Every target fd number assigned so far; used to detect duplicate
/// assignments and to pick the next auto-assigned number.
taken: BTreeSet<i32>,
/// `(target fd number, descriptor)` pairs in registration order.
fds: Vec<(i32, Arc<OwnedFd>)>,
/// Pre-built values for the `LISTEN_FDS` and `LISTEN_FDNAMES` environment
/// variables (in that order), set by `new_systemd_fds`. `None` when no
/// systemd socket activation was requested.
systemd_env: Option<(CString, CString)>,
}

impl Default for CmdFds {
fn default() -> Self {
Self::new()
}
}

impl CmdFds {
    /// Create a new, empty fd allocator without any systemd
    /// socket-activation configuration.
    pub fn new() -> Self {
        Self {
            taken: BTreeSet::new(),
            fds: Vec::new(),
            systemd_env: None,
        }
    }

    /// Create a new fd allocator with systemd socket-activation fds.
    ///
    /// Each `(fd, name)` pair is assigned a consecutive fd number starting
    /// at `SD_LISTEN_FDS_START` (3). The `LISTEN_PID`, `LISTEN_FDS`, and
    /// `LISTEN_FDNAMES` environment variables will be set in the child
    /// when [`CapStdExtCommandExt::take_fds`] is called. See
    /// [sd_listen_fds] for the protocol.
    ///
    /// Additional (non-systemd) fds can be registered afterwards via
    /// [`take_fd`](Self::take_fd) or [`take_fd_n`](Self::take_fd_n).
    ///
    /// # Panics
    ///
    /// Panics if the number of fds overflows an `i32`, or if any name
    /// contains a NUL byte.
    ///
    /// [sd_listen_fds]: https://www.freedesktop.org/software/systemd/man/latest/sd_listen_fds.html
    pub fn new_systemd_fds<'a>(
        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
    ) -> Self {
        let mut this = Self::new();
        this.register_systemd_fds(fds);
        this
    }

    /// Compute the next fd number above everything already taken
    /// (minimum `SD_LISTEN_FDS_START`).
    ///
    /// # Panics
    ///
    /// Panics if the highest taken number is `i32::MAX`.
    fn next_fd(&self) -> i32 {
        let above_highest = self
            .taken
            .last()
            .map(|n| n.checked_add(1).expect("fd number overflow"));
        // An explicit `take_fd_n` below 3 (e.g. replacing stdin) must not
        // drag auto-assignment down into the stdio range, so clamp to the
        // documented minimum. The clamped value cannot collide: it is still
        // strictly greater than every number in `taken`.
        above_highest.map_or(SD_LISTEN_FDS_START, |n| n.max(SD_LISTEN_FDS_START))
    }

    /// Record `n` as assigned.
    ///
    /// # Panics
    ///
    /// Panics if `n` has already been assigned.
    fn insert_fd(&mut self, n: i32) {
        let inserted = self.taken.insert(n);
        assert!(inserted, "fd {n} is already assigned");
    }

    /// Register a file descriptor at the next available fd number.
    ///
    /// Returns the fd number that will be assigned in the child; it is
    /// always at least `SD_LISTEN_FDS_START` (3) and above every number
    /// registered so far.
    /// Call [`CapStdExtCommandExt::take_fds`] to apply.
    pub fn take_fd(&mut self, fd: Arc<OwnedFd>) -> i32 {
        let n = self.next_fd();
        self.insert_fd(n);
        self.fds.push((n, fd));
        n
    }

    /// Register a file descriptor at a specific fd number.
    ///
    /// Call [`CapStdExtCommandExt::take_fds`] to apply.
    ///
    /// # Panics
    ///
    /// Panics if `target` has already been assigned.
    pub fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
        self.insert_fd(target);
        self.fds.push((target, fd));
        self
    }

    /// Assign consecutive fd numbers starting at `SD_LISTEN_FDS_START` to
    /// the given fds and pre-build the `LISTEN_FDS` / `LISTEN_FDNAMES`
    /// values for later use in the child.
    ///
    /// # Panics
    ///
    /// Panics if a target number is already assigned, if the fd count
    /// overflows an `i32`, or if a name contains a NUL byte.
    fn register_systemd_fds<'a>(
        &mut self,
        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
    ) {
        let mut n_fds: i32 = 0;
        let mut names = Vec::new();
        for (fd, name) in fds {
            let target = SD_LISTEN_FDS_START
                .checked_add(n_fds)
                .expect("too many fds");
            self.insert_fd(target);
            self.fds.push((target, fd));
            names.push(name.as_str());
            n_fds = n_fds.checked_add(1).expect("too many fds");
        }

        let fd_count =
            CString::new(n_fds.to_string()).expect("decimal digits never contain a NUL byte");
        // Environment values are C strings, so an interior NUL in any name
        // cannot be represented; fail loudly here rather than in the child.
        let fd_names = CString::new(names.join(":"))
            .expect("systemd fd names must not contain NUL bytes");
        self.systemd_env = Some((fd_count, fd_names));
    }
}

/// Extension trait for [`std::process::Command`].
///
/// [`cap_std::fs::Dir`]: https://docs.rs/cap-std/latest/cap_std/fs/struct.Dir.html
pub trait CapStdExtCommandExt {
/// Pass a file descriptor into the target process at a specific fd number.
#[deprecated = "Use CmdFds with take_fds() instead"]
fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self;

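/// Apply a [`CmdFds`] to this command, passing all registered file
/// descriptors and (if configured) setting up the systemd
/// socket-activation environment.
///
/// The `LISTEN_PID`, `LISTEN_FDS` and `LISTEN_FDNAMES` variables are set
/// with `setenv(3)` in the forked child rather than through
/// [`std::process::Command::env`]. Do not combine socket activation with
/// `Command::env()`: that makes the standard library build an explicit
/// `envp` array which replaces the environment after these variables
/// have been set, so they would be lost.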
fn take_fds(&mut self, fds: CmdFds) -> &mut Self;

/// Use the given directory as the current working directory for the process.
fn cwd_dir(&mut self, dir: Dir) -> &mut Self;

Expand All @@ -39,7 +205,24 @@ pub trait CapStdExtCommandExt {
fn lifecycle_bind_to_parent_thread(&mut self) -> &mut Self;
}

/// Wrapper around `libc::setenv` that checks the return value.
///
/// Sets the environment variable `key` to `val`, overwriting any existing
/// value (the `overwrite` argument is passed as `1`). On failure the
/// current OS error (`errno`) is returned.
///
/// The pointers are typed as `c_char` rather than `i8` because `c_char`
/// is unsigned on some targets (e.g. aarch64 Linux), and that is the
/// pointer type `CStr::as_ptr` produces.
///
/// # Safety
///
/// - `key` and `val` must point to valid nul-terminated C strings.
/// - Must only be called in a single-threaded context (e.g. after `fork()`
///   and before `exec()`).
#[allow(unsafe_code)]
unsafe fn check_setenv(
    key: *const std::ffi::c_char,
    val: *const std::ffi::c_char,
) -> std::io::Result<()> {
    // SAFETY: Caller guarantees we are in a single-threaded context
    // with valid nul-terminated C strings.
    let rc = unsafe { libc::setenv(key, val, 1) };
    if rc == 0 {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}

#[allow(unsafe_code)]
#[allow(deprecated)]
impl CapStdExtCommandExt for std::process::Command {
fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
unsafe {
Expand All @@ -62,6 +245,32 @@ impl CapStdExtCommandExt for std::process::Command {
self
}

fn take_fds(&mut self, fds: CmdFds) -> &mut Self {
for (target, fd) in fds.fds {
self.take_fd_n(fd, target);
}
if let Some((fd_count, fd_names)) = fds.systemd_env {
// Set LISTEN_PID/FDS/FDNAMES in the forked child via setenv(3).
// We cannot use Command::env() because it causes Rust to build
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like something to document on the API itself since anyone could be using Command::env() and lose the envs passed here IIUC.

// an envp array that replaces environ after our pre_exec setenv
// calls.
unsafe {
self.pre_exec(move || {
let pid = rustix::process::getpid();
let pid_dec = rustix::path::DecInt::new(pid.as_raw_nonzero().get());
// SAFETY: After fork() and before exec(), the child is
// single-threaded, so setenv (which is not thread-safe)
// is safe to call here.
check_setenv(c"LISTEN_PID".as_ptr(), pid_dec.as_c_str().as_ptr())?;
check_setenv(c"LISTEN_FDS".as_ptr(), fd_count.as_ptr())?;
check_setenv(c"LISTEN_FDNAMES".as_ptr(), fd_names.as_ptr())?;
Ok(())
});
}
}
self
}

fn cwd_dir(&mut self, dir: Dir) -> &mut Self {
unsafe {
self.pre_exec(move || {
Expand Down Expand Up @@ -92,6 +301,7 @@ mod tests {
use super::*;
use std::sync::Arc;

#[allow(deprecated)]
#[test]
fn test_take_fdn() -> anyhow::Result<()> {
// Pass srcfd == destfd and srcfd != destfd
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(crate) fn escape_attempt() -> io::Error {
/// Prelude, intended for glob import.
pub mod prelude {
#[cfg(not(windows))]
pub use super::cmdext::CapStdExtCommandExt;
pub use super::cmdext::{CapStdExtCommandExt, CmdFds, SystemdFdName};
pub use super::dirext::CapStdExtDirExt;
#[cfg(feature = "fs_utf8")]
pub use super::dirext::CapStdExtDirExtUtf8;
Expand Down
Loading