use rand::{rngs::OsRng, RngCore}; const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN; const WG_KEY_B64_LEN: usize = 44; const BASE64_ALPHABET: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; // Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c type Fe = [i64; 16]; #[derive(Clone, Copy)] pub struct KeyDecodeError(&'static str); impl KeyDecodeError { pub(crate) const fn invalid_length() -> Self { Self("invalid length") } pub(crate) const fn invalid_padding() -> Self { Self("invalid padding") } pub(crate) const fn invalid_base64_character() -> Self { Self("invalid base64 character") } } impl std::fmt::Debug for KeyDecodeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "KeyDecodeError: {}", self.0) } } impl std::fmt::Display for KeyDecodeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Key decode error: {}", self.0) } } impl std::error::Error for KeyDecodeError {} #[derive(Clone, Default, Copy, PartialEq, Eq, Hash)] pub struct Key([u8; WG_KEY_LEN]); impl Key { pub const fn new_unchecked_from(bytes: [u8; WG_KEY_LEN]) -> Self { Self(bytes) } pub fn into_array(self) -> [u8; WG_KEY_LEN] { self.0 } pub fn as_array(&self) -> &[u8; WG_KEY_LEN] { &self.0 } pub fn as_slice(&self) -> &[u8] { &self.0 } pub fn encode(&self) -> String { let encoded = encode_wg_key(&self.0); String::from_utf8(encoded.to_vec()).expect("WireGuard key should encode to valid UTF-8") } pub fn decode(encoded: &str) -> Result { decode_wg_key(encoded).map(Self) } pub fn generate_pub_priv() -> (Self, Self) { let private_key = Self::generate_private(); let public_key = Self::generate_public(&private_key); (public_key, private_key) } pub fn generate_public(private: &Key) -> Self { let mut r: i32 = Default::default(); let mut public_key: [u8; WG_KEY_LEN] = Default::default(); let mut z: [u8; WG_KEY_LEN] = private.0; let mut a = fe_new_one(1); let mut b = fe_new_one(9); let mut c = fe_new_one(0); let mut d = fe_new_one(1); let mut e = fe_new_default(); let mut f = fe_new_default(); clamp_key(&mut z); for i in (0..=254i32).rev() { r = ((z[(i >> 3) as usize] >> (i & 7)) & 1) as i32; cswap(&mut a, &mut b, r); cswap(&mut c, &mut d, r); add(&mut e, &a, &c); { let a_clone = a; subtract(&mut a, &a_clone, &c); } add(&mut c, &b, &d); { let b_clone = b; subtract(&mut b, &b_clone, &d); } multmod(&mut d, &e, &e); multmod(&mut f, &a, &a); { let a_clone = a; multmod(&mut a, &c, &a_clone); } multmod(&mut c, &b, &e); add(&mut e, &a, &c); { let a_clone = a; subtract(&mut a, &a_clone, &c); } multmod(&mut b, &a, &a); subtract(&mut c, &d, &f); //multmod(&mut a, &c, (const fe){ 0xdb41, 1 }); multmod(&mut a, &c, &fe_new_two(0xdb41, 1)); { let a_clone = a; add(&mut a, &a_clone, &d); } { let c_clone = c; multmod(&mut c, &c_clone, &a); } multmod(&mut a, &d, &f); multmod(&mut d, &b, &fe_new_one(9)); multmod(&mut b, &e, &e); cswap(&mut a, &mut b, r); cswap(&mut c, &mut d, r); } { let c_clone = c; invert(&mut c, &c_clone); } { let a_clone = a; multmod(&mut a, &a_clone, &c); } pack(&mut public_key, &a); memzero_explicit(&mut r); memzero_explicit(&mut z); memzero_explicit(&mut a); memzero_explicit(&mut b); memzero_explicit(&mut c); memzero_explicit(&mut d); memzero_explicit(&mut e); memzero_explicit(&mut f); Self(public_key) } pub fn generate_private() -> Self { let mut preshared = Self::generate_preshared(); clamp_key(&mut preshared.0); preshared } pub fn generate_preshared() -> Self { let mut key = [0u8; WG_KEY_LEN]; OsRng.fill_bytes(&mut key); Self(key) } } pub(crate) const fn decode_wg_key_const(encoded: &str) -> [u8; WG_KEY_LEN] { match decode_wg_key(encoded) { Ok(out) => out, Err(_) => panic!("invalid WireGuard key literal"), } } pub(crate) const fn decode_wg_key(encoded: &str) -> Result<[u8; WG_KEY_LEN], KeyDecodeError> { let bytes = encoded.as_bytes(); if bytes.len() != WG_KEY_B64_LEN { return Err(KeyDecodeError::invalid_length()); } if bytes[WG_KEY_B64_LEN - 1] != b'=' { return Err(KeyDecodeError::invalid_padding()); } let mut out = [0u8; WG_KEY_LEN]; let mut in_idx = 0; let mut out_idx = 0; while in_idx < 40 { let a = match decode_b64_char(bytes[in_idx]) { Ok(v) => v, Err(err) => return Err(err), }; let b = match decode_b64_char(bytes[in_idx + 1]) { Ok(v) => v, Err(err) => return Err(err), }; let c = match decode_b64_char(bytes[in_idx + 2]) { Ok(v) => v, Err(err) => return Err(err), }; let d = match decode_b64_char(bytes[in_idx + 3]) { Ok(v) => v, Err(err) => return Err(err), }; out[out_idx] = (a << 2) | (b >> 4); out[out_idx + 1] = (b << 4) | (c >> 2); out[out_idx + 2] = (c << 6) | d; in_idx += 4; out_idx += 3; } let a = match decode_b64_char(bytes[40]) { Ok(v) => v, Err(err) => return Err(err), }; let b = match decode_b64_char(bytes[41]) { Ok(v) => v, Err(err) => return Err(err), }; let c = match decode_b64_char(bytes[42]) { Ok(v) => v, Err(err) => return Err(err), }; if (c & 0b0000_0011) != 0 { return Err(KeyDecodeError::invalid_padding()); } out[30] = (a << 2) | (b >> 4); out[31] = (b << 4) | (c >> 2); Ok(out) } pub(crate) const fn encode_wg_key(key: &[u8; WG_KEY_LEN]) -> [u8; WG_KEY_B64_LEN] { let mut out = [b'='; WG_KEY_B64_LEN]; let mut in_idx = 0; let mut out_idx = 0; while in_idx < 30 { let b0 = key[in_idx]; let b1 = key[in_idx + 1]; let b2 = key[in_idx + 2]; out[out_idx] = BASE64_ALPHABET[(b0 >> 2) as usize]; out[out_idx + 1] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; out[out_idx + 2] = BASE64_ALPHABET[(((b1 & 0b0000_1111) << 2) | (b2 >> 6)) as usize]; out[out_idx + 3] = BASE64_ALPHABET[(b2 & 0b0011_1111) as usize]; in_idx += 3; out_idx += 4; } let b0 = key[30]; let b1 = key[31]; out[40] = BASE64_ALPHABET[(b0 >> 2) as usize]; out[41] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; out[42] = BASE64_ALPHABET[((b1 & 0b0000_1111) << 2) as usize]; out } const fn decode_b64_char(c: u8) -> Result { match c { b'A'..=b'Z' => Ok(c - b'A'), b'a'..=b'z' => Ok(c - b'a' + 26), b'0'..=b'9' => Ok(c - b'0' + 52), b'+' => Ok(62), b'/' => Ok(63), _ => Err(KeyDecodeError::invalid_base64_character()), } } impl From<&[u8; WG_KEY_LEN]> for Key { fn from(k: &[u8; WG_KEY_LEN]) -> Self { Self(*k) } } impl From<[u8; WG_KEY_LEN]> for Key { fn from(k: [u8; WG_KEY_LEN]) -> Self { Self(k) } } impl std::str::FromStr for Key { type Err = KeyDecodeError; fn from_str(s: &str) -> Result { Key::decode(s) } } impl std::fmt::Debug for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let buf = encode_wg_key(&self.0); let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); f.debug_tuple("Key").field(&b64).finish() } } impl std::fmt::Display for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let buf = encode_wg_key(&self.0); let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); f.write_str(b64) } } impl serde::Serialize for Key { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(&self.encode()) } } impl<'de> serde::Deserialize<'de> for Key { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s = String::deserialize(deserializer)?; Key::decode(&s).map_err(serde::de::Error::custom) } } fn fe_new_default() -> Fe { Default::default() } fn fe_new_one(x: i64) -> Fe { let mut fe = fe_new_default(); fe[0] = x; fe } fn fe_new_two(x: i64, y: i64) -> Fe { let mut fe = fe_new_default(); fe[0] = x; fe[1] = y; fe } fn clamp_key(key: &mut [u8]) { key[31] = (key[31] & 127) | 64; key[0] &= 248; } fn carry(o: &mut Fe) { for i in 0..16 { let x = if i == 15 { 38 } else { 1 }; o[(i + 1) % 16] += x * (o[i] >> 16); o[i] &= 0xffff; } } fn cswap(p: &mut Fe, q: &mut Fe, mut b: i32) { let mut t: i64 = 0; let mut c: i64 = !i64::from(b).wrapping_sub(1); for i in 0..16 { t = c & (p[i] ^ q[i]); p[i] ^= t; q[i] ^= t; } memzero_explicit(&mut t); memzero_explicit(&mut c); memzero_explicit(&mut b); } fn pack(o: &mut [u8; WG_KEY_LEN], n: &Fe) { let mut b: i32 = 0; let mut t: Fe = fe_new_default(); let mut m: Fe = fe_new_default(); t.copy_from_slice(n); carry(&mut t); carry(&mut t); carry(&mut t); for _ in 0..2 { m[0] = t[0] - 0xffed; for i in 1..15 { m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1); m[i - 1] &= 0xffff; } m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1); b = ((m[15] >> 16) & 1) as i32; m[14] &= 0xffff; cswap(&mut t, &mut m, 1 - b); } for i in 0..16 { o[2 * i] = (t[i] & 0xff) as u8; o[2 * i + 1] = (t[i] >> 8) as u8; } memzero_explicit(&mut m); memzero_explicit(&mut t); memzero_explicit(&mut b); } fn add(o: &mut Fe, a: &Fe, b: &Fe) { for i in 0..16 { o[i] = a[i] + b[i]; } } fn subtract(o: &mut Fe, a: &Fe, b: &Fe) { for i in 0..16 { o[i] = a[i] - b[i]; } } fn multmod(o: &mut Fe, a: &Fe, b: &Fe) { let mut t: [i64; 31] = [0; 31]; for i in 0..16 { for j in 0..16 { t[i + j] += a[i] * b[j]; } } for i in 0..15 { t[i] += 38 * t[i + 16]; } o.copy_from_slice(&t[..16]); carry(o); carry(o); memzero_explicit(&mut t); } fn invert(o: &mut Fe, i: &Fe) { let mut c: Fe = fe_new_default(); c.copy_from_slice(i); for a in (0..=253).rev() { { let c_clone = c; multmod(&mut c, &c_clone, &c_clone); } if a != 2 && a != 4 { { let c_clone = c; multmod(&mut c, &c_clone, i); } } } o.copy_from_slice(&c); memzero_explicit(&mut c); } fn memzero_explicit(v: &mut T) { unsafe { let zeroed = std::mem::zeroed(); std::ptr::write_volatile(v as *mut _, zeroed); } } #[cfg(test)] mod tests { use super::Key; const CONST_KEY: Key = crate::key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); #[test] fn decode_encode_key() { let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; let key = super::Key::decode(key).unwrap(); let key = key.encode(); assert_eq!(key, "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); } #[test] fn generate_public_key() { assert_eq!( Key::decode("3D5lgnI9ztvnuyWDm7dlBDgm6xr0+WVWPoo6HIfzHRU=").unwrap(), Key::generate_public( &Key::decode("+Op7voRskU0Zm2fHFR/5tVE+PJtnwn6cbnme71jXt0E=").unwrap() ) ); assert_eq!( Key::decode("//eq/raPUE4+sOlTlozx76XEE+W8L0bUqNfyg9IpX0Q=").unwrap(), Key::generate_public( &Key::decode("8OD8QPWH/a0D5LmbWVnb7bwFq4Fghy/QUEFkIhyL/EI=").unwrap() ) ); } #[test] fn const_key_macro_matches_runtime_decode() { let runtime = Key::decode("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=").expect("valid key"); assert_eq!(CONST_KEY, runtime); } }