summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordiogo464 <[email protected]>2026-02-17 11:41:02 +0000
committerdiogo464 <[email protected]>2026-02-17 11:41:02 +0000
commited0f17dea764fea93ea7c2b897fa4a9e31c9476e (patch)
treedd757fe08efbc863f7cac3c2e3288417ed550b93
parent56ac8740b79e291eabe6427d722921533b3a9837 (diff)
add const key macro and internal base64 codec
-rw-r--r--Cargo.lock7
-rw-r--r--Cargo.toml1
-rw-r--r--src/key.rs178
-rw-r--r--src/lib.rs29
4 files changed, 183 insertions, 32 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 6203091..66088b5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -9,12 +9,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
9checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" 9checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
10 10
11[[package]] 11[[package]]
12name = "base64"
13version = "0.21.7"
14source = "registry+https://github.com/rust-lang/crates.io-index"
15checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
16
17[[package]]
18name = "bitflags" 12name = "bitflags"
19version = "2.11.0" 13version = "2.11.0"
20source = "registry+https://github.com/rust-lang/crates.io-index" 14source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -699,7 +693,6 @@ name = "wireguard"
699version = "0.0.0" 693version = "0.0.0"
700dependencies = [ 694dependencies = [
701 "anyhow", 695 "anyhow",
702 "base64",
703 "futures", 696 "futures",
704 "genetlink", 697 "genetlink",
705 "ipnet", 698 "ipnet",
diff --git a/Cargo.toml b/Cargo.toml
index e7ffd1b..aea7700 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,7 +7,6 @@ edition = "2021"
7 7
8[dependencies] 8[dependencies]
9anyhow = "1.0.79" 9anyhow = "1.0.79"
10base64 = "0.21.7"
11futures = "0.3.30" 10futures = "0.3.30"
12genetlink = "=0.2.6" 11genetlink = "=0.2.6"
13ipnet = { version = "2.9.0", features = ["serde"] } 12ipnet = { version = "2.9.0", features = ["serde"] }
diff --git a/src/key.rs b/src/key.rs
index 49f9284..3fa7021 100644
--- a/src/key.rs
+++ b/src/key.rs
@@ -1,18 +1,40 @@
1use base64::Engine;
2use rand::{rngs::OsRng, RngCore}; 1use rand::{rngs::OsRng, RngCore};
3 2
4const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN; 3const WG_KEY_LEN: usize = netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN;
4const WG_KEY_B64_LEN: usize = 44;
5const BASE64_ALPHABET: [u8; 64] =
6 *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
5 7
6// Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c 8// Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c
7 9
8type Fe = [i64; 16]; 10type Fe = [i64; 16];
9 11
10#[derive(Debug)] 12#[derive(Clone, Copy)]
11pub struct KeyDecodeError; 13pub struct KeyDecodeError(&'static str);
14
15impl KeyDecodeError {
16 pub(crate) const fn invalid_length() -> Self {
17 Self("invalid length")
18 }
19
20 pub(crate) const fn invalid_padding() -> Self {
21 Self("invalid padding")
22 }
23
24 pub(crate) const fn invalid_base64_character() -> Self {
25 Self("invalid base64 character")
26 }
27}
28
29impl std::fmt::Debug for KeyDecodeError {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "KeyDecodeError: {}", self.0)
32 }
33}
12 34
13impl std::fmt::Display for KeyDecodeError { 35impl std::fmt::Display for KeyDecodeError {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 write!(f, "Key decode error") 37 write!(f, "Key decode error: {}", self.0)
16 } 38 }
17} 39}
18 40
@@ -22,6 +44,10 @@ impl std::error::Error for KeyDecodeError {}
22pub struct Key([u8; WG_KEY_LEN]); 44pub struct Key([u8; WG_KEY_LEN]);
23 45
24impl Key { 46impl Key {
47 pub const fn new_unchecked_from(bytes: [u8; WG_KEY_LEN]) -> Self {
48 Self(bytes)
49 }
50
25 pub fn into_array(self) -> [u8; WG_KEY_LEN] { 51 pub fn into_array(self) -> [u8; WG_KEY_LEN] {
26 self.0 52 self.0
27 } 53 }
@@ -35,19 +61,12 @@ impl Key {
35 } 61 }
36 62
37 pub fn encode(&self) -> String { 63 pub fn encode(&self) -> String {
38 base64::engine::general_purpose::STANDARD.encode(&self.0) 64 let encoded = encode_wg_key(&self.0);
65 String::from_utf8(encoded.to_vec()).expect("WireGuard key should encode to valid UTF-8")
39 } 66 }
40 67
41 pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> { 68 pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> {
42 let decoded = base64::engine::general_purpose::STANDARD 69 decode_wg_key(encoded).map(Self)
43 .decode(encoded)
44 .map_err(|_| KeyDecodeError)?;
45 if decoded.len() != WG_KEY_LEN {
46 return Err(KeyDecodeError);
47 }
48 let mut key = [0u8; WG_KEY_LEN];
49 key.copy_from_slice(&decoded);
50 Ok(Key(key))
51 } 70 }
52 71
53 pub fn generate_pub_priv() -> (Self, Self) { 72 pub fn generate_pub_priv() -> (Self, Self) {
@@ -148,6 +167,114 @@ impl Key {
148 } 167 }
149} 168}
150 169
170pub(crate) const fn decode_wg_key_const(encoded: &str) -> [u8; WG_KEY_LEN] {
171 match decode_wg_key(encoded) {
172 Ok(out) => out,
173 Err(_) => panic!("invalid WireGuard key literal"),
174 }
175}
176
177pub(crate) const fn decode_wg_key(encoded: &str) -> Result<[u8; WG_KEY_LEN], KeyDecodeError> {
178 let bytes = encoded.as_bytes();
179 if bytes.len() != WG_KEY_B64_LEN {
180 return Err(KeyDecodeError::invalid_length());
181 }
182
183 if bytes[WG_KEY_B64_LEN - 1] != b'=' {
184 return Err(KeyDecodeError::invalid_padding());
185 }
186
187 let mut out = [0u8; WG_KEY_LEN];
188 let mut in_idx = 0;
189 let mut out_idx = 0;
190
191 while in_idx < 40 {
192 let a = match decode_b64_char(bytes[in_idx]) {
193 Ok(v) => v,
194 Err(err) => return Err(err),
195 };
196 let b = match decode_b64_char(bytes[in_idx + 1]) {
197 Ok(v) => v,
198 Err(err) => return Err(err),
199 };
200 let c = match decode_b64_char(bytes[in_idx + 2]) {
201 Ok(v) => v,
202 Err(err) => return Err(err),
203 };
204 let d = match decode_b64_char(bytes[in_idx + 3]) {
205 Ok(v) => v,
206 Err(err) => return Err(err),
207 };
208
209 out[out_idx] = (a << 2) | (b >> 4);
210 out[out_idx + 1] = (b << 4) | (c >> 2);
211 out[out_idx + 2] = (c << 6) | d;
212
213 in_idx += 4;
214 out_idx += 3;
215 }
216
217 let a = match decode_b64_char(bytes[40]) {
218 Ok(v) => v,
219 Err(err) => return Err(err),
220 };
221 let b = match decode_b64_char(bytes[41]) {
222 Ok(v) => v,
223 Err(err) => return Err(err),
224 };
225 let c = match decode_b64_char(bytes[42]) {
226 Ok(v) => v,
227 Err(err) => return Err(err),
228 };
229 if (c & 0b0000_0011) != 0 {
230 return Err(KeyDecodeError::invalid_padding());
231 }
232
233 out[30] = (a << 2) | (b >> 4);
234 out[31] = (b << 4) | (c >> 2);
235
236 Ok(out)
237}
238
239pub(crate) const fn encode_wg_key(key: &[u8; WG_KEY_LEN]) -> [u8; WG_KEY_B64_LEN] {
240 let mut out = [b'='; WG_KEY_B64_LEN];
241 let mut in_idx = 0;
242 let mut out_idx = 0;
243
244 while in_idx < 30 {
245 let b0 = key[in_idx];
246 let b1 = key[in_idx + 1];
247 let b2 = key[in_idx + 2];
248
249 out[out_idx] = BASE64_ALPHABET[(b0 >> 2) as usize];
250 out[out_idx + 1] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize];
251 out[out_idx + 2] = BASE64_ALPHABET[(((b1 & 0b0000_1111) << 2) | (b2 >> 6)) as usize];
252 out[out_idx + 3] = BASE64_ALPHABET[(b2 & 0b0011_1111) as usize];
253
254 in_idx += 3;
255 out_idx += 4;
256 }
257
258 let b0 = key[30];
259 let b1 = key[31];
260 out[40] = BASE64_ALPHABET[(b0 >> 2) as usize];
261 out[41] = BASE64_ALPHABET[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize];
262 out[42] = BASE64_ALPHABET[((b1 & 0b0000_1111) << 2) as usize];
263
264 out
265}
266
267const fn decode_b64_char(c: u8) -> Result<u8, KeyDecodeError> {
268 match c {
269 b'A'..=b'Z' => Ok(c - b'A'),
270 b'a'..=b'z' => Ok(c - b'a' + 26),
271 b'0'..=b'9' => Ok(c - b'0' + 52),
272 b'+' => Ok(62),
273 b'/' => Ok(63),
274 _ => Err(KeyDecodeError::invalid_base64_character()),
275 }
276}
277
151impl From<&[u8; WG_KEY_LEN]> for Key { 278impl From<&[u8; WG_KEY_LEN]> for Key {
152 fn from(k: &[u8; WG_KEY_LEN]) -> Self { 279 fn from(k: &[u8; WG_KEY_LEN]) -> Self {
153 Self(*k) 280 Self(*k)
@@ -170,22 +297,16 @@ impl std::str::FromStr for Key {
170 297
171impl std::fmt::Debug for Key { 298impl std::fmt::Debug for Key {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 let mut buf = [0; WG_KEY_LEN * 2]; 300 let buf = encode_wg_key(&self.0);
174 let len = base64::engine::general_purpose::STANDARD 301 let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8");
175 .encode_slice(&self.0, &mut buf)
176 .expect("base64 should encode");
177 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
178 f.debug_tuple("Key").field(&b64).finish() 302 f.debug_tuple("Key").field(&b64).finish()
179 } 303 }
180} 304}
181 305
182impl std::fmt::Display for Key { 306impl std::fmt::Display for Key {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 let mut buf = [0; WG_KEY_LEN * 2]; 308 let buf = encode_wg_key(&self.0);
185 let len = base64::engine::general_purpose::STANDARD 309 let b64 = std::str::from_utf8(&buf).expect("WireGuard key should encode to valid UTF-8");
186 .encode_slice(&self.0, &mut buf)
187 .expect("base64 should encode");
188 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
189 f.write_str(b64) 310 f.write_str(b64)
190 } 311 }
191} 312}
@@ -346,6 +467,8 @@ fn memzero_explicit<T>(v: &mut T) {
346mod tests { 467mod tests {
347 use super::Key; 468 use super::Key;
348 469
470 const CONST_KEY: Key = crate::key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=");
471
349 #[test] 472 #[test]
350 fn decode_encode_key() { 473 fn decode_encode_key() {
351 let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY="; 474 let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=";
@@ -370,4 +493,11 @@ mod tests {
370 ) 493 )
371 ); 494 );
372 } 495 }
496
497 #[test]
498 fn const_key_macro_matches_runtime_decode() {
499 let runtime =
500 Key::decode("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=").expect("valid key");
501 assert_eq!(CONST_KEY, runtime);
502 }
373} 503}
diff --git a/src/lib.rs b/src/lib.rs
index 8ba36eb..72fb58f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,6 +23,35 @@ pub use view::*;
23 23
24pub use ipnet; 24pub use ipnet;
25 25
26#[doc(hidden)]
27pub const fn __decode_wg_key_const(
28 encoded: &str,
29) -> [u8; netlink_packet_wireguard::WireguardAttribute::WG_KEY_LEN] {
30 key::decode_wg_key_const(encoded)
31}
32
33/// Creates a [`Key`] from a canonical WireGuard base64 literal at compile time.
34///
35/// The literal must be exactly 44 characters of standard base64 and end with `=`.
36///
37/// ```
38/// use wireguard::{key, Key};
39///
40/// const PRIVATE_KEY: Key = key!("6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=");
41/// ```
42///
43/// ```compile_fail
44/// use wireguard::key;
45///
46/// const _BAD: wireguard::Key = key!("not-a-wireguard-key");
47/// ```
48#[macro_export]
49macro_rules! key {
50 ($value:literal) => {
51 $crate::Key::new_unchecked_from($crate::__decode_wg_key_const($value))
52 };
53}
54
26pub type Result<T, E = Error> = std::result::Result<T, E>; 55pub type Result<T, E = Error> = std::result::Result<T, E>;
27 56
28#[derive(Debug)] 57#[derive(Debug)]