Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(websocket_server sink): add ACK support for message buffering #22540

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add ACK support to message buffering feature of `websocket_server` sink, allowing this component to cache latest received messages per client.

authors: esensar Quad9DNS
135 changes: 129 additions & 6 deletions src/sinks/websocket_server/buffering.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::{collections::VecDeque, num::NonZeroUsize};
use crate::serde::default_decoding;
use std::{collections::VecDeque, net::SocketAddr, num::NonZeroUsize};

use bytes::Bytes;
use derivative::Derivative;
use tokio_tungstenite::tungstenite::{handshake::server::Request, Message};
use url::Url;
use uuid::Uuid;
use vector_config::configurable_component;
use vector_lib::{
codecs::decoding::{format::Deserializer as _, DeserializerConfig},
event::{Event, MaybeAsLogMut},
lookup::lookup_v2::ConfigValuePath,
};
use vrl::prelude::VrlValueConvert;

/// Configuration for message buffering which enables message replay for clients that connect later.
#[configurable_component]
Expand All @@ -27,6 +31,53 @@ pub struct MessageBufferingConfig {
/// clients can request replay starting from the message ID of their choosing.
#[serde(default, skip_serializing_if = "crate::serde::is_default")]
pub message_id_path: Option<ConfigValuePath>,

#[configurable(derived)]
pub client_ack_support: Option<BufferingAckConfig>,
}

/// Configuration for ACK support for message buffering.
/// Enabling ACK support makes it possible to replay messages for clients without requiring query
/// parameters at connection time. It moves the burden of tracking latest received messages from
/// clients to this component. It requires clients to respond to received messages with an ACK.
#[configurable_component]
#[derive(Clone, Debug, Derivative)]
pub struct BufferingAckConfig {
#[configurable(derived)]
#[derivative(Default(value = "default_decoding()"))]
#[serde(default = "default_decoding")]
pub ack_decoding: DeserializerConfig,

/// Name of the field that contains the ACKed message ID. Use "." if message ID is the root of
/// the message.
pub message_id_path: ConfigValuePath,

#[configurable(derived)]
#[serde(default = "default_client_key_config")]
pub client_key: ClientKeyConfig,
}

/// Configuration for client key used for tracking ACKed message for message buffering.
#[configurable_component]
#[derive(Clone, Debug)]
pub enum ClientKeyConfig {
/// Use client IP address as the unique key for that client
IpAddress {
/// Set to true if port should be included with the ip address.
///
/// By default port is not included
#[serde(default = "crate::serde::default_false")]
with_port: bool,
},
/// Use the value of a header on connection request as the unique key for that client
Header {
/// Name of the header to use as value
name: String,
},
}

const fn default_client_key_config() -> ClientKeyConfig {
ClientKeyConfig::IpAddress { with_port: false }
}

const fn default_max_events() -> NonZeroUsize {
Expand All @@ -50,7 +101,7 @@ impl BufferReplayRequest {
replay_from: None,
};

const fn with_replay_from(replay_from: Uuid) -> Self {
pub const fn with_replay_from(replay_from: Uuid) -> Self {
Self {
should_replay: true,
replay_from: Some(replay_from),
Expand All @@ -74,40 +125,77 @@ impl BufferReplayRequest {
pub trait WsMessageBufferConfig {
/// Returns true if this configuration enables buffering.
fn should_buffer(&self) -> bool;
/// Generates key for a client based on connection request and address.
/// This key should be used for storing client checkpoints (last ACKed message).
fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String>;
/// Returns configured size of the buffer.
fn buffer_capacity(&self) -> usize;
/// Extracts buffer replay request from the given connection request, based on configuration.
fn extract_message_replay_request(&self, request: &Request) -> BufferReplayRequest;
fn extract_message_replay_request(
&self,
request: &Request,
client_checkpoint: Option<Uuid>,
) -> BufferReplayRequest;
/// Adds a message ID that can be used for requesting replay into the event.
/// Created ID is returned to be stored in the buffer.
fn add_replay_message_id_to_event(&self, event: &mut Event) -> Uuid;
/// Handles ACK request and returns message ID, if available.
fn handle_ack_request(&self, request: Message) -> Option<Uuid>;
}

impl WsMessageBufferConfig for Option<MessageBufferingConfig> {
fn should_buffer(&self) -> bool {
self.is_some()
}

fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String> {
self.as_ref()
.and_then(|mb| mb.client_ack_support.as_ref())
.and_then(|ack| match &ack.client_key {
ClientKeyConfig::IpAddress { with_port } => Some(if *with_port {
client_address.to_string()
} else {
client_address.ip().to_string()
}),
ClientKeyConfig::Header { name } => request
.headers()
.get(name)
.and_then(|h| h.to_str().ok())
.map(ToString::to_string),
})
}

fn buffer_capacity(&self) -> usize {
self.as_ref().map_or(0, |mb| mb.max_events.get())
}

fn extract_message_replay_request(&self, request: &Request) -> BufferReplayRequest {
fn extract_message_replay_request(
&self,
request: &Request,
client_checkpoint: Option<Uuid>,
) -> BufferReplayRequest {
// Early return if buffering is disabled
if self.is_none() {
return BufferReplayRequest::NO_REPLAY;
}

let default_request = client_checkpoint
.map(BufferReplayRequest::with_replay_from)
// If we don't have ACK support, or don't have an ACK stored for the client,
// default to no replay
.unwrap_or(BufferReplayRequest::NO_REPLAY);

// Early return if query params are missing
let Some(query_params) = request.uri().query() else {
return BufferReplayRequest::NO_REPLAY;
return default_request;
};

// Early return if there is no query param for replay
if !query_params.contains(LAST_RECEIVED_QUERY_PARAM_NAME) {
return BufferReplayRequest::NO_REPLAY;
return default_request;
}

// Even if we have an ACK stored, query param should override the cached state
let base_url = Url::parse("ws://localhost").ok();
match Url::options()
.base_url(base_url.as_ref())
Expand Down Expand Up @@ -154,4 +242,39 @@ impl WsMessageBufferConfig for Option<MessageBufferingConfig> {
}
message_id
}

fn handle_ack_request(&self, request: Message) -> Option<Uuid> {
let ack_config = self
.as_ref()
.and_then(|mb| mb.client_ack_support.as_ref())?;

let parsed_message = ack_config
.ack_decoding
.build()
.unwrap()
.parse(request.into_data().into(), Default::default())
.inspect_err(|err| {
debug!(message = "Parsing ACK request failed.", %err);
})
.ok()?;

let Some(message_id_field) = parsed_message
.first()?
.maybe_as_log()?
.value()
.get(&ack_config.message_id_path)
else {
debug!("Couldn't find message ID in ACK request.");
return None;
};

message_id_field
.try_bytes_utf8_lossy()
.map_err(|_| "Message ID is not a valid string.")
.and_then(|id| {
Uuid::parse_str(id.trim()).map_err(|_| "Message ID is not a valid UUID.")
})
.inspect_err(|err| debug!(message = "Parsing message ID in ACK request failed.", %err))
.ok()
}
}
58 changes: 48 additions & 10 deletions src/sinks/websocket_server/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use async_trait::async_trait;
use bytes::BytesMut;
use futures::{
channel::mpsc::{unbounded, UnboundedSender},
pin_mut,
future, pin_mut,
stream::BoxStream,
StreamExt,
StreamExt, TryStreamExt,
};
use http::StatusCode;
use tokio::net::TcpStream;
Expand Down Expand Up @@ -48,7 +48,6 @@ use crate::{
use super::{buffering::MessageBufferingConfig, WebSocketListenerSinkConfig};

pub struct WebSocketListenerSink {
peers: Arc<Mutex<HashMap<SocketAddr, UnboundedSender<Message>>>>,
tls: MaybeTlsSettings,
transformer: Transformer,
encoder: Encoder<()>,
Expand All @@ -67,8 +66,8 @@ impl WebSocketListenerSink {
.auth
.map(|config| config.build(&cx.enrichment_tables))
.transpose()?;

Ok(Self {
peers: Arc::new(Mutex::new(HashMap::new())),
tls,
address: config.address,
transformer,
Expand All @@ -93,6 +92,7 @@ impl WebSocketListenerSink {
auth: Option<HttpServerAuthMatcher>,
message_buffering: Option<MessageBufferingConfig>,
peers: Arc<Mutex<HashMap<SocketAddr, UnboundedSender<Message>>>>,
client_checkpoints: Arc<Mutex<HashMap<String, Uuid>>>,
buffer: Arc<Mutex<VecDeque<(Uuid, Message)>>>,
mut listener: MaybeTlsListener,
) {
Expand All @@ -104,6 +104,7 @@ impl WebSocketListenerSink {
auth.clone(),
message_buffering.clone(),
Arc::clone(&peers),
Arc::clone(&client_checkpoints),
Arc::clone(&buffer),
stream,
open_gauge.clone(),
Expand All @@ -117,6 +118,7 @@ impl WebSocketListenerSink {
auth: Option<HttpServerAuthMatcher>,
message_buffering: Option<MessageBufferingConfig>,
peers: Arc<Mutex<HashMap<SocketAddr, UnboundedSender<Message>>>>,
client_checkpoints: Arc<Mutex<HashMap<String, Uuid>>>,
buffer: Arc<Mutex<VecDeque<(Uuid, Message)>>>,
stream: MaybeTlsIncomingStream<TcpStream>,
open_gauge: OpenGauge,
Expand All @@ -126,8 +128,20 @@ impl WebSocketListenerSink {
debug!("Incoming TCP connection from: {}", addr);

let mut buffer_replay = BufferReplayRequest::NO_REPLAY;
let mut client_checkpoint_key = None;

let header_callback = |req: &Request, response: Response| {
buffer_replay = message_buffering.extract_message_replay_request(req);
client_checkpoint_key = message_buffering.client_key(req, &addr);
buffer_replay = message_buffering.extract_message_replay_request(
req,
client_checkpoint_key.clone().and_then(|key| {
client_checkpoints
.lock()
.expect("mutex poisoned")
.get(&key)
.cloned()
}),
);
let Some(auth) = auth else {
return Ok(response);
};
Expand Down Expand Up @@ -173,12 +187,30 @@ impl WebSocketListenerSink {
});
}

let (outgoing, _incoming) = ws_stream.split();
let (outgoing, incoming) = ws_stream.split();

let incoming_data_handler = incoming.try_for_each(|msg| {
let ip = addr.ip();
debug!("Received a message from {}: {}", ip, msg.to_text().unwrap());
if let Some(client_key) = &client_checkpoint_key {
if let Some(checkpoint) = message_buffering.handle_ack_request(msg) {
debug!(
"Inserting checkpoint for {}({}): {}",
client_key, ip, checkpoint
);
client_checkpoints
.lock()
.unwrap()
.insert(client_key.clone(), checkpoint);
}
}

future::ok(())
});
let forward_data_to_client = rx.map(Ok).forward(outgoing);

pin_mut!(forward_data_to_client);
let _ = forward_data_to_client.await;
pin_mut!(forward_data_to_client, incoming_data_handler);
future::select(forward_data_to_client, incoming_data_handler).await;

{
let mut peers = peers.lock().unwrap();
Expand All @@ -205,14 +237,18 @@ impl StreamSink<Event> for WebSocketListenerSink {

let listener = self.tls.bind(&self.address).await.map_err(|_| ())?;

let peers = Arc::new(Mutex::new(HashMap::default()));
let message_buffer = Arc::new(Mutex::new(VecDeque::with_capacity(
self.message_buffering.buffer_capacity(),
)));
let client_checkpoints = Arc::new(Mutex::new(HashMap::default()));

tokio::spawn(
Self::handle_connections(
self.auth,
self.message_buffering.clone(),
Arc::clone(&self.peers),
Arc::clone(&peers),
Arc::clone(&client_checkpoints),
Arc::clone(&message_buffer),
listener,
)
Expand Down Expand Up @@ -251,7 +287,7 @@ impl StreamSink<Event> for WebSocketListenerSink {
buffer.push_back((message_id, message.clone()));
}

let peers = self.peers.lock().unwrap();
let peers = peers.lock().unwrap();
let broadcast_recipients = peers.iter().map(|(_, ws_sink)| ws_sink);
for recp in broadcast_recipients {
if let Err(error) = recp.unbounded_send(message.clone()) {
Expand Down Expand Up @@ -412,6 +448,7 @@ mod tests {
message_buffering: Some(MessageBufferingConfig {
max_events: NonZeroUsize::new(1).unwrap(),
message_id_path: None,
client_ack_support: None,
}),
..Default::default()
},
Expand Down Expand Up @@ -455,6 +492,7 @@ mod tests {
message_buffering: Some(MessageBufferingConfig {
max_events: NonZeroUsize::new(1).unwrap(),
message_id_path: None,
client_ack_support: None,
}),
..Default::default()
},
Expand Down