mod conf; mod key; mod setup; mod view; use std::{ borrow::Cow, net::{Ipv4Addr, Ipv6Addr}, }; use futures::{StreamExt, TryStreamExt}; use genetlink::{GenetlinkError, GenetlinkHandle}; use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST}; use netlink_packet_generic::GenlMessage; use netlink_packet_route::{link::LinkAttribute, route::RouteScope}; use netlink_packet_wireguard::{WireguardAttribute, WireguardCmd, WireguardMessage}; use rtnetlink::{Handle, LinkMessageBuilder, LinkSetRequest, LinkWireguard, RouteMessageBuilder}; pub use conf::*; pub use key::*; pub use setup::*; pub use view::*; pub use ipnet; #[doc(hidden)] pub const fn __decode_wg_key_const( encoded: &str, ) -> [u8; netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN] { key::decode_wg_key_const(encoded) } /// Creates a [`Key`] from a canonical WireGuard base64 literal at compile time. /// /// The literal must be exactly 44 characters of standard base64 and end with `=`. /// /// ``` /// use wireguard::{key, Key}; /// /// const PRIVATE_KEY: Key = key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); /// ``` /// /// ```compile_fail /// use wireguard::key; /// /// const _BAD: wireguard::Key = key!("not-a-wireguard-key"); /// ``` #[macro_export] macro_rules! key { ($value:literal) => { $crate::Key::new_unchecked_from($crate::__decode_wg_key_const($value)) }; } pub type Result = std::result::Result; #[derive(Debug)] pub struct Error { inner: Option>, message: Option>, } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match (self.message.as_ref(), self.inner.as_ref()) { (Some(message), Some(inner)) => write!(f, "{}: {}", message, inner), (Some(message), None) => write!(f, "{}", message), (None, Some(inner)) => write!(f, "{}", inner), (None, None) => write!(f, "Unknown error"), } } } impl std::error::Error for Error {} impl From for Error { fn from(inner: std::io::Error) -> Self { Self { inner: Some(Box::new(inner)), message: None, } } } impl From for Error { fn from(inner: GenetlinkError) -> Self { Self { inner: Some(Box::new(inner)), message: None, } } } impl From for Error { fn from(inner: rtnetlink::Error) -> Self { Self { inner: Some(Box::new(inner)), message: None, } } } impl Error { pub(crate) fn with_message(inner: E, message: impl Into>) -> Self where E: std::error::Error + Send + Sync + 'static, { Self { inner: Some(Box::new(inner)), message: Some(message.into()), } } pub(crate) fn message(message: impl Into>) -> Self { Self { inner: None, message: Some(message.into()), } } } struct Link { pub name: String, pub ifindex: u32, } pub struct WireGuard { rt_handle: Handle, gen_handle: GenetlinkHandle, } #[allow(clippy::await_holding_refcell_ref)] impl WireGuard { pub async fn new() -> Result { let (rt_connection, rt_handle, _) = rtnetlink::new_connection()?; tokio::spawn(rt_connection); let (gen_connection, gen_handle, _) = genetlink::new_connection()?; tokio::spawn(gen_connection); Ok(Self { rt_handle, gen_handle, }) } pub async fn create_device( &mut self, device_name: &str, descriptor: DeviceDescriptor, ) -> Result<()> { tracing::trace!("Creating device {}", device_name); self.link_create(device_name).await?; let link = self.link_get_by_name(device_name).await?; self.link_up(link.ifindex).await?; self.setup_device(device_name, descriptor).await?; tracing::trace!("Created device"); Ok(()) } pub async fn reload_device( &mut self, device_name: &str, descriptor: DeviceDescriptor, ) -> Result<()> { tracing::trace!("Reloading device {}", device_name); self.setup_device(device_name, descriptor).await?; tracing::trace!("Reloaded device"); Ok(()) } pub async fn remove_device(&self, device_name: &str) -> Result<()> { tracing::trace!("Removing device {}", device_name); let link = self.link_get_by_name(device_name).await?; self.link_down(link.ifindex).await?; self.link_delete(link.ifindex).await?; tracing::trace!("Removed device"); Ok(()) } pub async fn view_device(&mut self, device_name: &str) -> Result { let genlmsg: GenlMessage = GenlMessage::from_payload(WireguardMessage { cmd: WireguardCmd::GetDevice, attributes: vec![WireguardAttribute::IfName(device_name.to_string())], }); let mut nlmsg = NetlinkMessage::from(genlmsg); nlmsg.header.flags = NLM_F_REQUEST | NLM_F_DUMP; let mut resp = self.gen_handle.request(nlmsg).await?; while let Some(result) = resp.next().await { let rx_packet = result.map_err(|e| Error::with_message(e, "Error decoding packet"))?; match rx_packet.payload { NetlinkPayload::InnerMessage(genlmsg) => { return device_view_from_payload(genlmsg.payload); } NetlinkPayload::Error(e) => { return Err(Error::message(format!("Error: {:?}", e))); } _ => (), }; } unreachable!(); } pub async fn view_device_if_exists(&mut self, device_name: &str) -> Result> { let device_names = self.list_device_names().await?; if device_names.iter().any(|name| name == device_name) { Ok(Some(self.view_device(device_name).await?)) } else { Ok(None) } } pub async fn view_devices(&mut self) -> Result> { let device_names = self.list_device_names().await?; let mut devices = Vec::with_capacity(device_names.len()); for name in device_names { let device = self.view_device(&name).await?; devices.push(device); } Ok(devices) } pub async fn list_device_names(&self) -> Result> { Ok(self .link_list() .await? .into_iter() .map(|link| link.name) .collect()) } async fn setup_device( &mut self, device_name: &str, descriptor: DeviceDescriptor, ) -> Result<()> { tracing::trace!("Setting up device {}", device_name); let link = self.link_get_by_name(device_name).await?; for addr in descriptor.addresses.iter() { self.link_add_address(link.ifindex, *addr).await?; } let message = descriptor.into_wireguard(device_name.to_string()); let genlmsg: GenlMessage = GenlMessage::from_payload(message); let mut nlmsg = NetlinkMessage::from(genlmsg); nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK; let mut stream = self.gen_handle.request(nlmsg).await?; while (stream.next().await).is_some() {} tracing::trace!("Device setup"); Ok(()) } async fn link_create(&self, name: &str) -> Result<()> { self.rt_handle .link() .add(LinkMessageBuilder::::new(name).build()) .replace() .execute() .await?; Ok(()) } async fn link_delete(&self, ifindex: u32) -> Result<()> { self.rt_handle.link().del(ifindex).execute().await?; Ok(()) } async fn link_up(&self, ifindex: u32) -> Result<()> { tracing::trace!("Bringing up interface {}", ifindex); self.rt_handle .link() .set( LinkMessageBuilder::::default() .index(ifindex) .up() .build(), ) .execute() .await?; Ok(()) } async fn link_down(&self, ifindex: u32) -> Result<()> { tracing::trace!("Bringing down interface {}", ifindex); self.rt_handle .link() .set( LinkMessageBuilder::::default() .index(ifindex) .down() .build(), ) .execute() .await?; Ok(()) } async fn link_add_address(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> { tracing::trace!("Adding address {} to {}", net, ifindex); self.rt_handle .address() .add(ifindex, net.addr(), net.prefix_len()) .replace() .execute() .await?; Ok(()) } //TODO: return Result>? async fn link_get_by_name(&self, name: &str) -> Result { let link = self .link_list() .await? .into_iter() .find(|link| link.name == name) .ok_or_else(|| Error::message(format!("Link {} not found", name)))?; tracing::debug!("device {} has index {}", name, link.ifindex); Ok(link) } async fn link_list(&self) -> Result> { let mut links = Vec::new(); let mut link_stream = self.rt_handle.link().get().execute(); while let Some(link) = link_stream.try_next().await? { let mut is_wireguard = false; let mut link_name = None; for nla in link.attributes { match nla { LinkAttribute::IfName(name) => link_name = Some(name), LinkAttribute::LinkInfo(infos) => { for info in infos { if let netlink_packet_route::link::LinkInfo::Kind(kind) = info { if kind == netlink_packet_route::link::InfoKind::Wireguard { is_wireguard = true; break; } } } } _ => {} } if is_wireguard && link_name.is_some() { links.push(Link { name: link_name.unwrap(), ifindex: link.header.index, }); break; } } } Ok(links) } #[allow(unused)] async fn route_add(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> { tracing::trace!("Adding route {} to {}", net, ifindex); match net.addr() { std::net::IpAddr::V4(ip) => { self.rt_handle .route() .add( RouteMessageBuilder::::default() .scope(RouteScope::Link) .output_interface(ifindex) .destination_prefix(ip, net.prefix_len()) .build(), ) .replace() .execute() .await?; } std::net::IpAddr::V6(ip) => { self.rt_handle .route() .add( RouteMessageBuilder::::default() .scope(RouteScope::Link) .output_interface(ifindex) .destination_prefix(ip, net.prefix_len()) .build(), ) .replace() .execute() .await?; } }; Ok(()) } } pub async fn create_device( device_name: impl AsRef, device_descriptor: DeviceDescriptor, ) -> Result<()> { tracing::info!("creating device {}", device_name.as_ref()); tracing::debug!("device descriptor: {:#?}", device_descriptor); let mut wireguard = WireGuard::new().await?; wireguard .create_device(device_name.as_ref(), device_descriptor) .await } pub async fn reload_device( device_name: impl AsRef, device_descriptor: DeviceDescriptor, ) -> Result<()> { tracing::info!("reloading device {}", device_name.as_ref()); tracing::debug!("device descriptor: {:#?}", device_descriptor); let mut wireguard = WireGuard::new().await?; wireguard .reload_device(device_name.as_ref(), device_descriptor) .await } pub async fn device_exists(name: impl AsRef) -> Result { tracing::info!("checking if device {} exists", name.as_ref()); let mut wireguard = WireGuard::new().await?; wireguard .view_device_if_exists(name.as_ref()) .await .map(|x| x.is_some()) } pub async fn remove_device(name: impl AsRef) -> Result<()> { tracing::info!("removing device {}", name.as_ref()); let wireguard = WireGuard::new().await?; wireguard.remove_device(name.as_ref()).await } pub async fn view_device(name: impl AsRef) -> Result { tracing::info!("viewing device {}", name.as_ref()); let mut wireguard = WireGuard::new().await?; wireguard.view_device(name.as_ref()).await } pub async fn view_device_if_exists(name: impl AsRef) -> Result> { tracing::info!("viewing device {}", name.as_ref()); let mut wireguard = WireGuard::new().await?; wireguard.view_device_if_exists(name.as_ref()).await } pub async fn list_device_names() -> Result> { tracing::info!("listing device names"); let wireguard = WireGuard::new().await?; wireguard.list_device_names().await }