summaryrefslogtreecommitdiff
path: root/src/key.rs
diff options
context:
space:
mode:
authordiogo464 <[email protected]>2025-07-18 18:46:55 +0100
committerdiogo464 <[email protected]>2025-07-18 18:46:55 +0100
commit75ccbd675c22fb3275c5763518c3b97819db4c53 (patch)
tree1ff2a44abcac884875f80a44a569681e0f2d7a7d /src/key.rs
init
Diffstat (limited to 'src/key.rs')
-rw-r--r--src/key.rs372
1 files changed, 372 insertions, 0 deletions
diff --git a/src/key.rs b/src/key.rs
new file mode 100644
index 0000000..19bc127
--- /dev/null
+++ b/src/key.rs
@@ -0,0 +1,372 @@
1use base64::Engine;
2use netlink_packet_wireguard::constants::WG_KEY_LEN;
3use rand::{rngs::OsRng, RngCore};
4
5// Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c
6
7type Fe = [i64; 16];
8
9#[derive(Debug)]
10pub struct KeyDecodeError;
11
12impl std::fmt::Display for KeyDecodeError {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 write!(f, "Key decode error")
15 }
16}
17
18impl std::error::Error for KeyDecodeError {}
19
20#[derive(Clone, Default, Copy, PartialEq, Eq, Hash)]
21pub struct Key([u8; WG_KEY_LEN]);
22
23impl Key {
24 pub fn into_array(self) -> [u8; WG_KEY_LEN] {
25 self.0
26 }
27
28 pub fn as_array(&self) -> &[u8; WG_KEY_LEN] {
29 &self.0
30 }
31
32 pub fn as_slice(&self) -> &[u8] {
33 &self.0
34 }
35
36 pub fn encode(&self) -> String {
37 base64::engine::general_purpose::STANDARD.encode(&self.0)
38 }
39
40 pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> {
41 let decoded = base64::engine::general_purpose::STANDARD
42 .decode(encoded)
43 .map_err(|_| KeyDecodeError)?;
44 if decoded.len() != WG_KEY_LEN {
45 return Err(KeyDecodeError);
46 }
47 let mut key = [0u8; WG_KEY_LEN];
48 key.copy_from_slice(&decoded);
49 Ok(Key(key))
50 }
51
52 pub fn generate_pub_priv() -> (Self, Self) {
53 let private_key = Self::generate_private();
54 let public_key = Self::generate_public(&private_key);
55 (public_key, private_key)
56 }
57
58 pub fn generate_public(private: &Key) -> Self {
59 let mut r: i32 = Default::default();
60 let mut public_key: [u8; WG_KEY_LEN] = Default::default();
61 let mut z: [u8; WG_KEY_LEN] = private.0;
62 let mut a = fe_new_one(1);
63 let mut b = fe_new_one(9);
64 let mut c = fe_new_one(0);
65 let mut d = fe_new_one(1);
66 let mut e = fe_new_default();
67 let mut f = fe_new_default();
68
69 clamp_key(&mut z);
70
71 for i in (0..=254i32).rev() {
72 r = ((z[(i >> 3) as usize] >> (i & 7)) & 1) as i32;
73 cswap(&mut a, &mut b, r);
74 cswap(&mut c, &mut d, r);
75 add(&mut e, &a, &c);
76 {
77 let a_clone = a;
78 subtract(&mut a, &a_clone, &c);
79 }
80 add(&mut c, &b, &d);
81 {
82 let b_clone = b;
83 subtract(&mut b, &b_clone, &d);
84 }
85 multmod(&mut d, &e, &e);
86 multmod(&mut f, &a, &a);
87 {
88 let a_clone = a;
89 multmod(&mut a, &c, &a_clone);
90 }
91 multmod(&mut c, &b, &e);
92 add(&mut e, &a, &c);
93 {
94 let a_clone = a;
95 subtract(&mut a, &a_clone, &c);
96 }
97 multmod(&mut b, &a, &a);
98 subtract(&mut c, &d, &f);
99 //multmod(&mut a, &c, (const fe){ 0xdb41, 1 });
100 multmod(&mut a, &c, &fe_new_two(0xdb41, 1));
101 {
102 let a_clone = a;
103 add(&mut a, &a_clone, &d);
104 }
105 {
106 let c_clone = c;
107 multmod(&mut c, &c_clone, &a);
108 }
109 multmod(&mut a, &d, &f);
110 multmod(&mut d, &b, &fe_new_one(9));
111 multmod(&mut b, &e, &e);
112 cswap(&mut a, &mut b, r);
113 cswap(&mut c, &mut d, r);
114 }
115 {
116 let c_clone = c;
117 invert(&mut c, &c_clone);
118 }
119 {
120 let a_clone = a;
121 multmod(&mut a, &a_clone, &c);
122 }
123 pack(&mut public_key, &a);
124
125 memzero_explicit(&mut r);
126 memzero_explicit(&mut z);
127 memzero_explicit(&mut a);
128 memzero_explicit(&mut b);
129 memzero_explicit(&mut c);
130 memzero_explicit(&mut d);
131 memzero_explicit(&mut e);
132 memzero_explicit(&mut f);
133
134 Self(public_key)
135 }
136
137 pub fn generate_private() -> Self {
138 let mut preshared = Self::generate_preshared();
139 clamp_key(&mut preshared.0);
140 preshared
141 }
142
143 pub fn generate_preshared() -> Self {
144 let mut key = [0u8; WG_KEY_LEN];
145 OsRng.fill_bytes(&mut key);
146 Self(key)
147 }
148}
149
150impl From<&[u8; WG_KEY_LEN]> for Key {
151 fn from(k: &[u8; WG_KEY_LEN]) -> Self {
152 Self(*k)
153 }
154}
155
156impl From<[u8; WG_KEY_LEN]> for Key {
157 fn from(k: [u8; WG_KEY_LEN]) -> Self {
158 Self(k)
159 }
160}
161
162impl std::str::FromStr for Key {
163 type Err = KeyDecodeError;
164
165 fn from_str(s: &str) -> Result<Self, Self::Err> {
166 Key::decode(s)
167 }
168}
169
170impl std::fmt::Debug for Key {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 let mut buf = [0; WG_KEY_LEN * 2];
173 let len = base64::engine::general_purpose::STANDARD
174 .encode_slice(&self.0, &mut buf)
175 .expect("base64 should encode");
176 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
177 f.debug_tuple("Key").field(&b64).finish()
178 }
179}
180
181impl std::fmt::Display for Key {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 let mut buf = [0; WG_KEY_LEN * 2];
184 let len = base64::engine::general_purpose::STANDARD
185 .encode_slice(&self.0, &mut buf)
186 .expect("base64 should encode");
187 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
188 f.write_str(b64)
189 }
190}
191
192impl serde::Serialize for Key {
193 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194 where
195 S: serde::Serializer,
196 {
197 serializer.serialize_str(&self.encode())
198 }
199}
200
201impl<'de> serde::Deserialize<'de> for Key {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: serde::Deserializer<'de>,
205 {
206 let s = String::deserialize(deserializer)?;
207 Key::decode(&s).map_err(serde::de::Error::custom)
208 }
209}
210
211fn fe_new_default() -> Fe {
212 Default::default()
213}
214
215fn fe_new_one(x: i64) -> Fe {
216 let mut fe = fe_new_default();
217 fe[0] = x;
218 fe
219}
220
221fn fe_new_two(x: i64, y: i64) -> Fe {
222 let mut fe = fe_new_default();
223 fe[0] = x;
224 fe[1] = y;
225 fe
226}
227
228fn clamp_key(key: &mut [u8]) {
229 key[31] = (key[31] & 127) | 64;
230 key[0] &= 248;
231}
232
233fn carry(o: &mut Fe) {
234 for i in 0..16 {
235 let x = if i == 15 { 38 } else { 1 };
236 o[(i + 1) % 16] += x * (o[i] >> 16);
237 o[i] &= 0xffff;
238 }
239}
240
241fn cswap(p: &mut Fe, q: &mut Fe, mut b: i32) {
242 let mut t: i64 = 0;
243 let mut c: i64 = !i64::from(b).wrapping_sub(1);
244
245 for i in 0..16 {
246 t = c & (p[i] ^ q[i]);
247 p[i] ^= t;
248 q[i] ^= t;
249 }
250
251 memzero_explicit(&mut t);
252 memzero_explicit(&mut c);
253 memzero_explicit(&mut b);
254}
255
256fn pack(o: &mut [u8; WG_KEY_LEN], n: &Fe) {
257 let mut b: i32 = 0;
258 let mut t: Fe = fe_new_default();
259 let mut m: Fe = fe_new_default();
260
261 t.copy_from_slice(n);
262 carry(&mut t);
263 carry(&mut t);
264 carry(&mut t);
265 for _ in 0..2 {
266 m[0] = t[0] - 0xffed;
267 for i in 1..15 {
268 m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
269 m[i - 1] &= 0xffff;
270 }
271 m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
272 b = ((m[15] >> 16) & 1) as i32;
273 m[14] &= 0xffff;
274 cswap(&mut t, &mut m, 1 - b);
275 }
276 for i in 0..16 {
277 o[2 * i] = (t[i] & 0xff) as u8;
278 o[2 * i + 1] = (t[i] >> 8) as u8;
279 }
280
281 memzero_explicit(&mut m);
282 memzero_explicit(&mut t);
283 memzero_explicit(&mut b);
284}
285
286fn add(o: &mut Fe, a: &Fe, b: &Fe) {
287 for i in 0..16 {
288 o[i] = a[i] + b[i];
289 }
290}
291
292fn subtract(o: &mut Fe, a: &Fe, b: &Fe) {
293 for i in 0..16 {
294 o[i] = a[i] - b[i];
295 }
296}
297
298fn multmod(o: &mut Fe, a: &Fe, b: &Fe) {
299 let mut t: [i64; 31] = [0; 31];
300
301 for i in 0..16 {
302 for j in 0..16 {
303 t[i + j] += a[i] * b[j];
304 }
305 }
306 for i in 0..15 {
307 t[i] += 38 * t[i + 16];
308 }
309 o.copy_from_slice(&t[..16]);
310 carry(o);
311 carry(o);
312
313 memzero_explicit(&mut t);
314}
315
316fn invert(o: &mut Fe, i: &Fe) {
317 let mut c: Fe = fe_new_default();
318
319 c.copy_from_slice(i);
320 for a in (0..=253).rev() {
321 {
322 let c_clone = c;
323 multmod(&mut c, &c_clone, &c_clone);
324 }
325 if a != 2 && a != 4 {
326 {
327 let c_clone = c;
328 multmod(&mut c, &c_clone, i);
329 }
330 }
331 }
332 o.copy_from_slice(&c);
333
334 memzero_explicit(&mut c);
335}
336
337fn memzero_explicit<T>(v: &mut T) {
338 unsafe {
339 let zeroed = std::mem::zeroed();
340 std::ptr::write_volatile(v as *mut _, zeroed);
341 }
342}
343
344#[cfg(test)]
345mod tests {
346 use super::Key;
347
348 #[test]
349 fn decode_encode_key() {
350 let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=";
351 let key = super::Key::decode(key).unwrap();
352 let key = key.encode();
353 assert_eq!(key, "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=");
354 }
355
356 #[test]
357 fn generate_public_key() {
358 assert_eq!(
359 Key::decode("3D5lgnI9ztvnuyWDm7dlBDgm6xr0+WVWPoo6HIfzHRU=").unwrap(),
360 Key::generate_public(
361 &Key::decode("+Op7voRskU0Zm2fHFR/5tVE+PJtnwn6cbnme71jXt0E=").unwrap()
362 )
363 );
364
365 assert_eq!(
366 Key::decode("//eq/raPUE4+sOlTlozx76XEE+W8L0bUqNfyg9IpX0Q=").unwrap(),
367 Key::generate_public(
368 &Key::decode("8OD8QPWH/a0D5LmbWVnb7bwFq4Fghy/QUEFkIhyL/EI=").unwrap()
369 )
370 );
371 }
372}