aboutsummaryrefslogtreecommitdiff
path: root/embassy-net/src
diff options
context:
space:
mode:
authorQuentin Smith <[email protected]>2023-07-17 21:31:43 -0400
committerQuentin Smith <[email protected]>2023-07-17 21:31:43 -0400
commit6f02403184eb7fb7990fb88fc9df9c4328a690a3 (patch)
tree748f510e190bb2724750507a6e69ed1a8e08cb20 /embassy-net/src
parentd896f80405aa8963877049ed999e4aba25d6e2bb (diff)
parent6b5df4523aa1c4902f02e803450ae4b418e0e3ca (diff)
Merge remote-tracking branch 'origin/main' into nrf-pdm
Diffstat (limited to 'embassy-net/src')
-rw-r--r--embassy-net/src/device.rs189
-rw-r--r--embassy-net/src/dns.rs107
-rw-r--r--embassy-net/src/lib.rs727
-rw-r--r--embassy-net/src/packet_pool.rs107
-rw-r--r--embassy-net/src/stack.rs316
-rw-r--r--embassy-net/src/tcp.rs414
-rw-r--r--embassy-net/src/time.rs20
-rw-r--r--embassy-net/src/udp.rs138
8 files changed, 1269 insertions, 749 deletions
diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs
index c183bd58a..4513c86d3 100644
--- a/embassy-net/src/device.rs
+++ b/embassy-net/src/device.rs
@@ -1,129 +1,106 @@
1use core::task::Waker; 1use core::task::Context;
2 2
3use smoltcp::phy::{Device as SmolDevice, DeviceCapabilities}; 3use embassy_net_driver::{Capabilities, Checksum, Driver, Medium, RxToken, TxToken};
4use smoltcp::time::Instant as SmolInstant; 4use smoltcp::phy;
5 5use smoltcp::time::Instant;
6use crate::packet_pool::PacketBoxExt; 6
7use crate::{Packet, PacketBox, PacketBuf}; 7pub(crate) struct DriverAdapter<'d, 'c, T>
8 8where
9#[derive(PartialEq, Eq, Clone, Copy)] 9 T: Driver,
10pub enum LinkState { 10{
11 Down, 11 // must be Some when actually using this to rx/tx
12 Up, 12 pub cx: Option<&'d mut Context<'c>>,
13} 13 pub inner: &'d mut T,
14
15// 'static required due to the "fake GAT" in smoltcp::phy::Device.
16// https://github.com/smoltcp-rs/smoltcp/pull/572
17pub trait Device {
18 fn is_transmit_ready(&mut self) -> bool;
19 fn transmit(&mut self, pkt: PacketBuf);
20 fn receive(&mut self) -> Option<PacketBuf>;
21
22 fn register_waker(&mut self, waker: &Waker);
23 fn capabilities(&self) -> DeviceCapabilities;
24 fn link_state(&mut self) -> LinkState;
25 fn ethernet_address(&self) -> [u8; 6];
26} 14}
27 15
28impl<T: ?Sized + Device> Device for &'static mut T { 16impl<'d, 'c, T> phy::Device for DriverAdapter<'d, 'c, T>
29 fn is_transmit_ready(&mut self) -> bool { 17where
30 T::is_transmit_ready(self) 18 T: Driver,
31 } 19{
32 fn transmit(&mut self, pkt: PacketBuf) { 20 type RxToken<'a> = RxTokenAdapter<T::RxToken<'a>> where Self: 'a;
33 T::transmit(self, pkt) 21 type TxToken<'a> = TxTokenAdapter<T::TxToken<'a>> where Self: 'a;
34 } 22
35 fn receive(&mut self) -> Option<PacketBuf> { 23 fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
36 T::receive(self) 24 self.inner
25 .receive(self.cx.as_deref_mut().unwrap())
26 .map(|(rx, tx)| (RxTokenAdapter(rx), TxTokenAdapter(tx)))
37 } 27 }
38 fn register_waker(&mut self, waker: &Waker) {
39 T::register_waker(self, waker)
40 }
41 fn capabilities(&self) -> DeviceCapabilities {
42 T::capabilities(self)
43 }
44 fn link_state(&mut self) -> LinkState {
45 T::link_state(self)
46 }
47 fn ethernet_address(&self) -> [u8; 6] {
48 T::ethernet_address(self)
49 }
50}
51
52pub struct DeviceAdapter<D: Device> {
53 pub device: D,
54 caps: DeviceCapabilities,
55}
56 28
57impl<D: Device> DeviceAdapter<D> { 29 /// Construct a transmit token.
58 pub(crate) fn new(device: D) -> Self { 30 fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
59 Self { 31 self.inner.transmit(self.cx.as_deref_mut().unwrap()).map(TxTokenAdapter)
60 caps: device.capabilities(),
61 device,
62 }
63 } 32 }
64}
65 33
66impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> { 34 /// Get a description of device capabilities.
67 type RxToken = RxToken; 35 fn capabilities(&self) -> phy::DeviceCapabilities {
68 type TxToken = TxToken<'a, D>; 36 fn convert(c: Checksum) -> phy::Checksum {
69 37 match c {
70 fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { 38 Checksum::Both => phy::Checksum::Both,
71 let tx_pkt = PacketBox::new(Packet::new())?; 39 Checksum::Tx => phy::Checksum::Tx,
72 let rx_pkt = self.device.receive()?; 40 Checksum::Rx => phy::Checksum::Rx,
73 let rx_token = RxToken { pkt: rx_pkt }; 41 Checksum::None => phy::Checksum::None,
74 let tx_token = TxToken { 42 }
75 device: &mut self.device, 43 }
76 pkt: tx_pkt, 44 let caps: Capabilities = self.inner.capabilities();
45 let mut smolcaps = phy::DeviceCapabilities::default();
46
47 smolcaps.max_transmission_unit = caps.max_transmission_unit;
48 smolcaps.max_burst_size = caps.max_burst_size;
49 smolcaps.medium = match caps.medium {
50 #[cfg(feature = "medium-ethernet")]
51 Medium::Ethernet => phy::Medium::Ethernet,
52 #[cfg(feature = "medium-ip")]
53 Medium::Ip => phy::Medium::Ip,
54 #[allow(unreachable_patterns)]
55 _ => panic!(
56 "Unsupported medium {:?}. Make sure to enable it in embassy-net's Cargo features.",
57 caps.medium
58 ),
77 }; 59 };
78 60 smolcaps.checksum.ipv4 = convert(caps.checksum.ipv4);
79 Some((rx_token, tx_token)) 61 smolcaps.checksum.tcp = convert(caps.checksum.tcp);
80 } 62 smolcaps.checksum.udp = convert(caps.checksum.udp);
81 63 #[cfg(feature = "proto-ipv4")]
82 /// Construct a transmit token. 64 {
83 fn transmit(&'a mut self) -> Option<Self::TxToken> { 65 smolcaps.checksum.icmpv4 = convert(caps.checksum.icmpv4);
84 if !self.device.is_transmit_ready() { 66 }
85 return None; 67 #[cfg(feature = "proto-ipv6")]
68 {
69 smolcaps.checksum.icmpv6 = convert(caps.checksum.icmpv6);
86 } 70 }
87 71
88 let tx_pkt = PacketBox::new(Packet::new())?; 72 smolcaps
89 Some(TxToken {
90 device: &mut self.device,
91 pkt: tx_pkt,
92 })
93 }
94
95 /// Get a description of device capabilities.
96 fn capabilities(&self) -> DeviceCapabilities {
97 self.caps.clone()
98 } 73 }
99} 74}
100 75
101pub struct RxToken { 76pub(crate) struct RxTokenAdapter<T>(T)
102 pkt: PacketBuf, 77where
103} 78 T: RxToken;
104 79
105impl smoltcp::phy::RxToken for RxToken { 80impl<T> phy::RxToken for RxTokenAdapter<T>
106 fn consume<R, F>(mut self, _timestamp: SmolInstant, f: F) -> smoltcp::Result<R> 81where
82 T: RxToken,
83{
84 fn consume<R, F>(self, f: F) -> R
107 where 85 where
108 F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, 86 F: FnOnce(&mut [u8]) -> R,
109 { 87 {
110 f(&mut self.pkt) 88 self.0.consume(|buf| f(buf))
111 } 89 }
112} 90}
113 91
114pub struct TxToken<'a, D: Device> { 92pub(crate) struct TxTokenAdapter<T>(T)
115 device: &'a mut D, 93where
116 pkt: PacketBox, 94 T: TxToken;
117}
118 95
119impl<'a, D: Device> smoltcp::phy::TxToken for TxToken<'a, D> { 96impl<T> phy::TxToken for TxTokenAdapter<T>
120 fn consume<R, F>(self, _timestamp: SmolInstant, len: usize, f: F) -> smoltcp::Result<R> 97where
98 T: TxToken,
99{
100 fn consume<R, F>(self, len: usize, f: F) -> R
121 where 101 where
122 F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, 102 F: FnOnce(&mut [u8]) -> R,
123 { 103 {
124 let mut buf = self.pkt.slice(0..len); 104 self.0.consume(len, |buf| f(buf))
125 let r = f(&mut buf)?;
126 self.device.transmit(buf);
127 Ok(r)
128 } 105 }
129} 106}
diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs
new file mode 100644
index 000000000..94f75f108
--- /dev/null
+++ b/embassy-net/src/dns.rs
@@ -0,0 +1,107 @@
1//! DNS client compatible with the `embedded-nal-async` traits.
2//!
3//! This exists only for compatibility with crates that use `embedded-nal-async`.
4//! Prefer using [`Stack::dns_query`](crate::Stack::dns_query) directly if you're
5//! not using `embedded-nal-async`.
6
7use heapless::Vec;
8pub use smoltcp::socket::dns::{DnsQuery, Socket};
9pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
10pub use smoltcp::wire::{DnsQueryType, IpAddress};
11
12use crate::{Driver, Stack};
13
14/// Errors returned by DnsSocket.
15#[derive(Debug, PartialEq, Eq, Clone, Copy)]
16#[cfg_attr(feature = "defmt", derive(defmt::Format))]
17pub enum Error {
18 /// Invalid name
19 InvalidName,
20 /// Name too long
21 NameTooLong,
22 /// Name lookup failed
23 Failed,
24}
25
26impl From<GetQueryResultError> for Error {
27 fn from(_: GetQueryResultError) -> Self {
28 Self::Failed
29 }
30}
31
32impl From<StartQueryError> for Error {
33 fn from(e: StartQueryError) -> Self {
34 match e {
35 StartQueryError::NoFreeSlot => Self::Failed,
36 StartQueryError::InvalidName => Self::InvalidName,
37 StartQueryError::NameTooLong => Self::NameTooLong,
38 }
39 }
40}
41
42/// DNS client compatible with the `embedded-nal-async` traits.
43///
44/// This exists only for compatibility with crates that use `embedded-nal-async`.
45/// Prefer using [`Stack::dns_query`](crate::Stack::dns_query) directly if you're
46/// not using `embedded-nal-async`.
47pub struct DnsSocket<'a, D>
48where
49 D: Driver + 'static,
50{
51 stack: &'a Stack<D>,
52}
53
54impl<'a, D> DnsSocket<'a, D>
55where
56 D: Driver + 'static,
57{
58 /// Create a new DNS socket using the provided stack.
59 ///
60 /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
61 pub fn new(stack: &'a Stack<D>) -> Self {
62 Self { stack }
63 }
64
65 /// Make a query for a given name and return the corresponding IP addresses.
66 pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, 1>, Error> {
67 self.stack.dns_query(name, qtype).await
68 }
69}
70
71#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
72impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D>
73where
74 D: Driver + 'static,
75{
76 type Error = Error;
77
78 async fn get_host_by_name(
79 &self,
80 host: &str,
81 addr_type: embedded_nal_async::AddrType,
82 ) -> Result<embedded_nal_async::IpAddr, Self::Error> {
83 use embedded_nal_async::{AddrType, IpAddr};
84 let qtype = match addr_type {
85 AddrType::IPv6 => DnsQueryType::Aaaa,
86 _ => DnsQueryType::A,
87 };
88 let addrs = self.query(host, qtype).await?;
89 if let Some(first) = addrs.get(0) {
90 Ok(match first {
91 #[cfg(feature = "proto-ipv4")]
92 IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()),
93 #[cfg(feature = "proto-ipv6")]
94 IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()),
95 })
96 } else {
97 Err(Error::Failed)
98 }
99 }
100
101 async fn get_host_by_address(
102 &self,
103 _addr: embedded_nal_async::IpAddr,
104 ) -> Result<heapless::String<256>, Self::Error> {
105 todo!()
106 }
107}
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index 83d364715..0d0a986f6 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -1,31 +1,726 @@
1#![cfg_attr(not(feature = "std"), no_std)] 1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(clippy::new_without_default)] 2#![cfg_attr(feature = "nightly", feature(async_fn_in_trait, impl_trait_projections))]
3#![feature(generic_associated_types, type_alias_impl_trait)] 3#![warn(missing_docs)]
4#![doc = include_str!("../README.md")]
4 5
5// This mod MUST go first, so that the others see its macros. 6// This mod MUST go first, so that the others see its macros.
6pub(crate) mod fmt; 7pub(crate) mod fmt;
7 8
8mod device; 9mod device;
9mod packet_pool; 10#[cfg(feature = "dns")]
10mod stack; 11pub mod dns;
11
12pub use device::{Device, LinkState};
13pub use packet_pool::{Packet, PacketBox, PacketBoxExt, PacketBuf, MTU};
14pub use stack::{Config, ConfigStrategy, Stack, StackResources};
15
16#[cfg(feature = "tcp")] 12#[cfg(feature = "tcp")]
17pub mod tcp; 13pub mod tcp;
18 14mod time;
19#[cfg(feature = "udp")] 15#[cfg(feature = "udp")]
20pub mod udp; 16pub mod udp;
21 17
22// smoltcp reexports 18use core::cell::RefCell;
23pub use smoltcp::phy::{DeviceCapabilities, Medium}; 19use core::future::{poll_fn, Future};
24pub use smoltcp::time::{Duration as SmolDuration, Instant as SmolInstant}; 20use core::task::{Context, Poll};
21
22pub use embassy_net_driver as driver;
23use embassy_net_driver::{Driver, LinkState, Medium};
24use embassy_sync::waitqueue::WakerRegistration;
25use embassy_time::{Instant, Timer};
26use futures::pin_mut;
27use heapless::Vec;
28#[cfg(feature = "igmp")]
29pub use smoltcp::iface::MulticastError;
30use smoltcp::iface::{Interface, SocketHandle, SocketSet, SocketStorage};
31#[cfg(feature = "dhcpv4")]
32use smoltcp::socket::dhcpv4::{self, RetryConfig};
33#[cfg(feature = "udp")]
34pub use smoltcp::wire::IpListenEndpoint;
25#[cfg(feature = "medium-ethernet")] 35#[cfg(feature = "medium-ethernet")]
26pub use smoltcp::wire::{EthernetAddress, HardwareAddress}; 36pub use smoltcp::wire::{EthernetAddress, HardwareAddress};
27pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Cidr}; 37pub use smoltcp::wire::{IpAddress, IpCidr, IpEndpoint};
38#[cfg(feature = "proto-ipv4")]
39pub use smoltcp::wire::{Ipv4Address, Ipv4Cidr};
28#[cfg(feature = "proto-ipv6")] 40#[cfg(feature = "proto-ipv6")]
29pub use smoltcp::wire::{Ipv6Address, Ipv6Cidr}; 41pub use smoltcp::wire::{Ipv6Address, Ipv6Cidr};
30#[cfg(feature = "udp")] 42
31pub use smoltcp::{socket::udp::PacketMetadata, wire::IpListenEndpoint}; 43use crate::device::DriverAdapter;
44use crate::time::{instant_from_smoltcp, instant_to_smoltcp};
45
46const LOCAL_PORT_MIN: u16 = 1025;
47const LOCAL_PORT_MAX: u16 = 65535;
48#[cfg(feature = "dns")]
49const MAX_QUERIES: usize = 4;
50
51/// Memory resources needed for a network stack.
52pub struct StackResources<const SOCK: usize> {
53 sockets: [SocketStorage<'static>; SOCK],
54 #[cfg(feature = "dns")]
55 queries: [Option<dns::DnsQuery>; MAX_QUERIES],
56}
57
58impl<const SOCK: usize> StackResources<SOCK> {
59 /// Create a new set of stack resources.
60 pub const fn new() -> Self {
61 #[cfg(feature = "dns")]
62 const INIT: Option<dns::DnsQuery> = None;
63 Self {
64 sockets: [SocketStorage::EMPTY; SOCK],
65 #[cfg(feature = "dns")]
66 queries: [INIT; MAX_QUERIES],
67 }
68 }
69}
70
71/// Static IP address configuration.
72#[cfg(feature = "proto-ipv4")]
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct StaticConfigV4 {
75 /// IP address and subnet mask.
76 pub address: Ipv4Cidr,
77 /// Default gateway.
78 pub gateway: Option<Ipv4Address>,
79 /// DNS servers.
80 pub dns_servers: Vec<Ipv4Address, 3>,
81}
82
83/// Static IPv6 address configuration
84#[cfg(feature = "proto-ipv6")]
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct StaticConfigV6 {
87 /// IP address and subnet mask.
88 pub address: Ipv6Cidr,
89 /// Default gateway.
90 pub gateway: Option<Ipv6Address>,
91 /// DNS servers.
92 pub dns_servers: Vec<Ipv6Address, 3>,
93}
94
95/// DHCP configuration.
96#[cfg(feature = "dhcpv4")]
97#[derive(Debug, Clone, PartialEq, Eq)]
98pub struct DhcpConfig {
99 /// Maximum lease duration.
100 ///
101 /// If not set, the lease duration specified by the server will be used.
102 /// If set, the lease duration will be capped at this value.
103 pub max_lease_duration: Option<embassy_time::Duration>,
104 /// Retry configuration.
105 pub retry_config: RetryConfig,
106 /// Ignore NAKs from DHCP servers.
107 ///
108 /// This is not compliant with the DHCP RFCs, since theoretically we must stop using the assigned IP when receiving a NAK. This can increase reliability on broken networks with buggy routers or rogue DHCP servers, however.
109 pub ignore_naks: bool,
110 /// Server port. This is almost always 67. Do not change unless you know what you're doing.
111 pub server_port: u16,
112 /// Client port. This is almost always 68. Do not change unless you know what you're doing.
113 pub client_port: u16,
114}
115
116#[cfg(feature = "dhcpv4")]
117impl Default for DhcpConfig {
118 fn default() -> Self {
119 Self {
120 max_lease_duration: Default::default(),
121 retry_config: Default::default(),
122 ignore_naks: Default::default(),
123 server_port: smoltcp::wire::DHCP_SERVER_PORT,
124 client_port: smoltcp::wire::DHCP_CLIENT_PORT,
125 }
126 }
127}
128
129/// Network stack configuration.
130pub struct Config {
131 /// IPv4 configuration
132 #[cfg(feature = "proto-ipv4")]
133 pub ipv4: ConfigV4,
134 /// IPv6 configuration
135 #[cfg(feature = "proto-ipv6")]
136 pub ipv6: ConfigV6,
137}
138
139impl Config {
140 /// IPv4 configuration with static addressing.
141 #[cfg(feature = "proto-ipv4")]
142 pub fn ipv4_static(config: StaticConfigV4) -> Self {
143 Self {
144 ipv4: ConfigV4::Static(config),
145 #[cfg(feature = "proto-ipv6")]
146 ipv6: ConfigV6::None,
147 }
148 }
149
150 /// IPv6 configuration with static addressing.
151 #[cfg(feature = "proto-ipv6")]
152 pub fn ipv6_static(config: StaticConfigV6) -> Self {
153 Self {
154 #[cfg(feature = "proto-ipv4")]
155 ipv4: ConfigV4::None,
156 ipv6: ConfigV6::Static(config),
157 }
158 }
159
160 /// IPv6 configuration with dynamic addressing.
161 ///
162 /// # Example
163 /// ```rust
164 /// let _cfg = Config::dhcpv4(Default::default());
165 /// ```
166 #[cfg(feature = "dhcpv4")]
167 pub fn dhcpv4(config: DhcpConfig) -> Self {
168 Self {
169 ipv4: ConfigV4::Dhcp(config),
170 #[cfg(feature = "proto-ipv6")]
171 ipv6: ConfigV6::None,
172 }
173 }
174}
175
176/// Network stack IPv4 configuration.
177#[cfg(feature = "proto-ipv4")]
178pub enum ConfigV4 {
179 /// Use a static IPv4 address configuration.
180 Static(StaticConfigV4),
181 /// Use DHCP to obtain an IP address configuration.
182 #[cfg(feature = "dhcpv4")]
183 Dhcp(DhcpConfig),
184 /// Do not configure IPv6.
185 None,
186}
187
188/// Network stack IPv6 configuration.
189#[cfg(feature = "proto-ipv6")]
190pub enum ConfigV6 {
191 /// Use a static IPv6 address configuration.
192 Static(StaticConfigV6),
193 /// Do not configure IPv6.
194 None,
195}
196
197/// A network stack.
198///
199/// This is the main entry point for the network stack.
200pub struct Stack<D: Driver> {
201 pub(crate) socket: RefCell<SocketStack>,
202 inner: RefCell<Inner<D>>,
203}
204
205struct Inner<D: Driver> {
206 device: D,
207 link_up: bool,
208 #[cfg(feature = "proto-ipv4")]
209 static_v4: Option<StaticConfigV4>,
210 #[cfg(feature = "proto-ipv6")]
211 static_v6: Option<StaticConfigV6>,
212 #[cfg(feature = "dhcpv4")]
213 dhcp_socket: Option<SocketHandle>,
214 #[cfg(feature = "dns")]
215 dns_socket: SocketHandle,
216 #[cfg(feature = "dns")]
217 dns_waker: WakerRegistration,
218}
219
220pub(crate) struct SocketStack {
221 pub(crate) sockets: SocketSet<'static>,
222 pub(crate) iface: Interface,
223 pub(crate) waker: WakerRegistration,
224 next_local_port: u16,
225}
226
227impl<D: Driver + 'static> Stack<D> {
228 /// Create a new network stack.
229 pub fn new<const SOCK: usize>(
230 mut device: D,
231 config: Config,
232 resources: &'static mut StackResources<SOCK>,
233 random_seed: u64,
234 ) -> Self {
235 #[cfg(feature = "medium-ethernet")]
236 let medium = device.capabilities().medium;
237
238 let hardware_addr = match medium {
239 #[cfg(feature = "medium-ethernet")]
240 Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress(device.ethernet_address())),
241 #[cfg(feature = "medium-ip")]
242 Medium::Ip => HardwareAddress::Ip,
243 #[allow(unreachable_patterns)]
244 _ => panic!(
245 "Unsupported medium {:?}. Make sure to enable it in embassy-net's Cargo features.",
246 medium
247 ),
248 };
249 let mut iface_cfg = smoltcp::iface::Config::new(hardware_addr);
250 iface_cfg.random_seed = random_seed;
251
252 let iface = Interface::new(
253 iface_cfg,
254 &mut DriverAdapter {
255 inner: &mut device,
256 cx: None,
257 },
258 instant_to_smoltcp(Instant::now()),
259 );
260
261 let sockets = SocketSet::new(&mut resources.sockets[..]);
262
263 let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
264
265 let mut socket = SocketStack {
266 sockets,
267 iface,
268 waker: WakerRegistration::new(),
269 next_local_port,
270 };
271
272 let mut inner = Inner {
273 device,
274 link_up: false,
275 #[cfg(feature = "proto-ipv4")]
276 static_v4: None,
277 #[cfg(feature = "proto-ipv6")]
278 static_v6: None,
279 #[cfg(feature = "dhcpv4")]
280 dhcp_socket: None,
281 #[cfg(feature = "dns")]
282 dns_socket: socket.sockets.add(dns::Socket::new(
283 &[],
284 managed::ManagedSlice::Borrowed(&mut resources.queries),
285 )),
286 #[cfg(feature = "dns")]
287 dns_waker: WakerRegistration::new(),
288 };
289
290 #[cfg(feature = "proto-ipv4")]
291 match config.ipv4 {
292 ConfigV4::Static(config) => {
293 inner.apply_config_v4(&mut socket, config);
294 }
295 #[cfg(feature = "dhcpv4")]
296 ConfigV4::Dhcp(config) => {
297 let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
298 inner.apply_dhcp_config(&mut dhcp_socket, config);
299 let handle = socket.sockets.add(dhcp_socket);
300 inner.dhcp_socket = Some(handle);
301 }
302 ConfigV4::None => {}
303 }
304 #[cfg(feature = "proto-ipv6")]
305 match config.ipv6 {
306 ConfigV6::Static(config) => {
307 inner.apply_config_v6(&mut socket, config);
308 }
309 ConfigV6::None => {}
310 }
311
312 Self {
313 socket: RefCell::new(socket),
314 inner: RefCell::new(inner),
315 }
316 }
317
318 fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R {
319 f(&*self.socket.borrow(), &*self.inner.borrow())
320 }
321
322 fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R {
323 f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut())
324 }
325
326 /// Get the MAC address of the network interface.
327 pub fn ethernet_address(&self) -> [u8; 6] {
328 self.with(|_s, i| i.device.ethernet_address())
329 }
330
331 /// Get whether the link is up.
332 pub fn is_link_up(&self) -> bool {
333 self.with(|_s, i| i.link_up)
334 }
335
336 /// Get whether the network stack has a valid IP configuration.
337 /// This is true if the network stack has a static IP configuration or if DHCP has completed
338 pub fn is_config_up(&self) -> bool {
339 let v4_up;
340 let v6_up;
341
342 #[cfg(feature = "proto-ipv4")]
343 {
344 v4_up = self.config_v4().is_some();
345 }
346 #[cfg(not(feature = "proto-ipv4"))]
347 {
348 v4_up = false;
349 }
350
351 #[cfg(feature = "proto-ipv6")]
352 {
353 v6_up = self.config_v6().is_some();
354 }
355 #[cfg(not(feature = "proto-ipv6"))]
356 {
357 v6_up = false;
358 }
359
360 v4_up || v6_up
361 }
362
363 /// Get the current IPv4 configuration.
364 #[cfg(feature = "proto-ipv4")]
365 pub fn config_v4(&self) -> Option<StaticConfigV4> {
366 self.with(|_s, i| i.static_v4.clone())
367 }
368
369 /// Get the current IPv6 configuration.
370 #[cfg(feature = "proto-ipv6")]
371 pub fn config_v6(&self) -> Option<StaticConfigV6> {
372 self.with(|_s, i| i.static_v6.clone())
373 }
374
375 /// Run the network stack.
376 ///
377 /// You must call this in a background task, to process network events.
378 pub async fn run(&self) -> ! {
379 poll_fn(|cx| {
380 self.with_mut(|s, i| i.poll(cx, s));
381 Poll::<()>::Pending
382 })
383 .await;
384 unreachable!()
385 }
386
387 /// Make a query for a given name and return the corresponding IP addresses.
388 #[cfg(feature = "dns")]
389 pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> {
390 // For A and AAAA queries we try detect whether `name` is just an IP address
391 match qtype {
392 #[cfg(feature = "proto-ipv4")]
393 dns::DnsQueryType::A => {
394 if let Ok(ip) = name.parse().map(IpAddress::Ipv4) {
395 return Ok([ip].into_iter().collect());
396 }
397 }
398 #[cfg(feature = "proto-ipv6")]
399 dns::DnsQueryType::Aaaa => {
400 if let Ok(ip) = name.parse().map(IpAddress::Ipv6) {
401 return Ok([ip].into_iter().collect());
402 }
403 }
404 _ => {}
405 }
406
407 let query = poll_fn(|cx| {
408 self.with_mut(|s, i| {
409 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
410 match socket.start_query(s.iface.context(), name, qtype) {
411 Ok(handle) => Poll::Ready(Ok(handle)),
412 Err(dns::StartQueryError::NoFreeSlot) => {
413 i.dns_waker.register(cx.waker());
414 Poll::Pending
415 }
416 Err(e) => Poll::Ready(Err(e)),
417 }
418 })
419 })
420 .await?;
421
422 #[must_use = "to delay the drop handler invocation to the end of the scope"]
423 struct OnDrop<F: FnOnce()> {
424 f: core::mem::MaybeUninit<F>,
425 }
426
427 impl<F: FnOnce()> OnDrop<F> {
428 fn new(f: F) -> Self {
429 Self {
430 f: core::mem::MaybeUninit::new(f),
431 }
432 }
433
434 fn defuse(self) {
435 core::mem::forget(self)
436 }
437 }
438
439 impl<F: FnOnce()> Drop for OnDrop<F> {
440 fn drop(&mut self) {
441 unsafe { self.f.as_ptr().read()() }
442 }
443 }
444
445 let drop = OnDrop::new(|| {
446 self.with_mut(|s, i| {
447 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
448 socket.cancel_query(query);
449 s.waker.wake();
450 i.dns_waker.wake();
451 })
452 });
453
454 let res = poll_fn(|cx| {
455 self.with_mut(|s, i| {
456 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
457 match socket.get_query_result(query) {
458 Ok(addrs) => {
459 i.dns_waker.wake();
460 Poll::Ready(Ok(addrs))
461 }
462 Err(dns::GetQueryResultError::Pending) => {
463 socket.register_query_waker(query, cx.waker());
464 Poll::Pending
465 }
466 Err(e) => {
467 i.dns_waker.wake();
468 Poll::Ready(Err(e.into()))
469 }
470 }
471 })
472 })
473 .await;
474
475 drop.defuse();
476
477 res
478 }
479}
480
481#[cfg(feature = "igmp")]
482impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
483 /// Join a multicast group.
484 pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, MulticastError>
485 where
486 T: Into<IpAddress>,
487 {
488 let addr = addr.into();
489
490 self.with_mut(|s, i| {
491 s.iface
492 .join_multicast_group(&mut i.device, addr, instant_to_smoltcp(Instant::now()))
493 })
494 }
495
496 /// Leave a multicast group.
497 pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, MulticastError>
498 where
499 T: Into<IpAddress>,
500 {
501 let addr = addr.into();
502
503 self.with_mut(|s, i| {
504 s.iface
505 .leave_multicast_group(&mut i.device, addr, instant_to_smoltcp(Instant::now()))
506 })
507 }
508
509 /// Get whether the network stack has joined the given multicast group.
510 pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
511 self.socket.borrow().iface.has_multicast_group(addr)
512 }
513}
514
515impl SocketStack {
516 #[allow(clippy::absurd_extreme_comparisons, dead_code)]
517 pub fn get_local_port(&mut self) -> u16 {
518 let res = self.next_local_port;
519 self.next_local_port = if res >= LOCAL_PORT_MAX { LOCAL_PORT_MIN } else { res + 1 };
520 res
521 }
522}
523
524impl<D: Driver + 'static> Inner<D> {
525 #[cfg(feature = "proto-ipv4")]
526 fn apply_config_v4(&mut self, s: &mut SocketStack, config: StaticConfigV4) {
527 #[cfg(feature = "medium-ethernet")]
528 let medium = self.device.capabilities().medium;
529
530 debug!("Acquired IP configuration:");
531
532 debug!(" IP address: {}", config.address);
533 s.iface.update_ip_addrs(|addrs| {
534 if addrs.is_empty() {
535 addrs.push(IpCidr::Ipv4(config.address)).unwrap();
536 } else {
537 addrs[0] = IpCidr::Ipv4(config.address);
538 }
539 });
540
541 #[cfg(feature = "medium-ethernet")]
542 if medium == Medium::Ethernet {
543 if let Some(gateway) = config.gateway {
544 debug!(" Default gateway: {}", gateway);
545 s.iface.routes_mut().add_default_ipv4_route(gateway).unwrap();
546 } else {
547 debug!(" Default gateway: None");
548 s.iface.routes_mut().remove_default_ipv4_route();
549 }
550 }
551 for (i, s) in config.dns_servers.iter().enumerate() {
552 debug!(" DNS server {}: {}", i, s);
553 }
554
555 self.static_v4 = Some(config);
556
557 #[cfg(feature = "dns")]
558 {
559 self.update_dns_servers(s)
560 }
561 }
562
563 /// Replaces the current IPv6 static configuration with a newly supplied config.
564 #[cfg(feature = "proto-ipv6")]
565 fn apply_config_v6(&mut self, s: &mut SocketStack, config: StaticConfigV6) {
566 #[cfg(feature = "medium-ethernet")]
567 let medium = self.device.capabilities().medium;
568
569 debug!("Acquired IPv6 configuration:");
570
571 debug!(" IP address: {}", config.address);
572 s.iface.update_ip_addrs(|addrs| {
573 if addrs.is_empty() {
574 addrs.push(IpCidr::Ipv6(config.address)).unwrap();
575 } else {
576 addrs[0] = IpCidr::Ipv6(config.address);
577 }
578 });
579
580 #[cfg(feature = "medium-ethernet")]
581 if Medium::Ethernet == medium {
582 if let Some(gateway) = config.gateway {
583 debug!(" Default gateway: {}", gateway);
584 s.iface.routes_mut().add_default_ipv6_route(gateway).unwrap();
585 } else {
586 debug!(" Default gateway: None");
587 s.iface.routes_mut().remove_default_ipv6_route();
588 }
589 }
590 for (i, s) in config.dns_servers.iter().enumerate() {
591 debug!(" DNS server {}: {}", i, s);
592 }
593
594 self.static_v6 = Some(config);
595
596 #[cfg(feature = "dns")]
597 {
598 self.update_dns_servers(s)
599 }
600 }
601
602 #[cfg(feature = "dns")]
603 fn update_dns_servers(&mut self, s: &mut SocketStack) {
604 let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(self.dns_socket);
605
606 let servers_v4;
607 #[cfg(feature = "proto-ipv4")]
608 {
609 servers_v4 = self
610 .static_v4
611 .iter()
612 .flat_map(|cfg| cfg.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)));
613 };
614 #[cfg(not(feature = "proto-ipv4"))]
615 {
616 servers_v4 = core::iter::empty();
617 }
618
619 let servers_v6;
620 #[cfg(feature = "proto-ipv6")]
621 {
622 servers_v6 = self
623 .static_v6
624 .iter()
625 .flat_map(|cfg| cfg.dns_servers.iter().map(|c| IpAddress::Ipv6(*c)));
626 }
627 #[cfg(not(feature = "proto-ipv6"))]
628 {
629 servers_v6 = core::iter::empty();
630 }
631
632 // Prefer the v6 DNS servers over the v4 servers
633 let servers: Vec<IpAddress, 6> = servers_v6.chain(servers_v4).collect();
634 socket.update_servers(&servers[..]);
635 }
636
637 #[cfg(feature = "dhcpv4")]
638 fn apply_dhcp_config(&self, socket: &mut smoltcp::socket::dhcpv4::Socket, config: DhcpConfig) {
639 socket.set_ignore_naks(config.ignore_naks);
640 socket.set_max_lease_duration(config.max_lease_duration.map(crate::time::duration_to_smoltcp));
641 socket.set_ports(config.server_port, config.client_port);
642 socket.set_retry_config(config.retry_config);
643 }
644
645 #[allow(unused)] // used only with dhcp
646 fn unapply_config(&mut self, s: &mut SocketStack) {
647 #[cfg(feature = "medium-ethernet")]
648 let medium = self.device.capabilities().medium;
649
650 debug!("Lost IP configuration");
651 s.iface.update_ip_addrs(|ip_addrs| ip_addrs.clear());
652 #[cfg(feature = "medium-ethernet")]
653 if medium == Medium::Ethernet {
654 #[cfg(feature = "proto-ipv4")]
655 {
656 s.iface.routes_mut().remove_default_ipv4_route();
657 }
658 }
659 #[cfg(feature = "proto-ipv4")]
660 {
661 self.static_v4 = None
662 }
663 }
664
665 fn poll(&mut self, cx: &mut Context<'_>, s: &mut SocketStack) {
666 s.waker.register(cx.waker());
667
668 #[cfg(feature = "medium-ethernet")]
669 if self.device.capabilities().medium == Medium::Ethernet {
670 s.iface.set_hardware_addr(HardwareAddress::Ethernet(EthernetAddress(
671 self.device.ethernet_address(),
672 )));
673 }
674
675 let timestamp = instant_to_smoltcp(Instant::now());
676 let mut smoldev = DriverAdapter {
677 cx: Some(cx),
678 inner: &mut self.device,
679 };
680 s.iface.poll(timestamp, &mut smoldev, &mut s.sockets);
681
682 // Update link up
683 let old_link_up = self.link_up;
684 self.link_up = self.device.link_state(cx) == LinkState::Up;
685
686 // Print when changed
687 if old_link_up != self.link_up {
688 info!("link_up = {:?}", self.link_up);
689 }
690
691 #[cfg(feature = "dhcpv4")]
692 if let Some(dhcp_handle) = self.dhcp_socket {
693 let socket = s.sockets.get_mut::<dhcpv4::Socket>(dhcp_handle);
694
695 if self.link_up {
696 match socket.poll() {
697 None => {}
698 Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s),
699 Some(dhcpv4::Event::Configured(config)) => {
700 let config = StaticConfigV4 {
701 address: config.address,
702 gateway: config.router,
703 dns_servers: config.dns_servers,
704 };
705 self.apply_config_v4(s, config)
706 }
707 }
708 } else if old_link_up {
709 socket.reset();
710 self.unapply_config(s);
711 }
712 }
713 //if old_link_up || self.link_up {
714 // self.poll_configurator(timestamp)
715 //}
716 //
717
718 if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
719 let t = Timer::at(instant_from_smoltcp(poll_at));
720 pin_mut!(t);
721 if t.poll(cx).is_ready() {
722 cx.waker().wake_by_ref();
723 }
724 }
725 }
726}
diff --git a/embassy-net/src/packet_pool.rs b/embassy-net/src/packet_pool.rs
deleted file mode 100644
index cb8a1316c..000000000
--- a/embassy-net/src/packet_pool.rs
+++ /dev/null
@@ -1,107 +0,0 @@
1use core::ops::{Deref, DerefMut, Range};
2
3use as_slice::{AsMutSlice, AsSlice};
4use atomic_pool::{pool, Box};
5
6pub const MTU: usize = 1516;
7
8#[cfg(feature = "pool-4")]
9pub const PACKET_POOL_SIZE: usize = 4;
10
11#[cfg(feature = "pool-8")]
12pub const PACKET_POOL_SIZE: usize = 8;
13
14#[cfg(feature = "pool-16")]
15pub const PACKET_POOL_SIZE: usize = 16;
16
17#[cfg(feature = "pool-32")]
18pub const PACKET_POOL_SIZE: usize = 32;
19
20#[cfg(feature = "pool-64")]
21pub const PACKET_POOL_SIZE: usize = 64;
22
23#[cfg(feature = "pool-128")]
24pub const PACKET_POOL_SIZE: usize = 128;
25
26pool!(pub PacketPool: [Packet; PACKET_POOL_SIZE]);
27pub type PacketBox = Box<PacketPool>;
28
29#[repr(align(4))]
30pub struct Packet(pub [u8; MTU]);
31
32impl Packet {
33 pub const fn new() -> Self {
34 Self([0; MTU])
35 }
36}
37
38pub trait PacketBoxExt {
39 fn slice(self, range: Range<usize>) -> PacketBuf;
40}
41
42impl PacketBoxExt for PacketBox {
43 fn slice(self, range: Range<usize>) -> PacketBuf {
44 PacketBuf { packet: self, range }
45 }
46}
47
48impl AsSlice for Packet {
49 type Element = u8;
50
51 fn as_slice(&self) -> &[Self::Element] {
52 &self.deref()[..]
53 }
54}
55
56impl AsMutSlice for Packet {
57 fn as_mut_slice(&mut self) -> &mut [Self::Element] {
58 &mut self.deref_mut()[..]
59 }
60}
61
62impl Deref for Packet {
63 type Target = [u8; MTU];
64
65 fn deref(&self) -> &[u8; MTU] {
66 &self.0
67 }
68}
69
70impl DerefMut for Packet {
71 fn deref_mut(&mut self) -> &mut [u8; MTU] {
72 &mut self.0
73 }
74}
75
76pub struct PacketBuf {
77 packet: PacketBox,
78 range: Range<usize>,
79}
80
81impl AsSlice for PacketBuf {
82 type Element = u8;
83
84 fn as_slice(&self) -> &[Self::Element] {
85 &self.packet[self.range.clone()]
86 }
87}
88
89impl AsMutSlice for PacketBuf {
90 fn as_mut_slice(&mut self) -> &mut [Self::Element] {
91 &mut self.packet[self.range.clone()]
92 }
93}
94
95impl Deref for PacketBuf {
96 type Target = [u8];
97
98 fn deref(&self) -> &[u8] {
99 &self.packet[self.range.clone()]
100 }
101}
102
103impl DerefMut for PacketBuf {
104 fn deref_mut(&mut self) -> &mut [u8] {
105 &mut self.packet[self.range.clone()]
106 }
107}
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs
deleted file mode 100644
index 8d2dd4bca..000000000
--- a/embassy-net/src/stack.rs
+++ /dev/null
@@ -1,316 +0,0 @@
1use core::cell::UnsafeCell;
2use core::future::Future;
3use core::task::{Context, Poll};
4
5use embassy_sync::waitqueue::WakerRegistration;
6use embassy_time::{Instant, Timer};
7use futures::future::poll_fn;
8use futures::pin_mut;
9use heapless::Vec;
10#[cfg(feature = "dhcpv4")]
11use smoltcp::iface::SocketHandle;
12use smoltcp::iface::{Interface, InterfaceBuilder, SocketSet, SocketStorage};
13#[cfg(feature = "medium-ethernet")]
14use smoltcp::iface::{Neighbor, NeighborCache, Route, Routes};
15#[cfg(feature = "medium-ethernet")]
16use smoltcp::phy::{Device as _, Medium};
17#[cfg(feature = "dhcpv4")]
18use smoltcp::socket::dhcpv4;
19use smoltcp::time::Instant as SmolInstant;
20#[cfg(feature = "medium-ethernet")]
21use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress};
22use smoltcp::wire::{IpCidr, Ipv4Address, Ipv4Cidr};
23
24use crate::device::{Device, DeviceAdapter, LinkState};
25
26const LOCAL_PORT_MIN: u16 = 1025;
27const LOCAL_PORT_MAX: u16 = 65535;
28
29pub struct StackResources<const ADDR: usize, const SOCK: usize, const NEIGHBOR: usize> {
30 addresses: [IpCidr; ADDR],
31 sockets: [SocketStorage<'static>; SOCK],
32
33 #[cfg(feature = "medium-ethernet")]
34 routes: [Option<(IpCidr, Route)>; 1],
35 #[cfg(feature = "medium-ethernet")]
36 neighbor_cache: [Option<(IpAddress, Neighbor)>; NEIGHBOR],
37}
38
39impl<const ADDR: usize, const SOCK: usize, const NEIGHBOR: usize> StackResources<ADDR, SOCK, NEIGHBOR> {
40 pub fn new() -> Self {
41 Self {
42 addresses: [IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 32); ADDR],
43 sockets: [SocketStorage::EMPTY; SOCK],
44 #[cfg(feature = "medium-ethernet")]
45 routes: [None; 1],
46 #[cfg(feature = "medium-ethernet")]
47 neighbor_cache: [None; NEIGHBOR],
48 }
49 }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Config {
54 pub address: Ipv4Cidr,
55 pub gateway: Option<Ipv4Address>,
56 pub dns_servers: Vec<Ipv4Address, 3>,
57}
58
59pub enum ConfigStrategy {
60 Static(Config),
61 #[cfg(feature = "dhcpv4")]
62 Dhcp,
63}
64
65pub struct Stack<D: Device> {
66 pub(crate) socket: UnsafeCell<SocketStack>,
67 inner: UnsafeCell<Inner<D>>,
68}
69
70struct Inner<D: Device> {
71 device: DeviceAdapter<D>,
72 link_up: bool,
73 config: Option<Config>,
74 #[cfg(feature = "dhcpv4")]
75 dhcp_socket: Option<SocketHandle>,
76}
77
78pub(crate) struct SocketStack {
79 pub(crate) sockets: SocketSet<'static>,
80 pub(crate) iface: Interface<'static>,
81 pub(crate) waker: WakerRegistration,
82 next_local_port: u16,
83}
84
85unsafe impl<D: Device> Send for Stack<D> {}
86
87impl<D: Device + 'static> Stack<D> {
88 pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>(
89 device: D,
90 config: ConfigStrategy,
91 resources: &'static mut StackResources<ADDR, SOCK, NEIGH>,
92 random_seed: u64,
93 ) -> Self {
94 #[cfg(feature = "medium-ethernet")]
95 let medium = device.capabilities().medium;
96
97 #[cfg(feature = "medium-ethernet")]
98 let ethernet_addr = if medium == Medium::Ethernet {
99 device.ethernet_address()
100 } else {
101 [0, 0, 0, 0, 0, 0]
102 };
103
104 let mut device = DeviceAdapter::new(device);
105
106 let mut b = InterfaceBuilder::new();
107 b = b.ip_addrs(&mut resources.addresses[..]);
108 b = b.random_seed(random_seed);
109
110 #[cfg(feature = "medium-ethernet")]
111 if medium == Medium::Ethernet {
112 b = b.hardware_addr(HardwareAddress::Ethernet(EthernetAddress(ethernet_addr)));
113 b = b.neighbor_cache(NeighborCache::new(&mut resources.neighbor_cache[..]));
114 b = b.routes(Routes::new(&mut resources.routes[..]));
115 }
116
117 let iface = b.finalize(&mut device);
118
119 let sockets = SocketSet::new(&mut resources.sockets[..]);
120
121 let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
122
123 let mut inner = Inner {
124 device,
125 link_up: false,
126 config: None,
127 #[cfg(feature = "dhcpv4")]
128 dhcp_socket: None,
129 };
130 let mut socket = SocketStack {
131 sockets,
132 iface,
133 waker: WakerRegistration::new(),
134 next_local_port,
135 };
136
137 match config {
138 ConfigStrategy::Static(config) => inner.apply_config(&mut socket, config),
139 #[cfg(feature = "dhcpv4")]
140 ConfigStrategy::Dhcp => {
141 let handle = socket.sockets.add(smoltcp::socket::dhcpv4::Socket::new());
142 inner.dhcp_socket = Some(handle);
143 }
144 }
145
146 Self {
147 socket: UnsafeCell::new(socket),
148 inner: UnsafeCell::new(inner),
149 }
150 }
151
152 /// SAFETY: must not call reentrantly.
153 unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R {
154 f(&*self.socket.get(), &*self.inner.get())
155 }
156
157 /// SAFETY: must not call reentrantly.
158 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R {
159 f(&mut *self.socket.get(), &mut *self.inner.get())
160 }
161
162 pub fn ethernet_address(&self) -> [u8; 6] {
163 unsafe { self.with(|_s, i| i.device.device.ethernet_address()) }
164 }
165
166 pub fn is_link_up(&self) -> bool {
167 unsafe { self.with(|_s, i| i.link_up) }
168 }
169
170 pub fn is_config_up(&self) -> bool {
171 unsafe { self.with(|_s, i| i.config.is_some()) }
172 }
173
174 pub fn config(&self) -> Option<Config> {
175 unsafe { self.with(|_s, i| i.config.clone()) }
176 }
177
178 pub async fn run(&self) -> ! {
179 poll_fn(|cx| {
180 unsafe { self.with_mut(|s, i| i.poll(cx, s)) }
181 Poll::<()>::Pending
182 })
183 .await;
184 unreachable!()
185 }
186}
187
188impl SocketStack {
189 #[allow(clippy::absurd_extreme_comparisons)]
190 pub fn get_local_port(&mut self) -> u16 {
191 let res = self.next_local_port;
192 self.next_local_port = if res >= LOCAL_PORT_MAX { LOCAL_PORT_MIN } else { res + 1 };
193 res
194 }
195}
196
197impl<D: Device + 'static> Inner<D> {
198 fn apply_config(&mut self, s: &mut SocketStack, config: Config) {
199 #[cfg(feature = "medium-ethernet")]
200 let medium = self.device.capabilities().medium;
201
202 debug!("Acquired IP configuration:");
203
204 debug!(" IP address: {}", config.address);
205 self.set_ipv4_addr(s, config.address);
206
207 #[cfg(feature = "medium-ethernet")]
208 if medium == Medium::Ethernet {
209 if let Some(gateway) = config.gateway {
210 debug!(" Default gateway: {}", gateway);
211 s.iface.routes_mut().add_default_ipv4_route(gateway).unwrap();
212 } else {
213 debug!(" Default gateway: None");
214 s.iface.routes_mut().remove_default_ipv4_route();
215 }
216 }
217 for (i, s) in config.dns_servers.iter().enumerate() {
218 debug!(" DNS server {}: {}", i, s);
219 }
220
221 self.config = Some(config)
222 }
223
224 #[allow(unused)] // used only with dhcp
225 fn unapply_config(&mut self, s: &mut SocketStack) {
226 #[cfg(feature = "medium-ethernet")]
227 let medium = self.device.capabilities().medium;
228
229 debug!("Lost IP configuration");
230 self.set_ipv4_addr(s, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0));
231 #[cfg(feature = "medium-ethernet")]
232 if medium == Medium::Ethernet {
233 s.iface.routes_mut().remove_default_ipv4_route();
234 }
235 self.config = None
236 }
237
238 fn set_ipv4_addr(&mut self, s: &mut SocketStack, cidr: Ipv4Cidr) {
239 s.iface.update_ip_addrs(|addrs| {
240 let dest = addrs.iter_mut().next().unwrap();
241 *dest = IpCidr::Ipv4(cidr);
242 });
243 }
244
245 fn poll(&mut self, cx: &mut Context<'_>, s: &mut SocketStack) {
246 self.device.device.register_waker(cx.waker());
247 s.waker.register(cx.waker());
248
249 let timestamp = instant_to_smoltcp(Instant::now());
250 if s.iface.poll(timestamp, &mut self.device, &mut s.sockets).is_err() {
251 // If poll() returns error, it may not be done yet, so poll again later.
252 cx.waker().wake_by_ref();
253 return;
254 }
255
256 // Update link up
257 let old_link_up = self.link_up;
258 self.link_up = self.device.device.link_state() == LinkState::Up;
259
260 // Print when changed
261 if old_link_up != self.link_up {
262 info!("link_up = {:?}", self.link_up);
263 }
264
265 #[cfg(feature = "dhcpv4")]
266 if let Some(dhcp_handle) = self.dhcp_socket {
267 let socket = s.sockets.get_mut::<dhcpv4::Socket>(dhcp_handle);
268
269 if self.link_up {
270 match socket.poll() {
271 None => {}
272 Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s),
273 Some(dhcpv4::Event::Configured(config)) => {
274 let mut dns_servers = Vec::new();
275 for s in &config.dns_servers {
276 if let Some(addr) = s {
277 dns_servers.push(addr.clone()).unwrap();
278 }
279 }
280
281 self.apply_config(
282 s,
283 Config {
284 address: config.address,
285 gateway: config.router,
286 dns_servers,
287 },
288 )
289 }
290 }
291 } else if old_link_up {
292 socket.reset();
293 self.unapply_config(s);
294 }
295 }
296 //if old_link_up || self.link_up {
297 // self.poll_configurator(timestamp)
298 //}
299
300 if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
301 let t = Timer::at(instant_from_smoltcp(poll_at));
302 pin_mut!(t);
303 if t.poll(cx).is_ready() {
304 cx.waker().wake_by_ref();
305 }
306 }
307 }
308}
309
310fn instant_to_smoltcp(instant: Instant) -> SmolInstant {
311 SmolInstant::from_millis(instant.as_millis() as i64)
312}
313
314fn instant_from_smoltcp(instant: SmolInstant) -> Instant {
315 Instant::from_millis(instant.total_millis() as u64)
316}
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs
index 910772c7d..367675b13 100644
--- a/embassy-net/src/tcp.rs
+++ b/embassy-net/src/tcp.rs
@@ -1,24 +1,39 @@
1use core::cell::UnsafeCell; 1//! TCP sockets.
2use core::future::Future; 2//!
3//! # Listening
4//!
5//! `embassy-net` does not have a `TcpListener`. Instead, individual `TcpSocket`s can be put into
6//! listening mode by calling [`TcpSocket::accept`].
7//!
8//! Incoming connections when no socket is listening are rejected. To accept many incoming
9//! connections, create many sockets and put them all into listening mode.
10
11use core::cell::RefCell;
12use core::future::poll_fn;
3use core::mem; 13use core::mem;
4use core::task::Poll; 14use core::task::Poll;
5 15
6use futures::future::poll_fn; 16use embassy_net_driver::Driver;
17use embassy_time::Duration;
7use smoltcp::iface::{Interface, SocketHandle}; 18use smoltcp::iface::{Interface, SocketHandle};
8use smoltcp::socket::tcp; 19use smoltcp::socket::tcp;
9use smoltcp::time::Duration; 20pub use smoltcp::socket::tcp::State;
10use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; 21use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
11 22
12use super::stack::Stack; 23use crate::time::duration_to_smoltcp;
13use crate::stack::SocketStack; 24use crate::{SocketStack, Stack};
14use crate::Device;
15 25
26/// Error returned by TcpSocket read/write functions.
16#[derive(PartialEq, Eq, Clone, Copy, Debug)] 27#[derive(PartialEq, Eq, Clone, Copy, Debug)]
17#[cfg_attr(feature = "defmt", derive(defmt::Format))] 28#[cfg_attr(feature = "defmt", derive(defmt::Format))]
18pub enum Error { 29pub enum Error {
30 /// The connection was reset.
31 ///
32 /// This can happen on receiving a RST packet, or on timeout.
19 ConnectionReset, 33 ConnectionReset,
20} 34}
21 35
36/// Error returned by [`TcpSocket::connect`].
22#[derive(PartialEq, Eq, Clone, Copy, Debug)] 37#[derive(PartialEq, Eq, Clone, Copy, Debug)]
23#[cfg_attr(feature = "defmt", derive(defmt::Format))] 38#[cfg_attr(feature = "defmt", derive(defmt::Format))]
24pub enum ConnectError { 39pub enum ConnectError {
@@ -32,6 +47,7 @@ pub enum ConnectError {
32 NoRoute, 47 NoRoute,
33} 48}
34 49
50/// Error returned by [`TcpSocket::accept`].
35#[derive(PartialEq, Eq, Clone, Copy, Debug)] 51#[derive(PartialEq, Eq, Clone, Copy, Debug)]
36#[cfg_attr(feature = "defmt", derive(defmt::Format))] 52#[cfg_attr(feature = "defmt", derive(defmt::Format))]
37pub enum AcceptError { 53pub enum AcceptError {
@@ -43,22 +59,53 @@ pub enum AcceptError {
43 ConnectionReset, 59 ConnectionReset,
44} 60}
45 61
62/// A TCP socket.
46pub struct TcpSocket<'a> { 63pub struct TcpSocket<'a> {
47 io: TcpIo<'a>, 64 io: TcpIo<'a>,
48} 65}
49 66
67/// The reader half of a TCP socket.
50pub struct TcpReader<'a> { 68pub struct TcpReader<'a> {
51 io: TcpIo<'a>, 69 io: TcpIo<'a>,
52} 70}
53 71
72/// The writer half of a TCP socket.
54pub struct TcpWriter<'a> { 73pub struct TcpWriter<'a> {
55 io: TcpIo<'a>, 74 io: TcpIo<'a>,
56} 75}
57 76
77impl<'a> TcpReader<'a> {
78 /// Read data from the socket.
79 ///
80 /// Returns how many bytes were read, or an error. If no data is available, it waits
81 /// until there is at least one byte available.
82 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
83 self.io.read(buf).await
84 }
85}
86
87impl<'a> TcpWriter<'a> {
88 /// Write data to the socket.
89 ///
90 /// Returns how many bytes were written, or an error. If the socket is not ready to
91 /// accept data, it waits until it is.
92 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
93 self.io.write(buf).await
94 }
95
96 /// Flushes the written data to the socket.
97 ///
98 /// This waits until all data has been sent, and ACKed by the remote host. For a connection
99 /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent.
100 pub async fn flush(&mut self) -> Result<(), Error> {
101 self.io.flush().await
102 }
103}
104
58impl<'a> TcpSocket<'a> { 105impl<'a> TcpSocket<'a> {
59 pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { 106 /// Create a new TCP socket on the given stack, with the given buffers.
60 // safety: not accessed reentrantly. 107 pub fn new<D: Driver>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
61 let s = unsafe { &mut *stack.socket.get() }; 108 let s = &mut *stack.socket.borrow_mut();
62 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 109 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
63 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; 110 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
64 let handle = s.sockets.add(tcp::Socket::new( 111 let handle = s.sockets.add(tcp::Socket::new(
@@ -74,25 +121,28 @@ impl<'a> TcpSocket<'a> {
74 } 121 }
75 } 122 }
76 123
124 /// Split the socket into reader and a writer halves.
77 pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { 125 pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) {
78 (TcpReader { io: self.io }, TcpWriter { io: self.io }) 126 (TcpReader { io: self.io }, TcpWriter { io: self.io })
79 } 127 }
80 128
129 /// Connect to a remote host.
81 pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> 130 pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError>
82 where 131 where
83 T: Into<IpEndpoint>, 132 T: Into<IpEndpoint>,
84 { 133 {
85 // safety: not accessed reentrantly. 134 let local_port = self.io.stack.borrow_mut().get_local_port();
86 let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port();
87 135
88 // safety: not accessed reentrantly. 136 match {
89 match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { 137 self.io
138 .with_mut(|s, i| s.connect(i.context(), remote_endpoint, local_port))
139 } {
90 Ok(()) => {} 140 Ok(()) => {}
91 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), 141 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState),
92 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), 142 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute),
93 } 143 }
94 144
95 futures::future::poll_fn(|cx| unsafe { 145 poll_fn(|cx| {
96 self.io.with_mut(|s, _| match s.state() { 146 self.io.with_mut(|s, _| match s.state() {
97 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), 147 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)),
98 tcp::State::Listen => unreachable!(), 148 tcp::State::Listen => unreachable!(),
@@ -106,18 +156,20 @@ impl<'a> TcpSocket<'a> {
106 .await 156 .await
107 } 157 }
108 158
159 /// Accept a connection from a remote host.
160 ///
161 /// This function puts the socket in listening mode, and waits until a connection is received.
109 pub async fn accept<T>(&mut self, local_endpoint: T) -> Result<(), AcceptError> 162 pub async fn accept<T>(&mut self, local_endpoint: T) -> Result<(), AcceptError>
110 where 163 where
111 T: Into<IpListenEndpoint>, 164 T: Into<IpListenEndpoint>,
112 { 165 {
113 // safety: not accessed reentrantly. 166 match self.io.with_mut(|s, _| s.listen(local_endpoint)) {
114 match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } {
115 Ok(()) => {} 167 Ok(()) => {}
116 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), 168 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState),
117 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), 169 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort),
118 } 170 }
119 171
120 futures::future::poll_fn(|cx| unsafe { 172 poll_fn(|cx| {
121 self.io.with_mut(|s, _| match s.state() { 173 self.io.with_mut(|s, _| match s.state() {
122 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { 174 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
123 s.register_send_waker(cx.waker()); 175 s.register_send_waker(cx.waker());
@@ -129,52 +181,120 @@ impl<'a> TcpSocket<'a> {
129 .await 181 .await
130 } 182 }
131 183
184 /// Read data from the socket.
185 ///
186 /// Returns how many bytes were read, or an error. If no data is available, it waits
187 /// until there is at least one byte available.
188 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
189 self.io.read(buf).await
190 }
191
192 /// Write data to the socket.
193 ///
194 /// Returns how many bytes were written, or an error. If the socket is not ready to
195 /// accept data, it waits until it is.
196 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
197 self.io.write(buf).await
198 }
199
200 /// Flushes the written data to the socket.
201 ///
202 /// This waits until all data has been sent, and ACKed by the remote host. For a connection
203 /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent.
204 pub async fn flush(&mut self) -> Result<(), Error> {
205 self.io.flush().await
206 }
207
208 /// Set the timeout for the socket.
209 ///
210 /// If the timeout is set, the socket will be closed if no data is received for the
211 /// specified duration.
132 pub fn set_timeout(&mut self, duration: Option<Duration>) { 212 pub fn set_timeout(&mut self, duration: Option<Duration>) {
133 unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } 213 self.io
214 .with_mut(|s, _| s.set_timeout(duration.map(duration_to_smoltcp)))
134 } 215 }
135 216
217 /// Set the keep-alive interval for the socket.
218 ///
219 /// If the keep-alive interval is set, the socket will send keep-alive packets after
220 /// the specified duration of inactivity.
221 ///
222 /// If not set, the socket will not send keep-alive packets.
136 pub fn set_keep_alive(&mut self, interval: Option<Duration>) { 223 pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
137 unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } 224 self.io
225 .with_mut(|s, _| s.set_keep_alive(interval.map(duration_to_smoltcp)))
138 } 226 }
139 227
228 /// Set the hop limit field in the IP header of sent packets.
140 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { 229 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
141 unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } 230 self.io.with_mut(|s, _| s.set_hop_limit(hop_limit))
142 } 231 }
143 232
233 /// Get the local endpoint of the socket.
234 ///
235 /// Returns `None` if the socket is not bound (listening) or not connected.
144 pub fn local_endpoint(&self) -> Option<IpEndpoint> { 236 pub fn local_endpoint(&self) -> Option<IpEndpoint> {
145 unsafe { self.io.with(|s, _| s.local_endpoint()) } 237 self.io.with(|s, _| s.local_endpoint())
146 } 238 }
147 239
240 /// Get the remote endpoint of the socket.
241 ///
242 /// Returns `None` if the socket is not connected.
148 pub fn remote_endpoint(&self) -> Option<IpEndpoint> { 243 pub fn remote_endpoint(&self) -> Option<IpEndpoint> {
149 unsafe { self.io.with(|s, _| s.remote_endpoint()) } 244 self.io.with(|s, _| s.remote_endpoint())
150 } 245 }
151 246
152 pub fn state(&self) -> tcp::State { 247 /// Get the state of the socket.
153 unsafe { self.io.with(|s, _| s.state()) } 248 pub fn state(&self) -> State {
249 self.io.with(|s, _| s.state())
154 } 250 }
155 251
252 /// Close the write half of the socket.
253 ///
254 /// This closes only the write half of the socket. The read half side remains open, the
255 /// socket can still receive data.
256 ///
257 /// Data that has been written to the socket and not yet sent (or not yet ACKed) will still
258 /// still sent. The last segment of the pending to send data is sent with the FIN flag set.
156 pub fn close(&mut self) { 259 pub fn close(&mut self) {
157 unsafe { self.io.with_mut(|s, _| s.close()) } 260 self.io.with_mut(|s, _| s.close())
158 } 261 }
159 262
263 /// Forcibly close the socket.
264 ///
265 /// This instantly closes both the read and write halves of the socket. Any pending data
266 /// that has not been sent will be lost.
267 ///
268 /// Note that the TCP RST packet is not sent immediately - if the `TcpSocket` is dropped too soon
269 /// the remote host may not know the connection has been closed.
270 /// `abort()` callers should wait for a [`flush()`](TcpSocket::flush) call to complete before
271 /// dropping or reusing the socket.
160 pub fn abort(&mut self) { 272 pub fn abort(&mut self) {
161 unsafe { self.io.with_mut(|s, _| s.abort()) } 273 self.io.with_mut(|s, _| s.abort())
162 } 274 }
163 275
276 /// Get whether the socket is ready to send data, i.e. whether there is space in the send buffer.
164 pub fn may_send(&self) -> bool { 277 pub fn may_send(&self) -> bool {
165 unsafe { self.io.with(|s, _| s.may_send()) } 278 self.io.with(|s, _| s.may_send())
166 } 279 }
167 280
281 /// return whether the recieve half of the full-duplex connection is open.
282 /// This function returns true if it’s possible to receive data from the remote endpoint.
283 /// It will return true while there is data in the receive buffer, and if there isn’t,
284 /// as long as the remote endpoint has not closed the connection.
168 pub fn may_recv(&self) -> bool { 285 pub fn may_recv(&self) -> bool {
169 unsafe { self.io.with(|s, _| s.may_recv()) } 286 self.io.with(|s, _| s.may_recv())
287 }
288
289 /// Get whether the socket is ready to receive data, i.e. whether there is some pending data in the receive buffer.
290 pub fn can_recv(&self) -> bool {
291 self.io.with(|s, _| s.can_recv())
170 } 292 }
171} 293}
172 294
173impl<'a> Drop for TcpSocket<'a> { 295impl<'a> Drop for TcpSocket<'a> {
174 fn drop(&mut self) { 296 fn drop(&mut self) {
175 // safety: not accessed reentrantly. 297 self.io.stack.borrow_mut().sockets.remove(self.io.handle);
176 let s = unsafe { &mut *self.io.stack.get() };
177 s.sockets.remove(self.io.handle);
178 } 298 }
179} 299}
180 300
@@ -182,21 +302,19 @@ impl<'a> Drop for TcpSocket<'a> {
182 302
183#[derive(Copy, Clone)] 303#[derive(Copy, Clone)]
184struct TcpIo<'a> { 304struct TcpIo<'a> {
185 stack: &'a UnsafeCell<SocketStack>, 305 stack: &'a RefCell<SocketStack>,
186 handle: SocketHandle, 306 handle: SocketHandle,
187} 307}
188 308
189impl<'d> TcpIo<'d> { 309impl<'d> TcpIo<'d> {
190 /// SAFETY: must not call reentrantly. 310 fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R {
191 unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { 311 let s = &*self.stack.borrow();
192 let s = &*self.stack.get();
193 let socket = s.sockets.get::<tcp::Socket>(self.handle); 312 let socket = s.sockets.get::<tcp::Socket>(self.handle);
194 f(socket, &s.iface) 313 f(socket, &s.iface)
195 } 314 }
196 315
197 /// SAFETY: must not call reentrantly. 316 fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R {
198 unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { 317 let s = &mut *self.stack.borrow_mut();
199 let s = &mut *self.stack.get();
200 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); 318 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle);
201 let res = f(socket, &mut s.iface); 319 let res = f(socket, &mut s.iface);
202 s.waker.wake(); 320 s.waker.wake();
@@ -204,7 +322,7 @@ impl<'d> TcpIo<'d> {
204 } 322 }
205 323
206 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { 324 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
207 poll_fn(move |cx| unsafe { 325 poll_fn(move |cx| {
208 // CAUTION: smoltcp semantics around EOF are different to what you'd expect 326 // CAUTION: smoltcp semantics around EOF are different to what you'd expect
209 // from posix-like IO, so we have to tweak things here. 327 // from posix-like IO, so we have to tweak things here.
210 self.with_mut(|s, _| match s.recv_slice(buf) { 328 self.with_mut(|s, _| match s.recv_slice(buf) {
@@ -225,7 +343,7 @@ impl<'d> TcpIo<'d> {
225 } 343 }
226 344
227 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { 345 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
228 poll_fn(move |cx| unsafe { 346 poll_fn(move |cx| {
229 self.with_mut(|s, _| match s.send_slice(buf) { 347 self.with_mut(|s, _| match s.send_slice(buf) {
230 // Not ready to send (no space in the tx buffer) 348 // Not ready to send (no space in the tx buffer)
231 Ok(0) => { 349 Ok(0) => {
@@ -242,95 +360,89 @@ impl<'d> TcpIo<'d> {
242 } 360 }
243 361
244 async fn flush(&mut self) -> Result<(), Error> { 362 async fn flush(&mut self) -> Result<(), Error> {
245 poll_fn(move |_| { 363 poll_fn(move |cx| {
246 Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? 364 self.with_mut(|s, _| {
365 let waiting_close = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
366 // If there are outstanding send operations, register for wake up and wait
367 // smoltcp issues wake-ups when octets are dequeued from the send buffer
368 if s.send_queue() > 0 || waiting_close {
369 s.register_send_waker(cx.waker());
370 Poll::Pending
371 // No outstanding sends, socket is flushed
372 } else {
373 Poll::Ready(Ok(()))
374 }
375 })
247 }) 376 })
248 .await 377 .await
249 } 378 }
250} 379}
251 380
252impl embedded_io::Error for ConnectError { 381#[cfg(feature = "nightly")]
253 fn kind(&self) -> embedded_io::ErrorKind { 382mod embedded_io_impls {
254 embedded_io::ErrorKind::Other 383 use super::*;
255 }
256}
257 384
258impl embedded_io::Error for Error { 385 impl embedded_io::Error for ConnectError {
259 fn kind(&self) -> embedded_io::ErrorKind { 386 fn kind(&self) -> embedded_io::ErrorKind {
260 embedded_io::ErrorKind::Other 387 embedded_io::ErrorKind::Other
388 }
261 } 389 }
262}
263
264impl<'d> embedded_io::Io for TcpSocket<'d> {
265 type Error = Error;
266}
267 390
268impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { 391 impl embedded_io::Error for Error {
269 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 392 fn kind(&self) -> embedded_io::ErrorKind {
270 where 393 embedded_io::ErrorKind::Other
271 Self: 'a; 394 }
272
273 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
274 self.io.read(buf)
275 } 395 }
276}
277 396
278impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { 397 impl<'d> embedded_io::Io for TcpSocket<'d> {
279 type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 398 type Error = Error;
280 where
281 Self: 'a;
282
283 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> {
284 self.io.write(buf)
285 } 399 }
286 400
287 type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> 401 impl<'d> embedded_io::asynch::Read for TcpSocket<'d> {
288 where 402 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
289 Self: 'a; 403 self.io.read(buf).await
290 404 }
291 fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> {
292 self.io.flush()
293 } 405 }
294}
295
296impl<'d> embedded_io::Io for TcpReader<'d> {
297 type Error = Error;
298}
299 406
300impl<'d> embedded_io::asynch::Read for TcpReader<'d> { 407 impl<'d> embedded_io::asynch::Write for TcpSocket<'d> {
301 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 408 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
302 where 409 self.io.write(buf).await
303 Self: 'a; 410 }
304 411
305 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { 412 async fn flush(&mut self) -> Result<(), Self::Error> {
306 self.io.read(buf) 413 self.io.flush().await
414 }
307 } 415 }
308}
309 416
310impl<'d> embedded_io::Io for TcpWriter<'d> { 417 impl<'d> embedded_io::Io for TcpReader<'d> {
311 type Error = Error; 418 type Error = Error;
312} 419 }
313 420
314impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { 421 impl<'d> embedded_io::asynch::Read for TcpReader<'d> {
315 type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 422 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
316 where 423 self.io.read(buf).await
317 Self: 'a; 424 }
425 }
318 426
319 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { 427 impl<'d> embedded_io::Io for TcpWriter<'d> {
320 self.io.write(buf) 428 type Error = Error;
321 } 429 }
322 430
323 type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> 431 impl<'d> embedded_io::asynch::Write for TcpWriter<'d> {
324 where 432 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
325 Self: 'a; 433 self.io.write(buf).await
434 }
326 435
327 fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { 436 async fn flush(&mut self) -> Result<(), Self::Error> {
328 self.io.flush() 437 self.io.flush().await
438 }
329 } 439 }
330} 440}
331 441
332#[cfg(feature = "unstable-traits")] 442/// TCP client compatible with `embedded-nal-async` traits.
443#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
333pub mod client { 444pub mod client {
445 use core::cell::UnsafeCell;
334 use core::mem::MaybeUninit; 446 use core::mem::MaybeUninit;
335 use core::ptr::NonNull; 447 use core::ptr::NonNull;
336 448
@@ -339,49 +451,56 @@ pub mod client {
339 451
340 use super::*; 452 use super::*;
341 453
342 /// TCP client capable of creating up to N multiple connections with tx and rx buffers according to TX_SZ and RX_SZ. 454 /// TCP client connection pool compatible with `embedded-nal-async` traits.
343 pub struct TcpClient<'d, D: Device, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { 455 ///
456 /// The pool is capable of managing up to N concurrent connections with tx and rx buffers according to TX_SZ and RX_SZ.
457 pub struct TcpClient<'d, D: Driver, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> {
344 stack: &'d Stack<D>, 458 stack: &'d Stack<D>,
345 state: &'d TcpClientState<N, TX_SZ, RX_SZ>, 459 state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
346 } 460 }
347 461
348 impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { 462 impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> {
349 /// Create a new TcpClient 463 /// Create a new `TcpClient`.
350 pub fn new(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self { 464 pub fn new(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self {
351 Self { stack, state } 465 Self { stack, state }
352 } 466 }
353 } 467 }
354 468
355 impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_nal_async::TcpConnect 469 impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_nal_async::TcpConnect
356 for TcpClient<'d, D, N, TX_SZ, RX_SZ> 470 for TcpClient<'d, D, N, TX_SZ, RX_SZ>
357 { 471 {
358 type Error = Error; 472 type Error = Error;
359 type Connection<'m> = TcpConnection<'m, N, TX_SZ, RX_SZ> where Self: 'm; 473 type Connection<'m> = TcpConnection<'m, N, TX_SZ, RX_SZ> where Self: 'm;
360 type ConnectFuture<'m> = impl Future<Output = Result<Self::Connection<'m>, Self::Error>> + 'm 474
361 where 475 async fn connect<'a>(
362 Self: 'm; 476 &'a self,
363 477 remote: embedded_nal_async::SocketAddr,
364 fn connect<'m>(&'m self, remote: embedded_nal_async::SocketAddr) -> Self::ConnectFuture<'m> { 478 ) -> Result<Self::Connection<'a>, Self::Error>
365 async move { 479 where
366 let addr: crate::IpAddress = match remote.ip() { 480 Self: 'a,
367 IpAddr::V4(addr) => crate::IpAddress::Ipv4(crate::Ipv4Address::from_bytes(&addr.octets())), 481 {
368 #[cfg(feature = "proto-ipv6")] 482 let addr: crate::IpAddress = match remote.ip() {
369 IpAddr::V6(addr) => crate::IpAddress::Ipv6(crate::Ipv6Address::from_bytes(&addr.octets())), 483 #[cfg(feature = "proto-ipv4")]
370 #[cfg(not(feature = "proto-ipv6"))] 484 IpAddr::V4(addr) => crate::IpAddress::Ipv4(crate::Ipv4Address::from_bytes(&addr.octets())),
371 IpAddr::V6(_) => panic!("ipv6 support not enabled"), 485 #[cfg(not(feature = "proto-ipv4"))]
372 }; 486 IpAddr::V4(_) => panic!("ipv4 support not enabled"),
373 let remote_endpoint = (addr, remote.port()); 487 #[cfg(feature = "proto-ipv6")]
374 let mut socket = TcpConnection::new(&self.stack, self.state)?; 488 IpAddr::V6(addr) => crate::IpAddress::Ipv6(crate::Ipv6Address::from_bytes(&addr.octets())),
375 socket 489 #[cfg(not(feature = "proto-ipv6"))]
376 .socket 490 IpAddr::V6(_) => panic!("ipv6 support not enabled"),
377 .connect(remote_endpoint) 491 };
378 .await 492 let remote_endpoint = (addr, remote.port());
379 .map_err(|_| Error::ConnectionReset)?; 493 let mut socket = TcpConnection::new(&self.stack, self.state)?;
380 Ok(socket) 494 socket
381 } 495 .socket
496 .connect(remote_endpoint)
497 .await
498 .map_err(|_| Error::ConnectionReset)?;
499 Ok(socket)
382 } 500 }
383 } 501 }
384 502
503 /// Opened TCP connection in a [`TcpClient`].
385 pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> { 504 pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> {
386 socket: TcpSocket<'d>, 505 socket: TcpSocket<'d>,
387 state: &'d TcpClientState<N, TX_SZ, RX_SZ>, 506 state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
@@ -389,10 +508,10 @@ pub mod client {
389 } 508 }
390 509
391 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> { 510 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> {
392 fn new<D: Device>(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Result<Self, Error> { 511 fn new<D: Driver>(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Result<Self, Error> {
393 let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?; 512 let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?;
394 Ok(Self { 513 Ok(Self {
395 socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().0, &mut bufs.as_mut().1) }, 514 socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().1, &mut bufs.as_mut().0) },
396 state, 515 state,
397 bufs, 516 bufs,
398 }) 517 })
@@ -417,32 +536,20 @@ pub mod client {
417 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Read 536 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Read
418 for TcpConnection<'d, N, TX_SZ, RX_SZ> 537 for TcpConnection<'d, N, TX_SZ, RX_SZ>
419 { 538 {
420 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 539 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
421 where 540 self.socket.read(buf).await
422 Self: 'a;
423
424 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
425 self.socket.read(buf)
426 } 541 }
427 } 542 }
428 543
429 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Write 544 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Write
430 for TcpConnection<'d, N, TX_SZ, RX_SZ> 545 for TcpConnection<'d, N, TX_SZ, RX_SZ>
431 { 546 {
432 type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 547 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
433 where 548 self.socket.write(buf).await
434 Self: 'a;
435
436 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> {
437 self.socket.write(buf)
438 } 549 }
439 550
440 type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> 551 async fn flush(&mut self) -> Result<(), Self::Error> {
441 where 552 self.socket.flush().await
442 Self: 'a;
443
444 fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> {
445 self.socket.flush()
446 } 553 }
447 } 554 }
448 555
@@ -452,6 +559,7 @@ pub mod client {
452 } 559 }
453 560
454 impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> { 561 impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> {
562 /// Create a new `TcpClientState`.
455 pub const fn new() -> Self { 563 pub const fn new() -> Self {
456 Self { pool: Pool::new() } 564 Self { pool: Pool::new() }
457 } 565 }
diff --git a/embassy-net/src/time.rs b/embassy-net/src/time.rs
new file mode 100644
index 000000000..b98d40fdc
--- /dev/null
+++ b/embassy-net/src/time.rs
@@ -0,0 +1,20 @@
1#![allow(unused)]
2
3use embassy_time::{Duration, Instant};
4use smoltcp::time::{Duration as SmolDuration, Instant as SmolInstant};
5
6pub(crate) fn instant_to_smoltcp(instant: Instant) -> SmolInstant {
7 SmolInstant::from_micros(instant.as_micros() as i64)
8}
9
10pub(crate) fn instant_from_smoltcp(instant: SmolInstant) -> Instant {
11 Instant::from_micros(instant.total_micros() as u64)
12}
13
14pub(crate) fn duration_to_smoltcp(duration: Duration) -> SmolDuration {
15 SmolDuration::from_micros(duration.as_micros())
16}
17
18pub(crate) fn duration_from_smoltcp(duration: SmolDuration) -> Duration {
19 Duration::from_micros(duration.total_micros())
20}
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
index 78b09a492..0d97b6db1 100644
--- a/embassy-net/src/udp.rs
+++ b/embassy-net/src/udp.rs
@@ -1,15 +1,19 @@
1use core::cell::UnsafeCell; 1//! UDP sockets.
2
3use core::cell::RefCell;
4use core::future::poll_fn;
2use core::mem; 5use core::mem;
3use core::task::Poll; 6use core::task::{Context, Poll};
4 7
5use futures::future::poll_fn; 8use embassy_net_driver::Driver;
6use smoltcp::iface::{Interface, SocketHandle}; 9use smoltcp::iface::{Interface, SocketHandle};
7use smoltcp::socket::udp::{self, PacketMetadata}; 10use smoltcp::socket::udp;
11pub use smoltcp::socket::udp::PacketMetadata;
8use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; 12use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
9 13
10use super::stack::SocketStack; 14use crate::{SocketStack, Stack};
11use crate::{Device, Stack};
12 15
16/// Error returned by [`UdpSocket::bind`].
13#[derive(PartialEq, Eq, Clone, Copy, Debug)] 17#[derive(PartialEq, Eq, Clone, Copy, Debug)]
14#[cfg_attr(feature = "defmt", derive(defmt::Format))] 18#[cfg_attr(feature = "defmt", derive(defmt::Format))]
15pub enum BindError { 19pub enum BindError {
@@ -19,6 +23,7 @@ pub enum BindError {
19 NoRoute, 23 NoRoute,
20} 24}
21 25
26/// Error returned by [`UdpSocket::recv_from`] and [`UdpSocket::send_to`].
22#[derive(PartialEq, Eq, Clone, Copy, Debug)] 27#[derive(PartialEq, Eq, Clone, Copy, Debug)]
23#[cfg_attr(feature = "defmt", derive(defmt::Format))] 28#[cfg_attr(feature = "defmt", derive(defmt::Format))]
24pub enum Error { 29pub enum Error {
@@ -26,21 +31,22 @@ pub enum Error {
26 NoRoute, 31 NoRoute,
27} 32}
28 33
34/// An UDP socket.
29pub struct UdpSocket<'a> { 35pub struct UdpSocket<'a> {
30 stack: &'a UnsafeCell<SocketStack>, 36 stack: &'a RefCell<SocketStack>,
31 handle: SocketHandle, 37 handle: SocketHandle,
32} 38}
33 39
34impl<'a> UdpSocket<'a> { 40impl<'a> UdpSocket<'a> {
35 pub fn new<D: Device>( 41 /// Create a new UDP socket using the provided stack and buffers.
42 pub fn new<D: Driver>(
36 stack: &'a Stack<D>, 43 stack: &'a Stack<D>,
37 rx_meta: &'a mut [PacketMetadata], 44 rx_meta: &'a mut [PacketMetadata],
38 rx_buffer: &'a mut [u8], 45 rx_buffer: &'a mut [u8],
39 tx_meta: &'a mut [PacketMetadata], 46 tx_meta: &'a mut [PacketMetadata],
40 tx_buffer: &'a mut [u8], 47 tx_buffer: &'a mut [u8],
41 ) -> Self { 48 ) -> Self {
42 // safety: not accessed reentrantly. 49 let s = &mut *stack.socket.borrow_mut();
43 let s = unsafe { &mut *stack.socket.get() };
44 50
45 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; 51 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
46 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 52 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
@@ -57,101 +63,131 @@ impl<'a> UdpSocket<'a> {
57 } 63 }
58 } 64 }
59 65
66 /// Bind the socket to a local endpoint.
60 pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError> 67 pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
61 where 68 where
62 T: Into<IpListenEndpoint>, 69 T: Into<IpListenEndpoint>,
63 { 70 {
64 let mut endpoint = endpoint.into(); 71 let mut endpoint = endpoint.into();
65 72
66 // safety: not accessed reentrantly.
67 if endpoint.port == 0 { 73 if endpoint.port == 0 {
68 // If user didn't specify port allocate a dynamic port. 74 // If user didn't specify port allocate a dynamic port.
69 endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); 75 endpoint.port = self.stack.borrow_mut().get_local_port();
70 } 76 }
71 77
72 // safety: not accessed reentrantly. 78 match self.with_mut(|s, _| s.bind(endpoint)) {
73 match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } {
74 Ok(()) => Ok(()), 79 Ok(()) => Ok(()),
75 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), 80 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
76 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), 81 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
77 } 82 }
78 } 83 }
79 84
80 /// SAFETY: must not call reentrantly. 85 fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
81 unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { 86 let s = &*self.stack.borrow();
82 let s = &*self.stack.get();
83 let socket = s.sockets.get::<udp::Socket>(self.handle); 87 let socket = s.sockets.get::<udp::Socket>(self.handle);
84 f(socket, &s.iface) 88 f(socket, &s.iface)
85 } 89 }
86 90
87 /// SAFETY: must not call reentrantly. 91 fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
88 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { 92 let s = &mut *self.stack.borrow_mut();
89 let s = &mut *self.stack.get();
90 let socket = s.sockets.get_mut::<udp::Socket>(self.handle); 93 let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
91 let res = f(socket, &mut s.iface); 94 let res = f(socket, &mut s.iface);
92 s.waker.wake(); 95 s.waker.wake();
93 res 96 res
94 } 97 }
95 98
99 /// Receive a datagram.
100 ///
101 /// This method will wait until a datagram is received.
102 ///
103 /// Returns the number of bytes received and the remote endpoint.
96 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { 104 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
97 poll_fn(move |cx| unsafe { 105 poll_fn(move |cx| self.poll_recv_from(buf, cx)).await
98 self.with_mut(|s, _| match s.recv_slice(buf) { 106 }
99 Ok(x) => Poll::Ready(Ok(x)), 107
100 // No data ready 108 /// Receive a datagram.
101 Err(udp::RecvError::Exhausted) => { 109 ///
102 //s.register_recv_waker(cx.waker()); 110 /// When no datagram is available, this method will return `Poll::Pending` and
103 cx.waker().wake_by_ref(); 111 /// register the current task to be notified when a datagram is received.
104 Poll::Pending 112 ///
105 } 113 /// When a datagram is received, this method will return `Poll::Ready` with the
106 }) 114 /// number of bytes received and the remote endpoint.
115 pub fn poll_recv_from(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<(usize, IpEndpoint), Error>> {
116 self.with_mut(|s, _| match s.recv_slice(buf) {
117 Ok((n, meta)) => Poll::Ready(Ok((n, meta.endpoint))),
118 // No data ready
119 Err(udp::RecvError::Exhausted) => {
120 s.register_recv_waker(cx.waker());
121 Poll::Pending
122 }
107 }) 123 })
108 .await
109 } 124 }
110 125
126 /// Send a datagram to the specified remote endpoint.
127 ///
128 /// This method will wait until the datagram has been sent.
129 ///
130 /// When the remote endpoint is not reachable, this method will return `Err(Error::NoRoute)`
111 pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error> 131 pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error>
112 where 132 where
113 T: Into<IpEndpoint>, 133 T: Into<IpEndpoint>,
114 { 134 {
115 let remote_endpoint = remote_endpoint.into(); 135 let remote_endpoint: IpEndpoint = remote_endpoint.into();
116 poll_fn(move |cx| unsafe { 136 poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await
117 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { 137 }
118 // Entire datagram has been sent 138
119 Ok(()) => Poll::Ready(Ok(())), 139 /// Send a datagram to the specified remote endpoint.
120 Err(udp::SendError::BufferFull) => { 140 ///
121 s.register_send_waker(cx.waker()); 141 /// When the datagram has been sent, this method will return `Poll::Ready(Ok())`.
122 Poll::Pending 142 ///
123 } 143 /// When the socket's send buffer is full, this method will return `Poll::Pending`
124 Err(udp::SendError::Unaddressable) => Poll::Ready(Err(Error::NoRoute)), 144 /// and register the current task to be notified when the buffer has space available.
125 }) 145 ///
146 /// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`.
147 pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), Error>>
148 where
149 T: Into<IpEndpoint>,
150 {
151 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
152 // Entire datagram has been sent
153 Ok(()) => Poll::Ready(Ok(())),
154 Err(udp::SendError::BufferFull) => {
155 s.register_send_waker(cx.waker());
156 Poll::Pending
157 }
158 Err(udp::SendError::Unaddressable) => Poll::Ready(Err(Error::NoRoute)),
126 }) 159 })
127 .await
128 } 160 }
129 161
162 /// Returns the local endpoint of the socket.
130 pub fn endpoint(&self) -> IpListenEndpoint { 163 pub fn endpoint(&self) -> IpListenEndpoint {
131 unsafe { self.with(|s, _| s.endpoint()) } 164 self.with(|s, _| s.endpoint())
132 } 165 }
133 166
167 /// Returns whether the socket is open.
168
134 pub fn is_open(&self) -> bool { 169 pub fn is_open(&self) -> bool {
135 unsafe { self.with(|s, _| s.is_open()) } 170 self.with(|s, _| s.is_open())
136 } 171 }
137 172
173 /// Close the socket.
138 pub fn close(&mut self) { 174 pub fn close(&mut self) {
139 unsafe { self.with_mut(|s, _| s.close()) } 175 self.with_mut(|s, _| s.close())
140 } 176 }
141 177
178 /// Returns whether the socket is ready to send data, i.e. it has enough buffer space to hold a packet.
142 pub fn may_send(&self) -> bool { 179 pub fn may_send(&self) -> bool {
143 unsafe { self.with(|s, _| s.can_send()) } 180 self.with(|s, _| s.can_send())
144 } 181 }
145 182
183 /// Returns whether the socket is ready to receive data, i.e. it has received a packet that's now in the buffer.
146 pub fn may_recv(&self) -> bool { 184 pub fn may_recv(&self) -> bool {
147 unsafe { self.with(|s, _| s.can_recv()) } 185 self.with(|s, _| s.can_recv())
148 } 186 }
149} 187}
150 188
151impl Drop for UdpSocket<'_> { 189impl Drop for UdpSocket<'_> {
152 fn drop(&mut self) { 190 fn drop(&mut self) {
153 // safety: not accessed reentrantly. 191 self.stack.borrow_mut().sockets.remove(self.handle);
154 let s = unsafe { &mut *self.stack.get() };
155 s.sockets.remove(self.handle);
156 } 192 }
157} 193}