diff options
| author | diogo464 <[email protected]> | 2026-02-17 11:41:02 +0000 |
|---|---|---|
| committer | diogo464 <[email protected]> | 2026-02-17 11:41:02 +0000 |
| commit | ed0f17dea764fea93ea7c2b897fa4a9e31c9476e (patch) | |
| tree | dd757fe08efbc863f7cac3c2e3288417ed550b93 /src/key.rs | |
| parent | 56ac8740b79e291eabe6427d722921533b3a9837 (diff) | |
add const key macro and internal base64 codec
Diffstat (limited to 'src/key.rs')
| -rw-r--r-- | src/key.rs | 178 |
1 files changed, 154 insertions, 24 deletions
| @@ -1,18 +1,40 @@ | |||
| 1 | use base64::Engine; | ||
| 2 | use rand::{rngs::OsRng, RngCore}; | 1 | use rand::{rngs::OsRng, RngCore}; |
| 3 | 2 | ||
| 4 | const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN; | 3 | const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN; |
| 4 | const WG_KEY_B64_LEN: usize = 44; | ||
| 5 | const BASE64_ALPHABET: [u8; 64] = | ||
| 6 | *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | ||
| 5 | 7 | ||
| 6 | // Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c | 8 | // Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c |
| 7 | 9 | ||
| 8 | type Fe = [i64; 16]; | 10 | type Fe = [i64; 16]; |
| 9 | 11 | ||
| 10 | #[derive(Debug)] | 12 | #[derive(Clone, Copy)] |
| 11 | pub struct KeyDecodeError; | 13 | pub struct KeyDecodeError(&'static str); |
| 14 | |||
| 15 | impl KeyDecodeError { | ||
| 16 | pub(crate) const fn invalid_length() -> Self { | ||
| 17 | Self("invalid length") | ||
| 18 | } | ||
| 19 | |||
| 20 | pub(crate) const fn invalid_padding() -> Self { | ||
| 21 | Self("invalid padding") | ||
| 22 | } | ||
| 23 | |||
| 24 | pub(crate) const fn invalid_base64_character() -> Self { | ||
| 25 | Self("invalid base64 character") | ||
| 26 | } | ||
| 27 | } | ||
| 28 | |||
| 29 | impl std::fmt::Debug for KeyDecodeError { | ||
| 30 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| 31 | write!(f, "KeyDecodeError: {}", self.0) | ||
| 32 | } | ||
| 33 | } | ||
| 12 | 34 | ||
| 13 | impl std::fmt::Display for KeyDecodeError { | 35 | impl std::fmt::Display for KeyDecodeError { |
| 14 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | 36 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 15 | write!(f, "Key decode error") | 37 | write!(f, "Key decode error: {}", self.0) |
| 16 | } | 38 | } |
| 17 | } | 39 | } |
| 18 | 40 | ||
| @@ -22,6 +44,10 @@ impl std::error::Error for KeyDecodeError {} | |||
| 22 | pub struct Key([u8; WG_KEY_LEN]); | 44 | pub struct Key([u8; WG_KEY_LEN]); |
| 23 | 45 | ||
| 24 | impl Key { | 46 | impl Key { |
| 47 | pub const fn new_unchecked_from(bytes: [u8; WG_KEY_LEN]) -> Self { | ||
| 48 | Self(bytes) | ||
| 49 | } | ||
| 50 | |||
| 25 | pub fn into_array(self) -> [u8; WG_KEY_LEN] { | 51 | pub fn into_array(self) -> [u8; WG_KEY_LEN] { |
| 26 | self.0 | 52 | self.0 |
| 27 | } | 53 | } |
| @@ -35,19 +61,12 @@ impl Key { | |||
| 35 | } | 61 | } |
| 36 | 62 | ||
| 37 | pub fn encode(&self) -> String { | 63 | pub fn encode(&self) -> String { |
| 38 | base64::engine::general_purpose::STANDARD.encode(&self.0) | 64 | let encoded = encode_wg_key(&self.0); |
| 65 | String::from_utf8(encoded.to_vec()).expect("WireGuard key should encode to valid UTF-8") | ||
| 39 | } | 66 | } |
| 40 | 67 | ||
| 41 | pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> { | 68 | pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> { |
| 42 | let decoded = base64::engine::general_purpose::STANDARD | 69 | decode_wg_key(encoded).map(Self) |
| 43 | .decode(encoded) | ||
| 44 | .map_err(|_| KeyDecodeError)?; | ||
| 45 | if decoded.len() != WG_KEY_LEN { | ||
| 46 | return Err(KeyDecodeError); | ||
| 47 | } | ||
| 48 | let mut key = [0u8; WG_KEY_LEN]; | ||
| 49 | key.copy_from_slice(&decoded); | ||
| 50 | Ok(Key(key)) | ||
| 51 | } | 70 | } |
| 52 | 71 | ||
| 53 | pub fn generate_pub_priv() -> (Self, Self) { | 72 | pub fn generate_pub_priv() -> (Self, Self) { |
| @@ -148,6 +167,114 @@ impl Key { | |||
| 148 | } | 167 | } |
| 149 | } | 168 | } |
| 150 | 169 | ||
| 170 | pub(crate) const fn decode_wg_key_const(encoded: &str) -> [u8; WG_KEY_LEN] { | ||
| 171 | match decode_wg_key(encoded) { | ||
| 172 | Ok(out) => out, | ||
| 173 | Err(_) => panic!("invalid WireGuard key literal"), | ||
| 174 | } | ||
| 175 | } | ||
| 176 | |||
| 177 | pub(crate) const fn decode_wg_key(encoded: &str) -> Result<[u8; WG_KEY_LEN], KeyDecodeError> { | ||
| 178 | let bytes = encoded.as_bytes(); | ||
| 179 | if bytes.len() != WG_KEY_B64_LEN { | ||
| 180 | return Err(KeyDecodeError::invalid_length()); | ||
| 181 | } | ||
| 182 | |||
| 183 | if bytes[WG_KEY_B64_LEN - 1] != b'=' { | ||
| 184 | return Err(KeyDecodeError::invalid_padding()); | ||
| 185 | } | ||
| 186 | |||
| 187 | let mut out = [0u8; WG_KEY_LEN]; | ||
| 188 | let mut in_idx = 0; | ||
| 189 | let mut out_idx = 0; | ||
| 190 | |||
| 191 | while in_idx < 40 { | ||
| 192 | let a = match decode_b64_char(bytes[in_idx]) { | ||
| 193 | Ok(v) => v, | ||
| 194 | Err(err) => return Err(err), | ||
| 195 | }; | ||
| 196 | let b = match decode_b64_char(bytes[in_idx + 1]) { | ||
| 197 | Ok(v) => v, | ||
| 198 | Err(err) => return Err(err), | ||
| 199 | }; | ||
| 200 | let c = match decode_b64_char(bytes[in_idx + 2]) { | ||
| 201 | Ok(v) => v, | ||
| 202 | Err(err) => return Err(err), | ||
| 203 | }; | ||
| 204 | let d = match decode_b64_char(bytes[in_idx + 3]) { | ||
| 205 | Ok(v) => v, | ||
| 206 | Err(err) => return Err(err), | ||
| 207 | }; | ||
| 208 | |||
| 209 | out[out_idx] = (a << 2) | (b >> 4); | ||
| 210 | out[out_idx + 1] = (b << 4) | (c >> 2); | ||
| 211 | out[out_idx + 2] = (c << 6) | d; | ||
| 212 | |||
| 213 | in_idx += 4; | ||
| 214 | out_idx += 3; | ||
| 215 | } | ||
| 216 | |||
| 217 | let a = match decode_b64_char(bytes[40]) { | ||
| 218 | Ok(v) => v, | ||
| 219 | Err(err) => return Err(err), | ||
| 220 | }; | ||
| 221 | let b = match decode_b64_char(bytes[41]) { | ||
| 222 | Ok(v) => v, | ||
| 223 | Err(err) => return Err(err), | ||
| 224 | }; | ||
| 225 | let c = match decode_b64_char(bytes[42]) { | ||
| 226 | Ok(v) => v, | ||
| 227 | Err(err) => return Err(err), | ||
| 228 | }; | ||
| 229 | if (c & 0b0000_0011) != 0 { | ||
| 230 | return Err(KeyDecodeError::invalid_padding()); | ||
| 231 | } | ||
| 232 | |||
| 233 | out[30] = (a << 2) | (b >> 4); | ||
| 234 | out[31] = (b << 4) | (c >> 2); | ||
| 235 | |||
| 236 | Ok(out) | ||
| 237 | } | ||
| 238 | |||
| 239 | pub(crate) const fn encode_wg_key(key: &[u8; WG_KEY_LEN]) -> [u8; WG_KEY_B64_LEN] { | ||
| 240 | let mut out = [b'='; WG_KEY_B64_LEN]; | ||
| 241 | let mut in_idx = 0; | ||
| 242 | let mut out_idx = 0; | ||
| 243 | |||
| 244 | while in_idx < 30 { | ||
| 245 | let b0 = key[in_idx]; | ||
| 246 | let b1 = key[in_idx + 1]; | ||
| 247 | let b2 = key[in_idx + 2]; | ||
| 248 | |||
| 249 | out[out_idx] = BASE64_ALPHABET[(b0 >> 2) as usize]; | ||
| 250 | out[out_idx + 1] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; | ||
| 251 | out[out_idx + 2] = BASE64_ALPHABET[(((b1 & 0b0000_1111) << 2) | (b2 >> 6)) as usize]; | ||
| 252 | out[out_idx + 3] = BASE64_ALPHABET[(b2 & 0b0011_1111) as usize]; | ||
| 253 | |||
| 254 | in_idx += 3; | ||
| 255 | out_idx += 4; | ||
| 256 | } | ||
| 257 | |||
| 258 | let b0 = key[30]; | ||
| 259 | let b1 = key[31]; | ||
| 260 | out[40] = BASE64_ALPHABET[(b0 >> 2) as usize]; | ||
| 261 | out[41] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize]; | ||
| 262 | out[42] = BASE64_ALPHABET[((b1 & 0b0000_1111) << 2) as usize]; | ||
| 263 | |||
| 264 | out | ||
| 265 | } | ||
| 266 | |||
| 267 | const fn decode_b64_char(c: u8) -> Result<u8, KeyDecodeError> { | ||
| 268 | match c { | ||
| 269 | b'A'..=b'Z' => Ok(c - b'A'), | ||
| 270 | b'a'..=b'z' => Ok(c - b'a' + 26), | ||
| 271 | b'0'..=b'9' => Ok(c - b'0' + 52), | ||
| 272 | b'+' => Ok(62), | ||
| 273 | b'/' => Ok(63), | ||
| 274 | _ => Err(KeyDecodeError::invalid_base64_character()), | ||
| 275 | } | ||
| 276 | } | ||
| 277 | |||
| 151 | impl From<&[u8; WG_KEY_LEN]> for Key { | 278 | impl From<&[u8; WG_KEY_LEN]> for Key { |
| 152 | fn from(k: &[u8; WG_KEY_LEN]) -> Self { | 279 | fn from(k: &[u8; WG_KEY_LEN]) -> Self { |
| 153 | Self(*k) | 280 | Self(*k) |
| @@ -170,22 +297,16 @@ impl std::str::FromStr for Key { | |||
| 170 | 297 | ||
| 171 | impl std::fmt::Debug for Key { | 298 | impl std::fmt::Debug for Key { |
| 172 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | 299 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 173 | let mut buf = [0; WG_KEY_LEN * 2]; | 300 | let buf = encode_wg_key(&self.0); |
| 174 | let len = base64::engine::general_purpose::STANDARD | 301 | let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); |
| 175 | .encode_slice(&self.0, &mut buf) | ||
| 176 | .expect("base64 should encode"); | ||
| 177 | let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); | ||
| 178 | f.debug_tuple("Key").field(&b64).finish() | 302 | f.debug_tuple("Key").field(&b64).finish() |
| 179 | } | 303 | } |
| 180 | } | 304 | } |
| 181 | 305 | ||
| 182 | impl std::fmt::Display for Key { | 306 | impl std::fmt::Display for Key { |
| 183 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | 307 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 184 | let mut buf = [0; WG_KEY_LEN * 2]; | 308 | let buf = encode_wg_key(&self.0); |
| 185 | let len = base64::engine::general_purpose::STANDARD | 309 | let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8"); |
| 186 | .encode_slice(&self.0, &mut buf) | ||
| 187 | .expect("base64 should encode"); | ||
| 188 | let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8"); | ||
| 189 | f.write_str(b64) | 310 | f.write_str(b64) |
| 190 | } | 311 | } |
| 191 | } | 312 | } |
| @@ -346,6 +467,8 @@ fn memzero_explicit<T>(v: &mut T) { | |||
| 346 | mod tests { | 467 | mod tests { |
| 347 | use super::Key; | 468 | use super::Key; |
| 348 | 469 | ||
| 470 | const CONST_KEY: Key = crate::key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="); | ||
| 471 | |||
| 349 | #[test] | 472 | #[test] |
| 350 | fn decode_encode_key() { | 473 | fn decode_encode_key() { |
| 351 | let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; | 474 | let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; |
| @@ -370,4 +493,11 @@ mod tests { | |||
| 370 | ) | 493 | ) |
| 371 | ); | 494 | ); |
| 372 | } | 495 | } |
| 496 | |||
| 497 | #[test] | ||
| 498 | fn const_key_macro_matches_runtime_decode() { | ||
| 499 | let runtime = | ||
| 500 | Key::decode("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=").expect("valid key"); | ||
| 501 | assert_eq!(CONST_KEY, runtime); | ||
| 502 | } | ||
| 373 | } | 503 | } |
