Skip to content

Commit

Permalink
✨ zb: Add (tokio) support for unixexec transport
Browse files Browse the repository at this point in the history
  • Loading branch information
vially committed Sep 5, 2024
1 parent 23e6a2e commit 9fd4f64
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
13 changes: 13 additions & 0 deletions zbus/src/address/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use tokio::net::TcpStream;
use tokio_vsock::VsockStream;
#[cfg(windows)]
use uds_windows::UnixStream;
#[cfg(all(unix, feature = "tokio"))]
use unixexec::UnixExec;
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
use vsock::VsockStream;

Expand Down Expand Up @@ -52,6 +54,8 @@ use std::os::linux::net::SocketAddrExt;
feature = "tokio-vsock"
))]
pub use vsock_transport::Vsock;
#[cfg(all(unix, feature = "tokio"))]
mod unixexec;

/// The transport properties of a D-Bus address.
#[derive(Clone, Debug, PartialEq, Eq)]
Expand All @@ -77,6 +81,9 @@ pub enum Transport {
/// The type of `stream` is `vsock::VsockStream` with the `vsock` feature and
/// `tokio_vsock::VsockStream` with the `tokio-vsock` feature.
Vsock(Vsock),
/// A `unixexec` address.
#[cfg(all(unix, feature = "tokio"))]
UnixExec(UnixExec),
}

impl Transport {
Expand Down Expand Up @@ -136,6 +143,8 @@ impl Transport {
}
}
}
#[cfg(all(unix, feature = "tokio"))]
Transport::UnixExec(unixexec) => unixexec.connect().await.map(Stream::Unix),
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
Transport::Vsock(addr) => {
let stream = VsockStream::connect_with_cid_port(addr.cid(), addr.port())?;
Expand Down Expand Up @@ -211,6 +220,8 @@ impl Transport {
pub(super) fn from_options(transport: &str, options: HashMap<&str, &str>) -> Result<Self> {
match transport {
"unix" => Unix::from_options(options).map(Self::Unix),
#[cfg(all(unix, feature = "tokio"))]
"unixexec" => UnixExec::from_options(options).map(Self::UnixExec),
"tcp" => Tcp::from_options(options, false).map(Self::Tcp),
"nonce-tcp" => Tcp::from_options(options, true).map(Self::Tcp),
#[cfg(any(
Expand Down Expand Up @@ -334,6 +345,8 @@ impl Display for Transport {
match self {
Self::Tcp(tcp) => write!(f, "{}", tcp)?,
Self::Unix(unix) => write!(f, "{}", unix)?,
#[cfg(all(unix, feature = "tokio"))]
Self::UnixExec(unix) => write!(f, "{}", unix)?,
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
Expand Down
107 changes: 107 additions & 0 deletions zbus/src/address/transport/unixexec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use std::ffi::OsString;
use std::fmt::Display;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::process::CommandExt;
use std::path::PathBuf;
use std::process::Stdio;

use tokio::net::UnixStream;
use tracing::warn;

use crate::Error;

use super::encode_percents;

/// A unixexec domain socket transport in a D-Bus address.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UnixExec {
pub(super) path: PathBuf,
pub(super) arg0: Option<OsString>,
pub(super) args: Vec<String>,
}

impl UnixExec {
/// Create a new unixexec transport with the given path and arguments.
pub fn new(path: PathBuf, arg0: Option<OsString>, args: Vec<String>) -> Self {
Self { path, arg0, args }
}

pub(super) fn from_options(opts: std::collections::HashMap<&str, &str>) -> crate::Result<Self> {
let Some(path) = opts.get("path") else {
return Err(crate::Error::Address(
"unixexec address is missing `path`".to_owned(),
));
};

let arg0 = opts.get("argv0").map(OsString::from);

let mut args: Vec<String> = Vec::new();
let mut arg_index = 1;
while let Some(arg) = opts.get(format!("argv{arg_index}").as_str()) {
args.push(arg.to_string());
arg_index += 1;
}

Ok(Self::new(PathBuf::from(path), arg0, args))
}

pub(super) async fn connect(self) -> crate::Result<UnixStream> {
let mut child = tokio::process::Command::from(self)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;

let stdin = child
.stdin
.take()
.ok_or(Error::Failure("child stdin not found".into()))?;

let stdout = child
.stdout
.take()
.ok_or(Error::Failure("child stdout not found".into()))?;

let exec_stdio_stream = tokio::io::join(stdout, stdin);

let (transport_stream, unix_pipe_stream) = tokio::net::UnixStream::pair()?;

tokio::task::spawn(async move {
let mut unix_pipe_stream = unix_pipe_stream;
let mut exec_stdio_stream = exec_stdio_stream;
if let Err(e) =
tokio::io::copy_bidirectional(&mut unix_pipe_stream, &mut exec_stdio_stream).await
{
warn!("Error occurred while copying bidirectional streams: {}", e);
}
});

Ok(transport_stream)
}
}

impl Display for UnixExec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("unixexec:")?;
encode_percents(f, self.path.as_os_str().as_bytes())
}
}

impl From<UnixExec> for std::process::Command {
fn from(unixexec: UnixExec) -> Self {
let mut command = std::process::Command::new(unixexec.path);
command.args(unixexec.args);

if let Some(arg0) = unixexec.arg0.as_ref() {
command.arg0(arg0);
}

command
}
}

impl From<UnixExec> for tokio::process::Command {
fn from(unixexec: UnixExec) -> Self {
std::process::Command::from(unixexec).into()
}
}

0 comments on commit 9fd4f64

Please sign in to comment.