From 75ccbd675c22fb3275c5763518c3b97819db4c53 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Fri, 18 Jul 2025 18:46:55 +0100 Subject: init --- src/conf.rs | 692 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/key.rs | 372 ++++++++++++++++++++++++++++++++ src/lib.rs | 401 ++++++++++++++++++++++++++++++++++ src/setup.rs | 212 ++++++++++++++++++ src/view.rs | 130 +++++++++++ 5 files changed, 1807 insertions(+) create mode 100644 src/conf.rs create mode 100644 src/key.rs create mode 100644 src/lib.rs create mode 100644 src/setup.rs create mode 100644 src/view.rs (limited to 'src') diff --git a/src/conf.rs b/src/conf.rs new file mode 100644 index 0000000..b6e49a0 --- /dev/null +++ b/src/conf.rs @@ -0,0 +1,692 @@ +use std::{fmt::Write, str::FromStr}; + +use ipnet::IpNet; + +use super::Key; + +const FIELD_PRIVATE_KEY: &str = "PrivateKey"; +const FIELD_LISTEN_PORT: &str = "ListenPort"; +const FIELD_FWMARK: &str = "FwMark"; +const FIELD_PUBLIC_KEY: &str = "PublicKey"; +const FIELD_PRE_SHARED_KEY: &str = "PresharedKey"; +const FIELD_ALLOWED_IPS: &str = "AllowedIPs"; +const FIELD_ENDPOINT: &str = "Endpoint"; +const FIELD_PERSISTENT_KEEPALIVE: &str = "PersistentKeepalive"; + +// wg-quick fields +const FIELD_ADDRESS: &str = "Address"; +const FIELD_DNS: &str = "DNS"; + +macro_rules! header { + ($dest:expr, $h:expr) => { + writeln!($dest, "[{}]", $h).unwrap(); + }; +} + +macro_rules! field { + ($dest:expr, $n:expr, $v:expr) => { + writeln!($dest, "{} = {}", $n, $v).unwrap(); + }; +} + +macro_rules! field_csv { + ($dest:expr, $n:expr, $v:expr) => { + if !$v.is_empty() { + write!($dest, "{} = ", $n).unwrap(); + let mut comma = false; + for e in $v.iter() { + if comma { + write!($dest, ", ").unwrap(); + } else { + comma = true; + } + write!($dest, "{}", e).unwrap(); + } + writeln!($dest).unwrap(); + } + }; +} + +macro_rules! field_opt { + ($dest:expr, $n:expr, $v:expr) => { + if let Some(ref v) = $v { + field!($dest, $n, v); + } + }; +} + +#[derive(Debug, Clone)] +pub struct WgInterface { + pub private_key: Key, + pub address: Vec, + pub listen_port: Option, + pub fw_mark: Option, + pub dns: Option, +} + +#[derive(Debug, Clone)] +pub struct WgPeer { + pub public_key: Key, + pub preshared_key: Option, + pub allowed_ips: Vec, + pub endpoint: Option, + pub keep_alive: Option, +} + +impl WgPeer { + pub fn builder(public_key: Key) -> WgPeerBuilder { + WgPeerBuilder::new(public_key) + } +} + +#[derive(Debug, Clone)] +pub struct WgConf { + pub interface: WgInterface, + pub peers: Vec, +} + +impl WgConf { + pub fn builder() -> WgConfBuilder { + WgConfBuilder::new() + } +} + +#[derive(Debug)] +pub struct WgPeerBuilder { + pub public_key: Key, + pub preshared_key: Option, + pub allowed_ips: Vec, + pub endpoint: Option, + pub keep_alive: Option, +} + +impl WgPeerBuilder { + pub fn new(public_key: Key) -> WgPeerBuilder { + WgPeerBuilder { + public_key, + preshared_key: None, + allowed_ips: Vec::new(), + endpoint: None, + keep_alive: None, + } + } + + pub fn preshared_key(mut self, preshared_key: Key) -> Self { + self.preshared_key = Some(preshared_key); + self + } + + pub fn allowed_ip(mut self, allowed_ip: IpNet) -> Self { + self.allowed_ips.push(allowed_ip); + self + } + + pub fn allowed_ips(mut self, allowed_ips: impl IntoIterator) -> Self { + self.allowed_ips.extend(allowed_ips); + self + } + + pub fn endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = Some(endpoint.into()); + self + } + + pub fn endpoint_opt(mut self, endpoint: Option>) -> Self { + if let Some(endpoint) = endpoint { + self.endpoint = Some(endpoint.into()); + } + self + } + + pub fn keep_alive(mut self, keep_alive: u16) -> Self { + self.keep_alive = Some(keep_alive); + self + } + + pub fn build(self) -> WgPeer { + WgPeer { + public_key: self.public_key, + preshared_key: self.preshared_key, + allowed_ips: self.allowed_ips, + endpoint: self.endpoint, + keep_alive: self.keep_alive, + } + } +} + +#[derive(Debug)] +pub struct WgConfBuilder { + private_key: Option, + address: Vec, + listen_port: Option, + fw_mark: Option, + dns: Option, + peers: Vec, +} + +impl WgConfBuilder { + pub fn new() -> Self { + WgConfBuilder { + private_key: None, + address: Vec::new(), + listen_port: None, + fw_mark: None, + dns: None, + peers: Vec::new(), + } + } + + pub fn private_key(mut self, private_key: Key) -> Self { + self.private_key = Some(private_key); + self + } + + pub fn address(mut self, address: impl Into) -> Self { + self.address.push(address.into()); + self + } + + pub fn addresses(mut self, addresses: impl IntoIterator) -> Self { + self.address.extend(addresses); + self + } + + pub fn listen_port(mut self, listen_port: u16) -> Self { + self.listen_port = Some(listen_port); + self + } + + pub fn fw_mark(mut self, fw_mark: u32) -> Self { + self.fw_mark = Some(fw_mark); + self + } + + pub fn dns(mut self, dns: impl Into) -> Self { + self.dns = Some(dns.into()); + self + } + + pub fn dns_opt(mut self, dns: Option>) -> Self { + if let Some(dns) = dns { + self.dns = Some(dns.into()); + } + self + } + + pub fn peer(mut self, peer: WgPeer) -> Self { + self.peers.push(peer); + self + } + + pub fn peers(mut self, peers: impl IntoIterator) -> Self { + self.peers.extend(peers); + self + } + + pub fn build(self) -> WgConf { + WgConf { + interface: WgInterface { + private_key: self.private_key.unwrap_or_else(Key::generate_private), + address: self.address, + listen_port: self.listen_port, + fw_mark: self.fw_mark, + dns: self.dns, + }, + peers: self.peers, + } + } +} + +#[derive(Default)] +struct PartialConf { + interface: Option, + peers: Vec, +} + +pub fn parse_conf(conf: &str) -> anyhow::Result { + let mut iter = conf.lines().filter_map(|l| { + // remove whitespace on the sides + let l = l.trim(); + // remove the comment + let (l, _) = l.rsplit_once("#").unwrap_or((l, "")); + if l.is_empty() { + None + } else { + Some(l) + } + }); + + let mut partial = PartialConf::default(); + parse_partial(&mut partial, &mut iter)?; + + match partial.interface { + Some(interface) => Ok(WgConf { + interface, + peers: partial.peers, + }), + None => Err(anyhow::anyhow!("no interface found")), + } +} + +pub fn serialize_conf(conf: &WgConf) -> String { + let mut conf_str = String::new(); + header!(conf_str, "Interface"); + field!(conf_str, FIELD_PRIVATE_KEY, conf.interface.private_key); + field_csv!(conf_str, FIELD_ADDRESS, conf.interface.address); + field_opt!(conf_str, FIELD_LISTEN_PORT, conf.interface.listen_port); + field_opt!(conf_str, FIELD_FWMARK, conf.interface.fw_mark); + field_opt!(conf_str, FIELD_DNS, conf.interface.dns); + for peer in conf.peers.iter() { + writeln!(conf_str).unwrap(); + header!(conf_str, "Peer"); + field!(conf_str, FIELD_PUBLIC_KEY, peer.public_key); + field_opt!(conf_str, FIELD_PRE_SHARED_KEY, peer.preshared_key); + field_csv!(conf_str, FIELD_ALLOWED_IPS, peer.allowed_ips); + field_opt!(conf_str, FIELD_ENDPOINT, peer.endpoint); + field_opt!(conf_str, FIELD_PERSISTENT_KEEPALIVE, peer.keep_alive); + } + conf_str +} + +fn parse_partial<'s, I: Iterator>( + cfg: &mut PartialConf, + iter: &mut I, +) -> anyhow::Result<()> { + match iter.next() { + Some("[Interface]") => parse_interface(cfg, iter), + Some("[Peer]") => parse_peer(cfg, iter), + Some(line) => Err(anyhow::anyhow!("unexpected line: {}", line)), + None => Err(anyhow::anyhow!("unexpected end of file")), + } +} + +fn parse_interface<'s, I: Iterator>( + cfg: &mut PartialConf, + iter: &mut I, +) -> anyhow::Result<()> { + let mut private_key = None; + let mut address = Vec::new(); + let mut listen_port = None; + let mut fw_mark = None; + let mut dns = None; + let mut peer_next = false; + + if cfg.interface.is_some() { + anyhow::bail!("cannot have more than one interface"); + } + + while let Some(line) = iter.next() { + if line == "[Peer]" { + peer_next = true; + break; + } + + let (key, value) = parse_key_value(line)?; + match key { + FIELD_PRIVATE_KEY => private_key = Some(value.parse()?), + FIELD_LISTEN_PORT => listen_port = Some(value.parse()?), + FIELD_FWMARK => fw_mark = Some(value.parse()?), + FIELD_ADDRESS => address = parse_csv(value)?, + FIELD_DNS => dns = Some(value.to_string()), + _ => anyhow::bail!("unexpected key: {}", key), + } + } + + cfg.interface = Some(WgInterface { + private_key: private_key.ok_or_else(|| anyhow::anyhow!("interface missing private key"))?, + address, + listen_port, + fw_mark, + dns, + }); + + if peer_next { + parse_peer(cfg, iter) + } else { + Ok(()) + } +} + +fn parse_peer<'s, I: Iterator>( + cfg: &mut PartialConf, + iter: &mut I, +) -> anyhow::Result<()> { + let mut public_key = None; + let mut preshared_key = None; + let mut allowed_ips = Vec::new(); + let mut endpoint = None; + let mut keep_alive = None; + let mut interface_next = false; + let mut peer_next = false; + + while let Some(line) = iter.next() { + if line == "[Interface]" { + interface_next = true; + break; + } + if line == "[Peer]" { + peer_next = true; + break; + } + + let (key, value) = parse_key_value(line)?; + match key { + FIELD_PUBLIC_KEY => public_key = Some(value.parse()?), + FIELD_PRE_SHARED_KEY => preshared_key = Some(value.parse()?), + FIELD_ALLOWED_IPS => allowed_ips = parse_csv(value)?, + FIELD_ENDPOINT => endpoint = Some(value.to_string()), + FIELD_PERSISTENT_KEEPALIVE => keep_alive = Some(value.parse()?), + _ => anyhow::bail!("unexpected key: {}", key), + } + } + + cfg.peers.push(WgPeer { + public_key: public_key.ok_or_else(|| anyhow::anyhow!("peer missing public key"))?, + preshared_key, + allowed_ips, + endpoint, + keep_alive, + }); + + if interface_next { + parse_interface(cfg, iter) + } else if peer_next { + parse_peer(cfg, iter) + } else { + Ok(()) + } +} + +fn parse_key_value<'s>(line: &'s str) -> anyhow::Result<(&'s str, &'s str)> { + line.split_once("=") + .map(|(k, v)| (k.trim(), v.trim())) + .ok_or_else(|| anyhow::anyhow!("invalid line: {}", line)) +} + +fn parse_csv< + 'v, + T: FromStr, +>( + value: &'v str, +) -> anyhow::Result> { + let mut values = Vec::new(); + for v in value.split(',').map(str::trim) { + values.push(v.parse()?); + } + Ok(values) +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use ipnet::{IpNet, Ipv4Net}; + + use crate::Key; + + use super::{WgConfBuilder, WgPeerBuilder}; + + const TEST_CONF_1: &str = r#" + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 + + [Peer] + PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= + Endpoint = 192.95.5.67:1234 + AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + + [Peer] + PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= + Endpoint = [2607:5300:60:6b0::c05f:543]:2468 + AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 + + [Peer] + PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= + Endpoint = test.wireguard.com:18981 + + AllowedIPs = 10.10.10.230/32 + PersistentKeepalive = 54 +"#; + + const TEST_CONF_2: &str = r#" + [Peer] + PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= + Endpoint = 192.95.5.67:1234 + AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + + [Peer] + PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= + Endpoint = [2607:5300:60:6b0::c05f:543]:2468 + AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 + + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 + + [Peer] + PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= + Endpoint = test.wireguard.com:18981 + + AllowedIPs = 10.10.10.230/32 + PersistentKeepalive = 54 +"#; + + const TEST_CONF_3: &str = r#" + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 + + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51821 + + [Peer] + PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= + Endpoint = test.wireguard.com:18981 + AllowedIPs = 10.10.10.230/32 +"#; + + const TEST_CONF_4: &str = ""; + + const TEST_CONF_5: &str = r#" + PublicKey = 1 + + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 + + [Peer] + PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= + Endpoint = test.wireguard.com:18981 + AllowedIPs = 10.10.10.230/32 +"#; + + const TEST_CONF_6: &str = r#" + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 + Unknown = 1 + + [Peer] + PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= + Endpoint = test.wireguard.com:18981 + AllowedIPs = 10.10.10.230/32 +"#; + + const TEST_CONF_7: &str = r#" + [Interface] + PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= + ListenPort = 51820 +"#; + + #[test] + fn parse_config() { + parse_config_1_and_2(TEST_CONF_1); + } + + #[test] + fn parse_config_out_of_order_interface() { + parse_config_1_and_2(TEST_CONF_2); + } + + #[test] + #[should_panic] + fn parse_config_duplicate_interface() { + super::parse_conf(TEST_CONF_3).unwrap(); + } + + #[test] + #[should_panic] + fn parse_config_empty() { + super::parse_conf(TEST_CONF_4).unwrap(); + } + + #[test] + #[should_panic] + fn parse_config_out_of_order_field() { + super::parse_conf(TEST_CONF_5).unwrap(); + } + + #[test] + #[should_panic] + fn parse_config_unkown_field() { + super::parse_conf(TEST_CONF_6).unwrap(); + } + + #[test] + fn parse_config_no_peers() { + let cfg = super::parse_conf(TEST_CONF_7).unwrap(); + + assert_eq!( + "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", + cfg.interface.private_key.to_string(), + ); + assert_eq!(Some(51820), cfg.interface.listen_port); + assert_eq!(None, cfg.interface.fw_mark); + + assert_eq!(0, cfg.peers.len()); + } + + fn parse_config_1_and_2(conf_str: &str) { + let cfg = super::parse_conf(conf_str).unwrap(); + + assert_eq!( + "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", + cfg.interface.private_key.to_string() + ); + assert_eq!(Some(51820), cfg.interface.listen_port); + assert_eq!(None, cfg.interface.fw_mark); + + assert_eq!(3, cfg.peers.len()); + + let peer = &cfg.peers[0]; + assert_eq!( + "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", + peer.public_key.to_string() + ); + assert_eq!(None, peer.preshared_key); + assert_eq!(2, peer.allowed_ips.len()); + assert_eq!(Some("192.95.5.67:1234"), peer.endpoint.as_deref()); + assert_eq!(None, peer.keep_alive); + + let peer = &cfg.peers[1]; + assert_eq!( + "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", + peer.public_key.to_string() + ); + assert_eq!(None, peer.preshared_key); + assert_eq!(2, peer.allowed_ips.len()); + assert_eq!( + Some("[2607:5300:60:6b0::c05f:543]:2468"), + peer.endpoint.as_deref() + ); + assert_eq!(None, peer.keep_alive); + + let peer = &cfg.peers[2]; + assert_eq!( + "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", + peer.public_key.to_string() + ); + assert_eq!(None, peer.preshared_key); + assert_eq!(1, peer.allowed_ips.len()); + assert_eq!(Some("test.wireguard.com:18981"), peer.endpoint.as_deref()); + assert_eq!(Some(54), peer.keep_alive); + } + + #[test] + fn serialize_no_peers() { + let key = Key::decode("yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=").unwrap(); + let conf = WgConfBuilder::new() + .fw_mark(10) + .listen_port(6000) + .dns("dns.example.com") + .address(IpNet::V4( + Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 5), 24).unwrap(), + )) + .private_key(key) + .build(); + let serialized = super::serialize_conf(&conf); + + assert_eq!( + r#"[Interface] +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +Address = 10.0.0.5/24 +ListenPort = 6000 +FwMark = 10 +DNS = dns.example.com +"#, + serialized + ); + } + + #[test] + fn serialize_with_peers() { + let key1 = Key::decode("xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=").unwrap(); + let key2 = Key::decode("TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=").unwrap(); + let key3 = Key::decode("gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=").unwrap(); + + let conf = WgConfBuilder::new() + .private_key(key1) + .listen_port(51820) + .dns("dns.example.com") + .peer( + WgPeerBuilder::new(key2) + .keep_alive(10) + .endpoint("test.wireguard.com:18981") + .allowed_ip(ipnet::IpNet::V4( + Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 2), 24).unwrap(), + )) + .build(), + ) + .peer( + WgPeerBuilder::new(key3) + .allowed_ip(ipnet::IpNet::V4( + Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 3), 24).unwrap(), + )) + .build(), + ) + .build(); + + let serialized = super::serialize_conf(&conf); + + assert_eq!( + r#"[Interface] +PrivateKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +ListenPort = 51820 +DNS = dns.example.com + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +AllowedIPs = 10.0.0.2/24 +Endpoint = test.wireguard.com:18981 +PersistentKeepalive = 10 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +AllowedIPs = 10.0.0.3/24 +"#, + serialized + ); + } +} diff --git a/src/key.rs b/src/key.rs new file mode 100644 index 0000000..19bc127 --- /dev/null +++ b/src/key.rs @@ -0,0 +1,372 @@ +use base64::Engine; +use netlink_packet_wireguard::constants::WG_KEY_LEN; +use rand::{rngs::OsRng, RngCore}; + +// Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c + +type Fe = [i64; 16]; + +#[derive(Debug)] +pub struct KeyDecodeError; + +impl std::fmt::Display for KeyDecodeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Key decode error") + } +} + +impl std::error::Error for KeyDecodeError {} + +#[derive(Clone, Default, Copy, PartialEq, Eq, Hash)] +pub struct Key([u8; WG_KEY_LEN]); + +impl Key { + pub fn into_array(self) -> [u8; WG_KEY_LEN] { + self.0 + } + + pub fn as_array(&self) -> &[u8; WG_KEY_LEN] { + &self.0 + } + + pub fn as_slice(&self) -> &[u8] { + &self.0 + } + + pub fn encode(&self) -> String { + base64::engine::general_purpose::STANDARD.encode(&self.0) + } + + pub fn decode(encoded: &str) -> Result { + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|_| KeyDecodeError)?; + if decoded.len() != WG_KEY_LEN { + return Err(KeyDecodeError); + } + let mut key = [0u8; WG_KEY_LEN]; + key.copy_from_slice(&decoded); + Ok(Key(key)) + } + + pub fn generate_pub_priv() -> (Self, Self) { + let private_key = Self::generate_private(); + let public_key = Self::generate_public(&private_key); + (public_key, private_key) + } + + pub fn generate_public(private: &Key) -> Self { + let mut r: i32 = Default::default(); + let mut public_key: [u8; WG_KEY_LEN] = Default::default(); + let mut z: [u8; WG_KEY_LEN] = private.0; + let mut a = fe_new_one(1); + let mut b = fe_new_one(9); + let mut c = fe_new_one(0); + let mut d = fe_new_one(1); + let mut e = fe_new_default(); + let mut f = fe_new_default(); + + clamp_key(&mut z); + + for i in (0..=254i32).rev() { + r = ((z[(i >> 3) as usize] >> (i & 7)) & 1) as i32; + cswap(&mut a, &mut b, r); + cswap(&mut c, &mut d, r); + add(&mut e, &a, &c); + { + let a_clone = a; + subtract(&mut a, &a_clone, &c); + } + add(&mut c, &b, &d); + { + let b_clone = b; + subtract(&mut b, &b_clone, &d); + } + multmod(&mut d, &e, &e); + multmod(&mut f, &a, &a); + { + let a_clone = a; + multmod(&mut a, &c, &a_clone); + } + multmod(&mut c, &b, &e); + add(&mut e, &a, &c); + { + let a_clone = a; + subtract(&mut a, &a_clone, &c); + } + multmod(&mut b, &a, &a); + subtract(&mut c, &d, &f); + //multmod(&mut a, &c, (const fe){ 0xdb41, 1 }); + multmod(&mut a, &c, &fe_new_two(0xdb41, 1)); + { + let a_clone = a; + add(&mut a, &a_clone, &d); + } + { + let c_clone = c; + multmod(&mut c, &c_clone, &a); + } + multmod(&mut a, &d, &f); + multmod(&mut d, &b, &fe_new_one(9)); + multmod(&mut b, &e, &e); + cswap(&mut a, &mut b, r); + cswap(&mut c, &mut d, r); + } + { + let c_clone = c; + invert(&mut c, &c_clone); + } + { + let a_clone = a; + multmod(&mut a, &a_clone, &c); + } + pack(&mut public_key, &a); + + memzero_explicit(&mut r); + memzero_explicit(&mut z); + memzero_explicit(&mut a); + memzero_explicit(&mut b); + memzero_explicit(&mut c); + memzero_explicit(&mut d); + memzero_explicit(&mut e); + memzero_explicit(&mut f); + + Self(public_key) + } + + pub fn generate_private() -> Self { + let mut preshared = Self::generate_preshared(); + clamp_key(&mut preshared.0); + preshared + } + + pub fn generate_preshared() -> Self { + let mut key = [0u8; WG_KEY_LEN]; + OsRng.fill_bytes(&mut key); + Self(key) + } +} + +impl From<&[u8; WG_KEY_LEN]> for Key { + fn from(k: &[u8; WG_KEY_LEN]) -> Self { + Self(*k) + } +} + +impl From<[u8; WG_KEY_LEN]> for Key { + fn from(k: [u8; WG_KEY_LEN]) -> Self { + Self(k) + } +} + +impl std::str::FromStr for Key { + type Err = KeyDecodeError; + + fn from_str(s: &str) -> Result { + Key::decode(s) + } +} + +impl std::fmt::Debug for Key { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut buf = [0; WG_KEY_LEN * 2]; + let len = base64::engine::general_purpose::STANDARD + .encode_slice(&self.0, &mut buf) + .expect("base64 should encode"); + let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); + f.debug_tuple("Key").field(&b64).finish() + } +} + +impl std::fmt::Display for Key { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut buf = [0; WG_KEY_LEN * 2]; + let len = base64::engine::general_purpose::STANDARD + .encode_slice(&self.0, &mut buf) + .expect("base64 should encode"); + let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); + f.write_str(b64) + } +} + +impl serde::Serialize for Key { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.encode()) + } +} + +impl<'de> serde::Deserialize<'de> for Key { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Key::decode(&s).map_err(serde::de::Error::custom) + } +} + +fn fe_new_default() -> Fe { + Default::default() +} + +fn fe_new_one(x: i64) -> Fe { + let mut fe = fe_new_default(); + fe[0] = x; + fe +} + +fn fe_new_two(x: i64, y: i64) -> Fe { + let mut fe = fe_new_default(); + fe[0] = x; + fe[1] = y; + fe +} + +fn clamp_key(key: &mut [u8]) { + key[31] = (key[31] & 127) | 64; + key[0] &= 248; +} + +fn carry(o: &mut Fe) { + for i in 0..16 { + let x = if i == 15 { 38 } else { 1 }; + o[(i + 1) % 16] += x * (o[i] >> 16); + o[i] &= 0xffff; + } +} + +fn cswap(p: &mut Fe, q: &mut Fe, mut b: i32) { + let mut t: i64 = 0; + let mut c: i64 = !i64::from(b).wrapping_sub(1); + + for i in 0..16 { + t = c & (p[i] ^ q[i]); + p[i] ^= t; + q[i] ^= t; + } + + memzero_explicit(&mut t); + memzero_explicit(&mut c); + memzero_explicit(&mut b); +} + +fn pack(o: &mut [u8; WG_KEY_LEN], n: &Fe) { + let mut b: i32 = 0; + let mut t: Fe = fe_new_default(); + let mut m: Fe = fe_new_default(); + + t.copy_from_slice(n); + carry(&mut t); + carry(&mut t); + carry(&mut t); + for _ in 0..2 { + m[0] = t[0] - 0xffed; + for i in 1..15 { + m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1); + m[i - 1] &= 0xffff; + } + m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1); + b = ((m[15] >> 16) & 1) as i32; + m[14] &= 0xffff; + cswap(&mut t, &mut m, 1 - b); + } + for i in 0..16 { + o[2 * i] = (t[i] & 0xff) as u8; + o[2 * i + 1] = (t[i] >> 8) as u8; + } + + memzero_explicit(&mut m); + memzero_explicit(&mut t); + memzero_explicit(&mut b); +} + +fn add(o: &mut Fe, a: &Fe, b: &Fe) { + for i in 0..16 { + o[i] = a[i] + b[i]; + } +} + +fn subtract(o: &mut Fe, a: &Fe, b: &Fe) { + for i in 0..16 { + o[i] = a[i] - b[i]; + } +} + +fn multmod(o: &mut Fe, a: &Fe, b: &Fe) { + let mut t: [i64; 31] = [0; 31]; + + for i in 0..16 { + for j in 0..16 { + t[i + j] += a[i] * b[j]; + } + } + for i in 0..15 { + t[i] += 38 * t[i + 16]; + } + o.copy_from_slice(&t[..16]); + carry(o); + carry(o); + + memzero_explicit(&mut t); +} + +fn invert(o: &mut Fe, i: &Fe) { + let mut c: Fe = fe_new_default(); + + c.copy_from_slice(i); + for a in (0..=253).rev() { + { + let c_clone = c; + multmod(&mut c, &c_clone, &c_clone); + } + if a != 2 && a != 4 { + { + let c_clone = c; + multmod(&mut c, &c_clone, i); + } + } + } + o.copy_from_slice(&c); + + memzero_explicit(&mut c); +} + +fn memzero_explicit(v: &mut T) { + unsafe { + let zeroed = std::mem::zeroed(); + std::ptr::write_volatile(v as *mut _, zeroed); + } +} + +#[cfg(test)] +mod tests { + use super::Key; + + #[test] + fn decode_encode_key() { + let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; + let key = super::Key::decode(key).unwrap(); + let key = key.encode(); + assert_eq!(key, "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); + } + + #[test] + fn generate_public_key() { + assert_eq!( + Key::decode("3D5lgnI9ztvnuyWDm7dlBDgm6xr0+WVWPoo6HIfzHRU=").unwrap(), + Key::generate_public( + &Key::decode("+Op7voRskU0Zm2fHFR/5tVE+PJtnwn6cbnme71jXt0E=").unwrap() + ) + ); + + assert_eq!( + Key::decode("//eq/raPUE4+sOlTlozx76XEE+W8L0bUqNfyg9IpX0Q=").unwrap(), + Key::generate_public( + &Key::decode("8OD8QPWH/a0D5LmbWVnb7bwFq4Fghy/QUEFkIhyL/EI=").unwrap() + ) + ); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c47d618 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,401 @@ +mod conf; +mod key; +mod setup; +mod view; + +use std::borrow::Cow; + +use futures::{StreamExt, TryStreamExt}; +use genetlink::{GenetlinkError, GenetlinkHandle}; +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST}; +use netlink_packet_generic::GenlMessage; +use netlink_packet_route::{ + link::{InfoKind, LinkAttribute, LinkInfo}, + route::RouteScope, +}; +use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd}; +use rtnetlink::Handle; + +pub use conf::*; +pub use key::*; +pub use setup::*; +pub use view::*; + +pub use ipnet; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub struct Error { + inner: Option>, + message: Option>, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match (self.message.as_ref(), self.inner.as_ref()) { + (Some(message), Some(inner)) => write!(f, "{}: {}", message, inner), + (Some(message), None) => write!(f, "{}", message), + (None, Some(inner)) => write!(f, "{}", inner), + (None, None) => write!(f, "Unknown error"), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(inner: std::io::Error) -> Self { + Self { + inner: Some(Box::new(inner)), + message: None, + } + } +} + +impl From for Error { + fn from(inner: GenetlinkError) -> Self { + Self { + inner: Some(Box::new(inner)), + message: None, + } + } +} + +impl From for Error { + fn from(inner: rtnetlink::Error) -> Self { + Self { + inner: Some(Box::new(inner)), + message: None, + } + } +} + +impl Error { + pub(crate) fn with_message(inner: E, message: impl Into>) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + inner: Some(Box::new(inner)), + message: Some(message.into()), + } + } + + pub(crate) fn message(message: impl Into>) -> Self { + Self { + inner: None, + message: Some(message.into()), + } + } +} + +struct Link { + pub name: String, + pub ifindex: u32, +} + +pub struct WireGuard { + rt_handle: Handle, + gen_handle: GenetlinkHandle, +} + +#[allow(clippy::await_holding_refcell_ref)] +impl WireGuard { + pub async fn new() -> Result { + let (rt_connection, rt_handle, _) = rtnetlink::new_connection()?; + tokio::spawn(rt_connection); + let (gen_connection, gen_handle, _) = genetlink::new_connection()?; + tokio::spawn(gen_connection); + + Ok(Self { + rt_handle, + gen_handle, + }) + } + + pub async fn create_device( + &mut self, + device_name: &str, + descriptor: DeviceDescriptor, + ) -> Result<()> { + tracing::trace!("Creating device {}", device_name); + self.link_create(device_name).await?; + let link = self.link_get_by_name(device_name).await?; + self.link_up(link.ifindex).await?; + self.setup_device(device_name, descriptor).await?; + tracing::trace!("Created device"); + Ok(()) + } + + pub async fn reload_device( + &mut self, + device_name: &str, + descriptor: DeviceDescriptor, + ) -> Result<()> { + tracing::trace!("Reloading device {}", device_name); + self.setup_device(device_name, descriptor).await?; + tracing::trace!("Reloaded device"); + Ok(()) + } + + pub async fn remove_device(&self, device_name: &str) -> Result<()> { + tracing::trace!("Removing device {}", device_name); + let link = self.link_get_by_name(device_name).await?; + self.link_down(link.ifindex).await?; + self.link_delete(link.ifindex).await?; + tracing::trace!("Removed device"); + Ok(()) + } + + pub async fn view_device(&mut self, device_name: &str) -> Result { + let genlmsg: GenlMessage = GenlMessage::from_payload(Wireguard { + cmd: WireguardCmd::GetDevice, + nlas: vec![WgDeviceAttrs::IfName(device_name.to_string())], + }); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST | NLM_F_DUMP; + + let mut resp = self.gen_handle.request(nlmsg).await?; + while let Some(result) = resp.next().await { + let rx_packet = result.map_err(|e| Error::with_message(e, "Error decoding packet"))?; + match rx_packet.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + return device_view_from_payload(genlmsg.payload); + } + NetlinkPayload::Error(e) => { + return Err(Error::message(format!("Error: {:?}", e))); + } + _ => (), + }; + } + unreachable!(); + } + + pub async fn view_device_if_exists(&mut self, device_name: &str) -> Result> { + let device_names = self.list_device_names().await?; + if device_names.iter().any(|name| name == device_name) { + Ok(Some(self.view_device(device_name).await?)) + } else { + Ok(None) + } + } + + pub async fn view_devices(&mut self) -> Result> { + let device_names = self.list_device_names().await?; + let mut devices = Vec::with_capacity(device_names.len()); + for name in device_names { + let device = self.view_device(&name).await?; + devices.push(device); + } + Ok(devices) + } + + pub async fn list_device_names(&self) -> Result> { + Ok(self + .link_list() + .await? + .into_iter() + .map(|link| link.name) + .collect()) + } + + async fn setup_device( + &mut self, + device_name: &str, + descriptor: DeviceDescriptor, + ) -> Result<()> { + tracing::trace!("Setting up device {}", device_name); + + let link = self.link_get_by_name(device_name).await?; + for addr in descriptor.addresses.iter() { + self.link_add_address(link.ifindex, *addr).await?; + } + + let message = descriptor.into_wireguard(device_name.to_string()); + let genlmsg: GenlMessage = GenlMessage::from_payload(message); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK; + + let mut stream = self.gen_handle.request(nlmsg).await?; + while (stream.next().await).is_some() {} + tracing::trace!("Device setup"); + + Ok(()) + } + + async fn link_create(&self, name: &str) -> Result<()> { + let mut msg = self.rt_handle.link().add().replace(); + msg.message_mut() + .attributes + .push(LinkAttribute::LinkInfo(vec![LinkInfo::Kind( + InfoKind::Wireguard, + )])); + msg.message_mut() + .attributes + .push(LinkAttribute::IfName(name.to_string())); + msg.execute().await?; + Ok(()) + } + + async fn link_delete(&self, ifindex: u32) -> Result<()> { + self.rt_handle.link().del(ifindex).execute().await?; + Ok(()) + } + + async fn link_up(&self, ifindex: u32) -> Result<()> { + tracing::trace!("Bringing up interface {}", ifindex); + self.rt_handle.link().set(ifindex).up().execute().await?; + Ok(()) + } + + async fn link_down(&self, ifindex: u32) -> Result<()> { + tracing::trace!("Bringing down interface {}", ifindex); + self.rt_handle.link().set(ifindex).down().execute().await?; + Ok(()) + } + + async fn link_add_address(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> { + tracing::trace!("Adding address {} to {}", net, ifindex); + self.rt_handle + .address() + .add(ifindex, net.addr(), net.prefix_len()) + .replace() + .execute() + .await?; + Ok(()) + } + + //TODO: return Result>? + async fn link_get_by_name(&self, name: &str) -> Result { + let link = self + .link_list() + .await? + .into_iter() + .find(|link| link.name == name) + .ok_or_else(|| Error::message(format!("Link {} not found", name)))?; + tracing::debug!("device {} has index {}", name, link.ifindex); + Ok(link) + } + + async fn link_list(&self) -> Result> { + let mut links = Vec::new(); + let mut link_stream = self.rt_handle.link().get().execute(); + while let Some(link) = link_stream.try_next().await? { + let mut is_wireguard = false; + let mut link_name = None; + for nla in link.attributes { + match nla { + LinkAttribute::IfName(name) => link_name = Some(name), + LinkAttribute::LinkInfo(infos) => { + for info in infos { + if let netlink_packet_route::link::LinkInfo::Kind(kind) = info { + if kind == netlink_packet_route::link::InfoKind::Wireguard { + is_wireguard = true; + break; + } + } + } + } + _ => {} + } + if is_wireguard && link_name.is_some() { + links.push(Link { + name: link_name.unwrap(), + ifindex: link.header.index, + }); + break; + } + } + } + Ok(links) + } + + #[allow(unused)] + async fn route_add(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> { + tracing::trace!("Adding route {} to {}", net, ifindex); + let request = self + .rt_handle + .route() + .add() + .scope(RouteScope::Link) + .output_interface(ifindex) + .replace(); + + match net.addr() { + std::net::IpAddr::V4(ip) => { + request + .v4() + .destination_prefix(ip, net.prefix_len()) + .execute() + .await + } + std::net::IpAddr::V6(ip) => { + request + .v6() + .destination_prefix(ip, net.prefix_len()) + .execute() + .await + } + }?; + + Ok(()) + } +} + +pub async fn create_device( + device_name: impl AsRef, + device_descriptor: DeviceDescriptor, +) -> Result<()> { + tracing::info!("creating device {}", device_name.as_ref()); + tracing::debug!("device descriptor: {:#?}", device_descriptor); + let mut wireguard = WireGuard::new().await?; + wireguard + .create_device(device_name.as_ref(), device_descriptor) + .await +} + +pub async fn reload_device( + device_name: impl AsRef, + device_descriptor: DeviceDescriptor, +) -> Result<()> { + tracing::info!("reloading device {}", device_name.as_ref()); + tracing::debug!("device descriptor: {:#?}", device_descriptor); + let mut wireguard = WireGuard::new().await?; + wireguard + .reload_device(device_name.as_ref(), device_descriptor) + .await +} + +pub async fn device_exists(name: impl AsRef) -> Result { + tracing::info!("checking if device {} exists", name.as_ref()); + let mut wireguard = WireGuard::new().await?; + wireguard + .view_device_if_exists(name.as_ref()) + .await + .map(|x| x.is_some()) +} + +pub async fn remove_device(name: impl AsRef) -> Result<()> { + tracing::info!("removing device {}", name.as_ref()); + let wireguard = WireGuard::new().await?; + wireguard.remove_device(name.as_ref()).await +} + +pub async fn view_device(name: impl AsRef) -> Result { + tracing::info!("viewing device {}", name.as_ref()); + let mut wireguard = WireGuard::new().await?; + wireguard.view_device(name.as_ref()).await +} + +pub async fn view_device_if_exists(name: impl AsRef) -> Result> { + tracing::info!("viewing device {}", name.as_ref()); + let mut wireguard = WireGuard::new().await?; + wireguard.view_device_if_exists(name.as_ref()).await +} + +pub async fn list_device_names() -> Result> { + tracing::info!("listing device names"); + let wireguard = WireGuard::new().await?; + wireguard.list_device_names().await +} diff --git a/src/setup.rs b/src/setup.rs new file mode 100644 index 0000000..e7d454c --- /dev/null +++ b/src/setup.rs @@ -0,0 +1,212 @@ +use std::net::{IpAddr, SocketAddr}; + +use ipnet::IpNet; +use netlink_packet_wireguard::{ + constants::{AF_INET, AF_INET6, WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REPLACE_ALLOWEDIPS}, + nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, + Wireguard, WireguardCmd, +}; + +use super::Key; + +#[derive(Debug)] +pub struct PeerDescriptor { + pub(super) public_key: Key, + pub(super) preshared_key: Option, + pub(super) endpoint: Option, + pub(super) keepalive: Option, + pub(super) allowed_ips: Option>, +} + +impl PeerDescriptor { + pub fn new(public_key: Key) -> Self { + Self { + public_key, + preshared_key: None, + endpoint: None, + keepalive: None, + allowed_ips: None, + } + } + + pub fn preshared_key_optional(mut self, preshared_key: Option) -> Self { + self.preshared_key = preshared_key; + self + } + + pub fn preshared_key(mut self, preshared_key: Key) -> Self { + self.preshared_key = Some(preshared_key); + self + } + + pub fn endpoint_optional(mut self, endpoint: Option) -> Self { + self.endpoint = endpoint; + self + } + + pub fn endpoint(mut self, endpoint: SocketAddr) -> Self { + self.endpoint = Some(endpoint); + self + } + + pub fn keepalive_optional(mut self, keepalive: Option) -> Self { + self.keepalive = keepalive; + self + } + + pub fn keepalive(mut self, keepalive: u16) -> Self { + self.keepalive = Some(keepalive); + self + } + + pub fn allowed_ip_optional(self, allowed_ip: Option) -> Self { + if let Some(allowed_ip) = allowed_ip { + self.allowed_ip(allowed_ip) + } else { + self + } + } + + pub fn allowed_ip(mut self, allowed_ip: IpNet) -> Self { + let mut allowed_ips = self.allowed_ips.take().unwrap_or_default(); + allowed_ips.push(allowed_ip); + self.allowed_ips = Some(allowed_ips); + self + } + + pub fn allowed_ips_optional(self, allowed_ips: Option>) -> Self { + if let Some(allowed_ips) = allowed_ips { + self.allowed_ips(allowed_ips) + } else { + self + } + } + + pub fn allowed_ips(mut self, allowed_ips: Vec) -> Self { + self.allowed_ips = Some(allowed_ips); + self + } + + pub(super) fn into_wireguard(self) -> WgPeer { + let mut nlas = Vec::new(); + nlas.push(WgPeerAttrs::PublicKey(self.public_key.into_array())); + nlas.extend( + self.preshared_key + .map(|key| WgPeerAttrs::PresharedKey(key.into_array())), + ); + nlas.extend(self.endpoint.map(WgPeerAttrs::Endpoint)); + nlas.extend(self.keepalive.map(WgPeerAttrs::PersistentKeepalive)); + nlas.extend(self.allowed_ips.map(|allowed_ips| { + WgPeerAttrs::AllowedIps(allowed_ips.into_iter().map(ipnet_to_wg).collect()) + })); + nlas.push(WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS)); + WgPeer(nlas) + } +} + +#[derive(Debug)] +pub struct DeviceDescriptor { + pub(super) addresses: Vec, + pub(super) private_key: Option, + pub(super) listen_port: Option, + pub(super) fwmark: Option, + pub(super) peers: Option>, +} + +impl Default for DeviceDescriptor { + fn default() -> Self { + Self::new() + } +} + +impl DeviceDescriptor { + pub fn new() -> Self { + Self { + addresses: Vec::default(), + private_key: None, + listen_port: None, + fwmark: None, + peers: None, + } + } + + pub fn address(mut self, address: IpNet) -> Self { + self.addresses.push(address); + self + } + + pub fn addresses(mut self, addresses: impl IntoIterator) -> Self { + self.addresses.extend(addresses); + self + } + + pub fn private_key(mut self, key: Key) -> Self { + self.private_key = Some(key); + self + } + + pub fn listen_port(mut self, port: u16) -> Self { + self.listen_port = Some(port); + self + } + + pub fn listen_port_optional(mut self, port: Option) -> Self { + self.listen_port = port; + self + } + + pub fn fwmark(mut self, fwmark: u32) -> Self { + self.fwmark = Some(fwmark); + self + } + + pub fn peer(mut self, peer: PeerDescriptor) -> Self { + let mut p = self.peers.take().unwrap_or_default(); + p.push(peer); + self.peers = Some(p); + self + } + + pub fn peers(mut self, peers: impl IntoIterator) -> Self { + let mut p = self.peers.take().unwrap_or_default(); + p.extend(peers); + self.peers = Some(p); + self + } + + pub(super) fn into_wireguard(self, device_name: String) -> Wireguard { + let mut nlas = Vec::new(); + nlas.push(WgDeviceAttrs::IfName(device_name)); + nlas.extend( + self.private_key + .map(|key| WgDeviceAttrs::PrivateKey(key.into_array())), + ); + nlas.extend(self.listen_port.map(WgDeviceAttrs::ListenPort)); + nlas.extend(self.fwmark.map(WgDeviceAttrs::Fwmark)); + nlas.extend(self.peers.map(|peers| { + WgDeviceAttrs::Peers( + peers + .into_iter() + .map(PeerDescriptor::into_wireguard) + .collect(), + ) + })); + nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); + + Wireguard { + cmd: WireguardCmd::SetDevice, + nlas, + } + } +} + +fn ipnet_to_wg(net: IpNet) -> WgAllowedIp { + let mut nlas = Vec::default(); + nlas.push(WgAllowedIpAttrs::Cidr(net.prefix_len())); + nlas.push(WgAllowedIpAttrs::IpAddr(net.addr())); + match net.addr() { + IpAddr::V4(_) => nlas.push(WgAllowedIpAttrs::Family(AF_INET)), + IpAddr::V6(_) => nlas.push(WgAllowedIpAttrs::Family(AF_INET6)), + } + WgAllowedIp(nlas) +} diff --git a/src/view.rs b/src/view.rs new file mode 100644 index 0000000..2858811 --- /dev/null +++ b/src/view.rs @@ -0,0 +1,130 @@ +use std::{net::SocketAddr, time::SystemTime}; + +use ipnet::IpNet; +use netlink_packet_wireguard::{ + nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, + Wireguard, +}; + +use super::{Error, Key, Result}; + +#[derive(Debug, Clone)] +pub struct DeviceView { + pub name: String, + pub ifindex: u32, + pub private_key: Option, + pub public_key: Option, + pub listen_port: u16, + pub fwmark: u32, + pub peers: Vec, +} + +#[derive(Debug, Clone)] +pub struct PeerView { + pub public_key: Key, + pub preshared_key: Option, + pub endpoint: Option, + pub persistent_keepalive: Option, + pub last_handshake: SystemTime, + pub rx_bytes: u64, + pub tx_bytes: u64, + pub allowed_ips: Vec, +} + +pub(super) fn device_view_from_payload(wg: Wireguard) -> Result { + let mut if_index = None; + let mut if_name = None; + let mut private_key = None; + let mut public_key = None; + let mut listen_port = None; + let mut fwmark = None; + let mut peers = None; + + for nla in wg.nlas { + match nla { + WgDeviceAttrs::IfIndex(v) => if_index = Some(v), + WgDeviceAttrs::IfName(v) => if_name = Some(v), + WgDeviceAttrs::PrivateKey(v) => private_key = Some(Key::from(v)), + WgDeviceAttrs::PublicKey(v) => public_key = Some(Key::from(v)), + WgDeviceAttrs::ListenPort(v) => listen_port = Some(v), + WgDeviceAttrs::Fwmark(v) => fwmark = Some(v), + WgDeviceAttrs::Peers(v) => peers = Some(peers_from_wg_peers(v)?), + _ => {} + } + } + + Ok(DeviceView { + name: if_name.ok_or_else(|| Error::message("missing if_name"))?, + ifindex: if_index.ok_or_else(|| Error::message("missing if_index"))?, + private_key, + public_key, + listen_port: listen_port.ok_or_else(|| Error::message("missing listen_port"))?, + fwmark: fwmark.ok_or_else(|| Error::message("missing fwmark"))?, + peers: peers.unwrap_or_default(), + }) +} + +fn peers_from_wg_peers(wg_peers: Vec) -> Result> { + let mut peers = Vec::with_capacity(wg_peers.len()); + for wg_peer in wg_peers { + peers.push(peer_from_wg_peer(wg_peer)?); + } + Ok(peers) +} + +fn peer_from_wg_peer(wg_peer: WgPeer) -> Result { + let mut public_key = None; + let mut preshared_key = None; + let mut endpoint = None; + let mut persistent_keepalive = None; + let mut last_handshake = None; + let mut rx_bytes = None; + let mut tx_bytes = None; + let mut allowed_ips = Vec::default(); + + for attr in wg_peer.iter() { + match attr { + WgPeerAttrs::PublicKey(v) => public_key = Some(Key::from(v)), + WgPeerAttrs::PresharedKey(v) => preshared_key = Some(Key::from(v)), + WgPeerAttrs::Endpoint(v) => endpoint = Some(*v), + WgPeerAttrs::PersistentKeepalive(v) => persistent_keepalive = Some(*v), + WgPeerAttrs::LastHandshake(v) => last_handshake = Some(*v), + WgPeerAttrs::RxBytes(v) => rx_bytes = Some(*v), + WgPeerAttrs::TxBytes(v) => tx_bytes = Some(*v), + WgPeerAttrs::AllowedIps(v) => { + for ip in v { + allowed_ips.push(ipnet_from_wg(ip)?); + } + } + _ => {} + } + } + + Ok(PeerView { + public_key: public_key.ok_or_else(|| Error::message("missing public_key"))?, + preshared_key, + endpoint, + persistent_keepalive, + last_handshake: last_handshake.ok_or_else(|| Error::message("missing last_handshake"))?, + rx_bytes: rx_bytes.ok_or_else(|| Error::message("missing rx_bytes"))?, + tx_bytes: tx_bytes.ok_or_else(|| Error::message("missing tx_bytes"))?, + allowed_ips, + }) +} + +fn ipnet_from_wg(wg: &WgAllowedIp) -> Result { + let mut ip = None; + let mut prefix = None; + for attr in wg.iter() { + match attr { + WgAllowedIpAttrs::IpAddr(v) => ip = Some(*v), + WgAllowedIpAttrs::Cidr(v) => prefix = Some(*v), + _ => {} + } + } + Ok(IpNet::new( + ip.ok_or_else(|| Error::message("missing ip"))?, + prefix.ok_or_else(|| Error::message("missing prefix"))?, + ) + .map_err(|e| Error::with_message(e, "invalid ipnet"))?) +} -- cgit