Skip to content

Commit

Permalink
fix(host_metrics source): add defensive check to prevent panics (#22604)
Browse files Browse the repository at this point in the history
* fix(host_metrics source): avoid panic when reading from buffer

* tweaks

* add docs because these functions are complex

* changelog
  • Loading branch information
pront authored Mar 7, 2025
1 parent 559d9f2 commit d17c099
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 32 deletions.
3 changes: 3 additions & 0 deletions changelog.d/host_metrics_tcp_panic.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix potential panic in the `host_metrics` source when collecting TCP metrics.

authors: pront
121 changes: 89 additions & 32 deletions src/sources/host_metrics/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ enum TcpError {
InvalidTcpState { state: u8 },
#[snafu(display("Received an error message from netlink; code: {code}"))]
NetlinkMsgError { code: i32 },
#[snafu(display("Invalid message length: {length}"))]
InvalidLength { length: usize },
}

#[repr(u8)]
Expand Down Expand Up @@ -134,7 +136,84 @@ struct TcpStats {
tx_queued_bytes: f64,
}

async fn fetch_nl_inet_hdrs(addr_family: u8) -> Result<Vec<InetResponseHeader>, TcpError> {
/// Walks a buffer of back-to-back Netlink messages and collects every
/// [`InetResponseHeader`] found in it.
///
/// Each message starts with a native-endian `u32` holding the length of that
/// message; the parser uses this field to slice out one message at a time and
/// to advance to the next one.
///
/// # Arguments
/// * `buffer` - Raw bytes holding zero or more Netlink messages.
/// * `headers` - Destination vector; headers of `InetResponse` messages are appended to it.
///
/// # Returns
/// * `Ok(true)` once a `Done` message has been seen; nothing after it is parsed.
/// * `Ok(false)` if the buffer was exhausted, or a zero length field was hit, before
///   any `Done` message. The caller may invoke this function again with more data.
///
/// # Errors
/// * [`TcpError::InvalidLength`] if fewer than 4 bytes remain for the length field, or
///   if a message claims to be longer than the bytes that remain.
/// * [`TcpError::NetlinkMsgError`] if an error message carrying a code is received.
/// * The error produced via `NetlinkParseSnafu` if a message fails to deserialize.
fn parse_netlink_messages(
    buffer: &[u8],
    headers: &mut Vec<InetResponseHeader>,
) -> Result<bool, TcpError> {
    // Size in bytes of the leading length field of every message.
    const LENGTH_FIELD_SIZE: usize = 4;

    // The not-yet-parsed tail of `buffer`; shrinks by one message per iteration.
    let mut unparsed = buffer;

    while !unparsed.is_empty() {
        // Checked slicing: reading the `u32` from a shorter slice would panic.
        let length_field = unparsed
            .get(..LENGTH_FIELD_SIZE)
            .ok_or(TcpError::InvalidLength {
                length: unparsed.len(),
            })?;
        let message_len = NativeEndian::read_u32(length_field) as usize;

        // A zero length ends parsing without signalling completion.
        if message_len == 0 {
            return Ok(false);
        }

        // Reject messages that claim more bytes than are actually available.
        let message = unparsed
            .get(..message_len)
            .ok_or(TcpError::InvalidLength {
                length: message_len,
            })?;

        let packet =
            <NetlinkMessage<SockDiagMessage>>::deserialize(message).context(NetlinkParseSnafu)?;

        match packet.payload {
            NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(response)) => {
                headers.push(response.header)
            }
            NetlinkPayload::Done(_) => return Ok(true),
            NetlinkPayload::Error(error) => {
                // Error messages without a code are skipped, like any other payload.
                if let Some(code) = error.code {
                    return Err(TcpError::NetlinkMsgError { code: code.get() });
                }
            }
            _ => {}
        }

        // `message_len <= unparsed.len()` was established by the checked slice above.
        unparsed = &unparsed[message_len..];
    }

    Ok(false)
}

/// Fetches [`InetResponseHeader`]s for TCP sockets using Netlink.
///
/// # Arguments
/// * `addr_family` - Address family (`AF_INET` for IPv4, `AF_INET6` for IPv6).
///
/// # Returns
/// * `Ok(Vec<InetResponseHeader>)` containing headers for active TCP sockets.
/// * `Err(TcpError)` on socket creation, send, receive, or parsing errors.
///
/// # Errors
/// Returns [`TcpError`] for socket-related or message parsing failures.
///
/// # Notes
/// Asynchronously queries the kernel via a Netlink socket for TCP socket info.
async fn fetch_netlink_inet_headers(addr_family: u8) -> Result<Vec<InetResponseHeader>, TcpError> {
let unicast_socket: SocketAddr = SocketAddr::new(0, 0);
let mut socket = TokioSocket::new(NETLINK_SOCK_DIAG).context(NetlinkSocketSnafu)?;

Expand Down Expand Up @@ -163,34 +242,12 @@ async fn fetch_nl_inet_hdrs(addr_family: u8) -> Result<Vec<InetResponseHeader>,
.context(NetlinkSendSnafu)?;

let mut receive_buffer = vec![0; 4096];
let mut inet_resp_hdrs: Vec<InetResponseHeader> = Vec::new();
'outer: while let Ok(()) = socket.recv(&mut &mut receive_buffer[..]).await {
let mut offset = 0;
'inner: loop {
let bytes = &receive_buffer[offset..];
let length = NativeEndian::read_u32(&bytes[0..4]) as usize;
if length == 0 {
break 'inner;
}
let rx_packet =
<NetlinkMessage<SockDiagMessage>>::deserialize(bytes).context(NetlinkParseSnafu)?;

match rx_packet.payload {
NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(response)) => {
inet_resp_hdrs.push(response.header);
}
NetlinkPayload::Done(_) => {
break 'outer;
}
NetlinkPayload::Error(error) => {
if let Some(code) = error.code {
return Err(TcpError::NetlinkMsgError { code: code.get() });
}
}
_ => {}
}
let mut inet_resp_hdrs = Vec::with_capacity(32); // Pre-allocate with an estimate

offset += rx_packet.header.length as usize;
while let Ok(()) = socket.recv(&mut &mut receive_buffer[..]).await {
let done = parse_netlink_messages(&receive_buffer, &mut inet_resp_hdrs)?;
if done {
break;
}
}

Expand All @@ -214,11 +271,11 @@ fn parse_nl_inet_hdrs(

async fn build_tcp_stats() -> Result<TcpStats, TcpError> {
let mut tcp_stats = TcpStats::default();
let resp = fetch_nl_inet_hdrs(AF_INET).await?;
let resp = fetch_netlink_inet_headers(AF_INET).await?;
parse_nl_inet_hdrs(resp, &mut tcp_stats)?;

if is_ipv6_enabled() {
let resp = fetch_nl_inet_hdrs(AF_INET6).await?;
let resp = fetch_netlink_inet_headers(AF_INET6).await?;
parse_nl_inet_hdrs(resp, &mut tcp_stats)?;
}

Expand All @@ -239,7 +296,7 @@ mod tests {
};

use super::{
fetch_nl_inet_hdrs, parse_nl_inet_hdrs, TcpStats, STATE, TCP_CONNS_TOTAL,
fetch_netlink_inet_headers, parse_nl_inet_hdrs, TcpStats, STATE, TCP_CONNS_TOTAL,
TCP_RX_QUEUED_BYTES_TOTAL, TCP_TX_QUEUED_BYTES_TOTAL,
};
use crate::sources::host_metrics::{HostMetrics, HostMetricsConfig, MetricsBuffer};
Expand Down Expand Up @@ -296,7 +353,7 @@ mod tests {
// initiate a connection
let _stream = TcpStream::connect(addr).await.unwrap();

let hdrs = fetch_nl_inet_hdrs(AF_INET).await.unwrap();
let hdrs = fetch_netlink_inet_headers(AF_INET).await.unwrap();
// there should be at least two connections, one for the server and one for the client
assert!(hdrs.len() >= 2);

Expand Down

0 comments on commit d17c099

Please sign in to comment.