Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
kl committed Nov 14, 2024
1 parent a2fb9ec commit 8ec822a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 67 deletions.
83 changes: 51 additions & 32 deletions talpid-wireguard/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{
ffi::CString,
net::{Ipv4Addr, Ipv6Addr},
};
use talpid_types::net::wireguard::{PeerConfig, PrivateKey};
use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions};

/// Name to use for the tunnel device
Expand Down Expand Up @@ -121,38 +122,11 @@ impl Config {
/// Returns a CString with the appropriate config for WireGuard-go
// TODO: Consider outputting both overriding and additive configs
pub fn to_userspace_format(&self) -> CString {
// the order of insertion matters, public key entry denotes a new peer entry
let mut wg_conf = WgConfigBuffer::new();
wg_conf
.add::<&[u8]>("private_key", self.tunnel.private_key.to_bytes().as_ref())
.add("listen_port", "0");

#[cfg(target_os = "linux")]
if let Some(fwmark) = &self.fwmark {
wg_conf.add("fwmark", fwmark.to_string().as_str());
}

wg_conf.add("replace_peers", "true");

for peer in self.peers() {
wg_conf
.add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref())
.add("endpoint", peer.endpoint.to_string().as_str())
.add("replace_allowed_ips", "true");
if let Some(ref psk) = peer.psk {
wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref());
}
for addr in &peer.allowed_ips {
wg_conf.add("allowed_ip", addr.to_string().as_str());
}
#[cfg(daita)]
if peer.constant_packet_size {
wg_conf.add("constant_packet_size", "true");
}
}

let bytes = wg_conf.into_config();
CString::new(bytes).expect("null bytes inside config")
userspace_format(
&self.tunnel.private_key,
&self.entry_peer,
self.exit_peer.as_ref(),
)
}

/// Return whether the config connects to an exit peer from another remote peer.
Expand Down Expand Up @@ -242,3 +216,48 @@ impl WgConfigBuffer {
self.buf
}
}

/// Returns a CString with the appropriate config for WireGuard-go
pub fn userspace_format(
private_key: &PrivateKey,
entry_peer: &PeerConfig,
exit_peer: Option<&PeerConfig>,
) -> CString {
// the order of insertion matters, public key entry denotes a new peer entry
let mut wg_conf = WgConfigBuffer::new();
wg_conf
.add::<&[u8]>("private_key", private_key.to_bytes().as_ref())
.add("listen_port", "0");

#[cfg(target_os = "linux")]
if let Some(fwmark) = &self.fwmark {
wg_conf.add("fwmark", fwmark.to_string().as_str());
}

wg_conf.add("replace_peers", "true");

let peers = exit_peer
.as_ref()
.into_iter()
.chain(std::iter::once(&entry_peer));

for peer in peers {
wg_conf
.add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref())
.add("endpoint", peer.endpoint.to_string().as_str())
.add("replace_allowed_ips", "true");
if let Some(ref psk) = peer.psk {
wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref());
}
for addr in &peer.allowed_ips {
wg_conf.add("allowed_ip", addr.to_string().as_str());
}
#[cfg(daita)]
if peer.constant_packet_size {
wg_conf.add("constant_packet_size", "true");
}
}

let bytes = wg_conf.into_config();
CString::new(bytes).expect("null bytes inside config")
}
72 changes: 37 additions & 35 deletions talpid-wireguard/src/wireguard_go/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{
config,
stats::{Stats, StatsMap},
Config, Tunnel, TunnelError,
};
Expand Down Expand Up @@ -35,11 +36,20 @@ const DAITA_ACTIONS_CAPACITY: u32 = 1000;

type Result<T> = std::result::Result<T, TunnelError>;

struct LoggingContext(u64);
struct LoggingContext {
ordinal: u64,
path: Option<PathBuf>,
}

impl LoggingContext {
fn new(ordinal: u64, path: Option<PathBuf>) -> Self {
LoggingContext { ordinal, path }
}
}

impl Drop for LoggingContext {
fn drop(&mut self) {
clean_up_logging(self.0);
clean_up_logging(self.ordinal);
}
}

Expand Down Expand Up @@ -92,7 +102,7 @@ impl WgGoTunnel {

pub fn better_set_config(self, config: &Config) -> Result<Self> {
let state = self.as_state();
let log_path = state._log_path.clone();
let log_path = state.logging_context.path.clone();
let tun_provider = Arc::clone(&state.tun_provider);
let routes = config.get_tunnel_destinations();
#[cfg(daita)]
Expand Down Expand Up @@ -142,11 +152,8 @@ pub(crate) struct WgGoTunnelState {
// holding on to the tunnel device and the log file ensures that the associated file handles
// live long enough and get closed when the tunnel is stopped
_tunnel_device: Tun,
// HACK: Don't use this. Only sometimes. ;-)
#[cfg(target_os = "android")]
_log_path: Option<PathBuf>,
// context that maps to fs::File instance, used with logging callback
_logging_context: LoggingContext,
// context that maps to fs::File instance and stores the file path, used with logging callback
logging_context: LoggingContext,
#[cfg(target_os = "android")]
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(daita)]
Expand Down Expand Up @@ -223,7 +230,7 @@ impl WgGoTunnel {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
_logging_context: logging_context,
logging_context,
#[cfg(daita)]
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
Expand Down Expand Up @@ -279,22 +286,21 @@ impl WgGoTunnel {
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
) -> Result<Self> {
let tun_provider_clone = tun_provider.clone();

let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?;
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;

let interface_name: String = tunnel_device.interface_name().to_string();
let wg_config_str = config.to_userspace_format();
let _log_path = log_path;
let logging_context = initialize_logging(log_path)
.map(LoggingContext)
.map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
.map_err(TunnelError::LoggingError)?;

let wg_config_str = config.to_userspace_format();

let handle = wireguard_go_rs::Tunnel::turn_on(
&wg_config_str,
tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.0,
logging_context.ordinal,
)
.map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;

Expand All @@ -305,9 +311,8 @@ impl WgGoTunnel {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
_logging_context: logging_context,
_log_path: _log_path.map(|log_path| log_path.to_owned()),
tun_provider: tun_provider_clone,
logging_context,
tun_provider,
#[cfg(daita)]
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
Expand All @@ -323,24 +328,22 @@ impl WgGoTunnel {
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
) -> Result<Self> {
let tun_provider_clone = tun_provider.clone();

let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?;
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;

let interface_name: String = tunnel_device.interface_name().to_string();
let _log_path = log_path;
let logging_context = initialize_logging(log_path)
.map(LoggingContext)
.map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
.map_err(TunnelError::LoggingError)?;

let mut entry_config = config.clone();
entry_config.exit_peer = None;

let mut exit_config = config.clone();
exit_config.entry_peer = exit_peer;
let entry_config_str =
config::userspace_format(&config.tunnel.private_key, &config.entry_peer, None);

let entry_config_str = entry_config.to_userspace_format();
let exit_config_str = exit_config.to_userspace_format();
let exit_config_str = config::userspace_format(
&config.tunnel.private_key,
&exit_peer,
config.exit_peer.as_ref(),
);

let private_ip = config
.tunnel
Expand All @@ -356,7 +359,7 @@ impl WgGoTunnel {
&private_ip,
tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.0,
logging_context.ordinal,
)
.map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;

Expand All @@ -367,9 +370,8 @@ impl WgGoTunnel {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
_logging_context: logging_context,
_log_path: _log_path.map(|log_path| log_path.to_owned()),
tun_provider: tun_provider_clone,
logging_context,
tun_provider,
#[cfg(daita)]
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
Expand Down

0 comments on commit 8ec822a

Please sign in to comment.