aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--embassy-net/Cargo.toml2
-rw-r--r--embassy-net/src/dns.rs120
-rw-r--r--embassy-net/src/lib.rs83
-rw-r--r--examples/std/src/bin/net_dns.rs3
4 files changed, 131 insertions, 77 deletions
diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml
index 6eea8c307..53778899b 100644
--- a/embassy-net/Cargo.toml
+++ b/embassy-net/Cargo.toml
@@ -52,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false }
52stable_deref_trait = { version = "1.2.0", default-features = false } 52stable_deref_trait = { version = "1.2.0", default-features = false }
53futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] } 53futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] }
54atomic-pool = "1.0" 54atomic-pool = "1.0"
55embedded-nal-async = { version = "0.3.0", optional = true } 55embedded-nal-async = { version = "0.4.0", optional = true }
56atomic-polyfill = { version = "1.0" } 56atomic-polyfill = { version = "1.0" }
diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs
index e98247bfd..1815d258f 100644
--- a/embassy-net/src/dns.rs
+++ b/embassy-net/src/dns.rs
@@ -1,19 +1,10 @@
1//! DNS socket with async support. 1//! DNS socket with async support.
2use core::cell::RefCell;
3use core::future::poll_fn;
4use core::mem;
5use core::task::Poll;
6
7use embassy_hal_common::drop::OnDrop;
8use embassy_net_driver::Driver;
9use heapless::Vec; 2use heapless::Vec;
10use managed::ManagedSlice; 3pub use smoltcp::socket::dns::{DnsQuery, Socket, MAX_ADDRESS_COUNT};
11use smoltcp::iface::{Interface, SocketHandle}; 4pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
12pub use smoltcp::socket::dns::DnsQuery;
13use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError, MAX_ADDRESS_COUNT};
14pub use smoltcp::wire::{DnsQueryType, IpAddress}; 5pub use smoltcp::wire::{DnsQueryType, IpAddress};
15 6
16use crate::{SocketStack, Stack}; 7use crate::{Driver, Stack};
17 8
18/// Errors returned by DnsSocket. 9/// Errors returned by DnsSocket.
19#[derive(Debug, PartialEq, Eq, Clone, Copy)] 10#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -46,81 +37,64 @@ impl From<StartQueryError> for Error {
46} 37}
47 38
48/// Async socket for making DNS queries. 39/// Async socket for making DNS queries.
49pub struct DnsSocket<'a> { 40pub struct DnsSocket<'a, D>
50 stack: &'a RefCell<SocketStack>, 41where
51 handle: SocketHandle, 42 D: Driver + 'static,
43{
44 stack: &'a Stack<D>,
52} 45}
53 46
54impl<'a> DnsSocket<'a> { 47impl<'a, D> DnsSocket<'a, D>
48where
49 D: Driver + 'static,
50{
55 /// Create a new DNS socket using the provided stack and query storage. 51 /// Create a new DNS socket using the provided stack and query storage.
56 /// 52 ///
57 /// DNS servers are derived from the stack configuration. 53 /// DNS servers are derived from the stack configuration.
58 /// 54 ///
59 /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated. 55 /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
60 pub fn new<D, Q>(stack: &'a Stack<D>, queries: Q) -> Self 56 pub fn new(stack: &'a Stack<D>) -> Self {
61 where 57 Self { stack }
62 D: Driver + 'static,
63 Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
64 {
65 let servers = stack
66 .config()
67 .map(|c| {
68 let v: Vec<IpAddress, 3> = c.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
69 v
70 })
71 .unwrap_or(Vec::new());
72 let s = &mut *stack.socket.borrow_mut();
73 let queries: ManagedSlice<'static, Option<DnsQuery>> = unsafe { mem::transmute(queries.into()) };
74
75 let handle = s.sockets.add(dns::Socket::new(&servers[..], queries));
76 Self {
77 stack: &stack.socket,
78 handle,
79 }
80 }
81
82 fn with_mut<R>(&mut self, f: impl FnOnce(&mut dns::Socket, &mut Interface) -> R) -> R {
83 let s = &mut *self.stack.borrow_mut();
84 let socket = s.sockets.get_mut::<dns::Socket>(self.handle);
85 let res = f(socket, &mut s.iface);
86 s.waker.wake();
87 res
88 } 58 }
89 59
90 /// Make a query for a given name and return the corresponding IP addresses. 60 /// Make a query for a given name and return the corresponding IP addresses.
91 pub async fn query(&mut self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> { 61 pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> {
92 let query = match { self.with_mut(|s, i| s.start_query(i.context(), name, qtype)) } { 62 self.stack.dns_query(name, qtype).await
93 Ok(handle) => handle, 63 }
94 Err(e) => return Err(e.into()), 64}
95 };
96 65
97 let handle = self.handle; 66#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
98 let drop = OnDrop::new(|| { 67impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D>
99 let s = &mut *self.stack.borrow_mut(); 68where
100 let socket = s.sockets.get_mut::<dns::Socket>(handle); 69 D: Driver + 'static,
101 socket.cancel_query(query); 70{
102 s.waker.wake(); 71 type Error = Error;
103 });
104 72
105 let res = poll_fn(|cx| { 73 async fn get_host_by_name(
106 self.with_mut(|s, _| match s.get_query_result(query) { 74 &self,
107 Ok(addrs) => Poll::Ready(Ok(addrs)), 75 host: &str,
108 Err(GetQueryResultError::Pending) => { 76 addr_type: embedded_nal_async::AddrType,
109 s.register_query_waker(query, cx.waker()); 77 ) -> Result<embedded_nal_async::IpAddr, Self::Error> {
110 Poll::Pending 78 use embedded_nal_async::{AddrType, IpAddr};
111 } 79 let qtype = match addr_type {
112 Err(e) => Poll::Ready(Err(e.into())), 80 AddrType::IPv6 => DnsQueryType::Aaaa,
81 _ => DnsQueryType::A,
82 };
83 let addrs = self.query(host, qtype).await?;
84 if let Some(first) = addrs.get(0) {
85 Ok(match first {
86 IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()),
87 IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()),
113 }) 88 })
114 }) 89 } else {
115 .await; 90 Err(Error::Failed)
116 91 }
117 drop.defuse();
118 res
119 } 92 }
120}
121 93
122impl<'a> Drop for DnsSocket<'a> { 94 async fn get_host_by_address(
123 fn drop(&mut self) { 95 &self,
124 self.stack.borrow_mut().sockets.remove(self.handle); 96 _addr: embedded_nal_async::IpAddr,
97 ) -> Result<heapless::String<256>, Self::Error> {
98 todo!()
125 } 99 }
126} 100}
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index ae447d063..b63aa83df 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -48,15 +48,22 @@ use crate::device::DriverAdapter;
48 48
49const LOCAL_PORT_MIN: u16 = 1025; 49const LOCAL_PORT_MIN: u16 = 1025;
50const LOCAL_PORT_MAX: u16 = 65535; 50const LOCAL_PORT_MAX: u16 = 65535;
51const MAX_QUERIES: usize = 2;
51 52
52pub struct StackResources<const SOCK: usize> { 53pub struct StackResources<const SOCK: usize> {
53 sockets: [SocketStorage<'static>; SOCK], 54 sockets: [SocketStorage<'static>; SOCK],
55 #[cfg(feature = "dns")]
56 queries: Option<[Option<dns::DnsQuery>; MAX_QUERIES]>,
54} 57}
55 58
56impl<const SOCK: usize> StackResources<SOCK> { 59impl<const SOCK: usize> StackResources<SOCK> {
57 pub fn new() -> Self { 60 pub fn new() -> Self {
61 #[cfg(feature = "dns")]
62 const INIT: Option<dns::DnsQuery> = None;
58 Self { 63 Self {
59 sockets: [SocketStorage::EMPTY; SOCK], 64 sockets: [SocketStorage::EMPTY; SOCK],
65 #[cfg(feature = "dns")]
66 queries: Some([INIT; MAX_QUERIES]),
60 } 67 }
61 } 68 }
62} 69}
@@ -109,6 +116,8 @@ struct Inner<D: Driver> {
109 config: Option<StaticConfig>, 116 config: Option<StaticConfig>,
110 #[cfg(feature = "dhcpv4")] 117 #[cfg(feature = "dhcpv4")]
111 dhcp_socket: Option<SocketHandle>, 118 dhcp_socket: Option<SocketHandle>,
119 #[cfg(feature = "dns")]
120 dns_socket: Option<SocketHandle>,
112} 121}
113 122
114pub(crate) struct SocketStack { 123pub(crate) struct SocketStack {
@@ -153,6 +162,8 @@ impl<D: Driver + 'static> Stack<D> {
153 config: None, 162 config: None,
154 #[cfg(feature = "dhcpv4")] 163 #[cfg(feature = "dhcpv4")]
155 dhcp_socket: None, 164 dhcp_socket: None,
165 #[cfg(feature = "dns")]
166 dns_socket: None,
156 }; 167 };
157 let mut socket = SocketStack { 168 let mut socket = SocketStack {
158 sockets, 169 sockets,
@@ -161,8 +172,17 @@ impl<D: Driver + 'static> Stack<D> {
161 next_local_port, 172 next_local_port,
162 }; 173 };
163 174
175 #[cfg(feature = "dns")]
176 {
177 if let Some(queries) = resources.queries.take() {
178 inner.dns_socket = Some(socket.sockets.add(dns::Socket::new(&[], queries)));
179 }
180 }
181
164 match config { 182 match config {
165 Config::Static(config) => inner.apply_config(&mut socket, config), 183 Config::Static(config) => {
184 inner.apply_config(&mut socket, config);
185 }
166 #[cfg(feature = "dhcpv4")] 186 #[cfg(feature = "dhcpv4")]
167 Config::Dhcp(config) => { 187 Config::Dhcp(config) => {
168 let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); 188 let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
@@ -210,6 +230,59 @@ impl<D: Driver + 'static> Stack<D> {
210 .await; 230 .await;
211 unreachable!() 231 unreachable!()
212 } 232 }
233
234 #[cfg(feature = "dns")]
235 async fn dns_query(
236 &self,
237 name: &str,
238 qtype: dns::DnsQueryType,
239 ) -> Result<Vec<IpAddress, { dns::MAX_ADDRESS_COUNT }>, dns::Error> {
240 let query = self.with_mut(|s, i| {
241 if let Some(dns_handle) = i.dns_socket {
242 let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
243 match socket.start_query(s.iface.context(), name, qtype) {
244 Ok(handle) => Ok(handle),
245 Err(e) => Err(e.into()),
246 }
247 } else {
248 Err(dns::Error::Failed)
249 }
250 })?;
251
252 use embassy_hal_common::drop::OnDrop;
253 let drop = OnDrop::new(|| {
254 self.with_mut(|s, i| {
255 if let Some(dns_handle) = i.dns_socket {
256 let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
257 socket.cancel_query(query);
258 s.waker.wake();
259 }
260 })
261 });
262
263 let res = poll_fn(|cx| {
264 self.with_mut(|s, i| {
265 if let Some(dns_handle) = i.dns_socket {
266 let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
267 match socket.get_query_result(query) {
268 Ok(addrs) => Poll::Ready(Ok(addrs)),
269 Err(dns::GetQueryResultError::Pending) => {
270 socket.register_query_waker(query, cx.waker());
271 Poll::Pending
272 }
273 Err(e) => Poll::Ready(Err(e.into())),
274 }
275 } else {
276 Poll::Ready(Err(dns::Error::Failed))
277 }
278 })
279 })
280 .await;
281
282 drop.defuse();
283
284 res
285 }
213} 286}
214 287
215impl SocketStack { 288impl SocketStack {
@@ -251,6 +324,13 @@ impl<D: Driver + 'static> Inner<D> {
251 debug!(" DNS server {}: {}", i, s); 324 debug!(" DNS server {}: {}", i, s);
252 } 325 }
253 326
327 #[cfg(feature = "dns")]
328 if let Some(dns_socket) = self.dns_socket {
329 let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(dns_socket);
330 let servers: Vec<IpAddress, 3> = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
331 socket.update_servers(&servers[..]);
332 }
333
254 self.config = Some(config) 334 self.config = Some(config)
255 } 335 }
256 336
@@ -326,6 +406,7 @@ impl<D: Driver + 'static> Inner<D> {
326 //if old_link_up || self.link_up { 406 //if old_link_up || self.link_up {
327 // self.poll_configurator(timestamp) 407 // self.poll_configurator(timestamp)
328 //} 408 //}
409 //
329 410
330 if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { 411 if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
331 let t = Timer::at(instant_from_smoltcp(poll_at)); 412 let t = Timer::at(instant_from_smoltcp(poll_at));
diff --git a/examples/std/src/bin/net_dns.rs b/examples/std/src/bin/net_dns.rs
index 6203f8370..e787cb823 100644
--- a/examples/std/src/bin/net_dns.rs
+++ b/examples/std/src/bin/net_dns.rs
@@ -71,8 +71,7 @@ async fn main_task(spawner: Spawner) {
71 spawner.spawn(net_task(stack)).unwrap(); 71 spawner.spawn(net_task(stack)).unwrap();
72 72
73 // Then we can use it! 73 // Then we can use it!
74 74 let socket = DnsSocket::new(stack);
75 let mut socket = DnsSocket::new(stack, vec![]);
76 75
77 let host = "example.com"; 76 let host = "example.com";
78 info!("querying host {:?}...", host); 77 info!("querying host {:?}...", host);