summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordiogo464 <[email protected]>2025-07-18 18:46:55 +0100
committerdiogo464 <[email protected]>2025-07-18 18:46:55 +0100
commit75ccbd675c22fb3275c5763518c3b97819db4c53 (patch)
tree1ff2a44abcac884875f80a44a569681e0f2d7a7d
init
-rw-r--r--.gitignore1
-rw-r--r--Cargo.toml22
-rw-r--r--src/conf.rs692
-rw-r--r--src/key.rs372
-rw-r--r--src/lib.rs401
-rw-r--r--src/setup.rs212
-rw-r--r--src/view.rs130
7 files changed, 1830 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ea8c4bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
/target
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..b2c8f54
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,22 @@
1[package]
2name = "wireguard"
3version = "0.0.0"
4edition = "2021"
5
6# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7
8[dependencies]
9anyhow = "1.0.79"
10base64 = "0.21.7"
11futures = "0.3.30"
12genetlink = "0.2.5"
13ipnet = { version = "2.9.0", features = ["serde"] }
14netlink-packet-core = "=0.7.0"
15netlink-packet-generic = "=0.3.3"
16netlink-packet-route = "=0.18.1"
17netlink-packet-wireguard = "=0.2.3"
18rand = "0.8.5"
19rtnetlink = "=0.14.0"
20serde = "1.0.195"
21tokio = { version = "1.35.1", features = ["full"] }
22tracing = "0.1.40"
diff --git a/src/conf.rs b/src/conf.rs
new file mode 100644
index 0000000..b6e49a0
--- /dev/null
+++ b/src/conf.rs
@@ -0,0 +1,692 @@
1use std::{fmt::Write, str::FromStr};
2
3use ipnet::IpNet;
4
5use super::Key;
6
7const FIELD_PRIVATE_KEY: &str = "PrivateKey";
8const FIELD_LISTEN_PORT: &str = "ListenPort";
9const FIELD_FWMARK: &str = "FwMark";
10const FIELD_PUBLIC_KEY: &str = "PublicKey";
11const FIELD_PRE_SHARED_KEY: &str = "PresharedKey";
12const FIELD_ALLOWED_IPS: &str = "AllowedIPs";
13const FIELD_ENDPOINT: &str = "Endpoint";
14const FIELD_PERSISTENT_KEEPALIVE: &str = "PersistentKeepalive";
15
16// wg-quick fields
17const FIELD_ADDRESS: &str = "Address";
18const FIELD_DNS: &str = "DNS";
19
20macro_rules! header {
21 ($dest:expr, $h:expr) => {
22 writeln!($dest, "[{}]", $h).unwrap();
23 };
24}
25
26macro_rules! field {
27 ($dest:expr, $n:expr, $v:expr) => {
28 writeln!($dest, "{} = {}", $n, $v).unwrap();
29 };
30}
31
32macro_rules! field_csv {
33 ($dest:expr, $n:expr, $v:expr) => {
34 if !$v.is_empty() {
35 write!($dest, "{} = ", $n).unwrap();
36 let mut comma = false;
37 for e in $v.iter() {
38 if comma {
39 write!($dest, ", ").unwrap();
40 } else {
41 comma = true;
42 }
43 write!($dest, "{}", e).unwrap();
44 }
45 writeln!($dest).unwrap();
46 }
47 };
48}
49
50macro_rules! field_opt {
51 ($dest:expr, $n:expr, $v:expr) => {
52 if let Some(ref v) = $v {
53 field!($dest, $n, v);
54 }
55 };
56}
57
58#[derive(Debug, Clone)]
59pub struct WgInterface {
60 pub private_key: Key,
61 pub address: Vec<IpNet>,
62 pub listen_port: Option<u16>,
63 pub fw_mark: Option<u32>,
64 pub dns: Option<String>,
65}
66
67#[derive(Debug, Clone)]
68pub struct WgPeer {
69 pub public_key: Key,
70 pub preshared_key: Option<Key>,
71 pub allowed_ips: Vec<IpNet>,
72 pub endpoint: Option<String>,
73 pub keep_alive: Option<u16>,
74}
75
76impl WgPeer {
77 pub fn builder(public_key: Key) -> WgPeerBuilder {
78 WgPeerBuilder::new(public_key)
79 }
80}
81
82#[derive(Debug, Clone)]
83pub struct WgConf {
84 pub interface: WgInterface,
85 pub peers: Vec<WgPeer>,
86}
87
88impl WgConf {
89 pub fn builder() -> WgConfBuilder {
90 WgConfBuilder::new()
91 }
92}
93
94#[derive(Debug)]
95pub struct WgPeerBuilder {
96 pub public_key: Key,
97 pub preshared_key: Option<Key>,
98 pub allowed_ips: Vec<IpNet>,
99 pub endpoint: Option<String>,
100 pub keep_alive: Option<u16>,
101}
102
103impl WgPeerBuilder {
104 pub fn new(public_key: Key) -> WgPeerBuilder {
105 WgPeerBuilder {
106 public_key,
107 preshared_key: None,
108 allowed_ips: Vec::new(),
109 endpoint: None,
110 keep_alive: None,
111 }
112 }
113
114 pub fn preshared_key(mut self, preshared_key: Key) -> Self {
115 self.preshared_key = Some(preshared_key);
116 self
117 }
118
119 pub fn allowed_ip(mut self, allowed_ip: IpNet) -> Self {
120 self.allowed_ips.push(allowed_ip);
121 self
122 }
123
124 pub fn allowed_ips(mut self, allowed_ips: impl IntoIterator<Item = IpNet>) -> Self {
125 self.allowed_ips.extend(allowed_ips);
126 self
127 }
128
129 pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
130 self.endpoint = Some(endpoint.into());
131 self
132 }
133
134 pub fn endpoint_opt(mut self, endpoint: Option<impl Into<String>>) -> Self {
135 if let Some(endpoint) = endpoint {
136 self.endpoint = Some(endpoint.into());
137 }
138 self
139 }
140
141 pub fn keep_alive(mut self, keep_alive: u16) -> Self {
142 self.keep_alive = Some(keep_alive);
143 self
144 }
145
146 pub fn build(self) -> WgPeer {
147 WgPeer {
148 public_key: self.public_key,
149 preshared_key: self.preshared_key,
150 allowed_ips: self.allowed_ips,
151 endpoint: self.endpoint,
152 keep_alive: self.keep_alive,
153 }
154 }
155}
156
157#[derive(Debug)]
158pub struct WgConfBuilder {
159 private_key: Option<Key>,
160 address: Vec<IpNet>,
161 listen_port: Option<u16>,
162 fw_mark: Option<u32>,
163 dns: Option<String>,
164 peers: Vec<WgPeer>,
165}
166
167impl WgConfBuilder {
168 pub fn new() -> Self {
169 WgConfBuilder {
170 private_key: None,
171 address: Vec::new(),
172 listen_port: None,
173 fw_mark: None,
174 dns: None,
175 peers: Vec::new(),
176 }
177 }
178
179 pub fn private_key(mut self, private_key: Key) -> Self {
180 self.private_key = Some(private_key);
181 self
182 }
183
184 pub fn address(mut self, address: impl Into<IpNet>) -> Self {
185 self.address.push(address.into());
186 self
187 }
188
189 pub fn addresses(mut self, addresses: impl IntoIterator<Item = IpNet>) -> Self {
190 self.address.extend(addresses);
191 self
192 }
193
194 pub fn listen_port(mut self, listen_port: u16) -> Self {
195 self.listen_port = Some(listen_port);
196 self
197 }
198
199 pub fn fw_mark(mut self, fw_mark: u32) -> Self {
200 self.fw_mark = Some(fw_mark);
201 self
202 }
203
204 pub fn dns(mut self, dns: impl Into<String>) -> Self {
205 self.dns = Some(dns.into());
206 self
207 }
208
209 pub fn dns_opt(mut self, dns: Option<impl Into<String>>) -> Self {
210 if let Some(dns) = dns {
211 self.dns = Some(dns.into());
212 }
213 self
214 }
215
216 pub fn peer(mut self, peer: WgPeer) -> Self {
217 self.peers.push(peer);
218 self
219 }
220
221 pub fn peers(mut self, peers: impl IntoIterator<Item = WgPeer>) -> Self {
222 self.peers.extend(peers);
223 self
224 }
225
226 pub fn build(self) -> WgConf {
227 WgConf {
228 interface: WgInterface {
229 private_key: self.private_key.unwrap_or_else(Key::generate_private),
230 address: self.address,
231 listen_port: self.listen_port,
232 fw_mark: self.fw_mark,
233 dns: self.dns,
234 },
235 peers: self.peers,
236 }
237 }
238}
239
240#[derive(Default)]
241struct PartialConf {
242 interface: Option<WgInterface>,
243 peers: Vec<WgPeer>,
244}
245
246pub fn parse_conf(conf: &str) -> anyhow::Result<WgConf> {
247 let mut iter = conf.lines().filter_map(|l| {
248 // remove whitespace on the sides
249 let l = l.trim();
250 // remove the comment
251 let (l, _) = l.rsplit_once("#").unwrap_or((l, ""));
252 if l.is_empty() {
253 None
254 } else {
255 Some(l)
256 }
257 });
258
259 let mut partial = PartialConf::default();
260 parse_partial(&mut partial, &mut iter)?;
261
262 match partial.interface {
263 Some(interface) => Ok(WgConf {
264 interface,
265 peers: partial.peers,
266 }),
267 None => Err(anyhow::anyhow!("no interface found")),
268 }
269}
270
271pub fn serialize_conf(conf: &WgConf) -> String {
272 let mut conf_str = String::new();
273 header!(conf_str, "Interface");
274 field!(conf_str, FIELD_PRIVATE_KEY, conf.interface.private_key);
275 field_csv!(conf_str, FIELD_ADDRESS, conf.interface.address);
276 field_opt!(conf_str, FIELD_LISTEN_PORT, conf.interface.listen_port);
277 field_opt!(conf_str, FIELD_FWMARK, conf.interface.fw_mark);
278 field_opt!(conf_str, FIELD_DNS, conf.interface.dns);
279 for peer in conf.peers.iter() {
280 writeln!(conf_str).unwrap();
281 header!(conf_str, "Peer");
282 field!(conf_str, FIELD_PUBLIC_KEY, peer.public_key);
283 field_opt!(conf_str, FIELD_PRE_SHARED_KEY, peer.preshared_key);
284 field_csv!(conf_str, FIELD_ALLOWED_IPS, peer.allowed_ips);
285 field_opt!(conf_str, FIELD_ENDPOINT, peer.endpoint);
286 field_opt!(conf_str, FIELD_PERSISTENT_KEEPALIVE, peer.keep_alive);
287 }
288 conf_str
289}
290
291fn parse_partial<'s, I: Iterator<Item = &'s str>>(
292 cfg: &mut PartialConf,
293 iter: &mut I,
294) -> anyhow::Result<()> {
295 match iter.next() {
296 Some("[Interface]") => parse_interface(cfg, iter),
297 Some("[Peer]") => parse_peer(cfg, iter),
298 Some(line) => Err(anyhow::anyhow!("unexpected line: {}", line)),
299 None => Err(anyhow::anyhow!("unexpected end of file")),
300 }
301}
302
303fn parse_interface<'s, I: Iterator<Item = &'s str>>(
304 cfg: &mut PartialConf,
305 iter: &mut I,
306) -> anyhow::Result<()> {
307 let mut private_key = None;
308 let mut address = Vec::new();
309 let mut listen_port = None;
310 let mut fw_mark = None;
311 let mut dns = None;
312 let mut peer_next = false;
313
314 if cfg.interface.is_some() {
315 anyhow::bail!("cannot have more than one interface");
316 }
317
318 while let Some(line) = iter.next() {
319 if line == "[Peer]" {
320 peer_next = true;
321 break;
322 }
323
324 let (key, value) = parse_key_value(line)?;
325 match key {
326 FIELD_PRIVATE_KEY => private_key = Some(value.parse()?),
327 FIELD_LISTEN_PORT => listen_port = Some(value.parse()?),
328 FIELD_FWMARK => fw_mark = Some(value.parse()?),
329 FIELD_ADDRESS => address = parse_csv(value)?,
330 FIELD_DNS => dns = Some(value.to_string()),
331 _ => anyhow::bail!("unexpected key: {}", key),
332 }
333 }
334
335 cfg.interface = Some(WgInterface {
336 private_key: private_key.ok_or_else(|| anyhow::anyhow!("interface missing private key"))?,
337 address,
338 listen_port,
339 fw_mark,
340 dns,
341 });
342
343 if peer_next {
344 parse_peer(cfg, iter)
345 } else {
346 Ok(())
347 }
348}
349
350fn parse_peer<'s, I: Iterator<Item = &'s str>>(
351 cfg: &mut PartialConf,
352 iter: &mut I,
353) -> anyhow::Result<()> {
354 let mut public_key = None;
355 let mut preshared_key = None;
356 let mut allowed_ips = Vec::new();
357 let mut endpoint = None;
358 let mut keep_alive = None;
359 let mut interface_next = false;
360 let mut peer_next = false;
361
362 while let Some(line) = iter.next() {
363 if line == "[Interface]" {
364 interface_next = true;
365 break;
366 }
367 if line == "[Peer]" {
368 peer_next = true;
369 break;
370 }
371
372 let (key, value) = parse_key_value(line)?;
373 match key {
374 FIELD_PUBLIC_KEY => public_key = Some(value.parse()?),
375 FIELD_PRE_SHARED_KEY => preshared_key = Some(value.parse()?),
376 FIELD_ALLOWED_IPS => allowed_ips = parse_csv(value)?,
377 FIELD_ENDPOINT => endpoint = Some(value.to_string()),
378 FIELD_PERSISTENT_KEEPALIVE => keep_alive = Some(value.parse()?),
379 _ => anyhow::bail!("unexpected key: {}", key),
380 }
381 }
382
383 cfg.peers.push(WgPeer {
384 public_key: public_key.ok_or_else(|| anyhow::anyhow!("peer missing public key"))?,
385 preshared_key,
386 allowed_ips,
387 endpoint,
388 keep_alive,
389 });
390
391 if interface_next {
392 parse_interface(cfg, iter)
393 } else if peer_next {
394 parse_peer(cfg, iter)
395 } else {
396 Ok(())
397 }
398}
399
400fn parse_key_value<'s>(line: &'s str) -> anyhow::Result<(&'s str, &'s str)> {
401 line.split_once("=")
402 .map(|(k, v)| (k.trim(), v.trim()))
403 .ok_or_else(|| anyhow::anyhow!("invalid line: {}", line))
404}
405
406fn parse_csv<
407 'v,
408 T: FromStr<Err = impl std::error::Error + std::marker::Sync + std::marker::Send + 'static>,
409>(
410 value: &'v str,
411) -> anyhow::Result<Vec<T>> {
412 let mut values = Vec::new();
413 for v in value.split(',').map(str::trim) {
414 values.push(v.parse()?);
415 }
416 Ok(values)
417}
418
419#[cfg(test)]
420mod tests {
421 use std::net::Ipv4Addr;
422
423 use ipnet::{IpNet, Ipv4Net};
424
425 use crate::Key;
426
427 use super::{WgConfBuilder, WgPeerBuilder};
428
429 const TEST_CONF_1: &str = r#"
430 [Interface]
431 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
432 ListenPort = 51820
433
434 [Peer]
435 PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
436 Endpoint = 192.95.5.67:1234
437 AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
438
439 [Peer]
440 PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
441 Endpoint = [2607:5300:60:6b0::c05f:543]:2468
442 AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
443
444 [Peer]
445 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
446 Endpoint = test.wireguard.com:18981
447
448 AllowedIPs = 10.10.10.230/32
449 PersistentKeepalive = 54
450"#;
451
452 const TEST_CONF_2: &str = r#"
453 [Peer]
454 PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
455 Endpoint = 192.95.5.67:1234
456 AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
457
458 [Peer]
459 PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
460 Endpoint = [2607:5300:60:6b0::c05f:543]:2468
461 AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
462
463 [Interface]
464 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
465 ListenPort = 51820
466
467 [Peer]
468 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
469 Endpoint = test.wireguard.com:18981
470
471 AllowedIPs = 10.10.10.230/32
472 PersistentKeepalive = 54
473"#;
474
475 const TEST_CONF_3: &str = r#"
476 [Interface]
477 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
478 ListenPort = 51820
479
480 [Interface]
481 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
482 ListenPort = 51821
483
484 [Peer]
485 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
486 Endpoint = test.wireguard.com:18981
487 AllowedIPs = 10.10.10.230/32
488"#;
489
490 const TEST_CONF_4: &str = "";
491
492 const TEST_CONF_5: &str = r#"
493 PublicKey = 1
494
495 [Interface]
496 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
497 ListenPort = 51820
498
499 [Peer]
500 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
501 Endpoint = test.wireguard.com:18981
502 AllowedIPs = 10.10.10.230/32
503"#;
504
505 const TEST_CONF_6: &str = r#"
506 [Interface]
507 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
508 ListenPort = 51820
509 Unknown = 1
510
511 [Peer]
512 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
513 Endpoint = test.wireguard.com:18981
514 AllowedIPs = 10.10.10.230/32
515"#;
516
517 const TEST_CONF_7: &str = r#"
518 [Interface]
519 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
520 ListenPort = 51820
521"#;
522
523 #[test]
524 fn parse_config() {
525 parse_config_1_and_2(TEST_CONF_1);
526 }
527
528 #[test]
529 fn parse_config_out_of_order_interface() {
530 parse_config_1_and_2(TEST_CONF_2);
531 }
532
533 #[test]
534 #[should_panic]
535 fn parse_config_duplicate_interface() {
536 super::parse_conf(TEST_CONF_3).unwrap();
537 }
538
539 #[test]
540 #[should_panic]
541 fn parse_config_empty() {
542 super::parse_conf(TEST_CONF_4).unwrap();
543 }
544
545 #[test]
546 #[should_panic]
547 fn parse_config_out_of_order_field() {
548 super::parse_conf(TEST_CONF_5).unwrap();
549 }
550
551 #[test]
552 #[should_panic]
553 fn parse_config_unkown_field() {
554 super::parse_conf(TEST_CONF_6).unwrap();
555 }
556
557 #[test]
558 fn parse_config_no_peers() {
559 let cfg = super::parse_conf(TEST_CONF_7).unwrap();
560
561 assert_eq!(
562 "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=",
563 cfg.interface.private_key.to_string(),
564 );
565 assert_eq!(Some(51820), cfg.interface.listen_port);
566 assert_eq!(None, cfg.interface.fw_mark);
567
568 assert_eq!(0, cfg.peers.len());
569 }
570
571 fn parse_config_1_and_2(conf_str: &str) {
572 let cfg = super::parse_conf(conf_str).unwrap();
573
574 assert_eq!(
575 "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=",
576 cfg.interface.private_key.to_string()
577 );
578 assert_eq!(Some(51820), cfg.interface.listen_port);
579 assert_eq!(None, cfg.interface.fw_mark);
580
581 assert_eq!(3, cfg.peers.len());
582
583 let peer = &cfg.peers[0];
584 assert_eq!(
585 "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=",
586 peer.public_key.to_string()
587 );
588 assert_eq!(None, peer.preshared_key);
589 assert_eq!(2, peer.allowed_ips.len());
590 assert_eq!(Some("192.95.5.67:1234"), peer.endpoint.as_deref());
591 assert_eq!(None, peer.keep_alive);
592
593 let peer = &cfg.peers[1];
594 assert_eq!(
595 "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=",
596 peer.public_key.to_string()
597 );
598 assert_eq!(None, peer.preshared_key);
599 assert_eq!(2, peer.allowed_ips.len());
600 assert_eq!(
601 Some("[2607:5300:60:6b0::c05f:543]:2468"),
602 peer.endpoint.as_deref()
603 );
604 assert_eq!(None, peer.keep_alive);
605
606 let peer = &cfg.peers[2];
607 assert_eq!(
608 "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=",
609 peer.public_key.to_string()
610 );
611 assert_eq!(None, peer.preshared_key);
612 assert_eq!(1, peer.allowed_ips.len());
613 assert_eq!(Some("test.wireguard.com:18981"), peer.endpoint.as_deref());
614 assert_eq!(Some(54), peer.keep_alive);
615 }
616
617 #[test]
618 fn serialize_no_peers() {
619 let key = Key::decode("yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=").unwrap();
620 let conf = WgConfBuilder::new()
621 .fw_mark(10)
622 .listen_port(6000)
623 .dns("dns.example.com")
624 .address(IpNet::V4(
625 Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 5), 24).unwrap(),
626 ))
627 .private_key(key)
628 .build();
629 let serialized = super::serialize_conf(&conf);
630
631 assert_eq!(
632 r#"[Interface]
633PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
634Address = 10.0.0.5/24
635ListenPort = 6000
636FwMark = 10
637DNS = dns.example.com
638"#,
639 serialized
640 );
641 }
642
643 #[test]
644 fn serialize_with_peers() {
645 let key1 = Key::decode("xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=").unwrap();
646 let key2 = Key::decode("TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=").unwrap();
647 let key3 = Key::decode("gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=").unwrap();
648
649 let conf = WgConfBuilder::new()
650 .private_key(key1)
651 .listen_port(51820)
652 .dns("dns.example.com")
653 .peer(
654 WgPeerBuilder::new(key2)
655 .keep_alive(10)
656 .endpoint("test.wireguard.com:18981")
657 .allowed_ip(ipnet::IpNet::V4(
658 Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 2), 24).unwrap(),
659 ))
660 .build(),
661 )
662 .peer(
663 WgPeerBuilder::new(key3)
664 .allowed_ip(ipnet::IpNet::V4(
665 Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 3), 24).unwrap(),
666 ))
667 .build(),
668 )
669 .build();
670
671 let serialized = super::serialize_conf(&conf);
672
673 assert_eq!(
674 r#"[Interface]
675PrivateKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
676ListenPort = 51820
677DNS = dns.example.com
678
679[Peer]
680PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
681AllowedIPs = 10.0.0.2/24
682Endpoint = test.wireguard.com:18981
683PersistentKeepalive = 10
684
685[Peer]
686PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
687AllowedIPs = 10.0.0.3/24
688"#,
689 serialized
690 );
691 }
692}
diff --git a/src/key.rs b/src/key.rs
new file mode 100644
index 0000000..19bc127
--- /dev/null
+++ b/src/key.rs
@@ -0,0 +1,372 @@
1use base64::Engine;
2use netlink_packet_wireguard::constants::WG_KEY_LEN;
3use rand::{rngs::OsRng, RngCore};
4
5// Code from: https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library/wireguard.c
6
7type Fe = [i64; 16];
8
9#[derive(Debug)]
10pub struct KeyDecodeError;
11
12impl std::fmt::Display for KeyDecodeError {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 write!(f, "Key decode error")
15 }
16}
17
18impl std::error::Error for KeyDecodeError {}
19
20#[derive(Clone, Default, Copy, PartialEq, Eq, Hash)]
21pub struct Key([u8; WG_KEY_LEN]);
22
23impl Key {
24 pub fn into_array(self) -> [u8; WG_KEY_LEN] {
25 self.0
26 }
27
28 pub fn as_array(&self) -> &[u8; WG_KEY_LEN] {
29 &self.0
30 }
31
32 pub fn as_slice(&self) -> &[u8] {
33 &self.0
34 }
35
36 pub fn encode(&self) -> String {
37 base64::engine::general_purpose::STANDARD.encode(&self.0)
38 }
39
40 pub fn decode(encoded: &str) -> Result<Self, KeyDecodeError> {
41 let decoded = base64::engine::general_purpose::STANDARD
42 .decode(encoded)
43 .map_err(|_| KeyDecodeError)?;
44 if decoded.len() != WG_KEY_LEN {
45 return Err(KeyDecodeError);
46 }
47 let mut key = [0u8; WG_KEY_LEN];
48 key.copy_from_slice(&decoded);
49 Ok(Key(key))
50 }
51
52 pub fn generate_pub_priv() -> (Self, Self) {
53 let private_key = Self::generate_private();
54 let public_key = Self::generate_public(&private_key);
55 (public_key, private_key)
56 }
57
58 pub fn generate_public(private: &Key) -> Self {
59 let mut r: i32 = Default::default();
60 let mut public_key: [u8; WG_KEY_LEN] = Default::default();
61 let mut z: [u8; WG_KEY_LEN] = private.0;
62 let mut a = fe_new_one(1);
63 let mut b = fe_new_one(9);
64 let mut c = fe_new_one(0);
65 let mut d = fe_new_one(1);
66 let mut e = fe_new_default();
67 let mut f = fe_new_default();
68
69 clamp_key(&mut z);
70
71 for i in (0..=254i32).rev() {
72 r = ((z[(i >> 3) as usize] >> (i & 7)) & 1) as i32;
73 cswap(&mut a, &mut b, r);
74 cswap(&mut c, &mut d, r);
75 add(&mut e, &a, &c);
76 {
77 let a_clone = a;
78 subtract(&mut a, &a_clone, &c);
79 }
80 add(&mut c, &b, &d);
81 {
82 let b_clone = b;
83 subtract(&mut b, &b_clone, &d);
84 }
85 multmod(&mut d, &e, &e);
86 multmod(&mut f, &a, &a);
87 {
88 let a_clone = a;
89 multmod(&mut a, &c, &a_clone);
90 }
91 multmod(&mut c, &b, &e);
92 add(&mut e, &a, &c);
93 {
94 let a_clone = a;
95 subtract(&mut a, &a_clone, &c);
96 }
97 multmod(&mut b, &a, &a);
98 subtract(&mut c, &d, &f);
99 //multmod(&mut a, &c, (const fe){ 0xdb41, 1 });
100 multmod(&mut a, &c, &fe_new_two(0xdb41, 1));
101 {
102 let a_clone = a;
103 add(&mut a, &a_clone, &d);
104 }
105 {
106 let c_clone = c;
107 multmod(&mut c, &c_clone, &a);
108 }
109 multmod(&mut a, &d, &f);
110 multmod(&mut d, &b, &fe_new_one(9));
111 multmod(&mut b, &e, &e);
112 cswap(&mut a, &mut b, r);
113 cswap(&mut c, &mut d, r);
114 }
115 {
116 let c_clone = c;
117 invert(&mut c, &c_clone);
118 }
119 {
120 let a_clone = a;
121 multmod(&mut a, &a_clone, &c);
122 }
123 pack(&mut public_key, &a);
124
125 memzero_explicit(&mut r);
126 memzero_explicit(&mut z);
127 memzero_explicit(&mut a);
128 memzero_explicit(&mut b);
129 memzero_explicit(&mut c);
130 memzero_explicit(&mut d);
131 memzero_explicit(&mut e);
132 memzero_explicit(&mut f);
133
134 Self(public_key)
135 }
136
137 pub fn generate_private() -> Self {
138 let mut preshared = Self::generate_preshared();
139 clamp_key(&mut preshared.0);
140 preshared
141 }
142
143 pub fn generate_preshared() -> Self {
144 let mut key = [0u8; WG_KEY_LEN];
145 OsRng.fill_bytes(&mut key);
146 Self(key)
147 }
148}
149
150impl From<&[u8; WG_KEY_LEN]> for Key {
151 fn from(k: &[u8; WG_KEY_LEN]) -> Self {
152 Self(*k)
153 }
154}
155
156impl From<[u8; WG_KEY_LEN]> for Key {
157 fn from(k: [u8; WG_KEY_LEN]) -> Self {
158 Self(k)
159 }
160}
161
162impl std::str::FromStr for Key {
163 type Err = KeyDecodeError;
164
165 fn from_str(s: &str) -> Result<Self, Self::Err> {
166 Key::decode(s)
167 }
168}
169
170impl std::fmt::Debug for Key {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 let mut buf = [0; WG_KEY_LEN * 2];
173 let len = base64::engine::general_purpose::STANDARD
174 .encode_slice(&self.0, &mut buf)
175 .expect("base64 should encode");
176 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
177 f.debug_tuple("Key").field(&b64).finish()
178 }
179}
180
181impl std::fmt::Display for Key {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 let mut buf = [0; WG_KEY_LEN * 2];
184 let len = base64::engine::general_purpose::STANDARD
185 .encode_slice(&self.0, &mut buf)
186 .expect("base64 should encode");
187 let b64 = std::str::from_utf8(&buf[..len]).expect("base64 should be valid utf-8");
188 f.write_str(b64)
189 }
190}
191
192impl serde::Serialize for Key {
193 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194 where
195 S: serde::Serializer,
196 {
197 serializer.serialize_str(&self.encode())
198 }
199}
200
201impl<'de> serde::Deserialize<'de> for Key {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: serde::Deserializer<'de>,
205 {
206 let s = String::deserialize(deserializer)?;
207 Key::decode(&s).map_err(serde::de::Error::custom)
208 }
209}
210
211fn fe_new_default() -> Fe {
212 Default::default()
213}
214
215fn fe_new_one(x: i64) -> Fe {
216 let mut fe = fe_new_default();
217 fe[0] = x;
218 fe
219}
220
221fn fe_new_two(x: i64, y: i64) -> Fe {
222 let mut fe = fe_new_default();
223 fe[0] = x;
224 fe[1] = y;
225 fe
226}
227
228fn clamp_key(key: &mut [u8]) {
229 key[31] = (key[31] & 127) | 64;
230 key[0] &= 248;
231}
232
233fn carry(o: &mut Fe) {
234 for i in 0..16 {
235 let x = if i == 15 { 38 } else { 1 };
236 o[(i + 1) % 16] += x * (o[i] >> 16);
237 o[i] &= 0xffff;
238 }
239}
240
241fn cswap(p: &mut Fe, q: &mut Fe, mut b: i32) {
242 let mut t: i64 = 0;
243 let mut c: i64 = !i64::from(b).wrapping_sub(1);
244
245 for i in 0..16 {
246 t = c & (p[i] ^ q[i]);
247 p[i] ^= t;
248 q[i] ^= t;
249 }
250
251 memzero_explicit(&mut t);
252 memzero_explicit(&mut c);
253 memzero_explicit(&mut b);
254}
255
256fn pack(o: &mut [u8; WG_KEY_LEN], n: &Fe) {
257 let mut b: i32 = 0;
258 let mut t: Fe = fe_new_default();
259 let mut m: Fe = fe_new_default();
260
261 t.copy_from_slice(n);
262 carry(&mut t);
263 carry(&mut t);
264 carry(&mut t);
265 for _ in 0..2 {
266 m[0] = t[0] - 0xffed;
267 for i in 1..15 {
268 m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
269 m[i - 1] &= 0xffff;
270 }
271 m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
272 b = ((m[15] >> 16) & 1) as i32;
273 m[14] &= 0xffff;
274 cswap(&mut t, &mut m, 1 - b);
275 }
276 for i in 0..16 {
277 o[2 * i] = (t[i] & 0xff) as u8;
278 o[2 * i + 1] = (t[i] >> 8) as u8;
279 }
280
281 memzero_explicit(&mut m);
282 memzero_explicit(&mut t);
283 memzero_explicit(&mut b);
284}
285
286fn add(o: &mut Fe, a: &Fe, b: &Fe) {
287 for i in 0..16 {
288 o[i] = a[i] + b[i];
289 }
290}
291
292fn subtract(o: &mut Fe, a: &Fe, b: &Fe) {
293 for i in 0..16 {
294 o[i] = a[i] - b[i];
295 }
296}
297
298fn multmod(o: &mut Fe, a: &Fe, b: &Fe) {
299 let mut t: [i64; 31] = [0; 31];
300
301 for i in 0..16 {
302 for j in 0..16 {
303 t[i + j] += a[i] * b[j];
304 }
305 }
306 for i in 0..15 {
307 t[i] += 38 * t[i + 16];
308 }
309 o.copy_from_slice(&t[..16]);
310 carry(o);
311 carry(o);
312
313 memzero_explicit(&mut t);
314}
315
316fn invert(o: &mut Fe, i: &Fe) {
317 let mut c: Fe = fe_new_default();
318
319 c.copy_from_slice(i);
320 for a in (0..=253).rev() {
321 {
322 let c_clone = c;
323 multmod(&mut c, &c_clone, &c_clone);
324 }
325 if a != 2 && a != 4 {
326 {
327 let c_clone = c;
328 multmod(&mut c, &c_clone, i);
329 }
330 }
331 }
332 o.copy_from_slice(&c);
333
334 memzero_explicit(&mut c);
335}
336
337fn memzero_explicit<T>(v: &mut T) {
338 unsafe {
339 let zeroed = std::mem::zeroed();
340 std::ptr::write_volatile(v as *mut _, zeroed);
341 }
342}
343
344#[cfg(test)]
345mod tests {
346 use super::Key;
347
348 #[test]
349 fn decode_encode_key() {
350 let key = "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=";
351 let key = super::Key::decode(key).unwrap();
352 let key = key.encode();
353 assert_eq!(key, "6F5rOtYE5A2KcXTKf9jdzWa9Y/kuV5gPS3LcKlxmOnY=");
354 }
355
356 #[test]
357 fn generate_public_key() {
358 assert_eq!(
359 Key::decode("3D5lgnI9ztvnuyWDm7dlBDgm6xr0+WVWPoo6HIfzHRU=").unwrap(),
360 Key::generate_public(
361 &Key::decode("+Op7voRskU0Zm2fHFR/5tVE+PJtnwn6cbnme71jXt0E=").unwrap()
362 )
363 );
364
365 assert_eq!(
366 Key::decode("//eq/raPUE4+sOlTlozx76XEE+W8L0bUqNfyg9IpX0Q=").unwrap(),
367 Key::generate_public(
368 &Key::decode("8OD8QPWH/a0D5LmbWVnb7bwFq4Fghy/QUEFkIhyL/EI=").unwrap()
369 )
370 );
371 }
372}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c47d618
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,401 @@
1mod conf;
2mod key;
3mod setup;
4mod view;
5
6use std::borrow::Cow;
7
8use futures::{StreamExt, TryStreamExt};
9use genetlink::{GenetlinkError, GenetlinkHandle};
10use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST};
11use netlink_packet_generic::GenlMessage;
12use netlink_packet_route::{
13 link::{InfoKind, LinkAttribute, LinkInfo},
14 route::RouteScope,
15};
16use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd};
17use rtnetlink::Handle;
18
19pub use conf::*;
20pub use key::*;
21pub use setup::*;
22pub use view::*;
23
24pub use ipnet;
25
26pub type Result<T, E = Error> = std::result::Result<T, E>;
27
28#[derive(Debug)]
29pub struct Error {
30 inner: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
31 message: Option<Cow<'static, str>>,
32}
33
34impl std::fmt::Display for Error {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match (self.message.as_ref(), self.inner.as_ref()) {
37 (Some(message), Some(inner)) => write!(f, "{}: {}", message, inner),
38 (Some(message), None) => write!(f, "{}", message),
39 (None, Some(inner)) => write!(f, "{}", inner),
40 (None, None) => write!(f, "Unknown error"),
41 }
42 }
43}
44
45impl std::error::Error for Error {}
46
47impl From<std::io::Error> for Error {
48 fn from(inner: std::io::Error) -> Self {
49 Self {
50 inner: Some(Box::new(inner)),
51 message: None,
52 }
53 }
54}
55
56impl From<GenetlinkError> for Error {
57 fn from(inner: GenetlinkError) -> Self {
58 Self {
59 inner: Some(Box::new(inner)),
60 message: None,
61 }
62 }
63}
64
65impl From<rtnetlink::Error> for Error {
66 fn from(inner: rtnetlink::Error) -> Self {
67 Self {
68 inner: Some(Box::new(inner)),
69 message: None,
70 }
71 }
72}
73
74impl Error {
75 pub(crate) fn with_message<E>(inner: E, message: impl Into<Cow<'static, str>>) -> Self
76 where
77 E: std::error::Error + Send + Sync + 'static,
78 {
79 Self {
80 inner: Some(Box::new(inner)),
81 message: Some(message.into()),
82 }
83 }
84
85 pub(crate) fn message(message: impl Into<Cow<'static, str>>) -> Self {
86 Self {
87 inner: None,
88 message: Some(message.into()),
89 }
90 }
91}
92
93struct Link {
94 pub name: String,
95 pub ifindex: u32,
96}
97
98pub struct WireGuard {
99 rt_handle: Handle,
100 gen_handle: GenetlinkHandle,
101}
102
103#[allow(clippy::await_holding_refcell_ref)]
104impl WireGuard {
105 pub async fn new() -> Result<Self> {
106 let (rt_connection, rt_handle, _) = rtnetlink::new_connection()?;
107 tokio::spawn(rt_connection);
108 let (gen_connection, gen_handle, _) = genetlink::new_connection()?;
109 tokio::spawn(gen_connection);
110
111 Ok(Self {
112 rt_handle,
113 gen_handle,
114 })
115 }
116
117 pub async fn create_device(
118 &mut self,
119 device_name: &str,
120 descriptor: DeviceDescriptor,
121 ) -> Result<()> {
122 tracing::trace!("Creating device {}", device_name);
123 self.link_create(device_name).await?;
124 let link = self.link_get_by_name(device_name).await?;
125 self.link_up(link.ifindex).await?;
126 self.setup_device(device_name, descriptor).await?;
127 tracing::trace!("Created device");
128 Ok(())
129 }
130
131 pub async fn reload_device(
132 &mut self,
133 device_name: &str,
134 descriptor: DeviceDescriptor,
135 ) -> Result<()> {
136 tracing::trace!("Reloading device {}", device_name);
137 self.setup_device(device_name, descriptor).await?;
138 tracing::trace!("Reloaded device");
139 Ok(())
140 }
141
142 pub async fn remove_device(&self, device_name: &str) -> Result<()> {
143 tracing::trace!("Removing device {}", device_name);
144 let link = self.link_get_by_name(device_name).await?;
145 self.link_down(link.ifindex).await?;
146 self.link_delete(link.ifindex).await?;
147 tracing::trace!("Removed device");
148 Ok(())
149 }
150
151 pub async fn view_device(&mut self, device_name: &str) -> Result<DeviceView> {
152 let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
153 cmd: WireguardCmd::GetDevice,
154 nlas: vec![WgDeviceAttrs::IfName(device_name.to_string())],
155 });
156 let mut nlmsg = NetlinkMessage::from(genlmsg);
157 nlmsg.header.flags = NLM_F_REQUEST | NLM_F_DUMP;
158
159 let mut resp = self.gen_handle.request(nlmsg).await?;
160 while let Some(result) = resp.next().await {
161 let rx_packet = result.map_err(|e| Error::with_message(e, "Error decoding packet"))?;
162 match rx_packet.payload {
163 NetlinkPayload::InnerMessage(genlmsg) => {
164 return device_view_from_payload(genlmsg.payload);
165 }
166 NetlinkPayload::Error(e) => {
167 return Err(Error::message(format!("Error: {:?}", e)));
168 }
169 _ => (),
170 };
171 }
172 unreachable!();
173 }
174
175 pub async fn view_device_if_exists(&mut self, device_name: &str) -> Result<Option<DeviceView>> {
176 let device_names = self.list_device_names().await?;
177 if device_names.iter().any(|name| name == device_name) {
178 Ok(Some(self.view_device(device_name).await?))
179 } else {
180 Ok(None)
181 }
182 }
183
184 pub async fn view_devices(&mut self) -> Result<Vec<DeviceView>> {
185 let device_names = self.list_device_names().await?;
186 let mut devices = Vec::with_capacity(device_names.len());
187 for name in device_names {
188 let device = self.view_device(&name).await?;
189 devices.push(device);
190 }
191 Ok(devices)
192 }
193
194 pub async fn list_device_names(&self) -> Result<Vec<String>> {
195 Ok(self
196 .link_list()
197 .await?
198 .into_iter()
199 .map(|link| link.name)
200 .collect())
201 }
202
203 async fn setup_device(
204 &mut self,
205 device_name: &str,
206 descriptor: DeviceDescriptor,
207 ) -> Result<()> {
208 tracing::trace!("Setting up device {}", device_name);
209
210 let link = self.link_get_by_name(device_name).await?;
211 for addr in descriptor.addresses.iter() {
212 self.link_add_address(link.ifindex, *addr).await?;
213 }
214
215 let message = descriptor.into_wireguard(device_name.to_string());
216 let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(message);
217 let mut nlmsg = NetlinkMessage::from(genlmsg);
218 nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;
219
220 let mut stream = self.gen_handle.request(nlmsg).await?;
221 while (stream.next().await).is_some() {}
222 tracing::trace!("Device setup");
223
224 Ok(())
225 }
226
227 async fn link_create(&self, name: &str) -> Result<()> {
228 let mut msg = self.rt_handle.link().add().replace();
229 msg.message_mut()
230 .attributes
231 .push(LinkAttribute::LinkInfo(vec![LinkInfo::Kind(
232 InfoKind::Wireguard,
233 )]));
234 msg.message_mut()
235 .attributes
236 .push(LinkAttribute::IfName(name.to_string()));
237 msg.execute().await?;
238 Ok(())
239 }
240
241 async fn link_delete(&self, ifindex: u32) -> Result<()> {
242 self.rt_handle.link().del(ifindex).execute().await?;
243 Ok(())
244 }
245
246 async fn link_up(&self, ifindex: u32) -> Result<()> {
247 tracing::trace!("Bringing up interface {}", ifindex);
248 self.rt_handle.link().set(ifindex).up().execute().await?;
249 Ok(())
250 }
251
252 async fn link_down(&self, ifindex: u32) -> Result<()> {
253 tracing::trace!("Bringing down interface {}", ifindex);
254 self.rt_handle.link().set(ifindex).down().execute().await?;
255 Ok(())
256 }
257
258 async fn link_add_address(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> {
259 tracing::trace!("Adding address {} to {}", net, ifindex);
260 self.rt_handle
261 .address()
262 .add(ifindex, net.addr(), net.prefix_len())
263 .replace()
264 .execute()
265 .await?;
266 Ok(())
267 }
268
269 //TODO: return Result<Option<Link>>?
270 async fn link_get_by_name(&self, name: &str) -> Result<Link> {
271 let link = self
272 .link_list()
273 .await?
274 .into_iter()
275 .find(|link| link.name == name)
276 .ok_or_else(|| Error::message(format!("Link {} not found", name)))?;
277 tracing::debug!("device {} has index {}", name, link.ifindex);
278 Ok(link)
279 }
280
281 async fn link_list(&self) -> Result<Vec<Link>> {
282 let mut links = Vec::new();
283 let mut link_stream = self.rt_handle.link().get().execute();
284 while let Some(link) = link_stream.try_next().await? {
285 let mut is_wireguard = false;
286 let mut link_name = None;
287 for nla in link.attributes {
288 match nla {
289 LinkAttribute::IfName(name) => link_name = Some(name),
290 LinkAttribute::LinkInfo(infos) => {
291 for info in infos {
292 if let netlink_packet_route::link::LinkInfo::Kind(kind) = info {
293 if kind == netlink_packet_route::link::InfoKind::Wireguard {
294 is_wireguard = true;
295 break;
296 }
297 }
298 }
299 }
300 _ => {}
301 }
302 if is_wireguard && link_name.is_some() {
303 links.push(Link {
304 name: link_name.unwrap(),
305 ifindex: link.header.index,
306 });
307 break;
308 }
309 }
310 }
311 Ok(links)
312 }
313
314 #[allow(unused)]
315 async fn route_add(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> {
316 tracing::trace!("Adding route {} to {}", net, ifindex);
317 let request = self
318 .rt_handle
319 .route()
320 .add()
321 .scope(RouteScope::Link)
322 .output_interface(ifindex)
323 .replace();
324
325 match net.addr() {
326 std::net::IpAddr::V4(ip) => {
327 request
328 .v4()
329 .destination_prefix(ip, net.prefix_len())
330 .execute()
331 .await
332 }
333 std::net::IpAddr::V6(ip) => {
334 request
335 .v6()
336 .destination_prefix(ip, net.prefix_len())
337 .execute()
338 .await
339 }
340 }?;
341
342 Ok(())
343 }
344}
345
346pub async fn create_device(
347 device_name: impl AsRef<str>,
348 device_descriptor: DeviceDescriptor,
349) -> Result<()> {
350 tracing::info!("creating device {}", device_name.as_ref());
351 tracing::debug!("device descriptor: {:#?}", device_descriptor);
352 let mut wireguard = WireGuard::new().await?;
353 wireguard
354 .create_device(device_name.as_ref(), device_descriptor)
355 .await
356}
357
358pub async fn reload_device(
359 device_name: impl AsRef<str>,
360 device_descriptor: DeviceDescriptor,
361) -> Result<()> {
362 tracing::info!("reloading device {}", device_name.as_ref());
363 tracing::debug!("device descriptor: {:#?}", device_descriptor);
364 let mut wireguard = WireGuard::new().await?;
365 wireguard
366 .reload_device(device_name.as_ref(), device_descriptor)
367 .await
368}
369
370pub async fn device_exists(name: impl AsRef<str>) -> Result<bool> {
371 tracing::info!("checking if device {} exists", name.as_ref());
372 let mut wireguard = WireGuard::new().await?;
373 wireguard
374 .view_device_if_exists(name.as_ref())
375 .await
376 .map(|x| x.is_some())
377}
378
379pub async fn remove_device(name: impl AsRef<str>) -> Result<()> {
380 tracing::info!("removing device {}", name.as_ref());
381 let wireguard = WireGuard::new().await?;
382 wireguard.remove_device(name.as_ref()).await
383}
384
385pub async fn view_device(name: impl AsRef<str>) -> Result<DeviceView> {
386 tracing::info!("viewing device {}", name.as_ref());
387 let mut wireguard = WireGuard::new().await?;
388 wireguard.view_device(name.as_ref()).await
389}
390
391pub async fn view_device_if_exists(name: impl AsRef<str>) -> Result<Option<DeviceView>> {
392 tracing::info!("viewing device {}", name.as_ref());
393 let mut wireguard = WireGuard::new().await?;
394 wireguard.view_device_if_exists(name.as_ref()).await
395}
396
397pub async fn list_device_names() -> Result<Vec<String>> {
398 tracing::info!("listing device names");
399 let wireguard = WireGuard::new().await?;
400 wireguard.list_device_names().await
401}
diff --git a/src/setup.rs b/src/setup.rs
new file mode 100644
index 0000000..e7d454c
--- /dev/null
+++ b/src/setup.rs
@@ -0,0 +1,212 @@
1use std::net::{IpAddr, SocketAddr};
2
3use ipnet::IpNet;
4use netlink_packet_wireguard::{
5 constants::{AF_INET, AF_INET6, WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REPLACE_ALLOWEDIPS},
6 nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
7 Wireguard, WireguardCmd,
8};
9
10use super::Key;
11
12#[derive(Debug)]
13pub struct PeerDescriptor {
14 pub(super) public_key: Key,
15 pub(super) preshared_key: Option<Key>,
16 pub(super) endpoint: Option<SocketAddr>,
17 pub(super) keepalive: Option<u16>,
18 pub(super) allowed_ips: Option<Vec<IpNet>>,
19}
20
21impl PeerDescriptor {
22 pub fn new(public_key: Key) -> Self {
23 Self {
24 public_key,
25 preshared_key: None,
26 endpoint: None,
27 keepalive: None,
28 allowed_ips: None,
29 }
30 }
31
32 pub fn preshared_key_optional(mut self, preshared_key: Option<Key>) -> Self {
33 self.preshared_key = preshared_key;
34 self
35 }
36
37 pub fn preshared_key(mut self, preshared_key: Key) -> Self {
38 self.preshared_key = Some(preshared_key);
39 self
40 }
41
42 pub fn endpoint_optional(mut self, endpoint: Option<SocketAddr>) -> Self {
43 self.endpoint = endpoint;
44 self
45 }
46
47 pub fn endpoint(mut self, endpoint: SocketAddr) -> Self {
48 self.endpoint = Some(endpoint);
49 self
50 }
51
52 pub fn keepalive_optional(mut self, keepalive: Option<u16>) -> Self {
53 self.keepalive = keepalive;
54 self
55 }
56
57 pub fn keepalive(mut self, keepalive: u16) -> Self {
58 self.keepalive = Some(keepalive);
59 self
60 }
61
62 pub fn allowed_ip_optional(self, allowed_ip: Option<IpNet>) -> Self {
63 if let Some(allowed_ip) = allowed_ip {
64 self.allowed_ip(allowed_ip)
65 } else {
66 self
67 }
68 }
69
70 pub fn allowed_ip(mut self, allowed_ip: IpNet) -> Self {
71 let mut allowed_ips = self.allowed_ips.take().unwrap_or_default();
72 allowed_ips.push(allowed_ip);
73 self.allowed_ips = Some(allowed_ips);
74 self
75 }
76
77 pub fn allowed_ips_optional(self, allowed_ips: Option<Vec<IpNet>>) -> Self {
78 if let Some(allowed_ips) = allowed_ips {
79 self.allowed_ips(allowed_ips)
80 } else {
81 self
82 }
83 }
84
85 pub fn allowed_ips(mut self, allowed_ips: Vec<IpNet>) -> Self {
86 self.allowed_ips = Some(allowed_ips);
87 self
88 }
89
90 pub(super) fn into_wireguard(self) -> WgPeer {
91 let mut nlas = Vec::new();
92 nlas.push(WgPeerAttrs::PublicKey(self.public_key.into_array()));
93 nlas.extend(
94 self.preshared_key
95 .map(|key| WgPeerAttrs::PresharedKey(key.into_array())),
96 );
97 nlas.extend(self.endpoint.map(WgPeerAttrs::Endpoint));
98 nlas.extend(self.keepalive.map(WgPeerAttrs::PersistentKeepalive));
99 nlas.extend(self.allowed_ips.map(|allowed_ips| {
100 WgPeerAttrs::AllowedIps(allowed_ips.into_iter().map(ipnet_to_wg).collect())
101 }));
102 nlas.push(WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS));
103 WgPeer(nlas)
104 }
105}
106
107#[derive(Debug)]
108pub struct DeviceDescriptor {
109 pub(super) addresses: Vec<IpNet>,
110 pub(super) private_key: Option<Key>,
111 pub(super) listen_port: Option<u16>,
112 pub(super) fwmark: Option<u32>,
113 pub(super) peers: Option<Vec<PeerDescriptor>>,
114}
115
116impl Default for DeviceDescriptor {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl DeviceDescriptor {
123 pub fn new() -> Self {
124 Self {
125 addresses: Vec::default(),
126 private_key: None,
127 listen_port: None,
128 fwmark: None,
129 peers: None,
130 }
131 }
132
133 pub fn address(mut self, address: IpNet) -> Self {
134 self.addresses.push(address);
135 self
136 }
137
138 pub fn addresses(mut self, addresses: impl IntoIterator<Item = IpNet>) -> Self {
139 self.addresses.extend(addresses);
140 self
141 }
142
143 pub fn private_key(mut self, key: Key) -> Self {
144 self.private_key = Some(key);
145 self
146 }
147
148 pub fn listen_port(mut self, port: u16) -> Self {
149 self.listen_port = Some(port);
150 self
151 }
152
153 pub fn listen_port_optional(mut self, port: Option<u16>) -> Self {
154 self.listen_port = port;
155 self
156 }
157
158 pub fn fwmark(mut self, fwmark: u32) -> Self {
159 self.fwmark = Some(fwmark);
160 self
161 }
162
163 pub fn peer(mut self, peer: PeerDescriptor) -> Self {
164 let mut p = self.peers.take().unwrap_or_default();
165 p.push(peer);
166 self.peers = Some(p);
167 self
168 }
169
170 pub fn peers(mut self, peers: impl IntoIterator<Item = PeerDescriptor>) -> Self {
171 let mut p = self.peers.take().unwrap_or_default();
172 p.extend(peers);
173 self.peers = Some(p);
174 self
175 }
176
177 pub(super) fn into_wireguard(self, device_name: String) -> Wireguard {
178 let mut nlas = Vec::new();
179 nlas.push(WgDeviceAttrs::IfName(device_name));
180 nlas.extend(
181 self.private_key
182 .map(|key| WgDeviceAttrs::PrivateKey(key.into_array())),
183 );
184 nlas.extend(self.listen_port.map(WgDeviceAttrs::ListenPort));
185 nlas.extend(self.fwmark.map(WgDeviceAttrs::Fwmark));
186 nlas.extend(self.peers.map(|peers| {
187 WgDeviceAttrs::Peers(
188 peers
189 .into_iter()
190 .map(PeerDescriptor::into_wireguard)
191 .collect(),
192 )
193 }));
194 nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
195
196 Wireguard {
197 cmd: WireguardCmd::SetDevice,
198 nlas,
199 }
200 }
201}
202
203fn ipnet_to_wg(net: IpNet) -> WgAllowedIp {
204 let mut nlas = Vec::default();
205 nlas.push(WgAllowedIpAttrs::Cidr(net.prefix_len()));
206 nlas.push(WgAllowedIpAttrs::IpAddr(net.addr()));
207 match net.addr() {
208 IpAddr::V4(_) => nlas.push(WgAllowedIpAttrs::Family(AF_INET)),
209 IpAddr::V6(_) => nlas.push(WgAllowedIpAttrs::Family(AF_INET6)),
210 }
211 WgAllowedIp(nlas)
212}
diff --git a/src/view.rs b/src/view.rs
new file mode 100644
index 0000000..2858811
--- /dev/null
+++ b/src/view.rs
@@ -0,0 +1,130 @@
1use std::{net::SocketAddr, time::SystemTime};
2
3use ipnet::IpNet;
4use netlink_packet_wireguard::{
5 nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
6 Wireguard,
7};
8
9use super::{Error, Key, Result};
10
11#[derive(Debug, Clone)]
12pub struct DeviceView {
13 pub name: String,
14 pub ifindex: u32,
15 pub private_key: Option<Key>,
16 pub public_key: Option<Key>,
17 pub listen_port: u16,
18 pub fwmark: u32,
19 pub peers: Vec<PeerView>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PeerView {
24 pub public_key: Key,
25 pub preshared_key: Option<Key>,
26 pub endpoint: Option<SocketAddr>,
27 pub persistent_keepalive: Option<u16>,
28 pub last_handshake: SystemTime,
29 pub rx_bytes: u64,
30 pub tx_bytes: u64,
31 pub allowed_ips: Vec<IpNet>,
32}
33
34pub(super) fn device_view_from_payload(wg: Wireguard) -> Result<DeviceView> {
35 let mut if_index = None;
36 let mut if_name = None;
37 let mut private_key = None;
38 let mut public_key = None;
39 let mut listen_port = None;
40 let mut fwmark = None;
41 let mut peers = None;
42
43 for nla in wg.nlas {
44 match nla {
45 WgDeviceAttrs::IfIndex(v) => if_index = Some(v),
46 WgDeviceAttrs::IfName(v) => if_name = Some(v),
47 WgDeviceAttrs::PrivateKey(v) => private_key = Some(Key::from(v)),
48 WgDeviceAttrs::PublicKey(v) => public_key = Some(Key::from(v)),
49 WgDeviceAttrs::ListenPort(v) => listen_port = Some(v),
50 WgDeviceAttrs::Fwmark(v) => fwmark = Some(v),
51 WgDeviceAttrs::Peers(v) => peers = Some(peers_from_wg_peers(v)?),
52 _ => {}
53 }
54 }
55
56 Ok(DeviceView {
57 name: if_name.ok_or_else(|| Error::message("missing if_name"))?,
58 ifindex: if_index.ok_or_else(|| Error::message("missing if_index"))?,
59 private_key,
60 public_key,
61 listen_port: listen_port.ok_or_else(|| Error::message("missing listen_port"))?,
62 fwmark: fwmark.ok_or_else(|| Error::message("missing fwmark"))?,
63 peers: peers.unwrap_or_default(),
64 })
65}
66
67fn peers_from_wg_peers(wg_peers: Vec<WgPeer>) -> Result<Vec<PeerView>> {
68 let mut peers = Vec::with_capacity(wg_peers.len());
69 for wg_peer in wg_peers {
70 peers.push(peer_from_wg_peer(wg_peer)?);
71 }
72 Ok(peers)
73}
74
75fn peer_from_wg_peer(wg_peer: WgPeer) -> Result<PeerView> {
76 let mut public_key = None;
77 let mut preshared_key = None;
78 let mut endpoint = None;
79 let mut persistent_keepalive = None;
80 let mut last_handshake = None;
81 let mut rx_bytes = None;
82 let mut tx_bytes = None;
83 let mut allowed_ips = Vec::default();
84
85 for attr in wg_peer.iter() {
86 match attr {
87 WgPeerAttrs::PublicKey(v) => public_key = Some(Key::from(v)),
88 WgPeerAttrs::PresharedKey(v) => preshared_key = Some(Key::from(v)),
89 WgPeerAttrs::Endpoint(v) => endpoint = Some(*v),
90 WgPeerAttrs::PersistentKeepalive(v) => persistent_keepalive = Some(*v),
91 WgPeerAttrs::LastHandshake(v) => last_handshake = Some(*v),
92 WgPeerAttrs::RxBytes(v) => rx_bytes = Some(*v),
93 WgPeerAttrs::TxBytes(v) => tx_bytes = Some(*v),
94 WgPeerAttrs::AllowedIps(v) => {
95 for ip in v {
96 allowed_ips.push(ipnet_from_wg(ip)?);
97 }
98 }
99 _ => {}
100 }
101 }
102
103 Ok(PeerView {
104 public_key: public_key.ok_or_else(|| Error::message("missing public_key"))?,
105 preshared_key,
106 endpoint,
107 persistent_keepalive,
108 last_handshake: last_handshake.ok_or_else(|| Error::message("missing last_handshake"))?,
109 rx_bytes: rx_bytes.ok_or_else(|| Error::message("missing rx_bytes"))?,
110 tx_bytes: tx_bytes.ok_or_else(|| Error::message("missing tx_bytes"))?,
111 allowed_ips,
112 })
113}
114
115fn ipnet_from_wg(wg: &WgAllowedIp) -> Result<IpNet> {
116 let mut ip = None;
117 let mut prefix = None;
118 for attr in wg.iter() {
119 match attr {
120 WgAllowedIpAttrs::IpAddr(v) => ip = Some(*v),
121 WgAllowedIpAttrs::Cidr(v) => prefix = Some(*v),
122 _ => {}
123 }
124 }
125 Ok(IpNet::new(
126 ip.ok_or_else(|| Error::message("missing ip"))?,
127 prefix.ok_or_else(|| Error::message("missing prefix"))?,
128 )
129 .map_err(|e| Error::with_message(e, "invalid ipnet"))?)
130}