Skip to content

Commit

Permalink
refactor windows module
Browse files Browse the repository at this point in the history
  • Loading branch information
GyulyVGC committed Mar 13, 2024
1 parent 02bae43 commit 418fb16
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 128 deletions.
32 changes: 16 additions & 16 deletions src/platform/linux/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ pub fn get_all() -> crate::Result<HashSet<Listener>> {
Ok(listeners)
}

#[cfg(test)]
mod tests {
#[test]
fn test_get_all() {
let listeners = crate::get_all().unwrap();
assert!(!listeners.is_empty());

// let out = std::process::Command::new("netstat")
// .args(["-plnt"])
// .output()
// .unwrap();
// for l in String::from_utf8(out.stdout).unwrap().lines() {
// println!("{}", l);
// }
}
}
// #[cfg(test)]
// mod tests {
// #[test]
// fn test_get_all() {
// let listeners = crate::get_all().unwrap();
// assert!(!listeners.is_empty());
//
// // let out = std::process::Command::new("netstat")
// // .args(["-plnt"])
// // .output()
// // .unwrap();
// // for l in String::from_utf8(out.stdout).unwrap().lines() {
// // println!("{}", l);
// // }
// }
// }
32 changes: 16 additions & 16 deletions src/platform/macos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ pub fn get_all() -> crate::Result<HashSet<Listener>> {
Ok(listeners)
}

#[cfg(test)]
mod tests {
#[test]
fn test_get_all() {
let listeners = crate::get_all().unwrap();
assert!(!listeners.is_empty());

// let out = std::process::Command::new("netstat")
// .args(["-p", "tcp", "-van"])
// .output()
// .unwrap();
// for l in String::from_utf8(out.stdout).unwrap().lines().filter(|l| l.contains("LISTEN")) {
// println!("{}", l);
// }
}
}
// #[cfg(test)]
// mod tests {
// #[test]
// fn test_get_all() {
// let listeners = crate::get_all().unwrap();
// assert!(!listeners.is_empty());
//
// // let out = std::process::Command::new("netstat")
// // .args(["-p", "tcp", "-van"])
// // .output()
// // .unwrap();
// // for l in String::from_utf8(out.stdout).unwrap().lines().filter(|l| l.contains("LISTEN")) {
// // println!("{}", l);
// // }
// }
// }
24 changes: 4 additions & 20 deletions src/platform/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::collections::HashSet;
use crate::Listener;
use crate::platform::windows::socket_table::SocketTable;
use crate::platform::windows::tcp6_table::Tcp6Table;
use crate::platform::windows::tcp_listener::TcpListener;
use crate::platform::windows::tcp_table::TcpTable;
use socket_table::SocketTable;
use std::collections::HashSet;
use tcp_listener::TcpListener;

mod c_iphlpapi;
mod socket_table;
Expand All @@ -15,25 +13,11 @@ mod tcp_table;
pub fn get_all() -> crate::Result<HashSet<Listener>> {
let mut listeners = HashSet::new();

let tcp_listeners = entries::<TcpTable>();
let tcp6_listeners = entries::<Tcp6Table>();

for tcp_listener in tcp_listeners.iter().flatten().chain(tcp6_listeners.iter().flatten()) {
for tcp_listener in TcpListener::get_all() {
if let Some(listener) = tcp_listener.to_listener() {
listeners.insert(listener);
}
}

Ok(listeners)
}

fn entries<Table: SocketTable>() -> crate::Result<Vec<TcpListener>> {
let mut tcp_listeners = Vec::new();
let table = Table::get_table()?;
for i in 0..Table::get_rows_count(&table) {
if let Some(tcp_listener) = Table::get_tcp_listener(&table, i) {
tcp_listeners.push(tcp_listener);
}
}
Ok(tcp_listeners)
}
78 changes: 39 additions & 39 deletions src/platform/windows/socket_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,12 @@ use crate::platform::windows::tcp_table::{TcpRow, TcpTable};
use std::ffi::{c_ulong, c_void};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

pub trait SocketTable {
pub(super) trait SocketTable {
fn get_table() -> crate::Result<Vec<u8>>;
fn get_rows_count(table: &[u8]) -> usize;
fn get_tcp_listener(table: &[u8], index: usize) -> Option<TcpListener>;
}

fn get_tcp_table(address_family: c_ulong) -> crate::Result<Vec<u8>> {
let mut table_size: c_ulong = 0;
let mut err_code = unsafe {
GetExtendedTcpTable(
std::ptr::null_mut(),
&mut table_size,
FALSE,
address_family,
TCP_TABLE_OWNER_PID_ALL,
0,
)
};
let mut table = Vec::<u8>::new();
let mut iterations = 0;
while err_code == ERROR_INSUFFICIENT_BUFFER {
table = Vec::<u8>::with_capacity(table_size as usize);
err_code = unsafe {
GetExtendedTcpTable(
table.as_mut_ptr() as *mut c_void,
&mut table_size,
FALSE,
address_family,
TCP_TABLE_OWNER_PID_ALL,
0,
)
};
iterations += 1;
if iterations > 100 {
return Err("Failed to allocate buffer".into());
}
}
if err_code == NO_ERROR {
Ok(table)
} else {
Err("Failed to get TCP table".into())
}
}

impl SocketTable for TcpTable {
fn get_table() -> crate::Result<Vec<u8>> {
get_tcp_table(AF_INET)
Expand Down Expand Up @@ -104,3 +66,41 @@ impl SocketTable for Tcp6Table {
}
}
}

fn get_tcp_table(address_family: c_ulong) -> crate::Result<Vec<u8>> {
let mut table_size: c_ulong = 0;
let mut err_code = unsafe {
GetExtendedTcpTable(
std::ptr::null_mut(),
&mut table_size,
FALSE,
address_family,
TCP_TABLE_OWNER_PID_ALL,
0,
)
};
let mut table = Vec::<u8>::new();
let mut iterations = 0;
while err_code == ERROR_INSUFFICIENT_BUFFER {
table = Vec::<u8>::with_capacity(table_size as usize);
err_code = unsafe {
GetExtendedTcpTable(
table.as_mut_ptr() as *mut c_void,
&mut table_size,
FALSE,
address_family,
TCP_TABLE_OWNER_PID_ALL,
0,
)
};
iterations += 1;
if iterations > 100 {
return Err("Failed to allocate buffer".into());
}
}
if err_code == NO_ERROR {
Ok(table)
} else {
Err("Failed to get TCP table".into())
}
}
24 changes: 12 additions & 12 deletions src/platform/windows/tcp6_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ use std::os::raw::c_ulong;

#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct Tcp6Table {
pub rows_count: c_ulong,
pub rows: [Tcp6Row; 1],
pub(super) struct Tcp6Table {
pub(super) rows_count: c_ulong,
pub(super) rows: [Tcp6Row; 1],
}

#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct Tcp6Row {
pub local_addr: [c_uchar; 16],
pub local_scope_id: c_ulong,
pub local_port: c_ulong,
pub remote_addr: [c_uchar; 16],
pub remote_scope_id: c_ulong,
pub remote_port: c_ulong,
pub state: c_ulong,
pub owning_pid: c_ulong,
pub(super) struct Tcp6Row {
pub(super) local_addr: [c_uchar; 16],
local_scope_id: c_ulong,
pub(super) local_port: c_ulong,
remote_addr: [c_uchar; 16],
remote_scope_id: c_ulong,
remote_port: c_ulong,
pub(super) state: c_ulong,
pub(super) owning_pid: c_ulong,
}
41 changes: 26 additions & 15 deletions src/platform/windows/tcp_listener.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::net::{IpAddr, SocketAddr};
use crate::platform::windows::socket_table::SocketTable;
use crate::platform::windows::tcp6_table::Tcp6Table;
use crate::platform::windows::tcp_table::TcpTable;
use crate::Listener;
use std::net::{IpAddr, SocketAddr};

#[derive(Debug)]
pub(super) struct TcpListener {
Expand All @@ -9,6 +12,25 @@ pub(super) struct TcpListener {
}

impl TcpListener {
pub(super) fn get_all() -> Vec<TcpListener> {
Self::table_entries::<TcpTable>()
.into_iter()
.flatten()
.chain(Self::table_entries::<Tcp6Table>().into_iter().flatten())
.collect()
}

fn table_entries<Table: SocketTable>() -> crate::Result<Vec<Self>> {
let mut tcp_listeners = Vec::new();
let table = Table::get_table()?;
for i in 0..Table::get_rows_count(&table) {
if let Some(tcp_listener) = Table::get_tcp_listener(&table, i) {
tcp_listeners.push(tcp_listener);
}
}
Ok(tcp_listeners)
}

pub(super) fn new(local_addr: IpAddr, local_port: u16, pid: u32) -> Self {
Self {
local_addr,
Expand All @@ -17,24 +39,13 @@ impl TcpListener {
}
}

pub(super) fn local_addr(&self) -> IpAddr {
self.local_addr
}

pub(super) fn local_port(&self) -> u16 {
self.local_port
}

pub(super) fn pid(&self) -> u32 {
self.pid
}

pub(super) fn pname(&self) -> Option<String> {
fn pname(&self) -> Option<String> {
use std::mem::size_of;
use std::mem::zeroed;
use windows::Win32::Foundation::CloseHandle;
use windows::Win32::System::Diagnostics::ToolHelp::{
CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32, TH32CS_SNAPPROCESS,
CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32,
TH32CS_SNAPPROCESS,
};

let pid = self.pid;
Expand Down
20 changes: 10 additions & 10 deletions src/platform/windows/tcp_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ use std::ffi::c_ulong;

#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct TcpTable {
pub rows_count: c_ulong,
pub rows: [TcpRow; 1],
pub(super) struct TcpTable {
pub(super) rows_count: c_ulong,
pub(super) rows: [TcpRow; 1],
}

#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct TcpRow {
pub state: c_ulong,
pub local_addr: c_ulong,
pub local_port: c_ulong,
pub remote_addr: c_ulong,
pub remote_port: c_ulong,
pub owning_pid: c_ulong,
pub(super) struct TcpRow {
pub(super) state: c_ulong,
pub(super) local_addr: c_ulong,
pub(super) local_port: c_ulong,
remote_addr: c_ulong,
remote_port: c_ulong,
pub(super) owning_pid: c_ulong,
}

0 comments on commit 418fb16

Please sign in to comment.