use std::net::{IpAddr, SocketAddr}; use ipnet::IpNet; use netlink_packet_wireguard::{ WireguardAddressFamily, WireguardAllowedIp, WireguardAllowedIpAttr, WireguardAttribute, WireguardCmd, WireguardMessage, WireguardPeer, WireguardPeerAttribute, }; use super::Key; #[allow(unused)] mod constants { // this is copy pasted from the netlink_packet_wireguard's constants module because for some reason // they stopped exposing constants in commit 3067a394fc7bc28fadbed5359c44cce95aac0f13 pub const WGDEVICE_F_REPLACE_PEERS: u32 = 1 << 0; pub const WGPEER_F_REMOVE_ME: u32 = 1 << 0; pub const WGPEER_F_REPLACE_ALLOWEDIPS: u32 = 1 << 1; pub const WGPEER_F_UPDATE_ONLY: u32 = 1 << 2; pub const WGPEER_A_UNSPEC: u16 = 0; pub const WGPEER_A_PUBLIC_KEY: u16 = 1; pub const WGPEER_A_PRESHARED_KEY: u16 = 2; pub const WGPEER_A_FLAGS: u16 = 3; pub const WGPEER_A_ENDPOINT: u16 = 4; pub const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 5; pub const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 6; pub const WGPEER_A_RX_BYTES: u16 = 7; pub const WGPEER_A_TX_BYTES: u16 = 8; pub const WGPEER_A_ALLOWEDIPS: u16 = 9; pub const WGPEER_A_PROTOCOL_VERSION: u16 = 10; pub const WGALLOWEDIP_A_UNSPEC: u16 = 0; pub const WGALLOWEDIP_A_FAMILY: u16 = 1; pub const WGALLOWEDIP_A_IPADDR: u16 = 2; pub const WGALLOWEDIP_A_CIDR_MASK: u16 = 3; pub const AF_INET6: u16 = 10; pub const AF_INET: u16 = 2; } #[allow(unused)] pub(crate) use constants::*; #[derive(Debug)] pub struct PeerDescriptor { pub(super) public_key: Key, pub(super) preshared_key: Option, pub(super) endpoint: Option, pub(super) keepalive: Option, pub(super) allowed_ips: Option>, } impl PeerDescriptor { pub fn new(public_key: Key) -> Self { Self { public_key, preshared_key: None, endpoint: None, keepalive: None, allowed_ips: None, } } pub fn preshared_key_optional(mut self, preshared_key: Option) -> Self { self.preshared_key = preshared_key; self } pub fn preshared_key(mut self, preshared_key: Key) -> Self { self.preshared_key = Some(preshared_key); self } pub fn endpoint_optional(mut self, endpoint: Option) -> Self { self.endpoint = endpoint; self } pub fn endpoint(mut self, endpoint: SocketAddr) -> Self { self.endpoint = Some(endpoint); self } pub fn keepalive_optional(mut self, keepalive: Option) -> Self { self.keepalive = keepalive; self } pub fn keepalive(mut self, keepalive: u16) -> Self { self.keepalive = Some(keepalive); self } pub fn allowed_ip_optional(self, allowed_ip: Option) -> Self { if let Some(allowed_ip) = allowed_ip { self.allowed_ip(allowed_ip) } else { self } } pub fn allowed_ip(mut self, allowed_ip: IpNet) -> Self { let mut allowed_ips = self.allowed_ips.take().unwrap_or_default(); allowed_ips.push(allowed_ip); self.allowed_ips = Some(allowed_ips); self } pub fn allowed_ips_optional(self, allowed_ips: Option>) -> Self { if let Some(allowed_ips) = allowed_ips { self.allowed_ips(allowed_ips) } else { self } } pub fn allowed_ips(mut self, allowed_ips: Vec) -> Self { self.allowed_ips = Some(allowed_ips); self } pub(super) fn into_wireguard(self) -> WireguardPeer { let mut attributes = Vec::new(); attributes.push(WireguardPeerAttribute::PublicKey( self.public_key.into_array(), )); attributes.extend( self.preshared_key .map(|key| WireguardPeerAttribute::PresharedKey(key.into_array())), ); attributes.extend(self.endpoint.map(WireguardPeerAttribute::Endpoint)); attributes.extend( self.keepalive .map(WireguardPeerAttribute::PersistentKeepalive), ); attributes.extend(self.allowed_ips.map(|allowed_ips| { WireguardPeerAttribute::AllowedIps(allowed_ips.into_iter().map(ipnet_to_wg).collect()) })); attributes.push(WireguardPeerAttribute::Flags(WGPEER_F_REPLACE_ALLOWEDIPS)); WireguardPeer(attributes) } } #[derive(Debug)] pub struct DeviceDescriptor { pub(super) addresses: Vec, pub(super) private_key: Option, pub(super) listen_port: Option, pub(super) fwmark: Option, pub(super) peers: Option>, } impl Default for DeviceDescriptor { fn default() -> Self { Self::new() } } impl DeviceDescriptor { pub fn new() -> Self { Self { addresses: Vec::default(), private_key: None, listen_port: None, fwmark: None, peers: None, } } pub fn address(mut self, address: IpNet) -> Self { self.addresses.push(address); self } pub fn addresses(mut self, addresses: impl IntoIterator) -> Self { self.addresses.extend(addresses); self } pub fn private_key(mut self, key: Key) -> Self { self.private_key = Some(key); self } pub fn listen_port(mut self, port: u16) -> Self { self.listen_port = Some(port); self } pub fn listen_port_optional(mut self, port: Option) -> Self { self.listen_port = port; self } pub fn fwmark(mut self, fwmark: u32) -> Self { self.fwmark = Some(fwmark); self } pub fn peer(mut self, peer: PeerDescriptor) -> Self { let mut p = self.peers.take().unwrap_or_default(); p.push(peer); self.peers = Some(p); self } pub fn peers(mut self, peers: impl IntoIterator) -> Self { let mut p = self.peers.take().unwrap_or_default(); p.extend(peers); self.peers = Some(p); self } pub(super) fn into_wireguard(self, device_name: String) -> WireguardMessage { let mut attributes = Vec::new(); attributes.push(WireguardAttribute::IfName(device_name)); attributes.extend( self.private_key .map(|key| WireguardAttribute::PrivateKey(key.into_array())), ); attributes.extend(self.listen_port.map(WireguardAttribute::ListenPort)); attributes.extend(self.fwmark.map(WireguardAttribute::Fwmark)); attributes.extend(self.peers.map(|peers| { WireguardAttribute::Peers( peers .into_iter() .map(PeerDescriptor::into_wireguard) .collect(), ) })); attributes.push(WireguardAttribute::Flags(WGDEVICE_F_REPLACE_PEERS)); WireguardMessage { cmd: WireguardCmd::SetDevice, attributes, } } } fn ipnet_to_wg(net: IpNet) -> WireguardAllowedIp { let mut attributes = Vec::default(); attributes.push(WireguardAllowedIpAttr::Cidr(net.prefix_len())); attributes.push(WireguardAllowedIpAttr::IpAddr(net.addr())); match net.addr() { IpAddr::V4(_) => { attributes.push(WireguardAllowedIpAttr::Family(WireguardAddressFamily::Ipv4)) } IpAddr::V6(_) => { attributes.push(WireguardAllowedIpAttr::Family(WireguardAddressFamily::Ipv6)) } } WireguardAllowedIp(attributes) }