Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into auth_support
Browse files Browse the repository at this point in the history
  • Loading branch information
ting-ms committed May 29, 2024
2 parents d5e5a06 + a1282cd commit 81728b5
Show file tree
Hide file tree
Showing 18 changed files with 275 additions and 223 deletions.
9 changes: 8 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion rumqttc/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `size()` method on `Packet` calculates size once serialized.
* `read()` and `write()` methods on `Packet`.
* `ConnectionAborted` variant on `StateError` type to denote abrupt end to a connection
* `AUTH` packet support for enhanced authentication.
* `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`.
* `Auth` packet as per MQTT5 standards
* Allow configuring the `nodelay` property of underlying TCP client with the `tcp_nodelay` field in `NetworkOptions`
* `MqttOptions::set_auth_manager` that allows users to set their own authentication manager that implements the `AuthManager` trait.
* `Client::reauth` that enables users to send `AUTH` packet for re-authentication purposes.

Expand All @@ -22,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* rename `N` as `AsyncReadWrite` to describe usage.
* use `Framed` to encode/decode MQTT packets.
* use `Login` to store credentials
* Made `DisconnectProperties` struct public.
* Replace `Vec<Option<u16>>` with `FixedBitSet` for managing packet ids of released QoS 2 publishes and incoming QoS 2 publishes in `MqttState`.

### Deprecated

Expand All @@ -32,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Validate filters while creating subscription requests.
* Make v4::Connect::write return correct value
* Ordering of `State.events` related to `QoS > 0` publishes
* Filter PUBACK in pending save requests to fix unexpected PUBACK sent to reconnected broker.
* Resume session only if broker sends `CONNACK` with `session_present == 1`.

### Security

Expand Down
1 change: 1 addition & 0 deletions rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ url = { version = "2", default-features = false, optional = true }
# proxy
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true }
tokio-stream = "0.1.15"
fixedbitset = "0.5.7"
#auth
scram = { version = "0.6.0", optional = true }

Expand Down
6 changes: 3 additions & 3 deletions rumqttc/examples/async_auth_oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ async fn main() -> Result<(), Box<dyn Error>> {

// Re-authentication test.
let props = AuthProperties {
authentication_method: Some("OAUTH2-JWT".to_string()),
authentication_data: Some(pubsub_access_token.into()),
reason_string: None,
method: Some("OAUTH2-JWT".to_string()),
data: Some(pubsub_access_token.into()),
reason: None,
user_properties: Vec::new(),
};

Expand Down
6 changes: 3 additions & 3 deletions rumqttc/examples/async_auth_scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Reauthenticate using SCRAM-SHA-256
let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap();
let properties = AuthProperties {
authentication_method: Some("SCRAM-SHA-256".to_string()),
authentication_data: client_first,
reason_string: None,
method: Some("SCRAM-SHA-256".to_string()),
data: client_first,
reason: None,
user_properties: Vec::new(),
};
client.reauth(Some(properties)).await.unwrap();
Expand Down
1 change: 1 addition & 0 deletions rumqttc/examples/async_manual_acks_v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ fn create_conn() -> (AsyncClient, EventLoop) {
let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1884);
mqttoptions
.set_keep_alive(Duration::from_secs(5))
.set_session_expiry_interval(u32::MAX.into())
.set_manual_acks(true)
.set_clean_start(false);

Expand Down
6 changes: 3 additions & 3 deletions rumqttc/examples/sync_auth_scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ fn main() -> Result<(), Box<dyn Error>> {
// Reauthenticate using SCRAM-SHA-256
let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap();
let properties = AuthProperties {
authentication_method: Some("SCRAM-SHA-256".to_string()),
authentication_data: client_first,
reason_string: None,
method: Some("SCRAM-SHA-256".to_string()),
data: client_first,
reason: None,
user_properties: Vec::new(),
};
client.reauth(Some(properties)).unwrap();
Expand Down
32 changes: 23 additions & 9 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ impl EventLoop {
self.pending.extend(self.state.clean());

// drain requests from channel which weren't yet received
let requests_in_channel = self.requests_rx.drain();
let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();

requests_in_channel.retain(|request| {
match request {
Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack
_ => true,
}
});

self.pending.extend(requests_in_channel);
}

Expand All @@ -149,18 +157,24 @@ impl EventLoop {
Ok(inner) => inner?,
Err(_) => return Err(ConnectionError::NetworkTimeout),
};
// Last session might contain packets which aren't acked. If it's a new session, clear the pending packets.
if !connack.session_present {
self.pending.clear();
}
self.network = Some(network);

if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
}

return Ok(Event::Incoming(connack));
return Ok(Event::Incoming(Packet::ConnAck(connack)));
}

match self.select().await {
Ok(v) => Ok(v),
Err(e) => {
// MQTT requires that packets pending acknowledgement should be republished on session resume.
// Move pending messages from state to eventloop.
self.clean();
Err(e)
}
Expand Down Expand Up @@ -294,14 +308,14 @@ impl EventLoop {
async fn connect(
mqtt_options: &MqttOptions,
network_options: NetworkOptions,
) -> Result<(Network, Incoming), ConnectionError> {
) -> Result<(Network, ConnAck), ConnectionError> {
// connect to the broker
let mut network = network_connect(mqtt_options, network_options).await?;

// make MQTT connection request (which internally awaits for ack)
let packet = mqtt_connect(mqtt_options, &mut network).await?;
let connack = mqtt_connect(mqtt_options, &mut network).await?;

Ok((network, packet))
Ok((network, connack))
}

pub(crate) async fn socket_connect(
Expand All @@ -317,6 +331,8 @@ pub(crate) async fn socket_connect(
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};

socket.set_nodelay(network_options.tcp_nodelay)?;

if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
socket.set_send_buffer_size(send_buff_size).unwrap();
}
Expand Down Expand Up @@ -469,7 +485,7 @@ async fn network_connect(
async fn mqtt_connect(
options: &MqttOptions,
network: &mut Network,
) -> Result<Incoming, ConnectionError> {
) -> Result<ConnAck, ConnectionError> {
let keep_alive = options.keep_alive().as_secs() as u16;
let clean_session = options.clean_session();
let last_will = options.last_will();
Expand All @@ -485,9 +501,7 @@ async fn mqtt_connect(

// validate connack
match network.read().await? {
Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
Ok(Packet::ConnAck(connack))
}
Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => Ok(connack),
Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
packet => Err(ConnectionError::NotConnAck(packet)),
}
Expand Down
6 changes: 6 additions & 0 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ impl From<ClientConfig> for TlsConfiguration {
pub struct NetworkOptions {
tcp_send_buffer_size: Option<u32>,
tcp_recv_buffer_size: Option<u32>,
tcp_nodelay: bool,
conn_timeout: u64,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
bind_device: Option<String>,
Expand All @@ -381,12 +382,17 @@ impl NetworkOptions {
NetworkOptions {
tcp_send_buffer_size: None,
tcp_recv_buffer_size: None,
tcp_nodelay: false,
conn_timeout: 5,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
bind_device: None,
}
}

pub fn set_tcp_nodelay(&mut self, nodelay: bool) {
self.tcp_nodelay = nodelay;
}

pub fn set_tcp_send_buffer_size(&mut self, size: u32) {
self.tcp_send_buffer_size = Some(size);
}
Expand Down
56 changes: 21 additions & 35 deletions rumqttc/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{Event, Incoming, Outgoing, Request};

use crate::mqttbytes::v4::*;
use crate::mqttbytes::{self, *};
use fixedbitset::FixedBitSet;
use std::collections::VecDeque;
use std::{io, time::Instant};

Expand Down Expand Up @@ -62,9 +63,9 @@ pub struct MqttState {
/// Outgoing QoS 1, 2 publishes which aren't acked yet
pub(crate) outgoing_pub: Vec<Option<Publish>>,
/// Packet ids of released QoS 2 publishes
pub(crate) outgoing_rel: Vec<Option<u16>>,
pub(crate) outgoing_rel: FixedBitSet,
/// Packet ids on incoming QoS 2 publishes
pub(crate) incoming_pub: Vec<Option<u16>>,
pub(crate) incoming_pub: FixedBitSet,
/// Last collision due to broker not acking in order
pub collision: Option<Publish>,
/// Buffered incoming packets
Expand All @@ -89,8 +90,8 @@ impl MqttState {
max_inflight,
// index 0 is wasted as 0 is not a valid packet id
outgoing_pub: vec![None; max_inflight as usize + 1],
outgoing_rel: vec![None; max_inflight as usize + 1],
incoming_pub: vec![None; u16::MAX as usize + 1],
outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1),
incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1),
collision: None,
// TODO: Optimize these sizes later
events: VecDeque::with_capacity(100),
Expand All @@ -113,17 +114,14 @@ impl MqttState {
}

// remove and collect pending releases
for rel in self.outgoing_rel.iter_mut() {
if let Some(pkid) = rel.take() {
let request = Request::PubRel(PubRel::new(pkid));
pending.push(request);
}
for pkid in self.outgoing_rel.ones() {
let request = Request::PubRel(PubRel::new(pkid as u16));
pending.push(request);
}
self.outgoing_rel.clear();

// remove packed ids of incoming qos2 publishes
for id in self.incoming_pub.iter_mut() {
id.take();
}
// remove packet ids of incoming qos2 publishes
self.incoming_pub.clear();

self.await_pingresp = false;
self.collision_ping_count = 0;
Expand Down Expand Up @@ -210,7 +208,7 @@ impl MqttState {
}
QoS::ExactlyOnce => {
let pkid = publish.pkid;
self.incoming_pub[pkid as usize] = Some(pkid);
self.incoming_pub.insert(pkid as usize);

if !self.manual_acks {
let pubrec = PubRec::new(pkid);
Expand Down Expand Up @@ -261,7 +259,7 @@ impl MqttState {
}

// NOTE: Inflight - 1 for qos2 in comp
self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid);
self.outgoing_rel.insert(pubrec.pkid as usize);
let pubrel = PubRel { pkid: pubrec.pkid };
let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
self.events.push_back(event);
Expand All @@ -270,16 +268,12 @@ impl MqttState {
}

fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
let publish = self
.incoming_pub
.get_mut(pubrel.pkid as usize)
.ok_or(StateError::Unsolicited(pubrel.pkid))?;

if publish.take().is_none() {
if !self.incoming_pub.contains(pubrel.pkid as usize) {
error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
return Err(StateError::Unsolicited(pubrel.pkid));
}

self.incoming_pub.set(pubrel.pkid as usize, false);
let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
let pubcomp = PubComp { pkid: pubrel.pkid };
self.events.push_back(event);
Expand All @@ -288,17 +282,12 @@ impl MqttState {
}

fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
if self
.outgoing_rel
.get_mut(pubcomp.pkid as usize)
.ok_or(StateError::Unsolicited(pubcomp.pkid))?
.take()
.is_none()
{
if !self.outgoing_rel.contains(pubcomp.pkid as usize) {
error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
return Err(StateError::Unsolicited(pubcomp.pkid));
}

self.outgoing_rel.set(pubcomp.pkid as usize, false);
self.inflight -= 1;
let packet = self.check_collision(pubcomp.pkid).map(|publish| {
let event = Event::Outgoing(Outgoing::Publish(publish.pkid));
Expand Down Expand Up @@ -486,7 +475,7 @@ impl MqttState {
_ => pubrel,
};

self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid);
self.outgoing_rel.insert(pubrel.pkid as usize);
self.inflight += 1;
Ok(pubrel)
}
Expand Down Expand Up @@ -610,10 +599,8 @@ mod test {
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

let pkid = mqtt.incoming_pub[3].unwrap();

// only qos2 publish should be add to queue
assert_eq!(pkid, 3);
assert!(mqtt.incoming_pub.contains(3));
}

#[test]
Expand Down Expand Up @@ -656,8 +643,7 @@ mod test {
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

let pkid = mqtt.incoming_pub[3].unwrap();
assert_eq!(pkid, 3);
assert!(mqtt.incoming_pub.contains(3));

assert!(mqtt.events.is_empty());
}
Expand Down Expand Up @@ -725,7 +711,7 @@ mod test {
assert_eq!(backup.unwrap().pkid, 1);

// check if the qos2 element's release pkid is 2
assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2);
assert!(mqtt.outgoing_rel.contains(2));
}

#[test]
Expand Down
Loading

0 comments on commit 81728b5

Please sign in to comment.