summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs401
1 files changed, 401 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c47d618
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,401 @@
1mod conf;
2mod key;
3mod setup;
4mod view;
5
6use std::borrow::Cow;
7
8use futures::{StreamExt, TryStreamExt};
9use genetlink::{GenetlinkError, GenetlinkHandle};
10use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST};
11use netlink_packet_generic::GenlMessage;
12use netlink_packet_route::{
13 link::{InfoKind, LinkAttribute, LinkInfo},
14 route::RouteScope,
15};
16use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd};
17use rtnetlink::Handle;
18
19pub use conf::*;
20pub use key::*;
21pub use setup::*;
22pub use view::*;
23
24pub use ipnet;
25
26pub type Result<T, E = Error> = std::result::Result<T, E>;
27
28#[derive(Debug)]
29pub struct Error {
30 inner: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
31 message: Option<Cow<'static, str>>,
32}
33
34impl std::fmt::Display for Error {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match (self.message.as_ref(), self.inner.as_ref()) {
37 (Some(message), Some(inner)) => write!(f, "{}: {}", message, inner),
38 (Some(message), None) => write!(f, "{}", message),
39 (None, Some(inner)) => write!(f, "{}", inner),
40 (None, None) => write!(f, "Unknown error"),
41 }
42 }
43}
44
45impl std::error::Error for Error {}
46
47impl From<std::io::Error> for Error {
48 fn from(inner: std::io::Error) -> Self {
49 Self {
50 inner: Some(Box::new(inner)),
51 message: None,
52 }
53 }
54}
55
56impl From<GenetlinkError> for Error {
57 fn from(inner: GenetlinkError) -> Self {
58 Self {
59 inner: Some(Box::new(inner)),
60 message: None,
61 }
62 }
63}
64
65impl From<rtnetlink::Error> for Error {
66 fn from(inner: rtnetlink::Error) -> Self {
67 Self {
68 inner: Some(Box::new(inner)),
69 message: None,
70 }
71 }
72}
73
74impl Error {
75 pub(crate) fn with_message<E>(inner: E, message: impl Into<Cow<'static, str>>) -> Self
76 where
77 E: std::error::Error + Send + Sync + 'static,
78 {
79 Self {
80 inner: Some(Box::new(inner)),
81 message: Some(message.into()),
82 }
83 }
84
85 pub(crate) fn message(message: impl Into<Cow<'static, str>>) -> Self {
86 Self {
87 inner: None,
88 message: Some(message.into()),
89 }
90 }
91}
92
93struct Link {
94 pub name: String,
95 pub ifindex: u32,
96}
97
98pub struct WireGuard {
99 rt_handle: Handle,
100 gen_handle: GenetlinkHandle,
101}
102
103#[allow(clippy::await_holding_refcell_ref)]
104impl WireGuard {
105 pub async fn new() -> Result<Self> {
106 let (rt_connection, rt_handle, _) = rtnetlink::new_connection()?;
107 tokio::spawn(rt_connection);
108 let (gen_connection, gen_handle, _) = genetlink::new_connection()?;
109 tokio::spawn(gen_connection);
110
111 Ok(Self {
112 rt_handle,
113 gen_handle,
114 })
115 }
116
117 pub async fn create_device(
118 &mut self,
119 device_name: &str,
120 descriptor: DeviceDescriptor,
121 ) -> Result<()> {
122 tracing::trace!("Creating device {}", device_name);
123 self.link_create(device_name).await?;
124 let link = self.link_get_by_name(device_name).await?;
125 self.link_up(link.ifindex).await?;
126 self.setup_device(device_name, descriptor).await?;
127 tracing::trace!("Created device");
128 Ok(())
129 }
130
131 pub async fn reload_device(
132 &mut self,
133 device_name: &str,
134 descriptor: DeviceDescriptor,
135 ) -> Result<()> {
136 tracing::trace!("Reloading device {}", device_name);
137 self.setup_device(device_name, descriptor).await?;
138 tracing::trace!("Reloaded device");
139 Ok(())
140 }
141
142 pub async fn remove_device(&self, device_name: &str) -> Result<()> {
143 tracing::trace!("Removing device {}", device_name);
144 let link = self.link_get_by_name(device_name).await?;
145 self.link_down(link.ifindex).await?;
146 self.link_delete(link.ifindex).await?;
147 tracing::trace!("Removed device");
148 Ok(())
149 }
150
151 pub async fn view_device(&mut self, device_name: &str) -> Result<DeviceView> {
152 let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
153 cmd: WireguardCmd::GetDevice,
154 nlas: vec![WgDeviceAttrs::IfName(device_name.to_string())],
155 });
156 let mut nlmsg = NetlinkMessage::from(genlmsg);
157 nlmsg.header.flags = NLM_F_REQUEST | NLM_F_DUMP;
158
159 let mut resp = self.gen_handle.request(nlmsg).await?;
160 while let Some(result) = resp.next().await {
161 let rx_packet = result.map_err(|e| Error::with_message(e, "Error decoding packet"))?;
162 match rx_packet.payload {
163 NetlinkPayload::InnerMessage(genlmsg) => {
164 return device_view_from_payload(genlmsg.payload);
165 }
166 NetlinkPayload::Error(e) => {
167 return Err(Error::message(format!("Error: {:?}", e)));
168 }
169 _ => (),
170 };
171 }
172 unreachable!();
173 }
174
175 pub async fn view_device_if_exists(&mut self, device_name: &str) -> Result<Option<DeviceView>> {
176 let device_names = self.list_device_names().await?;
177 if device_names.iter().any(|name| name == device_name) {
178 Ok(Some(self.view_device(device_name).await?))
179 } else {
180 Ok(None)
181 }
182 }
183
184 pub async fn view_devices(&mut self) -> Result<Vec<DeviceView>> {
185 let device_names = self.list_device_names().await?;
186 let mut devices = Vec::with_capacity(device_names.len());
187 for name in device_names {
188 let device = self.view_device(&name).await?;
189 devices.push(device);
190 }
191 Ok(devices)
192 }
193
194 pub async fn list_device_names(&self) -> Result<Vec<String>> {
195 Ok(self
196 .link_list()
197 .await?
198 .into_iter()
199 .map(|link| link.name)
200 .collect())
201 }
202
203 async fn setup_device(
204 &mut self,
205 device_name: &str,
206 descriptor: DeviceDescriptor,
207 ) -> Result<()> {
208 tracing::trace!("Setting up device {}", device_name);
209
210 let link = self.link_get_by_name(device_name).await?;
211 for addr in descriptor.addresses.iter() {
212 self.link_add_address(link.ifindex, *addr).await?;
213 }
214
215 let message = descriptor.into_wireguard(device_name.to_string());
216 let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(message);
217 let mut nlmsg = NetlinkMessage::from(genlmsg);
218 nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;
219
220 let mut stream = self.gen_handle.request(nlmsg).await?;
221 while (stream.next().await).is_some() {}
222 tracing::trace!("Device setup");
223
224 Ok(())
225 }
226
227 async fn link_create(&self, name: &str) -> Result<()> {
228 let mut msg = self.rt_handle.link().add().replace();
229 msg.message_mut()
230 .attributes
231 .push(LinkAttribute::LinkInfo(vec![LinkInfo::Kind(
232 InfoKind::Wireguard,
233 )]));
234 msg.message_mut()
235 .attributes
236 .push(LinkAttribute::IfName(name.to_string()));
237 msg.execute().await?;
238 Ok(())
239 }
240
241 async fn link_delete(&self, ifindex: u32) -> Result<()> {
242 self.rt_handle.link().del(ifindex).execute().await?;
243 Ok(())
244 }
245
246 async fn link_up(&self, ifindex: u32) -> Result<()> {
247 tracing::trace!("Bringing up interface {}", ifindex);
248 self.rt_handle.link().set(ifindex).up().execute().await?;
249 Ok(())
250 }
251
252 async fn link_down(&self, ifindex: u32) -> Result<()> {
253 tracing::trace!("Bringing down interface {}", ifindex);
254 self.rt_handle.link().set(ifindex).down().execute().await?;
255 Ok(())
256 }
257
258 async fn link_add_address(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> {
259 tracing::trace!("Adding address {} to {}", net, ifindex);
260 self.rt_handle
261 .address()
262 .add(ifindex, net.addr(), net.prefix_len())
263 .replace()
264 .execute()
265 .await?;
266 Ok(())
267 }
268
269 //TODO: return Result<Option<Link>>?
270 async fn link_get_by_name(&self, name: &str) -> Result<Link> {
271 let link = self
272 .link_list()
273 .await?
274 .into_iter()
275 .find(|link| link.name == name)
276 .ok_or_else(|| Error::message(format!("Link {} not found", name)))?;
277 tracing::debug!("device {} has index {}", name, link.ifindex);
278 Ok(link)
279 }
280
281 async fn link_list(&self) -> Result<Vec<Link>> {
282 let mut links = Vec::new();
283 let mut link_stream = self.rt_handle.link().get().execute();
284 while let Some(link) = link_stream.try_next().await? {
285 let mut is_wireguard = false;
286 let mut link_name = None;
287 for nla in link.attributes {
288 match nla {
289 LinkAttribute::IfName(name) => link_name = Some(name),
290 LinkAttribute::LinkInfo(infos) => {
291 for info in infos {
292 if let netlink_packet_route::link::LinkInfo::Kind(kind) = info {
293 if kind == netlink_packet_route::link::InfoKind::Wireguard {
294 is_wireguard = true;
295 break;
296 }
297 }
298 }
299 }
300 _ => {}
301 }
302 if is_wireguard && link_name.is_some() {
303 links.push(Link {
304 name: link_name.unwrap(),
305 ifindex: link.header.index,
306 });
307 break;
308 }
309 }
310 }
311 Ok(links)
312 }
313
314 #[allow(unused)]
315 async fn route_add(&self, ifindex: u32, net: ipnet::IpNet) -> Result<()> {
316 tracing::trace!("Adding route {} to {}", net, ifindex);
317 let request = self
318 .rt_handle
319 .route()
320 .add()
321 .scope(RouteScope::Link)
322 .output_interface(ifindex)
323 .replace();
324
325 match net.addr() {
326 std::net::IpAddr::V4(ip) => {
327 request
328 .v4()
329 .destination_prefix(ip, net.prefix_len())
330 .execute()
331 .await
332 }
333 std::net::IpAddr::V6(ip) => {
334 request
335 .v6()
336 .destination_prefix(ip, net.prefix_len())
337 .execute()
338 .await
339 }
340 }?;
341
342 Ok(())
343 }
344}
345
346pub async fn create_device(
347 device_name: impl AsRef<str>,
348 device_descriptor: DeviceDescriptor,
349) -> Result<()> {
350 tracing::info!("creating device {}", device_name.as_ref());
351 tracing::debug!("device descriptor: {:#?}", device_descriptor);
352 let mut wireguard = WireGuard::new().await?;
353 wireguard
354 .create_device(device_name.as_ref(), device_descriptor)
355 .await
356}
357
358pub async fn reload_device(
359 device_name: impl AsRef<str>,
360 device_descriptor: DeviceDescriptor,
361) -> Result<()> {
362 tracing::info!("reloading device {}", device_name.as_ref());
363 tracing::debug!("device descriptor: {:#?}", device_descriptor);
364 let mut wireguard = WireGuard::new().await?;
365 wireguard
366 .reload_device(device_name.as_ref(), device_descriptor)
367 .await
368}
369
370pub async fn device_exists(name: impl AsRef<str>) -> Result<bool> {
371 tracing::info!("checking if device {} exists", name.as_ref());
372 let mut wireguard = WireGuard::new().await?;
373 wireguard
374 .view_device_if_exists(name.as_ref())
375 .await
376 .map(|x| x.is_some())
377}
378
379pub async fn remove_device(name: impl AsRef<str>) -> Result<()> {
380 tracing::info!("removing device {}", name.as_ref());
381 let wireguard = WireGuard::new().await?;
382 wireguard.remove_device(name.as_ref()).await
383}
384
385pub async fn view_device(name: impl AsRef<str>) -> Result<DeviceView> {
386 tracing::info!("viewing device {}", name.as_ref());
387 let mut wireguard = WireGuard::new().await?;
388 wireguard.view_device(name.as_ref()).await
389}
390
391pub async fn view_device_if_exists(name: impl AsRef<str>) -> Result<Option<DeviceView>> {
392 tracing::info!("viewing device {}", name.as_ref());
393 let mut wireguard = WireGuard::new().await?;
394 wireguard.view_device_if_exists(name.as_ref()).await
395}
396
397pub async fn list_device_names() -> Result<Vec<String>> {
398 tracing::info!("listing device names");
399 let wireguard = WireGuard::new().await?;
400 wireguard.list_device_names().await
401}