diff --git a/netlink-proto/examples/listen_uevents.rs b/netlink-proto/examples/listen_uevents.rs new file mode 100644 index 00000000..2e90ebd1 --- /dev/null +++ b/netlink-proto/examples/listen_uevents.rs @@ -0,0 +1,59 @@ +use futures::StreamExt; +use netlink_proto::{new_connection, sys::{protocols::NETLINK_KOBJECT_UEVENT, SocketAddr}}; + +use netlink_packet_core::{NetlinkDeserializable, NetlinkHeader, NetlinkSerializable}; + +#[derive(Debug, PartialEq, Eq, Clone)] +enum UEvent { + Add, +} + +impl NetlinkSerializable for UEvent { + fn message_type(&self) -> u16 { + todo!() + } + + fn buffer_len(&self) -> usize { + todo!() + } + + fn serialize(&self, buffer: &mut [u8]) { + todo!() + } +} + +impl NetlinkDeserializable for UEvent { + type Error = std::io::Error; + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { + let s = String::from_utf8_lossy(payload); + println!("{}", s); + + Ok(UEvent::Add) + } +} + +#[tokio::main] +async fn main() -> Result<(), String> { + env_logger::init(); + // Create the netlink socket. + let (mut conn, mut _handle, mut messages) = new_connection::(NETLINK_KOBJECT_UEVENT) + .map_err(|e| format!("Failed to create a new netlink connection: {}", e))?; + + let sa = SocketAddr::new(std::process::id(), 1); + + conn.socket_mut().bind(&sa).unwrap(); + + // Spawn the `Connection` in the background + tokio::spawn(conn); + + // Print all the messages received in response + loop { + if let Some(packet) = messages.next().await { + println!("<<< {:?}", packet); + } else { + break; + } + } + + Ok(()) +} diff --git a/netlink-proto/src/codecs.rs b/netlink-proto/src/codecs.rs index d5c1af48..3e6d370c 100644 --- a/netlink-proto/src/codecs.rs +++ b/netlink-proto/src/codecs.rs @@ -2,27 +2,45 @@ use std::{fmt::Debug, io, marker::PhantomData}; use bytes::{BufMut, BytesMut}; use netlink_packet_core::{ - NetlinkBuffer, - NetlinkDeserializable, - NetlinkMessage, + NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable, }; + +use crate::sys::protocols::{NETLINK_AUDIT, NETLINK_GENERIC, NETLINK_KOBJECT_UEVENT}; + use tokio_util::codec::{Decoder, Encoder}; +#[derive(Eq, PartialEq)] +enum CodecType { + /// Normal Netlink packet with header + Packet, + /// Audit packets lenght is unreliable + AuditPacket, + /// kobject_uevent packets do not have headers at all + UEventPacket, +} + pub struct NetlinkCodec { phantom: PhantomData, + ty: CodecType, } impl Default for NetlinkCodec { fn default() -> Self { - Self::new() + Self::new(NETLINK_GENERIC) } } impl NetlinkCodec { - pub fn new() -> Self { + pub fn new(protocol: isize) -> Self { + let ty = match protocol { + NETLINK_AUDIT => CodecType::AuditPacket, + NETLINK_KOBJECT_UEVENT => CodecType::UEventPacket, + _ => CodecType::Packet, + }; NetlinkCodec { phantom: PhantomData, + ty, } } } @@ -47,54 +65,59 @@ where return Ok(None); } - // This is a bit hacky because we don't want to keep `src` - // borrowed, since we need to mutate it later. - let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { - #[cfg(not(feature = "workaround-audit-bug"))] - Ok(buf) => Ok(buf.length() as usize), - #[cfg(feature = "workaround-audit-bug")] - Ok(buf) => { - if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { - // The audit messages are sometimes truncated, - // because the length specified in the header, - // does not take the header itself into - // account. To workaround this, we tweak the - // length. We've noticed two occurences of - // truncated packets: - // - // - the length of the header is not included (see also: - // https://github.com/mozilla/libaudit-go/issues/24) - // - some rule message have some padding for alignment (see - // https://github.com/linux-audit/audit-userspace/issues/78) which is not - // taken into account in the buffer length. - warn!("found what looks like a truncated audit packet"); - Ok(src.as_ref().len()) - } else { - Ok(buf.length() as usize) + // the uevent packets do not have any header + let len = if self.ty == CodecType::UEventPacket { + src.len() + } else { + // This is a bit hacky because we don't want to keep `src` + // borrowed, since we need to mutate it later. + let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { + Ok(buf) => { + if self.ty == CodecType::Packet { + Ok(buf.length() as usize) + } else { + if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { + // The audit messages are sometimes truncated, + // because the length specified in the header, + // does not take the header itself into + // account. To workaround this, we tweak the + // length. We've noticed two occurences of + // truncated packets: + // + // - the length of the header is not included (see also: + // https://github.com/mozilla/libaudit-go/issues/24) + // - some rule message have some padding for alignment (see + // https://github.com/linux-audit/audit-userspace/issues/78) which is not + // taken into account in the buffer length. + warn!("found what looks like a truncated audit packet"); + Ok(src.as_ref().len()) + } else { + Ok(buf.length() as usize) + } + } } - } - Err(e) => { - // We either received a truncated packet, or the - // packet if malformed (invalid length field). In - // both case, we can't decode the datagram, and we - // cannot find the start of the next one (if - // any). The only solution is to clear the buffer - // and potentially lose some datagrams. - error!("failed to decode datagram: {:?}: {:#x?}.", e, src.as_ref()); - Err(()) - } - }; + Err(e) => { + // We either received a truncated packet, or the + // packet if malformed (invalid length field). In + // both case, we can't decode the datagram, and we + // cannot find the start of the next one (if + // any). The only solution is to clear the buffer + // and potentially lose some datagrams. + error!("failed to decode datagram: {:?}: {:#x?}.", e, src.as_ref()); + Err(()) + } + }; - if len_res.is_err() { - error!("clearing the whole socket buffer. Datagrams may have been lost"); - src.clear(); - return Ok(None); - } + if len_res.is_err() { + error!("clearing the whole socket buffer. Datagrams may have been lost"); + src.clear(); + return Ok(None); + } - let len = len_res.unwrap(); + len_res.unwrap() + }; - #[cfg(feature = "workaround-audit-bug")] - let bytes = { + let bytes = if self.ty == CodecType::AuditPacket { let mut bytes = src.split_to(len); { let mut buf = NetlinkBuffer::new(bytes.as_mut()); @@ -120,21 +143,38 @@ where } } bytes + } else { + src.split_to(len) }; - #[cfg(not(feature = "workaround-audit-bug"))] - let bytes = src.split_to(len); - - let parsed = NetlinkMessage::::deserialize(&bytes); - match parsed { - Ok(packet) => { - trace!("<<< {:?}", packet); - return Ok(Some(packet)); + + if self.ty == CodecType::UEventPacket { + // dummy header, unused + let header = NetlinkHeader::default(); + match T::deserialize(&header, &bytes) { + Ok(packet) => { + trace!("<<< {:?}", packet); + return Ok(Some(NetlinkMessage::new( + header, + NetlinkPayload::InnerMessage(packet), + ))); + } + Err(e) => { + error!("failed to decode packet {:#x?}: {}", &bytes, e); + } } - Err(e) => { - error!("failed to decode packet {:#x?}: {}", &bytes, e); - // continue looping, there may be more datagrams in the buffer + } else { + let parsed = NetlinkMessage::::deserialize(&bytes); + match parsed { + Ok(packet) => { + trace!("<<< {:?}", packet); + return Ok(Some(packet)); + } + Err(e) => { + error!("failed to decode packet {:#x?}: {}", &bytes, e); + // continue looping, there may be more datagrams in the buffer + } } - } + }; } } } diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index 7406c198..4f7a0ac3 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -7,25 +7,18 @@ use std::{ use futures::{ channel::mpsc::{UnboundedReceiver, UnboundedSender}, - Future, - Sink, - Stream, + Future, Sink, Stream, }; use log::{error, warn}; use netlink_packet_core::{ - NetlinkDeserializable, - NetlinkMessage, - NetlinkPayload, - NetlinkSerializable, + NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, }; use crate::{ codecs::NetlinkCodec, framed::NetlinkFramed, sys::{Socket, SocketAddr}, - Protocol, - Request, - Response, + Protocol, Request, Response, }; /// Connection to a Netlink socket, running in the background. @@ -61,7 +54,7 @@ where ) -> io::Result { let socket = Socket::new(protocol)?; Ok(Connection { - socket: NetlinkFramed::new(socket, NetlinkCodec::>::new()), + socket: NetlinkFramed::new(socket, NetlinkCodec::>::new(protocol)), protocol: Protocol::new(), requests_rx: Some(requests_rx), unsolicited_messages_tx: Some(unsolicited_messages_tx), diff --git a/netlink-sys/examples/listen_uevents.rs b/netlink-sys/examples/listen_uevents.rs new file mode 100644 index 00000000..9f69cd95 --- /dev/null +++ b/netlink-sys/examples/listen_uevents.rs @@ -0,0 +1,33 @@ +// Build: +// +// ``` +// cd netlink-sys +// cargo run --example listen_uevents +// +// ``` +// +// Run *as root*: +// +// ``` +// find /sys -name uevent -exec sh -c 'echo add >"{}"' '; +// ``` +// +// To generate events. + +use std::process; + +use netlink_sys::{protocols::NETLINK_KOBJECT_UEVENT, Socket, SocketAddr}; + +fn main() { + let mut socket = Socket::new(NETLINK_KOBJECT_UEVENT).unwrap(); + let sa = SocketAddr::new(process::id(), 1); + let mut buf = vec![0; 1024 * 8]; + + socket.bind(&sa); + + loop { + let n = socket.recv(&mut buf, 0).unwrap(); + let s = String::from_utf8_lossy(&buf[..n]); + println!(">> {}", s); + } +}