From ed0f17dea764fea93ea7c2b897fa4a9e31c9476e Mon Sep 17 00:00:00 2001 From: diogo464 Date: Tue, 17 Feb 2026 11:41:02 +0000 Subject: add const key macro and internal base64 codec --- Cargo.lock | 7 --- Cargo.toml | 1 - src/key.rs | 178 ++++++++++++++++++++++++++++++++++++++++++++++++++++--------- src/lib.rs | 29 ++++++++++ 4 files changed, 183 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6203091..66088b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,12 +8,6 @@ version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "bitflags" version = "2.11.0" @@ -699,7 +693,6 @@ name = "wireguard" version = "0.0.0" dependencies = [ "anyhow", - "base64", "futures", "genetlink", "ipnet", diff --git a/Cargo.toml b/Cargo.toml index e7ffd1b..aea7700 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" [dependencies] anyhow = "1.0.79" -base64 = "0.21.7" futures = "0.3.30" genetlink = "=0.2.6" ipnet = { version = "2.9.0", features = ["serde"] } diff --git a/src/key.rs b/src/key.rs index 49f9284..3fa7021 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,18 +1,40 @@ -use base64::Engine; use rand::{rngs::OsRng, RngCore}; const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN; +const WG_KEY_B64_LEN: usize = 44; +const BASE64_ALPHABET: [u8; 64] = + *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; // Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c type Fe = [i64; 16]; -#[derive(Debug)] -pub struct KeyDecodeError; +#[derive(Clone, Copy)] +pub struct KeyDecodeError(&'static str); + +impl KeyDecodeError { + pub(crate) const fn invalid_length() -> Self { + Self("invalid length") + } + + pub(crate) const fn invalid_padding() -> Self { + Self("invalid padding") + } + + pub(crate) const fn invalid_base64_character() -> Self { + Self("invalid base64 character") + } +} + +impl std::fmt::Debug for KeyDecodeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "KeyDecodeError: {}", self.0) + } +} impl std::fmt::Display for KeyDecodeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Key decode error") + write!(f, "Key decode error: {}", self.0) } } @@ -22,6 +44,10 @@ impl std::error::Error for KeyDecodeError {} pub struct Key([u8; WG_KEY_LEN]); impl Key { + pub const fn new_unchecked_from(bytes: [u8; WG_KEY_LEN]) -> Self { + Self(bytes) + } + pub fn into_array(self) -> [u8; WG_KEY_LEN] { self.0 } @@ -35,19 +61,12 @@ impl Key { } pub fn encode(&self) -> String { - base64::engine::general_purpose::STANDARD.encode(&self.0) + let encoded = encode_wg_key(&self.0); + String::from_utf8(encoded.to_vec()).expect("WireGuard key should encode to valid UTF-8") } pub fn decode(encoded: &str) -> Result { - let decoded = base64::engine::general_purpose::STANDARD - .decode(encoded) - .map_err(|_| KeyDecodeError)?; - if decoded.len() != WG_KEY_LEN { - return Err(KeyDecodeError); - } - let mut key = [0u8; WG_KEY_LEN]; - key.copy_from_slice(&decoded); - Ok(Key(key)) + decode_wg_key(encoded).map(Self) } pub fn generate_pub_priv() -> (Self, Self) { @@ -148,6 +167,114 @@ impl Key { } } +pub(crate) const fn decode_wg_key_const(encoded: &str) -> [u8; WG_KEY_LEN] { + match decode_wg_key(encoded) { + Ok(out) => out, + Err(_) => panic!("invalid WireGuard key literal"), + } +} + +pub(crate) const fn decode_wg_key(encoded: &str) -> Result<[u8; WG_KEY_LEN], KeyDecodeError> { + let bytes = encoded.as_bytes(); + if bytes.len() != WG_KEY_B64_LEN { + return Err(KeyDecodeError::invalid_length()); + } + + if bytes[WG_KEY_B64_LEN - 1] != b'=' { + return Err(KeyDecodeError::invalid_padding()); + } + + let mut out = [0u8; WG_KEY_LEN]; + let mut in_idx = 0; + let mut out_idx = 0; + + while in_idx < 40 { + let a = match decode_b64_char(bytes[in_idx]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + let b = match decode_b64_char(bytes[in_idx + 1]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + let c = match decode_b64_char(bytes[in_idx + 2]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + let d = match decode_b64_char(bytes[in_idx + 3]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + + out[out_idx] = (a << 2) | (b >> 4); + out[out_idx + 1] = (b << 4) | (c >> 2); + out[out_idx + 2] = (c << 6) | d; + + in_idx += 4; + out_idx += 3; + } + + let a = match decode_b64_char(bytes[40]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + let b = match decode_b64_char(bytes[41]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + let c = match decode_b64_char(bytes[42]) { + Ok(v) => v, + Err(err) => return Err(err), + }; + if (c & 0b0000_0011) != 0 { + return Err(KeyDecodeError::invalid_padding()); + } + + out[30] = (a << 2) | (b >> 4); + out[31] = (b << 4) | (c >> 2); + + Ok(out) +} + +pub(crate) const fn encode_wg_key(key: &[u8; WG_KEY_LEN]) -> [u8; WG_KEY_B64_LEN] { + let mut out = [b'='; WG_KEY_B64_LEN]; + let mut in_idx = 0; + let mut out_idx = 0; + + while in_idx < 30 { + let b0 = key[in_idx]; + let b1 = key[in_idx + 1]; + let b2 = key[in_idx + 2]; + + out[out_idx] = BASE64_ALPHABET[(b0 >> 2) as usize]; + out[out_idx + 1] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; + out[out_idx + 2] = BASE64_ALPHABET[(((b1 & 0b0000_1111) << 2) | (b2 >> 6)) as usize]; + out[out_idx + 3] = BASE64_ALPHABET[(b2 & 0b0011_1111) as usize]; + + in_idx += 3; + out_idx += 4; + } + + let b0 = key[30]; + let b1 = key[31]; + out[40] = BASE64_ALPHABET[(b0 >> 2) as usize]; + out[41] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; + out[42] = BASE64_ALPHABET[((b1 & 0b0000_1111) << 2) as usize]; + + out +} + +const fn decode_b64_char(c: u8) -> Result { + match c { + b'A'..=b'Z' => Ok(c - b'A'), + b'a'..=b'z' => Ok(c - b'a' + 26), + b'0'..=b'9' => Ok(c - b'0' + 52), + b'+' => Ok(62), + b'/' => Ok(63), + _ => Err(KeyDecodeError::invalid_base64_character()), + } +} + impl From<&[u8; WG_KEY_LEN]> for Key { fn from(k: &[u8; WG_KEY_LEN]) -> Self { Self(*k) @@ -170,22 +297,16 @@ impl std::str::FromStr for Key { impl std::fmt::Debug for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut buf = [0; WG_KEY_LEN * 2]; - let len = base64::engine::general_purpose::STANDARD - .encode_slice(&self.0, &mut buf) - .expect("base64 should encode"); - let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); + let buf = encode_wg_key(&self.0); + let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); f.debug_tuple("Key").field(&b64).finish() } } impl std::fmt::Display for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut buf = [0; WG_KEY_LEN * 2]; - let len = base64::engine::general_purpose::STANDARD - .encode_slice(&self.0, &mut buf) - .expect("base64 should encode"); - let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); + let buf = encode_wg_key(&self.0); + let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); f.write_str(b64) } } @@ -346,6 +467,8 @@ fn memzero_explicit(v: &mut T) { mod tests { use super::Key; + const CONST_KEY: Key = crate::key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); + #[test] fn decode_encode_key() { let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; @@ -370,4 +493,11 @@ mod tests { ) ); } + + #[test] + fn const_key_macro_matches_runtime_decode() { + let runtime = + Key::decode("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=").expect("valid key"); + assert_eq!(CONST_KEY, runtime); + } } diff --git a/src/lib.rs b/src/lib.rs index 8ba36eb..72fb58f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,35 @@ pub use view::*; pub use ipnet; +#[doc(hidden)] +pub const fn __decode_wg_key_const( + encoded: &str, +) -> [u8; netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN] { + key::decode_wg_key_const(encoded) +} + +/// Creates a [`Key`] from a canonical WireGuard base64 literal at compile time. +/// +/// The literal must be exactly 44 characters of standard base64 and end with `=`. +/// +/// ``` +/// use wireguard::{key, Key}; +/// +/// const PRIVATE_KEY: Key = key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); +/// ``` +/// +/// ```compile_fail +/// use wireguard::key; +/// +/// const _BAD: wireguard::Key = key!("not-a-wireguard-key"); +/// ``` +#[macro_export] +macro_rules! key { + ($value:literal) => { + $crate::Key::new_unchecked_from($crate::__decode_wg_key_const($value)) + }; +} + pub type Result = std::result::Result; #[derive(Debug)] -- cgit