diff --git a/.cirrus.yml b/.cirrus.yml index 9c7b696..ca8fd19 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -49,7 +49,29 @@ task: - cargo check --features mio_08 - cargo check --features tokio test_script: - - cargo test --all-features --no-fail-fast -- --test-threads=1 --nocapture + - cargo test --features mio_08,tokio --no-fail-fast -- --test-threads=1 --nocapture + before_cache_script: + - rm -rf $HOME/.cargo/registry/index + +task: + name: Linux amd64 1.75 + container: + image: rust:1.75.0 + cpu: 1 + memory: 1G # OOMs with 512MB + allow_failures: false + env: + RUST_BACKTRACE: 1 + cargo_cache: + folder: $HOME/.cargo/registry + fingerprint_script: cat Cargo.lock 2> /dev/null || true + target_cache: + folder: target + fingerprint_script: cat Cargo.lock 2> /dev/null || true + check_script: + - cargo check --features tokio,async_trait + test_script: + - cargo test --features tokio,async_trait --no-fail-fast -- --test-threads=1 --nocapture before_cache_script: - rm -rf $HOME/.cargo/registry/index diff --git a/Cargo.toml b/Cargo.toml index f4458a3..809fd49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,11 @@ categories = ["os::unix-apis", "asynchronous"] edition = "2021" exclude = ["tests", "src/bin", ".vscode"] +[features] +mio_08 = [ "dep:mio_08" ] +tokio = [ "dep:tokio" ] +async_trait = [] + [target."cfg(unix)".dependencies] libc = "0.2.90" # peer credentials for DragonFly BSD and NetBSD, SO_PEERSEC on all Linux architectures # enabling this feature implements the extension traits for mio 0.8's unix socket types @@ -26,4 +31,4 @@ tokio = { version = "1.28", features = ["net"], optional=true } tokio = { version = "1.28", features = ["io-util", "macros", "rt", 'rt-multi-thread'] } [package.metadata.docs.rs] -features = ["mio_08", "tokio"] +features = ["mio_08", "tokio", "async_trait"] diff --git a/README.md b/README.md index adf846d..7801a27 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,16 @@ To enable it, add this to Cargo.toml: uds = {version="0.4.2", features=["tokio"]} ``` +### async + +The combination of `tokio` and `async_trait` add async `send` and `recv` methods to the extension trait, in addition to everything already added by the `tokio` feature. +This increases the minimum Rust version to 1.75. + +```toml +[dependencies] +uds = {version="0.4.2", features=["tokio", "async_trait"]} +``` + ## Minimum Rust version The minimum Rust version is 1.63. diff --git a/src/tokio/traits.rs b/src/tokio/traits.rs index 3aa3d86..da7021a 100644 --- a/src/tokio/traits.rs +++ b/src/tokio/traits.rs @@ -11,6 +11,11 @@ use crate::addr::UnixSocketAddr; use crate::helpers::*; use crate::credentials::*; +#[cfg(feature = "async_trait")] +use crate::ancillary::*; +#[cfg(feature = "async_trait")] +use std::os::fd::RawFd; + mod private { use super::*; pub trait Sealed {} @@ -20,9 +25,6 @@ mod private { } /// Extension trait for `tokio::net::UnixStream`. -/// -/// Doesn't have `send_fds()` or `recv_fds()`, -/// because they would be `async` which isn't supported in traits yet. pub trait UnixStreamExt: AsRawFd + private::Sealed { /// Get the address of this socket, as a type that fully supports abstract addresses. fn local_unix_addr(&self) -> Result { @@ -56,6 +58,22 @@ pub trait UnixStreamExt: AsRawFd + private::Sealed { fn initial_peer_selinux_context(&self, buffer: &mut[u8]) -> Result { selinux_context(self.as_raw_fd(), buffer) } + + /// Sends file descriptors in addition to bytes. + #[cfg(feature = "async_trait")] + fn send_fds( + &self, + bytes: &[u8], + fds: &[RawFd], + ) -> (impl Send + std::future::Future>); + + /// Receives file descriptors in addition to bytes. + #[cfg(feature = "async_trait")] + fn recv_fds( + &self, + buf: &mut [u8], + fd_buf: &mut [RawFd], + ) -> (impl Send + std::future::Future>); } impl UnixStreamExt for UnixStream { @@ -71,6 +89,29 @@ impl UnixStreamExt for UnixStream { set_unix_addr(socket.as_raw_fd(), SetAddr::PEER, to)?; UnixStream::from_std(unsafe { stdUnixStream::from_raw_fd(socket.into_raw_fd()) }) } + + /// Sends file descriptors in addition to bytes. + #[cfg(feature = "async_trait")] + async fn send_fds(&self, bytes: &[u8], fds: &[RawFd]) -> Result { + self.async_io(tokio_crate::io::Interest::WRITABLE, || { + send_ancillary(self.as_raw_fd(), None, 0, &[io::IoSlice::new(bytes)], fds, None) + }) + .await + } + + /// Receives file descriptors in addition to bytes. + #[cfg(feature = "async_trait")] + async fn recv_fds( + &self, + buf: &mut [u8], + fd_buf: &mut [RawFd], + ) -> Result<(usize, usize), io::Error> { + self.async_io(tokio_crate::io::Interest::READABLE, || { + recv_fds(self.as_raw_fd(), None, &mut [io::IoSliceMut::new(buf)], fd_buf) + .map(|(bytes, _, fds)| (bytes, fds)) + }) + .await + } } /// Extension trait for using [`UnixSocketAddr`](struct.UnixSocketAddr.html) with `tokio::net::UnixListener`. diff --git a/tests/tokio_traits.rs b/tests/tokio_traits.rs index 5341aef..7a9fc58 100644 --- a/tests/tokio_traits.rs +++ b/tests/tokio_traits.rs @@ -103,3 +103,59 @@ async fn initial_pair_credentials() { let me = a.initial_pair_credentials().expect("get peer credentials"); assert_eq!(me, b.initial_pair_credentials().unwrap()); } + +#[cfg(feature="async_trait")] +mod async_trait { + use super::*; + + use std::os::unix::io::IntoRawFd; + + use libc::close; + + #[cfg_attr(not(any(target_os = "illumos", target_os = "solaris")), tokio::test)] + async fn many_fds() { + let (a, b) = std::os::unix::net::UnixStream::pair().expect("create stream socket pair"); + + a.set_nonblocking(true).expect("set a to nonblocking"); + let a = UnixStream::from_std(a).expect("convert to tokio unix stream"); + + // only odd numbers cause difference between CMSG_SPACE() and CMSG_LEN(), and only on 64bit. + let mut fds = [-1; 99]; + for (i, fd) in fds.iter_mut().enumerate() { + match b.try_clone() { + Ok(clone) => *fd = clone.into_raw_fd(), + Err(e) => panic!("failed to clone the {}nt fd: {}", i + 1, e), + } + } + a.send_fds(&b"99"[..], &fds).await.expect("send 99 fds"); + + let mut recv_buf = [0; 10]; + let mut recv_fds = [-1; 200]; + + b.set_nonblocking(true).expect("set b to nonblocking"); + let b = UnixStream::from_std(b).expect("convert to tokio unix stream"); + b.recv_fds(&mut recv_buf, &mut recv_fds) + .await + .expect("receive 99 of 200 fds"); + for (i, &received) in recv_fds[..fds.len()].iter().enumerate() { + assert_ne!(received, -1, "rerceived fd {} is not -1", i + 1); + assert_eq!( + unsafe { close(received) }, + 0, + "close(received fd {}) failed: {}", + i + 1, + std::io::Error::last_os_error(), + ); + } + + for (i, &sent) in fds.iter().enumerate() { + assert_eq!( + unsafe { close(sent) }, + 0, + "close(sent fd {}) failed: {}", + i + 1, + std::io::Error::last_os_error(), + ); + } + } +}