diff --git a/zbus/src/blocking/connection/builder.rs b/zbus/src/blocking/connection/builder.rs index 496645772..4efd47cbf 100644 --- a/zbus/src/blocking/connection/builder.rs +++ b/zbus/src/blocking/connection/builder.rs @@ -15,8 +15,9 @@ use zvariant::ObjectPath; #[cfg(feature = "p2p")] use crate::Guid; use crate::{ - address::ToAddresses, blocking::Connection, connection::socket::BoxedSplit, - names::WellKnownName, object_server::Interface, utils::block_on, AuthMechanism, Error, Result, + address::ToAddresses, blocking::Connection, conn::AuthMechanism, + connection::socket::BoxedSplit, names::WellKnownName, object_server::Interface, + utils::block_on, Error, Result, }; /// A builder for [`zbus::blocking::Connection`]. diff --git a/zbus/src/connection/builder.rs b/zbus/src/connection/builder.rs index dd81bf100..a20dc1ddd 100644 --- a/zbus/src/connection/builder.rs +++ b/zbus/src/connection/builder.rs @@ -7,7 +7,7 @@ use std::net::TcpStream; #[cfg(all(unix, not(feature = "tokio")))] use std::os::unix::net::UnixStream; use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, HashSet}, vec, }; #[cfg(feature = "tokio")] @@ -68,7 +68,7 @@ pub struct Builder<'a> { internal_executor: bool, interfaces: Interfaces<'a>, names: HashSet>, - auth_mechanisms: Option>, + auth_mechanism: Option, #[cfg(feature = "bus-impl")] unique_name: Option>, } @@ -190,7 +190,7 @@ impl<'a> Builder<'a> { /// Specify the mechanism to use during authentication. pub fn auth_mechanism(mut self, auth_mechanism: AuthMechanism) -> Self { - self.auth_mechanisms = Some(VecDeque::from(vec![auth_mechanism])); + self.auth_mechanism = Some(auth_mechanism); self } @@ -381,7 +381,7 @@ impl<'a> Builder<'a> { match self.guid { None => { // SASL Handshake - Authenticated::client(stream, server_guid, self.auth_mechanisms, is_bus_conn) + Authenticated::client(stream, server_guid, self.auth_mechanism, is_bus_conn) .await? } Some(guid) => { @@ -402,7 +402,7 @@ impl<'a> Builder<'a> { client_uid, #[cfg(windows)] client_sid, - self.auth_mechanisms, + self.auth_mechanism, unique_name, ) .await? @@ -410,7 +410,7 @@ impl<'a> Builder<'a> { } #[cfg(not(feature = "p2p"))] - Authenticated::client(stream, server_guid, self.auth_mechanisms, is_bus_conn).await? + Authenticated::client(stream, server_guid, self.auth_mechanism, is_bus_conn).await? }; // SAFETY: `Authenticated` is always built with these fields set to `Some`. @@ -468,7 +468,7 @@ impl<'a> Builder<'a> { internal_executor: true, interfaces: HashMap::new(), names: HashSet::new(), - auth_mechanisms: None, + auth_mechanism: None, #[cfg(feature = "bus-impl")] unique_name: None, } diff --git a/zbus/src/connection/handshake/client.rs b/zbus/src/connection/handshake/client.rs index c85010df2..6cf3269a9 100644 --- a/zbus/src/connection/handshake/client.rs +++ b/zbus/src/connection/handshake/client.rs @@ -1,6 +1,5 @@ use async_trait::async_trait; -use std::collections::VecDeque; -use tracing::{debug, instrument, trace, warn}; +use tracing::{instrument, trace, warn}; use crate::{conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName, Message}; @@ -24,19 +23,14 @@ impl Client { /// Start a handshake on this client socket pub fn new( socket: BoxedSplit, - mechanisms: Option>, + mechanism: Option, server_guid: Option, bus: bool, ) -> Client { - let mechanisms = mechanisms.unwrap_or_else(|| { - let mut mechanisms = VecDeque::new(); - mechanisms.push_back(AuthMechanism::External); - mechanisms.push_back(AuthMechanism::Anonymous); - mechanisms - }); + let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism()); Client { - common: Common::new(socket, mechanisms), + common: Common::new(socket, mechanism), server_guid, bus, } @@ -84,32 +78,37 @@ impl Client { /// Perform the authentication handshake with the server. #[instrument(skip(self))] async fn authenticate(&mut self) -> Result<()> { - loop { - let mechanism = self.common.next_mechanism()?; - trace!("Trying {mechanism} mechanism"); - let auth_cmd = match mechanism { - AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())), - AuthMechanism::External => { - Command::Auth(Some(mechanism), Some(sasl_auth_id()?.into_bytes())) - } - }; - self.common.write_command(auth_cmd).await?; + let mechanism = self.common.mechanism(); + trace!("Trying {mechanism} mechanism"); + let auth_cmd = match mechanism { + AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())), + AuthMechanism::External => { + Command::Auth(Some(mechanism), Some(sasl_auth_id()?.into_bytes())) + } + }; + self.common.write_command(auth_cmd).await?; - match self.common.read_command().await? { - Command::Ok(guid) => { - trace!("Received OK from server"); - self.set_guid(guid)?; + match self.common.read_command().await? { + Command::Ok(guid) => { + trace!("Received OK from server"); + self.set_guid(guid)?; - return Ok(()); - } - Command::Rejected(_) => debug!("{mechanism} rejected by the server"), - Command::Error(e) => debug!("Received error from server: {e}"), - cmd => { - return Err(Error::Handshake(format!( - "Unexpected command from server: {cmd}" - ))) - } + Ok(()) + } + Command::Rejected(accepted) => { + let list = accepted + .iter() + .map(|m| m.to_string()) + .collect::>() + .join(", "); + Err(Error::Handshake(format!( + "{mechanism} rejected by the server. Accepted mechanisms: [{list}]" + ))) } + Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))), + cmd => Err(Error::Handshake(format!( + "Unexpected command from server: {cmd}" + ))), } } diff --git a/zbus/src/connection/handshake/command.rs b/zbus/src/connection/handshake/command.rs index be2252461..08b3a9d1e 100644 --- a/zbus/src/connection/handshake/command.rs +++ b/zbus/src/connection/handshake/command.rs @@ -1,6 +1,6 @@ use std::{fmt, str::FromStr}; -use crate::{AuthMechanism, Error, Guid, OwnedGuid, Result}; +use crate::{conn::AuthMechanism, Error, Guid, OwnedGuid, Result}; // The plain-text SASL profile authentication protocol described here: // diff --git a/zbus/src/connection/handshake/common.rs b/zbus/src/connection/handshake/common.rs index 9c5a4d327..d87f15037 100644 --- a/zbus/src/connection/handshake/common.rs +++ b/zbus/src/connection/handshake/common.rs @@ -1,4 +1,3 @@ -use std::collections::VecDeque; use tracing::{instrument, trace}; use super::{AuthMechanism, BoxedSplit, Command}; @@ -12,21 +11,20 @@ pub(super) struct Common { #[cfg(unix)] received_fds: Vec, cap_unix_fd: bool, - // the current AUTH mechanism is front, ordered by priority - mechanisms: VecDeque, + mechanism: AuthMechanism, first_command: bool, } impl Common { /// Start a handshake on this client socket - pub fn new(socket: BoxedSplit, mechanisms: VecDeque) -> Self { + pub fn new(socket: BoxedSplit, mechanism: AuthMechanism) -> Self { Self { socket, recv_buffer: Vec::new(), #[cfg(unix)] received_fds: Vec::new(), cap_unix_fd: false, - mechanisms, + mechanism, first_command: true, } } @@ -44,9 +42,8 @@ impl Common { self.cap_unix_fd = cap_unix_fd; } - #[cfg(feature = "p2p")] - pub fn mechanisms(&self) -> &VecDeque { - &self.mechanisms + pub fn mechanism(&self) -> AuthMechanism { + self.mechanism } pub fn into_components(self) -> IntoComponentsReturn { @@ -56,7 +53,7 @@ impl Common { #[cfg(unix)] self.received_fds, self.cap_unix_fd, - self.mechanisms, + self.mechanism, ) } @@ -175,12 +172,6 @@ impl Common { Ok(commands) } - - pub fn next_mechanism(&mut self) -> Result { - self.mechanisms - .pop_front() - .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms".into())) - } } #[cfg(unix)] @@ -189,7 +180,7 @@ type IntoComponentsReturn = ( Vec, Vec, bool, - VecDeque, + AuthMechanism, ); #[cfg(not(unix))] -type IntoComponentsReturn = (BoxedSplit, Vec, bool, VecDeque); +type IntoComponentsReturn = (BoxedSplit, Vec, bool, AuthMechanism); diff --git a/zbus/src/connection/handshake/mod.rs b/zbus/src/connection/handshake/mod.rs index 56a665cbd..b54d4d021 100644 --- a/zbus/src/connection/handshake/mod.rs +++ b/zbus/src/connection/handshake/mod.rs @@ -8,7 +8,7 @@ mod server; use async_trait::async_trait; #[cfg(unix)] use nix::unistd::Uid; -use std::{collections::VecDeque, fmt::Debug}; +use std::fmt::Debug; use zbus_names::OwnedUniqueName; #[cfg(windows)] @@ -51,10 +51,10 @@ impl Authenticated { pub async fn client( socket: BoxedSplit, server_guid: Option, - mechanisms: Option>, + mechanism: Option, bus: bool, ) -> Result { - Client::new(socket, mechanisms, server_guid, bus) + Client::new(socket, mechanism, server_guid, bus) .perform() .await } @@ -68,7 +68,7 @@ impl Authenticated { guid: OwnedGuid, #[cfg(unix)] client_uid: Option, #[cfg(windows)] client_sid: Option, - auth_mechanisms: Option>, + auth_mechanism: Option, unique_name: Option, ) -> Result { Server::new( @@ -78,7 +78,7 @@ impl Authenticated { client_uid, #[cfg(windows)] client_sid, - auth_mechanisms, + auth_mechanism, unique_name, )? .perform() @@ -250,7 +250,7 @@ mod tests { p1.into(), Guid::generate().into(), Some(Uid::effective().into()), - Some(vec![AuthMechanism::Anonymous].into()), + Some(AuthMechanism::Anonymous), None, ) .unwrap(); @@ -267,7 +267,7 @@ mod tests { p1.into(), Guid::generate().into(), Some(Uid::effective().into()), - Some(vec![AuthMechanism::Anonymous].into()), + Some(AuthMechanism::Anonymous), None, ) .unwrap(); diff --git a/zbus/src/connection/handshake/server.rs b/zbus/src/connection/handshake/server.rs index 518a556f7..3480311af 100644 --- a/zbus/src/connection/handshake/server.rs +++ b/zbus/src/connection/handshake/server.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use std::collections::VecDeque; use tracing::{instrument, trace}; use crate::names::OwnedUniqueName; @@ -41,21 +40,13 @@ impl Server { guid: OwnedGuid, #[cfg(unix)] client_uid: Option, #[cfg(windows)] client_sid: Option, - mechanisms: Option>, + mechanism: Option, unique_name: Option, ) -> Result { - let mechanisms = match mechanisms { - Some(mechanisms) => mechanisms, - None => { - let mut mechanisms = VecDeque::new(); - mechanisms.push_back(AuthMechanism::External); - - mechanisms - } - }; + let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism()); Ok(Server { - common: Common::new(socket, mechanisms), + common: Common::new(socket, mechanism), step: ServerHandshakeStep::WaitingForAuth, #[cfg(unix)] client_uid, @@ -111,8 +102,7 @@ impl Server { #[instrument(skip(self))] async fn rejected_error(&mut self) -> Result<()> { - let mechanisms = self.common.mechanisms().iter().cloned().collect(); - let cmd = Command::Rejected(mechanisms); + let cmd = Command::Rejected(vec![self.common.mechanism()]); trace!("Sending authentication error"); self.common.write_command(cmd).await?; self.step = ServerHandshakeStep::WaitingForAuth; @@ -141,22 +131,24 @@ impl Server { trace!("Waiting for authentication"); let reply = self.common.read_command().await?; match reply { - Command::Auth(mech, resp) => { - let mech = mech.filter(|m| self.common.mechanisms().contains(m)); + Command::Auth(requested_mech, resp) => { + let mech = self.common.mechanism(); + if requested_mech != Some(mech) { + self.rejected_error().await?; - match (mech, &resp) { - (Some(mech), None) => { + return Ok(()); + } + + match &resp { + None => { trace!("Sending data request"); self.common.write_command(Command::Data(None)).await?; self.step = ServerHandshakeStep::WaitingForData(mech); } - (Some(AuthMechanism::Anonymous), Some(_)) => { - self.auth_ok().await?; - } - (Some(AuthMechanism::External), Some(sasl_id)) => { - self.check_external_auth(sasl_id).await?; - } - _ => self.rejected_error().await?, + Some(sasl_id) => match mech { + AuthMechanism::Anonymous => self.auth_ok().await?, + AuthMechanism::External => self.check_external_auth(sasl_id).await?, + }, } } Command::Cancel | Command::Error(_) => { diff --git a/zbus/src/connection/mod.rs b/zbus/src/connection/mod.rs index a76e9bc88..c1f34a5c7 100644 --- a/zbus/src/connection/mod.rs +++ b/zbus/src/connection/mod.rs @@ -41,6 +41,7 @@ mod socket_reader; use socket_reader::SocketReader; pub(crate) mod handshake; +pub use handshake::AuthMechanism; use handshake::Authenticated; mod connect; @@ -1528,7 +1529,7 @@ mod p2p_tests { use test_log::test; use zvariant::{Endian, NATIVE_ENDIAN}; - use crate::{AuthMechanism, Guid}; + use crate::{conn::AuthMechanism, Guid}; use super::*; diff --git a/zbus/src/connection/socket/channel.rs b/zbus/src/connection/socket/channel.rs index acb86f483..b0b46ef7a 100644 --- a/zbus/src/connection/socket/channel.rs +++ b/zbus/src/connection/socket/channel.rs @@ -2,7 +2,7 @@ use std::io; use async_broadcast::{broadcast, Receiver, Sender}; -use crate::{fdo::ConnectionCredentials, Message}; +use crate::{conn::AuthMechanism, fdo::ConnectionCredentials, Message}; /// An in-process channel-based socket. /// @@ -72,6 +72,10 @@ impl super::ReadHalf for Reader { async fn peer_credentials(&mut self) -> io::Result { self_credentials().await } + + fn auth_mechanism(&self) -> AuthMechanism { + AuthMechanism::Anonymous + } } /// The writer half of a [`Channel`]. diff --git a/zbus/src/connection/socket/mod.rs b/zbus/src/connection/socket/mod.rs index 9d9062c0a..348058bfc 100644 --- a/zbus/src/connection/socket/mod.rs +++ b/zbus/src/connection/socket/mod.rs @@ -18,6 +18,7 @@ use std::{io, mem}; use tracing::trace; use crate::{ + conn::AuthMechanism, fdo::ConnectionCredentials, message::{ header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE}, @@ -237,6 +238,13 @@ pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static { async fn peer_credentials(&mut self) -> io::Result { Ok(ConnectionCredentials::default()) } + + /// The authentication mechanism to use for this socket on the target OS. + /// + /// Default is `AuthMechanism::External`. + fn auth_mechanism(&self) -> AuthMechanism { + AuthMechanism::External + } } /// The write half of a socket. @@ -354,6 +362,10 @@ impl ReadHalf for Box { async fn peer_credentials(&mut self) -> io::Result { (**self).peer_credentials().await } + + fn auth_mechanism(&self) -> AuthMechanism { + (**self).auth_mechanism() + } } #[async_trait::async_trait] diff --git a/zbus/src/connection/socket/tcp.rs b/zbus/src/connection/socket/tcp.rs index d66db8108..80c62a0da 100644 --- a/zbus/src/connection/socket/tcp.rs +++ b/zbus/src/connection/socket/tcp.rs @@ -48,6 +48,11 @@ impl ReadHalf for Arc> { ) .await } + + #[cfg(not(windows))] + fn auth_mechanism(&self) -> crate::conn::AuthMechanism { + crate::conn::AuthMechanism::Anonymous + } } #[cfg(not(feature = "tokio"))] @@ -120,6 +125,11 @@ impl ReadHalf for tokio::net::tcp::OwnedReadHalf { ) .await } + + #[cfg(not(windows))] + fn auth_mechanism(&self) -> crate::conn::AuthMechanism { + crate::conn::AuthMechanism::Anonymous + } } #[cfg(feature = "tokio")] diff --git a/zbus/src/connection/socket/vsock.rs b/zbus/src/connection/socket/vsock.rs index 4889167c2..803f85703 100644 --- a/zbus/src/connection/socket/vsock.rs +++ b/zbus/src/connection/socket/vsock.rs @@ -26,6 +26,10 @@ impl super::ReadHalf for std::sync::Arc> { } } } + + fn auth_mechanism(&self) -> crate::AuthMechanism { + crate::AuthMechanism::Anonymous + } } #[cfg(all(feature = "vsock", not(feature = "tokio")))] @@ -86,6 +90,10 @@ impl super::ReadHalf for tokio_vsock::ReadHalf { ret }) } + + fn auth_mechanism(&self) -> crate::conn::AuthMechanism { + crate::conn::AuthMechanism::Anonymous + } } #[cfg(feature = "tokio-vsock")] diff --git a/zbus/src/lib.rs b/zbus/src/lib.rs index b4abf950f..96517f855 100644 --- a/zbus/src/lib.rs +++ b/zbus/src/lib.rs @@ -58,7 +58,12 @@ pub use message::Message; pub mod connection; /// Alias for `connection` module, for convenience. pub use connection as conn; -pub use connection::{handshake::AuthMechanism, Connection}; +#[deprecated( + since = "5.0.0", + note = "Please use `connection::AuthMechanism` instead" +)] +pub use connection::handshake::AuthMechanism; +pub use connection::Connection; mod message_stream; pub use message_stream::*;