diff options
| author | Quentin Smith <[email protected]> | 2023-07-17 21:31:43 -0400 |
|---|---|---|
| committer | Quentin Smith <[email protected]> | 2023-07-17 21:31:43 -0400 |
| commit | 6f02403184eb7fb7990fb88fc9df9c4328a690a3 (patch) | |
| tree | 748f510e190bb2724750507a6e69ed1a8e08cb20 /embassy-net/src | |
| parent | d896f80405aa8963877049ed999e4aba25d6e2bb (diff) | |
| parent | 6b5df4523aa1c4902f02e803450ae4b418e0e3ca (diff) | |
Merge remote-tracking branch 'origin/main' into nrf-pdm
Diffstat (limited to 'embassy-net/src')
| -rw-r--r-- | embassy-net/src/device.rs | 189 | ||||
| -rw-r--r-- | embassy-net/src/dns.rs | 107 | ||||
| -rw-r--r-- | embassy-net/src/lib.rs | 727 | ||||
| -rw-r--r-- | embassy-net/src/packet_pool.rs | 107 | ||||
| -rw-r--r-- | embassy-net/src/stack.rs | 316 | ||||
| -rw-r--r-- | embassy-net/src/tcp.rs | 414 | ||||
| -rw-r--r-- | embassy-net/src/time.rs | 20 | ||||
| -rw-r--r-- | embassy-net/src/udp.rs | 138 |
8 files changed, 1269 insertions, 749 deletions
diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs index c183bd58a..4513c86d3 100644 --- a/embassy-net/src/device.rs +++ b/embassy-net/src/device.rs | |||
| @@ -1,129 +1,106 @@ | |||
| 1 | use core::task::Waker; | 1 | use core::task::Context; |
| 2 | 2 | ||
| 3 | use smoltcp::phy::{Device as SmolDevice, DeviceCapabilities}; | 3 | use embassy_net_driver::{Capabilities, Checksum, Driver, Medium, RxToken, TxToken}; |
| 4 | use smoltcp::time::Instant as SmolInstant; | 4 | use smoltcp::phy; |
| 5 | 5 | use smoltcp::time::Instant; | |
| 6 | use crate::packet_pool::PacketBoxExt; | 6 | |
| 7 | use crate::{Packet, PacketBox, PacketBuf}; | 7 | pub(crate) struct DriverAdapter<'d, 'c, T> |
| 8 | 8 | where | |
| 9 | #[derive(PartialEq, Eq, Clone, Copy)] | 9 | T: Driver, |
| 10 | pub enum LinkState { | 10 | { |
| 11 | Down, | 11 | // must be Some when actually using this to rx/tx |
| 12 | Up, | 12 | pub cx: Option<&'d mut Context<'c>>, |
| 13 | } | 13 | pub inner: &'d mut T, |
| 14 | |||
| 15 | // 'static required due to the "fake GAT" in smoltcp::phy::Device. | ||
| 16 | // https://github.com/smoltcp-rs/smoltcp/pull/572 | ||
| 17 | pub trait Device { | ||
| 18 | fn is_transmit_ready(&mut self) -> bool; | ||
| 19 | fn transmit(&mut self, pkt: PacketBuf); | ||
| 20 | fn receive(&mut self) -> Option<PacketBuf>; | ||
| 21 | |||
| 22 | fn register_waker(&mut self, waker: &Waker); | ||
| 23 | fn capabilities(&self) -> DeviceCapabilities; | ||
| 24 | fn link_state(&mut self) -> LinkState; | ||
| 25 | fn ethernet_address(&self) -> [u8; 6]; | ||
| 26 | } | 14 | } |
| 27 | 15 | ||
| 28 | impl<T: ?Sized + Device> Device for &'static mut T { | 16 | impl<'d, 'c, T> phy::Device for DriverAdapter<'d, 'c, T> |
| 29 | fn is_transmit_ready(&mut self) -> bool { | 17 | where |
| 30 | T::is_transmit_ready(self) | 18 | T: Driver, |
| 31 | } | 19 | { |
| 32 | fn transmit(&mut self, pkt: PacketBuf) { | 20 | type RxToken<'a> = RxTokenAdapter<T::RxToken<'a>> where Self: 'a; |
| 33 | T::transmit(self, pkt) | 21 | type TxToken<'a> = TxTokenAdapter<T::TxToken<'a>> where Self: 'a; |
| 34 | } | 22 | |
| 35 | fn receive(&mut self) -> Option<PacketBuf> { | 23 | fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { |
| 36 | T::receive(self) | 24 | self.inner |
| 25 | .receive(self.cx.as_deref_mut().unwrap()) | ||
| 26 | .map(|(rx, tx)| (RxTokenAdapter(rx), TxTokenAdapter(tx))) | ||
| 37 | } | 27 | } |
| 38 | fn register_waker(&mut self, waker: &Waker) { | ||
| 39 | T::register_waker(self, waker) | ||
| 40 | } | ||
| 41 | fn capabilities(&self) -> DeviceCapabilities { | ||
| 42 | T::capabilities(self) | ||
| 43 | } | ||
| 44 | fn link_state(&mut self) -> LinkState { | ||
| 45 | T::link_state(self) | ||
| 46 | } | ||
| 47 | fn ethernet_address(&self) -> [u8; 6] { | ||
| 48 | T::ethernet_address(self) | ||
| 49 | } | ||
| 50 | } | ||
| 51 | |||
| 52 | pub struct DeviceAdapter<D: Device> { | ||
| 53 | pub device: D, | ||
| 54 | caps: DeviceCapabilities, | ||
| 55 | } | ||
| 56 | 28 | ||
| 57 | impl<D: Device> DeviceAdapter<D> { | 29 | /// Construct a transmit token. |
| 58 | pub(crate) fn new(device: D) -> Self { | 30 | fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { |
| 59 | Self { | 31 | self.inner.transmit(self.cx.as_deref_mut().unwrap()).map(TxTokenAdapter) |
| 60 | caps: device.capabilities(), | ||
| 61 | device, | ||
| 62 | } | ||
| 63 | } | 32 | } |
| 64 | } | ||
| 65 | 33 | ||
| 66 | impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> { | 34 | /// Get a description of device capabilities. |
| 67 | type RxToken = RxToken; | 35 | fn capabilities(&self) -> phy::DeviceCapabilities { |
| 68 | type TxToken = TxToken<'a, D>; | 36 | fn convert(c: Checksum) -> phy::Checksum { |
| 69 | 37 | match c { | |
| 70 | fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { | 38 | Checksum::Both => phy::Checksum::Both, |
| 71 | let tx_pkt = PacketBox::new(Packet::new())?; | 39 | Checksum::Tx => phy::Checksum::Tx, |
| 72 | let rx_pkt = self.device.receive()?; | 40 | Checksum::Rx => phy::Checksum::Rx, |
| 73 | let rx_token = RxToken { pkt: rx_pkt }; | 41 | Checksum::None => phy::Checksum::None, |
| 74 | let tx_token = TxToken { | 42 | } |
| 75 | device: &mut self.device, | 43 | } |
| 76 | pkt: tx_pkt, | 44 | let caps: Capabilities = self.inner.capabilities(); |
| 45 | let mut smolcaps = phy::DeviceCapabilities::default(); | ||
| 46 | |||
| 47 | smolcaps.max_transmission_unit = caps.max_transmission_unit; | ||
| 48 | smolcaps.max_burst_size = caps.max_burst_size; | ||
| 49 | smolcaps.medium = match caps.medium { | ||
| 50 | #[cfg(feature = "medium-ethernet")] | ||
| 51 | Medium::Ethernet => phy::Medium::Ethernet, | ||
| 52 | #[cfg(feature = "medium-ip")] | ||
| 53 | Medium::Ip => phy::Medium::Ip, | ||
| 54 | #[allow(unreachable_patterns)] | ||
| 55 | _ => panic!( | ||
| 56 | "Unsupported medium {:?}. Make sure to enable it in embassy-net's Cargo features.", | ||
| 57 | caps.medium | ||
| 58 | ), | ||
| 77 | }; | 59 | }; |
| 78 | 60 | smolcaps.checksum.ipv4 = convert(caps.checksum.ipv4); | |
| 79 | Some((rx_token, tx_token)) | 61 | smolcaps.checksum.tcp = convert(caps.checksum.tcp); |
| 80 | } | 62 | smolcaps.checksum.udp = convert(caps.checksum.udp); |
| 81 | 63 | #[cfg(feature = "proto-ipv4")] | |
| 82 | /// Construct a transmit token. | 64 | { |
| 83 | fn transmit(&'a mut self) -> Option<Self::TxToken> { | 65 | smolcaps.checksum.icmpv4 = convert(caps.checksum.icmpv4); |
| 84 | if !self.device.is_transmit_ready() { | 66 | } |
| 85 | return None; | 67 | #[cfg(feature = "proto-ipv6")] |
| 68 | { | ||
| 69 | smolcaps.checksum.icmpv6 = convert(caps.checksum.icmpv6); | ||
| 86 | } | 70 | } |
| 87 | 71 | ||
| 88 | let tx_pkt = PacketBox::new(Packet::new())?; | 72 | smolcaps |
| 89 | Some(TxToken { | ||
| 90 | device: &mut self.device, | ||
| 91 | pkt: tx_pkt, | ||
| 92 | }) | ||
| 93 | } | ||
| 94 | |||
| 95 | /// Get a description of device capabilities. | ||
| 96 | fn capabilities(&self) -> DeviceCapabilities { | ||
| 97 | self.caps.clone() | ||
| 98 | } | 73 | } |
| 99 | } | 74 | } |
| 100 | 75 | ||
| 101 | pub struct RxToken { | 76 | pub(crate) struct RxTokenAdapter<T>(T) |
| 102 | pkt: PacketBuf, | 77 | where |
| 103 | } | 78 | T: RxToken; |
| 104 | 79 | ||
| 105 | impl smoltcp::phy::RxToken for RxToken { | 80 | impl<T> phy::RxToken for RxTokenAdapter<T> |
| 106 | fn consume<R, F>(mut self, _timestamp: SmolInstant, f: F) -> smoltcp::Result<R> | 81 | where |
| 82 | T: RxToken, | ||
| 83 | { | ||
| 84 | fn consume<R, F>(self, f: F) -> R | ||
| 107 | where | 85 | where |
| 108 | F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, | 86 | F: FnOnce(&mut [u8]) -> R, |
| 109 | { | 87 | { |
| 110 | f(&mut self.pkt) | 88 | self.0.consume(|buf| f(buf)) |
| 111 | } | 89 | } |
| 112 | } | 90 | } |
| 113 | 91 | ||
| 114 | pub struct TxToken<'a, D: Device> { | 92 | pub(crate) struct TxTokenAdapter<T>(T) |
| 115 | device: &'a mut D, | 93 | where |
| 116 | pkt: PacketBox, | 94 | T: TxToken; |
| 117 | } | ||
| 118 | 95 | ||
| 119 | impl<'a, D: Device> smoltcp::phy::TxToken for TxToken<'a, D> { | 96 | impl<T> phy::TxToken for TxTokenAdapter<T> |
| 120 | fn consume<R, F>(self, _timestamp: SmolInstant, len: usize, f: F) -> smoltcp::Result<R> | 97 | where |
| 98 | T: TxToken, | ||
| 99 | { | ||
| 100 | fn consume<R, F>(self, len: usize, f: F) -> R | ||
| 121 | where | 101 | where |
| 122 | F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, | 102 | F: FnOnce(&mut [u8]) -> R, |
| 123 | { | 103 | { |
| 124 | let mut buf = self.pkt.slice(0..len); | 104 | self.0.consume(len, |buf| f(buf)) |
| 125 | let r = f(&mut buf)?; | ||
| 126 | self.device.transmit(buf); | ||
| 127 | Ok(r) | ||
| 128 | } | 105 | } |
| 129 | } | 106 | } |
diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs new file mode 100644 index 000000000..94f75f108 --- /dev/null +++ b/embassy-net/src/dns.rs | |||
| @@ -0,0 +1,107 @@ | |||
| 1 | //! DNS client compatible with the `embedded-nal-async` traits. | ||
| 2 | //! | ||
| 3 | //! This exists only for compatibility with crates that use `embedded-nal-async`. | ||
| 4 | //! Prefer using [`Stack::dns_query`](crate::Stack::dns_query) directly if you're | ||
| 5 | //! not using `embedded-nal-async`. | ||
| 6 | |||
| 7 | use heapless::Vec; | ||
| 8 | pub use smoltcp::socket::dns::{DnsQuery, Socket}; | ||
| 9 | pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError}; | ||
| 10 | pub use smoltcp::wire::{DnsQueryType, IpAddress}; | ||
| 11 | |||
| 12 | use crate::{Driver, Stack}; | ||
| 13 | |||
| 14 | /// Errors returned by DnsSocket. | ||
| 15 | #[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||
| 16 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | ||
| 17 | pub enum Error { | ||
| 18 | /// Invalid name | ||
| 19 | InvalidName, | ||
| 20 | /// Name too long | ||
| 21 | NameTooLong, | ||
| 22 | /// Name lookup failed | ||
| 23 | Failed, | ||
| 24 | } | ||
| 25 | |||
| 26 | impl From<GetQueryResultError> for Error { | ||
| 27 | fn from(_: GetQueryResultError) -> Self { | ||
| 28 | Self::Failed | ||
| 29 | } | ||
| 30 | } | ||
| 31 | |||
| 32 | impl From<StartQueryError> for Error { | ||
| 33 | fn from(e: StartQueryError) -> Self { | ||
| 34 | match e { | ||
| 35 | StartQueryError::NoFreeSlot => Self::Failed, | ||
| 36 | StartQueryError::InvalidName => Self::InvalidName, | ||
| 37 | StartQueryError::NameTooLong => Self::NameTooLong, | ||
| 38 | } | ||
| 39 | } | ||
| 40 | } | ||
| 41 | |||
| 42 | /// DNS client compatible with the `embedded-nal-async` traits. | ||
| 43 | /// | ||
| 44 | /// This exists only for compatibility with crates that use `embedded-nal-async`. | ||
| 45 | /// Prefer using [`Stack::dns_query`](crate::Stack::dns_query) directly if you're | ||
| 46 | /// not using `embedded-nal-async`. | ||
| 47 | pub struct DnsSocket<'a, D> | ||
| 48 | where | ||
| 49 | D: Driver + 'static, | ||
| 50 | { | ||
| 51 | stack: &'a Stack<D>, | ||
| 52 | } | ||
| 53 | |||
| 54 | impl<'a, D> DnsSocket<'a, D> | ||
| 55 | where | ||
| 56 | D: Driver + 'static, | ||
| 57 | { | ||
| 58 | /// Create a new DNS socket using the provided stack. | ||
| 59 | /// | ||
| 60 | /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated. | ||
| 61 | pub fn new(stack: &'a Stack<D>) -> Self { | ||
| 62 | Self { stack } | ||
| 63 | } | ||
| 64 | |||
| 65 | /// Make a query for a given name and return the corresponding IP addresses. | ||
| 66 | pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, 1>, Error> { | ||
| 67 | self.stack.dns_query(name, qtype).await | ||
| 68 | } | ||
| 69 | } | ||
| 70 | |||
| 71 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] | ||
| 72 | impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D> | ||
| 73 | where | ||
| 74 | D: Driver + 'static, | ||
| 75 | { | ||
| 76 | type Error = Error; | ||
| 77 | |||
| 78 | async fn get_host_by_name( | ||
| 79 | &self, | ||
| 80 | host: &str, | ||
| 81 | addr_type: embedded_nal_async::AddrType, | ||
| 82 | ) -> Result<embedded_nal_async::IpAddr, Self::Error> { | ||
| 83 | use embedded_nal_async::{AddrType, IpAddr}; | ||
| 84 | let qtype = match addr_type { | ||
| 85 | AddrType::IPv6 => DnsQueryType::Aaaa, | ||
| 86 | _ => DnsQueryType::A, | ||
| 87 | }; | ||
| 88 | let addrs = self.query(host, qtype).await?; | ||
| 89 | if let Some(first) = addrs.get(0) { | ||
| 90 | Ok(match first { | ||
| 91 | #[cfg(feature = "proto-ipv4")] | ||
| 92 | IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()), | ||
| 93 | #[cfg(feature = "proto-ipv6")] | ||
| 94 | IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()), | ||
| 95 | }) | ||
| 96 | } else { | ||
| 97 | Err(Error::Failed) | ||
| 98 | } | ||
| 99 | } | ||
| 100 | |||
| 101 | async fn get_host_by_address( | ||
| 102 | &self, | ||
| 103 | _addr: embedded_nal_async::IpAddr, | ||
| 104 | ) -> Result<heapless::String<256>, Self::Error> { | ||
| 105 | todo!() | ||
| 106 | } | ||
| 107 | } | ||
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index 83d364715..0d0a986f6 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs | |||
| @@ -1,31 +1,726 @@ | |||
| 1 | #![cfg_attr(not(feature = "std"), no_std)] | 1 | #![cfg_attr(not(feature = "std"), no_std)] |
| 2 | #![allow(clippy::new_without_default)] | 2 | #![cfg_attr(feature = "nightly", feature(async_fn_in_trait, impl_trait_projections))] |
| 3 | #![feature(generic_associated_types, type_alias_impl_trait)] | 3 | #![warn(missing_docs)] |
| 4 | #![doc = include_str!("../README.md")] | ||
| 4 | 5 | ||
| 5 | // This mod MUST go first, so that the others see its macros. | 6 | // This mod MUST go first, so that the others see its macros. |
| 6 | pub(crate) mod fmt; | 7 | pub(crate) mod fmt; |
| 7 | 8 | ||
| 8 | mod device; | 9 | mod device; |
| 9 | mod packet_pool; | 10 | #[cfg(feature = "dns")] |
| 10 | mod stack; | 11 | pub mod dns; |
| 11 | |||
| 12 | pub use device::{Device, LinkState}; | ||
| 13 | pub use packet_pool::{Packet, PacketBox, PacketBoxExt, PacketBuf, MTU}; | ||
| 14 | pub use stack::{Config, ConfigStrategy, Stack, StackResources}; | ||
| 15 | |||
| 16 | #[cfg(feature = "tcp")] | 12 | #[cfg(feature = "tcp")] |
| 17 | pub mod tcp; | 13 | pub mod tcp; |
| 18 | 14 | mod time; | |
| 19 | #[cfg(feature = "udp")] | 15 | #[cfg(feature = "udp")] |
| 20 | pub mod udp; | 16 | pub mod udp; |
| 21 | 17 | ||
| 22 | // smoltcp reexports | 18 | use core::cell::RefCell; |
| 23 | pub use smoltcp::phy::{DeviceCapabilities, Medium}; | 19 | use core::future::{poll_fn, Future}; |
| 24 | pub use smoltcp::time::{Duration as SmolDuration, Instant as SmolInstant}; | 20 | use core::task::{Context, Poll}; |
| 21 | |||
| 22 | pub use embassy_net_driver as driver; | ||
| 23 | use embassy_net_driver::{Driver, LinkState, Medium}; | ||
| 24 | use embassy_sync::waitqueue::WakerRegistration; | ||
| 25 | use embassy_time::{Instant, Timer}; | ||
| 26 | use futures::pin_mut; | ||
| 27 | use heapless::Vec; | ||
| 28 | #[cfg(feature = "igmp")] | ||
| 29 | pub use smoltcp::iface::MulticastError; | ||
| 30 | use smoltcp::iface::{Interface, SocketHandle, SocketSet, SocketStorage}; | ||
| 31 | #[cfg(feature = "dhcpv4")] | ||
| 32 | use smoltcp::socket::dhcpv4::{self, RetryConfig}; | ||
| 33 | #[cfg(feature = "udp")] | ||
| 34 | pub use smoltcp::wire::IpListenEndpoint; | ||
| 25 | #[cfg(feature = "medium-ethernet")] | 35 | #[cfg(feature = "medium-ethernet")] |
| 26 | pub use smoltcp::wire::{EthernetAddress, HardwareAddress}; | 36 | pub use smoltcp::wire::{EthernetAddress, HardwareAddress}; |
| 27 | pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Cidr}; | 37 | pub use smoltcp::wire::{IpAddress, IpCidr, IpEndpoint}; |
| 38 | #[cfg(feature = "proto-ipv4")] | ||
| 39 | pub use smoltcp::wire::{Ipv4Address, Ipv4Cidr}; | ||
| 28 | #[cfg(feature = "proto-ipv6")] | 40 | #[cfg(feature = "proto-ipv6")] |
| 29 | pub use smoltcp::wire::{Ipv6Address, Ipv6Cidr}; | 41 | pub use smoltcp::wire::{Ipv6Address, Ipv6Cidr}; |
| 30 | #[cfg(feature = "udp")] | 42 | |
| 31 | pub use smoltcp::{socket::udp::PacketMetadata, wire::IpListenEndpoint}; | 43 | use crate::device::DriverAdapter; |
| 44 | use crate::time::{instant_from_smoltcp, instant_to_smoltcp}; | ||
| 45 | |||
| 46 | const LOCAL_PORT_MIN: u16 = 1025; | ||
| 47 | const LOCAL_PORT_MAX: u16 = 65535; | ||
| 48 | #[cfg(feature = "dns")] | ||
| 49 | const MAX_QUERIES: usize = 4; | ||
| 50 | |||
| 51 | /// Memory resources needed for a network stack. | ||
| 52 | pub struct StackResources<const SOCK: usize> { | ||
| 53 | sockets: [SocketStorage<'static>; SOCK], | ||
| 54 | #[cfg(feature = "dns")] | ||
| 55 | queries: [Option<dns::DnsQuery>; MAX_QUERIES], | ||
| 56 | } | ||
| 57 | |||
| 58 | impl<const SOCK: usize> StackResources<SOCK> { | ||
| 59 | /// Create a new set of stack resources. | ||
| 60 | pub const fn new() -> Self { | ||
| 61 | #[cfg(feature = "dns")] | ||
| 62 | const INIT: Option<dns::DnsQuery> = None; | ||
| 63 | Self { | ||
| 64 | sockets: [SocketStorage::EMPTY; SOCK], | ||
| 65 | #[cfg(feature = "dns")] | ||
| 66 | queries: [INIT; MAX_QUERIES], | ||
| 67 | } | ||
| 68 | } | ||
| 69 | } | ||
| 70 | |||
| 71 | /// Static IP address configuration. | ||
| 72 | #[cfg(feature = "proto-ipv4")] | ||
| 73 | #[derive(Debug, Clone, PartialEq, Eq)] | ||
| 74 | pub struct StaticConfigV4 { | ||
| 75 | /// IP address and subnet mask. | ||
| 76 | pub address: Ipv4Cidr, | ||
| 77 | /// Default gateway. | ||
| 78 | pub gateway: Option<Ipv4Address>, | ||
| 79 | /// DNS servers. | ||
| 80 | pub dns_servers: Vec<Ipv4Address, 3>, | ||
| 81 | } | ||
| 82 | |||
| 83 | /// Static IPv6 address configuration | ||
| 84 | #[cfg(feature = "proto-ipv6")] | ||
| 85 | #[derive(Debug, Clone, PartialEq, Eq)] | ||
| 86 | pub struct StaticConfigV6 { | ||
| 87 | /// IP address and subnet mask. | ||
| 88 | pub address: Ipv6Cidr, | ||
| 89 | /// Default gateway. | ||
| 90 | pub gateway: Option<Ipv6Address>, | ||
| 91 | /// DNS servers. | ||
| 92 | pub dns_servers: Vec<Ipv6Address, 3>, | ||
| 93 | } | ||
| 94 | |||
| 95 | /// DHCP configuration. | ||
| 96 | #[cfg(feature = "dhcpv4")] | ||
| 97 | #[derive(Debug, Clone, PartialEq, Eq)] | ||
| 98 | pub struct DhcpConfig { | ||
| 99 | /// Maximum lease duration. | ||
| 100 | /// | ||
| 101 | /// If not set, the lease duration specified by the server will be used. | ||
| 102 | /// If set, the lease duration will be capped at this value. | ||
| 103 | pub max_lease_duration: Option<embassy_time::Duration>, | ||
| 104 | /// Retry configuration. | ||
| 105 | pub retry_config: RetryConfig, | ||
| 106 | /// Ignore NAKs from DHCP servers. | ||
| 107 | /// | ||
| 108 | /// This is not compliant with the DHCP RFCs, since theoretically we must stop using the assigned IP when receiving a NAK. This can increase reliability on broken networks with buggy routers or rogue DHCP servers, however. | ||
| 109 | pub ignore_naks: bool, | ||
| 110 | /// Server port. This is almost always 67. Do not change unless you know what you're doing. | ||
| 111 | pub server_port: u16, | ||
| 112 | /// Client port. This is almost always 68. Do not change unless you know what you're doing. | ||
| 113 | pub client_port: u16, | ||
| 114 | } | ||
| 115 | |||
| 116 | #[cfg(feature = "dhcpv4")] | ||
| 117 | impl Default for DhcpConfig { | ||
| 118 | fn default() -> Self { | ||
| 119 | Self { | ||
| 120 | max_lease_duration: Default::default(), | ||
| 121 | retry_config: Default::default(), | ||
| 122 | ignore_naks: Default::default(), | ||
| 123 | server_port: smoltcp::wire::DHCP_SERVER_PORT, | ||
| 124 | client_port: smoltcp::wire::DHCP_CLIENT_PORT, | ||
| 125 | } | ||
| 126 | } | ||
| 127 | } | ||
| 128 | |||
| 129 | /// Network stack configuration. | ||
| 130 | pub struct Config { | ||
| 131 | /// IPv4 configuration | ||
| 132 | #[cfg(feature = "proto-ipv4")] | ||
| 133 | pub ipv4: ConfigV4, | ||
| 134 | /// IPv6 configuration | ||
| 135 | #[cfg(feature = "proto-ipv6")] | ||
| 136 | pub ipv6: ConfigV6, | ||
| 137 | } | ||
| 138 | |||
| 139 | impl Config { | ||
| 140 | /// IPv4 configuration with static addressing. | ||
| 141 | #[cfg(feature = "proto-ipv4")] | ||
| 142 | pub fn ipv4_static(config: StaticConfigV4) -> Self { | ||
| 143 | Self { | ||
| 144 | ipv4: ConfigV4::Static(config), | ||
| 145 | #[cfg(feature = "proto-ipv6")] | ||
| 146 | ipv6: ConfigV6::None, | ||
| 147 | } | ||
| 148 | } | ||
| 149 | |||
| 150 | /// IPv6 configuration with static addressing. | ||
| 151 | #[cfg(feature = "proto-ipv6")] | ||
| 152 | pub fn ipv6_static(config: StaticConfigV6) -> Self { | ||
| 153 | Self { | ||
| 154 | #[cfg(feature = "proto-ipv4")] | ||
| 155 | ipv4: ConfigV4::None, | ||
| 156 | ipv6: ConfigV6::Static(config), | ||
| 157 | } | ||
| 158 | } | ||
| 159 | |||
| 160 | /// IPv6 configuration with dynamic addressing. | ||
| 161 | /// | ||
| 162 | /// # Example | ||
| 163 | /// ```rust | ||
| 164 | /// let _cfg = Config::dhcpv4(Default::default()); | ||
| 165 | /// ``` | ||
| 166 | #[cfg(feature = "dhcpv4")] | ||
| 167 | pub fn dhcpv4(config: DhcpConfig) -> Self { | ||
| 168 | Self { | ||
| 169 | ipv4: ConfigV4::Dhcp(config), | ||
| 170 | #[cfg(feature = "proto-ipv6")] | ||
| 171 | ipv6: ConfigV6::None, | ||
| 172 | } | ||
| 173 | } | ||
| 174 | } | ||
| 175 | |||
| 176 | /// Network stack IPv4 configuration. | ||
| 177 | #[cfg(feature = "proto-ipv4")] | ||
| 178 | pub enum ConfigV4 { | ||
| 179 | /// Use a static IPv4 address configuration. | ||
| 180 | Static(StaticConfigV4), | ||
| 181 | /// Use DHCP to obtain an IP address configuration. | ||
| 182 | #[cfg(feature = "dhcpv4")] | ||
| 183 | Dhcp(DhcpConfig), | ||
| 184 | /// Do not configure IPv6. | ||
| 185 | None, | ||
| 186 | } | ||
| 187 | |||
| 188 | /// Network stack IPv6 configuration. | ||
| 189 | #[cfg(feature = "proto-ipv6")] | ||
| 190 | pub enum ConfigV6 { | ||
| 191 | /// Use a static IPv6 address configuration. | ||
| 192 | Static(StaticConfigV6), | ||
| 193 | /// Do not configure IPv6. | ||
| 194 | None, | ||
| 195 | } | ||
| 196 | |||
| 197 | /// A network stack. | ||
| 198 | /// | ||
| 199 | /// This is the main entry point for the network stack. | ||
| 200 | pub struct Stack<D: Driver> { | ||
| 201 | pub(crate) socket: RefCell<SocketStack>, | ||
| 202 | inner: RefCell<Inner<D>>, | ||
| 203 | } | ||
| 204 | |||
| 205 | struct Inner<D: Driver> { | ||
| 206 | device: D, | ||
| 207 | link_up: bool, | ||
| 208 | #[cfg(feature = "proto-ipv4")] | ||
| 209 | static_v4: Option<StaticConfigV4>, | ||
| 210 | #[cfg(feature = "proto-ipv6")] | ||
| 211 | static_v6: Option<StaticConfigV6>, | ||
| 212 | #[cfg(feature = "dhcpv4")] | ||
| 213 | dhcp_socket: Option<SocketHandle>, | ||
| 214 | #[cfg(feature = "dns")] | ||
| 215 | dns_socket: SocketHandle, | ||
| 216 | #[cfg(feature = "dns")] | ||
| 217 | dns_waker: WakerRegistration, | ||
| 218 | } | ||
| 219 | |||
| 220 | pub(crate) struct SocketStack { | ||
| 221 | pub(crate) sockets: SocketSet<'static>, | ||
| 222 | pub(crate) iface: Interface, | ||
| 223 | pub(crate) waker: WakerRegistration, | ||
| 224 | next_local_port: u16, | ||
| 225 | } | ||
| 226 | |||
| 227 | impl<D: Driver + 'static> Stack<D> { | ||
| 228 | /// Create a new network stack. | ||
| 229 | pub fn new<const SOCK: usize>( | ||
| 230 | mut device: D, | ||
| 231 | config: Config, | ||
| 232 | resources: &'static mut StackResources<SOCK>, | ||
| 233 | random_seed: u64, | ||
| 234 | ) -> Self { | ||
| 235 | #[cfg(feature = "medium-ethernet")] | ||
| 236 | let medium = device.capabilities().medium; | ||
| 237 | |||
| 238 | let hardware_addr = match medium { | ||
| 239 | #[cfg(feature = "medium-ethernet")] | ||
| 240 | Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress(device.ethernet_address())), | ||
| 241 | #[cfg(feature = "medium-ip")] | ||
| 242 | Medium::Ip => HardwareAddress::Ip, | ||
| 243 | #[allow(unreachable_patterns)] | ||
| 244 | _ => panic!( | ||
| 245 | "Unsupported medium {:?}. Make sure to enable it in embassy-net's Cargo features.", | ||
| 246 | medium | ||
| 247 | ), | ||
| 248 | }; | ||
| 249 | let mut iface_cfg = smoltcp::iface::Config::new(hardware_addr); | ||
| 250 | iface_cfg.random_seed = random_seed; | ||
| 251 | |||
| 252 | let iface = Interface::new( | ||
| 253 | iface_cfg, | ||
| 254 | &mut DriverAdapter { | ||
| 255 | inner: &mut device, | ||
| 256 | cx: None, | ||
| 257 | }, | ||
| 258 | instant_to_smoltcp(Instant::now()), | ||
| 259 | ); | ||
| 260 | |||
| 261 | let sockets = SocketSet::new(&mut resources.sockets[..]); | ||
| 262 | |||
| 263 | let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; | ||
| 264 | |||
| 265 | let mut socket = SocketStack { | ||
| 266 | sockets, | ||
| 267 | iface, | ||
| 268 | waker: WakerRegistration::new(), | ||
| 269 | next_local_port, | ||
| 270 | }; | ||
| 271 | |||
| 272 | let mut inner = Inner { | ||
| 273 | device, | ||
| 274 | link_up: false, | ||
| 275 | #[cfg(feature = "proto-ipv4")] | ||
| 276 | static_v4: None, | ||
| 277 | #[cfg(feature = "proto-ipv6")] | ||
| 278 | static_v6: None, | ||
| 279 | #[cfg(feature = "dhcpv4")] | ||
| 280 | dhcp_socket: None, | ||
| 281 | #[cfg(feature = "dns")] | ||
| 282 | dns_socket: socket.sockets.add(dns::Socket::new( | ||
| 283 | &[], | ||
| 284 | managed::ManagedSlice::Borrowed(&mut resources.queries), | ||
| 285 | )), | ||
| 286 | #[cfg(feature = "dns")] | ||
| 287 | dns_waker: WakerRegistration::new(), | ||
| 288 | }; | ||
| 289 | |||
| 290 | #[cfg(feature = "proto-ipv4")] | ||
| 291 | match config.ipv4 { | ||
| 292 | ConfigV4::Static(config) => { | ||
| 293 | inner.apply_config_v4(&mut socket, config); | ||
| 294 | } | ||
| 295 | #[cfg(feature = "dhcpv4")] | ||
| 296 | ConfigV4::Dhcp(config) => { | ||
| 297 | let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); | ||
| 298 | inner.apply_dhcp_config(&mut dhcp_socket, config); | ||
| 299 | let handle = socket.sockets.add(dhcp_socket); | ||
| 300 | inner.dhcp_socket = Some(handle); | ||
| 301 | } | ||
| 302 | ConfigV4::None => {} | ||
| 303 | } | ||
| 304 | #[cfg(feature = "proto-ipv6")] | ||
| 305 | match config.ipv6 { | ||
| 306 | ConfigV6::Static(config) => { | ||
| 307 | inner.apply_config_v6(&mut socket, config); | ||
| 308 | } | ||
| 309 | ConfigV6::None => {} | ||
| 310 | } | ||
| 311 | |||
| 312 | Self { | ||
| 313 | socket: RefCell::new(socket), | ||
| 314 | inner: RefCell::new(inner), | ||
| 315 | } | ||
| 316 | } | ||
| 317 | |||
| 318 | fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { | ||
| 319 | f(&*self.socket.borrow(), &*self.inner.borrow()) | ||
| 320 | } | ||
| 321 | |||
| 322 | fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { | ||
| 323 | f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut()) | ||
| 324 | } | ||
| 325 | |||
| 326 | /// Get the MAC address of the network interface. | ||
| 327 | pub fn ethernet_address(&self) -> [u8; 6] { | ||
| 328 | self.with(|_s, i| i.device.ethernet_address()) | ||
| 329 | } | ||
| 330 | |||
| 331 | /// Get whether the link is up. | ||
| 332 | pub fn is_link_up(&self) -> bool { | ||
| 333 | self.with(|_s, i| i.link_up) | ||
| 334 | } | ||
| 335 | |||
| 336 | /// Get whether the network stack has a valid IP configuration. | ||
| 337 | /// This is true if the network stack has a static IP configuration or if DHCP has completed | ||
| 338 | pub fn is_config_up(&self) -> bool { | ||
| 339 | let v4_up; | ||
| 340 | let v6_up; | ||
| 341 | |||
| 342 | #[cfg(feature = "proto-ipv4")] | ||
| 343 | { | ||
| 344 | v4_up = self.config_v4().is_some(); | ||
| 345 | } | ||
| 346 | #[cfg(not(feature = "proto-ipv4"))] | ||
| 347 | { | ||
| 348 | v4_up = false; | ||
| 349 | } | ||
| 350 | |||
| 351 | #[cfg(feature = "proto-ipv6")] | ||
| 352 | { | ||
| 353 | v6_up = self.config_v6().is_some(); | ||
| 354 | } | ||
| 355 | #[cfg(not(feature = "proto-ipv6"))] | ||
| 356 | { | ||
| 357 | v6_up = false; | ||
| 358 | } | ||
| 359 | |||
| 360 | v4_up || v6_up | ||
| 361 | } | ||
| 362 | |||
| 363 | /// Get the current IPv4 configuration. | ||
| 364 | #[cfg(feature = "proto-ipv4")] | ||
| 365 | pub fn config_v4(&self) -> Option<StaticConfigV4> { | ||
| 366 | self.with(|_s, i| i.static_v4.clone()) | ||
| 367 | } | ||
| 368 | |||
| 369 | /// Get the current IPv6 configuration. | ||
| 370 | #[cfg(feature = "proto-ipv6")] | ||
| 371 | pub fn config_v6(&self) -> Option<StaticConfigV6> { | ||
| 372 | self.with(|_s, i| i.static_v6.clone()) | ||
| 373 | } | ||
| 374 | |||
| 375 | /// Run the network stack. | ||
| 376 | /// | ||
| 377 | /// You must call this in a background task, to process network events. | ||
| 378 | pub async fn run(&self) -> ! { | ||
| 379 | poll_fn(|cx| { | ||
| 380 | self.with_mut(|s, i| i.poll(cx, s)); | ||
| 381 | Poll::<()>::Pending | ||
| 382 | }) | ||
| 383 | .await; | ||
| 384 | unreachable!() | ||
| 385 | } | ||
| 386 | |||
| 387 | /// Make a query for a given name and return the corresponding IP addresses. | ||
| 388 | #[cfg(feature = "dns")] | ||
| 389 | pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> { | ||
| 390 | // For A and AAAA queries we try detect whether `name` is just an IP address | ||
| 391 | match qtype { | ||
| 392 | #[cfg(feature = "proto-ipv4")] | ||
| 393 | dns::DnsQueryType::A => { | ||
| 394 | if let Ok(ip) = name.parse().map(IpAddress::Ipv4) { | ||
| 395 | return Ok([ip].into_iter().collect()); | ||
| 396 | } | ||
| 397 | } | ||
| 398 | #[cfg(feature = "proto-ipv6")] | ||
| 399 | dns::DnsQueryType::Aaaa => { | ||
| 400 | if let Ok(ip) = name.parse().map(IpAddress::Ipv6) { | ||
| 401 | return Ok([ip].into_iter().collect()); | ||
| 402 | } | ||
| 403 | } | ||
| 404 | _ => {} | ||
| 405 | } | ||
| 406 | |||
| 407 | let query = poll_fn(|cx| { | ||
| 408 | self.with_mut(|s, i| { | ||
| 409 | let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||
| 410 | match socket.start_query(s.iface.context(), name, qtype) { | ||
| 411 | Ok(handle) => Poll::Ready(Ok(handle)), | ||
| 412 | Err(dns::StartQueryError::NoFreeSlot) => { | ||
| 413 | i.dns_waker.register(cx.waker()); | ||
| 414 | Poll::Pending | ||
| 415 | } | ||
| 416 | Err(e) => Poll::Ready(Err(e)), | ||
| 417 | } | ||
| 418 | }) | ||
| 419 | }) | ||
| 420 | .await?; | ||
| 421 | |||
| 422 | #[must_use = "to delay the drop handler invocation to the end of the scope"] | ||
| 423 | struct OnDrop<F: FnOnce()> { | ||
| 424 | f: core::mem::MaybeUninit<F>, | ||
| 425 | } | ||
| 426 | |||
| 427 | impl<F: FnOnce()> OnDrop<F> { | ||
| 428 | fn new(f: F) -> Self { | ||
| 429 | Self { | ||
| 430 | f: core::mem::MaybeUninit::new(f), | ||
| 431 | } | ||
| 432 | } | ||
| 433 | |||
| 434 | fn defuse(self) { | ||
| 435 | core::mem::forget(self) | ||
| 436 | } | ||
| 437 | } | ||
| 438 | |||
| 439 | impl<F: FnOnce()> Drop for OnDrop<F> { | ||
| 440 | fn drop(&mut self) { | ||
| 441 | unsafe { self.f.as_ptr().read()() } | ||
| 442 | } | ||
| 443 | } | ||
| 444 | |||
| 445 | let drop = OnDrop::new(|| { | ||
| 446 | self.with_mut(|s, i| { | ||
| 447 | let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||
| 448 | socket.cancel_query(query); | ||
| 449 | s.waker.wake(); | ||
| 450 | i.dns_waker.wake(); | ||
| 451 | }) | ||
| 452 | }); | ||
| 453 | |||
| 454 | let res = poll_fn(|cx| { | ||
| 455 | self.with_mut(|s, i| { | ||
| 456 | let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||
| 457 | match socket.get_query_result(query) { | ||
| 458 | Ok(addrs) => { | ||
| 459 | i.dns_waker.wake(); | ||
| 460 | Poll::Ready(Ok(addrs)) | ||
| 461 | } | ||
| 462 | Err(dns::GetQueryResultError::Pending) => { | ||
| 463 | socket.register_query_waker(query, cx.waker()); | ||
| 464 | Poll::Pending | ||
| 465 | } | ||
| 466 | Err(e) => { | ||
| 467 | i.dns_waker.wake(); | ||
| 468 | Poll::Ready(Err(e.into())) | ||
| 469 | } | ||
| 470 | } | ||
| 471 | }) | ||
| 472 | }) | ||
| 473 | .await; | ||
| 474 | |||
| 475 | drop.defuse(); | ||
| 476 | |||
| 477 | res | ||
| 478 | } | ||
| 479 | } | ||
| 480 | |||
| 481 | #[cfg(feature = "igmp")] | ||
| 482 | impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> { | ||
| 483 | /// Join a multicast group. | ||
| 484 | pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, MulticastError> | ||
| 485 | where | ||
| 486 | T: Into<IpAddress>, | ||
| 487 | { | ||
| 488 | let addr = addr.into(); | ||
| 489 | |||
| 490 | self.with_mut(|s, i| { | ||
| 491 | s.iface | ||
| 492 | .join_multicast_group(&mut i.device, addr, instant_to_smoltcp(Instant::now())) | ||
| 493 | }) | ||
| 494 | } | ||
| 495 | |||
| 496 | /// Leave a multicast group. | ||
| 497 | pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, MulticastError> | ||
| 498 | where | ||
| 499 | T: Into<IpAddress>, | ||
| 500 | { | ||
| 501 | let addr = addr.into(); | ||
| 502 | |||
| 503 | self.with_mut(|s, i| { | ||
| 504 | s.iface | ||
| 505 | .leave_multicast_group(&mut i.device, addr, instant_to_smoltcp(Instant::now())) | ||
| 506 | }) | ||
| 507 | } | ||
| 508 | |||
| 509 | /// Get whether the network stack has joined the given multicast group. | ||
| 510 | pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { | ||
| 511 | self.socket.borrow().iface.has_multicast_group(addr) | ||
| 512 | } | ||
| 513 | } | ||
| 514 | |||
| 515 | impl SocketStack { | ||
| 516 | #[allow(clippy::absurd_extreme_comparisons, dead_code)] | ||
| 517 | pub fn get_local_port(&mut self) -> u16 { | ||
| 518 | let res = self.next_local_port; | ||
| 519 | self.next_local_port = if res >= LOCAL_PORT_MAX { LOCAL_PORT_MIN } else { res + 1 }; | ||
| 520 | res | ||
| 521 | } | ||
| 522 | } | ||
| 523 | |||
| 524 | impl<D: Driver + 'static> Inner<D> { | ||
| 525 | #[cfg(feature = "proto-ipv4")] | ||
| 526 | fn apply_config_v4(&mut self, s: &mut SocketStack, config: StaticConfigV4) { | ||
| 527 | #[cfg(feature = "medium-ethernet")] | ||
| 528 | let medium = self.device.capabilities().medium; | ||
| 529 | |||
| 530 | debug!("Acquired IP configuration:"); | ||
| 531 | |||
| 532 | debug!(" IP address: {}", config.address); | ||
| 533 | s.iface.update_ip_addrs(|addrs| { | ||
| 534 | if addrs.is_empty() { | ||
| 535 | addrs.push(IpCidr::Ipv4(config.address)).unwrap(); | ||
| 536 | } else { | ||
| 537 | addrs[0] = IpCidr::Ipv4(config.address); | ||
| 538 | } | ||
| 539 | }); | ||
| 540 | |||
| 541 | #[cfg(feature = "medium-ethernet")] | ||
| 542 | if medium == Medium::Ethernet { | ||
| 543 | if let Some(gateway) = config.gateway { | ||
| 544 | debug!(" Default gateway: {}", gateway); | ||
| 545 | s.iface.routes_mut().add_default_ipv4_route(gateway).unwrap(); | ||
| 546 | } else { | ||
| 547 | debug!(" Default gateway: None"); | ||
| 548 | s.iface.routes_mut().remove_default_ipv4_route(); | ||
| 549 | } | ||
| 550 | } | ||
| 551 | for (i, s) in config.dns_servers.iter().enumerate() { | ||
| 552 | debug!(" DNS server {}: {}", i, s); | ||
| 553 | } | ||
| 554 | |||
| 555 | self.static_v4 = Some(config); | ||
| 556 | |||
| 557 | #[cfg(feature = "dns")] | ||
| 558 | { | ||
| 559 | self.update_dns_servers(s) | ||
| 560 | } | ||
| 561 | } | ||
| 562 | |||
| 563 | /// Replaces the current IPv6 static configuration with a newly supplied config. | ||
| 564 | #[cfg(feature = "proto-ipv6")] | ||
| 565 | fn apply_config_v6(&mut self, s: &mut SocketStack, config: StaticConfigV6) { | ||
| 566 | #[cfg(feature = "medium-ethernet")] | ||
| 567 | let medium = self.device.capabilities().medium; | ||
| 568 | |||
| 569 | debug!("Acquired IPv6 configuration:"); | ||
| 570 | |||
| 571 | debug!(" IP address: {}", config.address); | ||
| 572 | s.iface.update_ip_addrs(|addrs| { | ||
| 573 | if addrs.is_empty() { | ||
| 574 | addrs.push(IpCidr::Ipv6(config.address)).unwrap(); | ||
| 575 | } else { | ||
| 576 | addrs[0] = IpCidr::Ipv6(config.address); | ||
| 577 | } | ||
| 578 | }); | ||
| 579 | |||
| 580 | #[cfg(feature = "medium-ethernet")] | ||
| 581 | if Medium::Ethernet == medium { | ||
| 582 | if let Some(gateway) = config.gateway { | ||
| 583 | debug!(" Default gateway: {}", gateway); | ||
| 584 | s.iface.routes_mut().add_default_ipv6_route(gateway).unwrap(); | ||
| 585 | } else { | ||
| 586 | debug!(" Default gateway: None"); | ||
| 587 | s.iface.routes_mut().remove_default_ipv6_route(); | ||
| 588 | } | ||
| 589 | } | ||
| 590 | for (i, s) in config.dns_servers.iter().enumerate() { | ||
| 591 | debug!(" DNS server {}: {}", i, s); | ||
| 592 | } | ||
| 593 | |||
| 594 | self.static_v6 = Some(config); | ||
| 595 | |||
| 596 | #[cfg(feature = "dns")] | ||
| 597 | { | ||
| 598 | self.update_dns_servers(s) | ||
| 599 | } | ||
| 600 | } | ||
| 601 | |||
| 602 | #[cfg(feature = "dns")] | ||
| 603 | fn update_dns_servers(&mut self, s: &mut SocketStack) { | ||
| 604 | let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(self.dns_socket); | ||
| 605 | |||
| 606 | let servers_v4; | ||
| 607 | #[cfg(feature = "proto-ipv4")] | ||
| 608 | { | ||
| 609 | servers_v4 = self | ||
| 610 | .static_v4 | ||
| 611 | .iter() | ||
| 612 | .flat_map(|cfg| cfg.dns_servers.iter().map(|c| IpAddress::Ipv4(*c))); | ||
| 613 | }; | ||
| 614 | #[cfg(not(feature = "proto-ipv4"))] | ||
| 615 | { | ||
| 616 | servers_v4 = core::iter::empty(); | ||
| 617 | } | ||
| 618 | |||
| 619 | let servers_v6; | ||
| 620 | #[cfg(feature = "proto-ipv6")] | ||
| 621 | { | ||
| 622 | servers_v6 = self | ||
| 623 | .static_v6 | ||
| 624 | .iter() | ||
| 625 | .flat_map(|cfg| cfg.dns_servers.iter().map(|c| IpAddress::Ipv6(*c))); | ||
| 626 | } | ||
| 627 | #[cfg(not(feature = "proto-ipv6"))] | ||
| 628 | { | ||
| 629 | servers_v6 = core::iter::empty(); | ||
| 630 | } | ||
| 631 | |||
| 632 | // Prefer the v6 DNS servers over the v4 servers | ||
| 633 | let servers: Vec<IpAddress, 6> = servers_v6.chain(servers_v4).collect(); | ||
| 634 | socket.update_servers(&servers[..]); | ||
| 635 | } | ||
| 636 | |||
| 637 | #[cfg(feature = "dhcpv4")] | ||
| 638 | fn apply_dhcp_config(&self, socket: &mut smoltcp::socket::dhcpv4::Socket, config: DhcpConfig) { | ||
| 639 | socket.set_ignore_naks(config.ignore_naks); | ||
| 640 | socket.set_max_lease_duration(config.max_lease_duration.map(crate::time::duration_to_smoltcp)); | ||
| 641 | socket.set_ports(config.server_port, config.client_port); | ||
| 642 | socket.set_retry_config(config.retry_config); | ||
| 643 | } | ||
| 644 | |||
| 645 | #[allow(unused)] // used only with dhcp | ||
| 646 | fn unapply_config(&mut self, s: &mut SocketStack) { | ||
| 647 | #[cfg(feature = "medium-ethernet")] | ||
| 648 | let medium = self.device.capabilities().medium; | ||
| 649 | |||
| 650 | debug!("Lost IP configuration"); | ||
| 651 | s.iface.update_ip_addrs(|ip_addrs| ip_addrs.clear()); | ||
| 652 | #[cfg(feature = "medium-ethernet")] | ||
| 653 | if medium == Medium::Ethernet { | ||
| 654 | #[cfg(feature = "proto-ipv4")] | ||
| 655 | { | ||
| 656 | s.iface.routes_mut().remove_default_ipv4_route(); | ||
| 657 | } | ||
| 658 | } | ||
| 659 | #[cfg(feature = "proto-ipv4")] | ||
| 660 | { | ||
| 661 | self.static_v4 = None | ||
| 662 | } | ||
| 663 | } | ||
| 664 | |||
| 665 | fn poll(&mut self, cx: &mut Context<'_>, s: &mut SocketStack) { | ||
| 666 | s.waker.register(cx.waker()); | ||
| 667 | |||
| 668 | #[cfg(feature = "medium-ethernet")] | ||
| 669 | if self.device.capabilities().medium == Medium::Ethernet { | ||
| 670 | s.iface.set_hardware_addr(HardwareAddress::Ethernet(EthernetAddress( | ||
| 671 | self.device.ethernet_address(), | ||
| 672 | ))); | ||
| 673 | } | ||
| 674 | |||
| 675 | let timestamp = instant_to_smoltcp(Instant::now()); | ||
| 676 | let mut smoldev = DriverAdapter { | ||
| 677 | cx: Some(cx), | ||
| 678 | inner: &mut self.device, | ||
| 679 | }; | ||
| 680 | s.iface.poll(timestamp, &mut smoldev, &mut s.sockets); | ||
| 681 | |||
| 682 | // Update link up | ||
| 683 | let old_link_up = self.link_up; | ||
| 684 | self.link_up = self.device.link_state(cx) == LinkState::Up; | ||
| 685 | |||
| 686 | // Print when changed | ||
| 687 | if old_link_up != self.link_up { | ||
| 688 | info!("link_up = {:?}", self.link_up); | ||
| 689 | } | ||
| 690 | |||
| 691 | #[cfg(feature = "dhcpv4")] | ||
| 692 | if let Some(dhcp_handle) = self.dhcp_socket { | ||
| 693 | let socket = s.sockets.get_mut::<dhcpv4::Socket>(dhcp_handle); | ||
| 694 | |||
| 695 | if self.link_up { | ||
| 696 | match socket.poll() { | ||
| 697 | None => {} | ||
| 698 | Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), | ||
| 699 | Some(dhcpv4::Event::Configured(config)) => { | ||
| 700 | let config = StaticConfigV4 { | ||
| 701 | address: config.address, | ||
| 702 | gateway: config.router, | ||
| 703 | dns_servers: config.dns_servers, | ||
| 704 | }; | ||
| 705 | self.apply_config_v4(s, config) | ||
| 706 | } | ||
| 707 | } | ||
| 708 | } else if old_link_up { | ||
| 709 | socket.reset(); | ||
| 710 | self.unapply_config(s); | ||
| 711 | } | ||
| 712 | } | ||
| 713 | //if old_link_up || self.link_up { | ||
| 714 | // self.poll_configurator(timestamp) | ||
| 715 | //} | ||
| 716 | // | ||
| 717 | |||
| 718 | if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { | ||
| 719 | let t = Timer::at(instant_from_smoltcp(poll_at)); | ||
| 720 | pin_mut!(t); | ||
| 721 | if t.poll(cx).is_ready() { | ||
| 722 | cx.waker().wake_by_ref(); | ||
| 723 | } | ||
| 724 | } | ||
| 725 | } | ||
| 726 | } | ||
diff --git a/embassy-net/src/packet_pool.rs b/embassy-net/src/packet_pool.rs deleted file mode 100644 index cb8a1316c..000000000 --- a/embassy-net/src/packet_pool.rs +++ /dev/null | |||
| @@ -1,107 +0,0 @@ | |||
| 1 | use core::ops::{Deref, DerefMut, Range}; | ||
| 2 | |||
| 3 | use as_slice::{AsMutSlice, AsSlice}; | ||
| 4 | use atomic_pool::{pool, Box}; | ||
| 5 | |||
| 6 | pub const MTU: usize = 1516; | ||
| 7 | |||
| 8 | #[cfg(feature = "pool-4")] | ||
| 9 | pub const PACKET_POOL_SIZE: usize = 4; | ||
| 10 | |||
| 11 | #[cfg(feature = "pool-8")] | ||
| 12 | pub const PACKET_POOL_SIZE: usize = 8; | ||
| 13 | |||
| 14 | #[cfg(feature = "pool-16")] | ||
| 15 | pub const PACKET_POOL_SIZE: usize = 16; | ||
| 16 | |||
| 17 | #[cfg(feature = "pool-32")] | ||
| 18 | pub const PACKET_POOL_SIZE: usize = 32; | ||
| 19 | |||
| 20 | #[cfg(feature = "pool-64")] | ||
| 21 | pub const PACKET_POOL_SIZE: usize = 64; | ||
| 22 | |||
| 23 | #[cfg(feature = "pool-128")] | ||
| 24 | pub const PACKET_POOL_SIZE: usize = 128; | ||
| 25 | |||
| 26 | pool!(pub PacketPool: [Packet; PACKET_POOL_SIZE]); | ||
| 27 | pub type PacketBox = Box<PacketPool>; | ||
| 28 | |||
| 29 | #[repr(align(4))] | ||
| 30 | pub struct Packet(pub [u8; MTU]); | ||
| 31 | |||
| 32 | impl Packet { | ||
| 33 | pub const fn new() -> Self { | ||
| 34 | Self([0; MTU]) | ||
| 35 | } | ||
| 36 | } | ||
| 37 | |||
| 38 | pub trait PacketBoxExt { | ||
| 39 | fn slice(self, range: Range<usize>) -> PacketBuf; | ||
| 40 | } | ||
| 41 | |||
| 42 | impl PacketBoxExt for PacketBox { | ||
| 43 | fn slice(self, range: Range<usize>) -> PacketBuf { | ||
| 44 | PacketBuf { packet: self, range } | ||
| 45 | } | ||
| 46 | } | ||
| 47 | |||
| 48 | impl AsSlice for Packet { | ||
| 49 | type Element = u8; | ||
| 50 | |||
| 51 | fn as_slice(&self) -> &[Self::Element] { | ||
| 52 | &self.deref()[..] | ||
| 53 | } | ||
| 54 | } | ||
| 55 | |||
| 56 | impl AsMutSlice for Packet { | ||
| 57 | fn as_mut_slice(&mut self) -> &mut [Self::Element] { | ||
| 58 | &mut self.deref_mut()[..] | ||
| 59 | } | ||
| 60 | } | ||
| 61 | |||
| 62 | impl Deref for Packet { | ||
| 63 | type Target = [u8; MTU]; | ||
| 64 | |||
| 65 | fn deref(&self) -> &[u8; MTU] { | ||
| 66 | &self.0 | ||
| 67 | } | ||
| 68 | } | ||
| 69 | |||
| 70 | impl DerefMut for Packet { | ||
| 71 | fn deref_mut(&mut self) -> &mut [u8; MTU] { | ||
| 72 | &mut self.0 | ||
| 73 | } | ||
| 74 | } | ||
| 75 | |||
| 76 | pub struct PacketBuf { | ||
| 77 | packet: PacketBox, | ||
| 78 | range: Range<usize>, | ||
| 79 | } | ||
| 80 | |||
| 81 | impl AsSlice for PacketBuf { | ||
| 82 | type Element = u8; | ||
| 83 | |||
| 84 | fn as_slice(&self) -> &[Self::Element] { | ||
| 85 | &self.packet[self.range.clone()] | ||
| 86 | } | ||
| 87 | } | ||
| 88 | |||
| 89 | impl AsMutSlice for PacketBuf { | ||
| 90 | fn as_mut_slice(&mut self) -> &mut [Self::Element] { | ||
| 91 | &mut self.packet[self.range.clone()] | ||
| 92 | } | ||
| 93 | } | ||
| 94 | |||
| 95 | impl Deref for PacketBuf { | ||
| 96 | type Target = [u8]; | ||
| 97 | |||
| 98 | fn deref(&self) -> &[u8] { | ||
| 99 | &self.packet[self.range.clone()] | ||
| 100 | } | ||
| 101 | } | ||
| 102 | |||
| 103 | impl DerefMut for PacketBuf { | ||
| 104 | fn deref_mut(&mut self) -> &mut [u8] { | ||
| 105 | &mut self.packet[self.range.clone()] | ||
| 106 | } | ||
| 107 | } | ||
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs deleted file mode 100644 index 8d2dd4bca..000000000 --- a/embassy-net/src/stack.rs +++ /dev/null | |||
| @@ -1,316 +0,0 @@ | |||
| 1 | use core::cell::UnsafeCell; | ||
| 2 | use core::future::Future; | ||
| 3 | use core::task::{Context, Poll}; | ||
| 4 | |||
| 5 | use embassy_sync::waitqueue::WakerRegistration; | ||
| 6 | use embassy_time::{Instant, Timer}; | ||
| 7 | use futures::future::poll_fn; | ||
| 8 | use futures::pin_mut; | ||
| 9 | use heapless::Vec; | ||
| 10 | #[cfg(feature = "dhcpv4")] | ||
| 11 | use smoltcp::iface::SocketHandle; | ||
| 12 | use smoltcp::iface::{Interface, InterfaceBuilder, SocketSet, SocketStorage}; | ||
| 13 | #[cfg(feature = "medium-ethernet")] | ||
| 14 | use smoltcp::iface::{Neighbor, NeighborCache, Route, Routes}; | ||
| 15 | #[cfg(feature = "medium-ethernet")] | ||
| 16 | use smoltcp::phy::{Device as _, Medium}; | ||
| 17 | #[cfg(feature = "dhcpv4")] | ||
| 18 | use smoltcp::socket::dhcpv4; | ||
| 19 | use smoltcp::time::Instant as SmolInstant; | ||
| 20 | #[cfg(feature = "medium-ethernet")] | ||
| 21 | use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress}; | ||
| 22 | use smoltcp::wire::{IpCidr, Ipv4Address, Ipv4Cidr}; | ||
| 23 | |||
| 24 | use crate::device::{Device, DeviceAdapter, LinkState}; | ||
| 25 | |||
| 26 | const LOCAL_PORT_MIN: u16 = 1025; | ||
| 27 | const LOCAL_PORT_MAX: u16 = 65535; | ||
| 28 | |||
| 29 | pub struct StackResources<const ADDR: usize, const SOCK: usize, const NEIGHBOR: usize> { | ||
| 30 | addresses: [IpCidr; ADDR], | ||
| 31 | sockets: [SocketStorage<'static>; SOCK], | ||
| 32 | |||
| 33 | #[cfg(feature = "medium-ethernet")] | ||
| 34 | routes: [Option<(IpCidr, Route)>; 1], | ||
| 35 | #[cfg(feature = "medium-ethernet")] | ||
| 36 | neighbor_cache: [Option<(IpAddress, Neighbor)>; NEIGHBOR], | ||
| 37 | } | ||
| 38 | |||
| 39 | impl<const ADDR: usize, const SOCK: usize, const NEIGHBOR: usize> StackResources<ADDR, SOCK, NEIGHBOR> { | ||
| 40 | pub fn new() -> Self { | ||
| 41 | Self { | ||
| 42 | addresses: [IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 32); ADDR], | ||
| 43 | sockets: [SocketStorage::EMPTY; SOCK], | ||
| 44 | #[cfg(feature = "medium-ethernet")] | ||
| 45 | routes: [None; 1], | ||
| 46 | #[cfg(feature = "medium-ethernet")] | ||
| 47 | neighbor_cache: [None; NEIGHBOR], | ||
| 48 | } | ||
| 49 | } | ||
| 50 | } | ||
| 51 | |||
| 52 | #[derive(Debug, Clone, PartialEq, Eq)] | ||
| 53 | pub struct Config { | ||
| 54 | pub address: Ipv4Cidr, | ||
| 55 | pub gateway: Option<Ipv4Address>, | ||
| 56 | pub dns_servers: Vec<Ipv4Address, 3>, | ||
| 57 | } | ||
| 58 | |||
| 59 | pub enum ConfigStrategy { | ||
| 60 | Static(Config), | ||
| 61 | #[cfg(feature = "dhcpv4")] | ||
| 62 | Dhcp, | ||
| 63 | } | ||
| 64 | |||
| 65 | pub struct Stack<D: Device> { | ||
| 66 | pub(crate) socket: UnsafeCell<SocketStack>, | ||
| 67 | inner: UnsafeCell<Inner<D>>, | ||
| 68 | } | ||
| 69 | |||
| 70 | struct Inner<D: Device> { | ||
| 71 | device: DeviceAdapter<D>, | ||
| 72 | link_up: bool, | ||
| 73 | config: Option<Config>, | ||
| 74 | #[cfg(feature = "dhcpv4")] | ||
| 75 | dhcp_socket: Option<SocketHandle>, | ||
| 76 | } | ||
| 77 | |||
| 78 | pub(crate) struct SocketStack { | ||
| 79 | pub(crate) sockets: SocketSet<'static>, | ||
| 80 | pub(crate) iface: Interface<'static>, | ||
| 81 | pub(crate) waker: WakerRegistration, | ||
| 82 | next_local_port: u16, | ||
| 83 | } | ||
| 84 | |||
| 85 | unsafe impl<D: Device> Send for Stack<D> {} | ||
| 86 | |||
| 87 | impl<D: Device + 'static> Stack<D> { | ||
| 88 | pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( | ||
| 89 | device: D, | ||
| 90 | config: ConfigStrategy, | ||
| 91 | resources: &'static mut StackResources<ADDR, SOCK, NEIGH>, | ||
| 92 | random_seed: u64, | ||
| 93 | ) -> Self { | ||
| 94 | #[cfg(feature = "medium-ethernet")] | ||
| 95 | let medium = device.capabilities().medium; | ||
| 96 | |||
| 97 | #[cfg(feature = "medium-ethernet")] | ||
| 98 | let ethernet_addr = if medium == Medium::Ethernet { | ||
| 99 | device.ethernet_address() | ||
| 100 | } else { | ||
| 101 | [0, 0, 0, 0, 0, 0] | ||
| 102 | }; | ||
| 103 | |||
| 104 | let mut device = DeviceAdapter::new(device); | ||
| 105 | |||
| 106 | let mut b = InterfaceBuilder::new(); | ||
| 107 | b = b.ip_addrs(&mut resources.addresses[..]); | ||
| 108 | b = b.random_seed(random_seed); | ||
| 109 | |||
| 110 | #[cfg(feature = "medium-ethernet")] | ||
| 111 | if medium == Medium::Ethernet { | ||
| 112 | b = b.hardware_addr(HardwareAddress::Ethernet(EthernetAddress(ethernet_addr))); | ||
| 113 | b = b.neighbor_cache(NeighborCache::new(&mut resources.neighbor_cache[..])); | ||
| 114 | b = b.routes(Routes::new(&mut resources.routes[..])); | ||
| 115 | } | ||
| 116 | |||
| 117 | let iface = b.finalize(&mut device); | ||
| 118 | |||
| 119 | let sockets = SocketSet::new(&mut resources.sockets[..]); | ||
| 120 | |||
| 121 | let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; | ||
| 122 | |||
| 123 | let mut inner = Inner { | ||
| 124 | device, | ||
| 125 | link_up: false, | ||
| 126 | config: None, | ||
| 127 | #[cfg(feature = "dhcpv4")] | ||
| 128 | dhcp_socket: None, | ||
| 129 | }; | ||
| 130 | let mut socket = SocketStack { | ||
| 131 | sockets, | ||
| 132 | iface, | ||
| 133 | waker: WakerRegistration::new(), | ||
| 134 | next_local_port, | ||
| 135 | }; | ||
| 136 | |||
| 137 | match config { | ||
| 138 | ConfigStrategy::Static(config) => inner.apply_config(&mut socket, config), | ||
| 139 | #[cfg(feature = "dhcpv4")] | ||
| 140 | ConfigStrategy::Dhcp => { | ||
| 141 | let handle = socket.sockets.add(smoltcp::socket::dhcpv4::Socket::new()); | ||
| 142 | inner.dhcp_socket = Some(handle); | ||
| 143 | } | ||
| 144 | } | ||
| 145 | |||
| 146 | Self { | ||
| 147 | socket: UnsafeCell::new(socket), | ||
| 148 | inner: UnsafeCell::new(inner), | ||
| 149 | } | ||
| 150 | } | ||
| 151 | |||
| 152 | /// SAFETY: must not call reentrantly. | ||
| 153 | unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { | ||
| 154 | f(&*self.socket.get(), &*self.inner.get()) | ||
| 155 | } | ||
| 156 | |||
| 157 | /// SAFETY: must not call reentrantly. | ||
| 158 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { | ||
| 159 | f(&mut *self.socket.get(), &mut *self.inner.get()) | ||
| 160 | } | ||
| 161 | |||
| 162 | pub fn ethernet_address(&self) -> [u8; 6] { | ||
| 163 | unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } | ||
| 164 | } | ||
| 165 | |||
| 166 | pub fn is_link_up(&self) -> bool { | ||
| 167 | unsafe { self.with(|_s, i| i.link_up) } | ||
| 168 | } | ||
| 169 | |||
| 170 | pub fn is_config_up(&self) -> bool { | ||
| 171 | unsafe { self.with(|_s, i| i.config.is_some()) } | ||
| 172 | } | ||
| 173 | |||
| 174 | pub fn config(&self) -> Option<Config> { | ||
| 175 | unsafe { self.with(|_s, i| i.config.clone()) } | ||
| 176 | } | ||
| 177 | |||
| 178 | pub async fn run(&self) -> ! { | ||
| 179 | poll_fn(|cx| { | ||
| 180 | unsafe { self.with_mut(|s, i| i.poll(cx, s)) } | ||
| 181 | Poll::<()>::Pending | ||
| 182 | }) | ||
| 183 | .await; | ||
| 184 | unreachable!() | ||
| 185 | } | ||
| 186 | } | ||
| 187 | |||
| 188 | impl SocketStack { | ||
| 189 | #[allow(clippy::absurd_extreme_comparisons)] | ||
| 190 | pub fn get_local_port(&mut self) -> u16 { | ||
| 191 | let res = self.next_local_port; | ||
| 192 | self.next_local_port = if res >= LOCAL_PORT_MAX { LOCAL_PORT_MIN } else { res + 1 }; | ||
| 193 | res | ||
| 194 | } | ||
| 195 | } | ||
| 196 | |||
| 197 | impl<D: Device + 'static> Inner<D> { | ||
| 198 | fn apply_config(&mut self, s: &mut SocketStack, config: Config) { | ||
| 199 | #[cfg(feature = "medium-ethernet")] | ||
| 200 | let medium = self.device.capabilities().medium; | ||
| 201 | |||
| 202 | debug!("Acquired IP configuration:"); | ||
| 203 | |||
| 204 | debug!(" IP address: {}", config.address); | ||
| 205 | self.set_ipv4_addr(s, config.address); | ||
| 206 | |||
| 207 | #[cfg(feature = "medium-ethernet")] | ||
| 208 | if medium == Medium::Ethernet { | ||
| 209 | if let Some(gateway) = config.gateway { | ||
| 210 | debug!(" Default gateway: {}", gateway); | ||
| 211 | s.iface.routes_mut().add_default_ipv4_route(gateway).unwrap(); | ||
| 212 | } else { | ||
| 213 | debug!(" Default gateway: None"); | ||
| 214 | s.iface.routes_mut().remove_default_ipv4_route(); | ||
| 215 | } | ||
| 216 | } | ||
| 217 | for (i, s) in config.dns_servers.iter().enumerate() { | ||
| 218 | debug!(" DNS server {}: {}", i, s); | ||
| 219 | } | ||
| 220 | |||
| 221 | self.config = Some(config) | ||
| 222 | } | ||
| 223 | |||
| 224 | #[allow(unused)] // used only with dhcp | ||
| 225 | fn unapply_config(&mut self, s: &mut SocketStack) { | ||
| 226 | #[cfg(feature = "medium-ethernet")] | ||
| 227 | let medium = self.device.capabilities().medium; | ||
| 228 | |||
| 229 | debug!("Lost IP configuration"); | ||
| 230 | self.set_ipv4_addr(s, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)); | ||
| 231 | #[cfg(feature = "medium-ethernet")] | ||
| 232 | if medium == Medium::Ethernet { | ||
| 233 | s.iface.routes_mut().remove_default_ipv4_route(); | ||
| 234 | } | ||
| 235 | self.config = None | ||
| 236 | } | ||
| 237 | |||
| 238 | fn set_ipv4_addr(&mut self, s: &mut SocketStack, cidr: Ipv4Cidr) { | ||
| 239 | s.iface.update_ip_addrs(|addrs| { | ||
| 240 | let dest = addrs.iter_mut().next().unwrap(); | ||
| 241 | *dest = IpCidr::Ipv4(cidr); | ||
| 242 | }); | ||
| 243 | } | ||
| 244 | |||
| 245 | fn poll(&mut self, cx: &mut Context<'_>, s: &mut SocketStack) { | ||
| 246 | self.device.device.register_waker(cx.waker()); | ||
| 247 | s.waker.register(cx.waker()); | ||
| 248 | |||
| 249 | let timestamp = instant_to_smoltcp(Instant::now()); | ||
| 250 | if s.iface.poll(timestamp, &mut self.device, &mut s.sockets).is_err() { | ||
| 251 | // If poll() returns error, it may not be done yet, so poll again later. | ||
| 252 | cx.waker().wake_by_ref(); | ||
| 253 | return; | ||
| 254 | } | ||
| 255 | |||
| 256 | // Update link up | ||
| 257 | let old_link_up = self.link_up; | ||
| 258 | self.link_up = self.device.device.link_state() == LinkState::Up; | ||
| 259 | |||
| 260 | // Print when changed | ||
| 261 | if old_link_up != self.link_up { | ||
| 262 | info!("link_up = {:?}", self.link_up); | ||
| 263 | } | ||
| 264 | |||
| 265 | #[cfg(feature = "dhcpv4")] | ||
| 266 | if let Some(dhcp_handle) = self.dhcp_socket { | ||
| 267 | let socket = s.sockets.get_mut::<dhcpv4::Socket>(dhcp_handle); | ||
| 268 | |||
| 269 | if self.link_up { | ||
| 270 | match socket.poll() { | ||
| 271 | None => {} | ||
| 272 | Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), | ||
| 273 | Some(dhcpv4::Event::Configured(config)) => { | ||
| 274 | let mut dns_servers = Vec::new(); | ||
| 275 | for s in &config.dns_servers { | ||
| 276 | if let Some(addr) = s { | ||
| 277 | dns_servers.push(addr.clone()).unwrap(); | ||
| 278 | } | ||
| 279 | } | ||
| 280 | |||
| 281 | self.apply_config( | ||
| 282 | s, | ||
| 283 | Config { | ||
| 284 | address: config.address, | ||
| 285 | gateway: config.router, | ||
| 286 | dns_servers, | ||
| 287 | }, | ||
| 288 | ) | ||
| 289 | } | ||
| 290 | } | ||
| 291 | } else if old_link_up { | ||
| 292 | socket.reset(); | ||
| 293 | self.unapply_config(s); | ||
| 294 | } | ||
| 295 | } | ||
| 296 | //if old_link_up || self.link_up { | ||
| 297 | // self.poll_configurator(timestamp) | ||
| 298 | //} | ||
| 299 | |||
| 300 | if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { | ||
| 301 | let t = Timer::at(instant_from_smoltcp(poll_at)); | ||
| 302 | pin_mut!(t); | ||
| 303 | if t.poll(cx).is_ready() { | ||
| 304 | cx.waker().wake_by_ref(); | ||
| 305 | } | ||
| 306 | } | ||
| 307 | } | ||
| 308 | } | ||
| 309 | |||
| 310 | fn instant_to_smoltcp(instant: Instant) -> SmolInstant { | ||
| 311 | SmolInstant::from_millis(instant.as_millis() as i64) | ||
| 312 | } | ||
| 313 | |||
| 314 | fn instant_from_smoltcp(instant: SmolInstant) -> Instant { | ||
| 315 | Instant::from_millis(instant.total_millis() as u64) | ||
| 316 | } | ||
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index 910772c7d..367675b13 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs | |||
| @@ -1,24 +1,39 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | //! TCP sockets. |
| 2 | use core::future::Future; | 2 | //! |
| 3 | //! # Listening | ||
| 4 | //! | ||
| 5 | //! `embassy-net` does not have a `TcpListener`. Instead, individual `TcpSocket`s can be put into | ||
| 6 | //! listening mode by calling [`TcpSocket::accept`]. | ||
| 7 | //! | ||
| 8 | //! Incoming connections when no socket is listening are rejected. To accept many incoming | ||
| 9 | //! connections, create many sockets and put them all into listening mode. | ||
| 10 | |||
| 11 | use core::cell::RefCell; | ||
| 12 | use core::future::poll_fn; | ||
| 3 | use core::mem; | 13 | use core::mem; |
| 4 | use core::task::Poll; | 14 | use core::task::Poll; |
| 5 | 15 | ||
| 6 | use futures::future::poll_fn; | 16 | use embassy_net_driver::Driver; |
| 17 | use embassy_time::Duration; | ||
| 7 | use smoltcp::iface::{Interface, SocketHandle}; | 18 | use smoltcp::iface::{Interface, SocketHandle}; |
| 8 | use smoltcp::socket::tcp; | 19 | use smoltcp::socket::tcp; |
| 9 | use smoltcp::time::Duration; | 20 | pub use smoltcp::socket::tcp::State; |
| 10 | use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; | 21 | use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; |
| 11 | 22 | ||
| 12 | use super::stack::Stack; | 23 | use crate::time::duration_to_smoltcp; |
| 13 | use crate::stack::SocketStack; | 24 | use crate::{SocketStack, Stack}; |
| 14 | use crate::Device; | ||
| 15 | 25 | ||
| 26 | /// Error returned by TcpSocket read/write functions. | ||
| 16 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] | 27 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] |
| 17 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | 28 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] |
| 18 | pub enum Error { | 29 | pub enum Error { |
| 30 | /// The connection was reset. | ||
| 31 | /// | ||
| 32 | /// This can happen on receiving a RST packet, or on timeout. | ||
| 19 | ConnectionReset, | 33 | ConnectionReset, |
| 20 | } | 34 | } |
| 21 | 35 | ||
| 36 | /// Error returned by [`TcpSocket::connect`]. | ||
| 22 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] | 37 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] |
| 23 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | 38 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] |
| 24 | pub enum ConnectError { | 39 | pub enum ConnectError { |
| @@ -32,6 +47,7 @@ pub enum ConnectError { | |||
| 32 | NoRoute, | 47 | NoRoute, |
| 33 | } | 48 | } |
| 34 | 49 | ||
| 50 | /// Error returned by [`TcpSocket::accept`]. | ||
| 35 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] | 51 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] |
| 36 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | 52 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] |
| 37 | pub enum AcceptError { | 53 | pub enum AcceptError { |
| @@ -43,22 +59,53 @@ pub enum AcceptError { | |||
| 43 | ConnectionReset, | 59 | ConnectionReset, |
| 44 | } | 60 | } |
| 45 | 61 | ||
| 62 | /// A TCP socket. | ||
| 46 | pub struct TcpSocket<'a> { | 63 | pub struct TcpSocket<'a> { |
| 47 | io: TcpIo<'a>, | 64 | io: TcpIo<'a>, |
| 48 | } | 65 | } |
| 49 | 66 | ||
| 67 | /// The reader half of a TCP socket. | ||
| 50 | pub struct TcpReader<'a> { | 68 | pub struct TcpReader<'a> { |
| 51 | io: TcpIo<'a>, | 69 | io: TcpIo<'a>, |
| 52 | } | 70 | } |
| 53 | 71 | ||
| 72 | /// The writer half of a TCP socket. | ||
| 54 | pub struct TcpWriter<'a> { | 73 | pub struct TcpWriter<'a> { |
| 55 | io: TcpIo<'a>, | 74 | io: TcpIo<'a>, |
| 56 | } | 75 | } |
| 57 | 76 | ||
| 77 | impl<'a> TcpReader<'a> { | ||
| 78 | /// Read data from the socket. | ||
| 79 | /// | ||
| 80 | /// Returns how many bytes were read, or an error. If no data is available, it waits | ||
| 81 | /// until there is at least one byte available. | ||
| 82 | pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { | ||
| 83 | self.io.read(buf).await | ||
| 84 | } | ||
| 85 | } | ||
| 86 | |||
| 87 | impl<'a> TcpWriter<'a> { | ||
| 88 | /// Write data to the socket. | ||
| 89 | /// | ||
| 90 | /// Returns how many bytes were written, or an error. If the socket is not ready to | ||
| 91 | /// accept data, it waits until it is. | ||
| 92 | pub async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { | ||
| 93 | self.io.write(buf).await | ||
| 94 | } | ||
| 95 | |||
| 96 | /// Flushes the written data to the socket. | ||
| 97 | /// | ||
| 98 | /// This waits until all data has been sent, and ACKed by the remote host. For a connection | ||
| 99 | /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent. | ||
| 100 | pub async fn flush(&mut self) -> Result<(), Error> { | ||
| 101 | self.io.flush().await | ||
| 102 | } | ||
| 103 | } | ||
| 104 | |||
| 58 | impl<'a> TcpSocket<'a> { | 105 | impl<'a> TcpSocket<'a> { |
| 59 | pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { | 106 | /// Create a new TCP socket on the given stack, with the given buffers. |
| 60 | // safety: not accessed reentrantly. | 107 | pub fn new<D: Driver>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { |
| 61 | let s = unsafe { &mut *stack.socket.get() }; | 108 | let s = &mut *stack.socket.borrow_mut(); |
| 62 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 109 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| 63 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; | 110 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; |
| 64 | let handle = s.sockets.add(tcp::Socket::new( | 111 | let handle = s.sockets.add(tcp::Socket::new( |
| @@ -74,25 +121,28 @@ impl<'a> TcpSocket<'a> { | |||
| 74 | } | 121 | } |
| 75 | } | 122 | } |
| 76 | 123 | ||
| 124 | /// Split the socket into reader and a writer halves. | ||
| 77 | pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { | 125 | pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { |
| 78 | (TcpReader { io: self.io }, TcpWriter { io: self.io }) | 126 | (TcpReader { io: self.io }, TcpWriter { io: self.io }) |
| 79 | } | 127 | } |
| 80 | 128 | ||
| 129 | /// Connect to a remote host. | ||
| 81 | pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> | 130 | pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> |
| 82 | where | 131 | where |
| 83 | T: Into<IpEndpoint>, | 132 | T: Into<IpEndpoint>, |
| 84 | { | 133 | { |
| 85 | // safety: not accessed reentrantly. | 134 | let local_port = self.io.stack.borrow_mut().get_local_port(); |
| 86 | let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port(); | ||
| 87 | 135 | ||
| 88 | // safety: not accessed reentrantly. | 136 | match { |
| 89 | match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { | 137 | self.io |
| 138 | .with_mut(|s, i| s.connect(i.context(), remote_endpoint, local_port)) | ||
| 139 | } { | ||
| 90 | Ok(()) => {} | 140 | Ok(()) => {} |
| 91 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), | 141 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), |
| 92 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), | 142 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), |
| 93 | } | 143 | } |
| 94 | 144 | ||
| 95 | futures::future::poll_fn(|cx| unsafe { | 145 | poll_fn(|cx| { |
| 96 | self.io.with_mut(|s, _| match s.state() { | 146 | self.io.with_mut(|s, _| match s.state() { |
| 97 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), | 147 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), |
| 98 | tcp::State::Listen => unreachable!(), | 148 | tcp::State::Listen => unreachable!(), |
| @@ -106,18 +156,20 @@ impl<'a> TcpSocket<'a> { | |||
| 106 | .await | 156 | .await |
| 107 | } | 157 | } |
| 108 | 158 | ||
| 159 | /// Accept a connection from a remote host. | ||
| 160 | /// | ||
| 161 | /// This function puts the socket in listening mode, and waits until a connection is received. | ||
| 109 | pub async fn accept<T>(&mut self, local_endpoint: T) -> Result<(), AcceptError> | 162 | pub async fn accept<T>(&mut self, local_endpoint: T) -> Result<(), AcceptError> |
| 110 | where | 163 | where |
| 111 | T: Into<IpListenEndpoint>, | 164 | T: Into<IpListenEndpoint>, |
| 112 | { | 165 | { |
| 113 | // safety: not accessed reentrantly. | 166 | match self.io.with_mut(|s, _| s.listen(local_endpoint)) { |
| 114 | match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } { | ||
| 115 | Ok(()) => {} | 167 | Ok(()) => {} |
| 116 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), | 168 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), |
| 117 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), | 169 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), |
| 118 | } | 170 | } |
| 119 | 171 | ||
| 120 | futures::future::poll_fn(|cx| unsafe { | 172 | poll_fn(|cx| { |
| 121 | self.io.with_mut(|s, _| match s.state() { | 173 | self.io.with_mut(|s, _| match s.state() { |
| 122 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { | 174 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { |
| 123 | s.register_send_waker(cx.waker()); | 175 | s.register_send_waker(cx.waker()); |
| @@ -129,52 +181,120 @@ impl<'a> TcpSocket<'a> { | |||
| 129 | .await | 181 | .await |
| 130 | } | 182 | } |
| 131 | 183 | ||
| 184 | /// Read data from the socket. | ||
| 185 | /// | ||
| 186 | /// Returns how many bytes were read, or an error. If no data is available, it waits | ||
| 187 | /// until there is at least one byte available. | ||
| 188 | pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { | ||
| 189 | self.io.read(buf).await | ||
| 190 | } | ||
| 191 | |||
| 192 | /// Write data to the socket. | ||
| 193 | /// | ||
| 194 | /// Returns how many bytes were written, or an error. If the socket is not ready to | ||
| 195 | /// accept data, it waits until it is. | ||
| 196 | pub async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { | ||
| 197 | self.io.write(buf).await | ||
| 198 | } | ||
| 199 | |||
| 200 | /// Flushes the written data to the socket. | ||
| 201 | /// | ||
| 202 | /// This waits until all data has been sent, and ACKed by the remote host. For a connection | ||
| 203 | /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent. | ||
| 204 | pub async fn flush(&mut self) -> Result<(), Error> { | ||
| 205 | self.io.flush().await | ||
| 206 | } | ||
| 207 | |||
| 208 | /// Set the timeout for the socket. | ||
| 209 | /// | ||
| 210 | /// If the timeout is set, the socket will be closed if no data is received for the | ||
| 211 | /// specified duration. | ||
| 132 | pub fn set_timeout(&mut self, duration: Option<Duration>) { | 212 | pub fn set_timeout(&mut self, duration: Option<Duration>) { |
| 133 | unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } | 213 | self.io |
| 214 | .with_mut(|s, _| s.set_timeout(duration.map(duration_to_smoltcp))) | ||
| 134 | } | 215 | } |
| 135 | 216 | ||
| 217 | /// Set the keep-alive interval for the socket. | ||
| 218 | /// | ||
| 219 | /// If the keep-alive interval is set, the socket will send keep-alive packets after | ||
| 220 | /// the specified duration of inactivity. | ||
| 221 | /// | ||
| 222 | /// If not set, the socket will not send keep-alive packets. | ||
| 136 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { | 223 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { |
| 137 | unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } | 224 | self.io |
| 225 | .with_mut(|s, _| s.set_keep_alive(interval.map(duration_to_smoltcp))) | ||
| 138 | } | 226 | } |
| 139 | 227 | ||
| 228 | /// Set the hop limit field in the IP header of sent packets. | ||
| 140 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { | 229 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { |
| 141 | unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } | 230 | self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) |
| 142 | } | 231 | } |
| 143 | 232 | ||
| 233 | /// Get the local endpoint of the socket. | ||
| 234 | /// | ||
| 235 | /// Returns `None` if the socket is not bound (listening) or not connected. | ||
| 144 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { | 236 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { |
| 145 | unsafe { self.io.with(|s, _| s.local_endpoint()) } | 237 | self.io.with(|s, _| s.local_endpoint()) |
| 146 | } | 238 | } |
| 147 | 239 | ||
| 240 | /// Get the remote endpoint of the socket. | ||
| 241 | /// | ||
| 242 | /// Returns `None` if the socket is not connected. | ||
| 148 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { | 243 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { |
| 149 | unsafe { self.io.with(|s, _| s.remote_endpoint()) } | 244 | self.io.with(|s, _| s.remote_endpoint()) |
| 150 | } | 245 | } |
| 151 | 246 | ||
| 152 | pub fn state(&self) -> tcp::State { | 247 | /// Get the state of the socket. |
| 153 | unsafe { self.io.with(|s, _| s.state()) } | 248 | pub fn state(&self) -> State { |
| 249 | self.io.with(|s, _| s.state()) | ||
| 154 | } | 250 | } |
| 155 | 251 | ||
| 252 | /// Close the write half of the socket. | ||
| 253 | /// | ||
| 254 | /// This closes only the write half of the socket. The read half side remains open, the | ||
| 255 | /// socket can still receive data. | ||
| 256 | /// | ||
| 257 | /// Data that has been written to the socket and not yet sent (or not yet ACKed) will still | ||
| 258 | /// still sent. The last segment of the pending to send data is sent with the FIN flag set. | ||
| 156 | pub fn close(&mut self) { | 259 | pub fn close(&mut self) { |
| 157 | unsafe { self.io.with_mut(|s, _| s.close()) } | 260 | self.io.with_mut(|s, _| s.close()) |
| 158 | } | 261 | } |
| 159 | 262 | ||
| 263 | /// Forcibly close the socket. | ||
| 264 | /// | ||
| 265 | /// This instantly closes both the read and write halves of the socket. Any pending data | ||
| 266 | /// that has not been sent will be lost. | ||
| 267 | /// | ||
| 268 | /// Note that the TCP RST packet is not sent immediately - if the `TcpSocket` is dropped too soon | ||
| 269 | /// the remote host may not know the connection has been closed. | ||
| 270 | /// `abort()` callers should wait for a [`flush()`](TcpSocket::flush) call to complete before | ||
| 271 | /// dropping or reusing the socket. | ||
| 160 | pub fn abort(&mut self) { | 272 | pub fn abort(&mut self) { |
| 161 | unsafe { self.io.with_mut(|s, _| s.abort()) } | 273 | self.io.with_mut(|s, _| s.abort()) |
| 162 | } | 274 | } |
| 163 | 275 | ||
| 276 | /// Get whether the socket is ready to send data, i.e. whether there is space in the send buffer. | ||
| 164 | pub fn may_send(&self) -> bool { | 277 | pub fn may_send(&self) -> bool { |
| 165 | unsafe { self.io.with(|s, _| s.may_send()) } | 278 | self.io.with(|s, _| s.may_send()) |
| 166 | } | 279 | } |
| 167 | 280 | ||
| 281 | /// return whether the recieve half of the full-duplex connection is open. | ||
| 282 | /// This function returns true if it’s possible to receive data from the remote endpoint. | ||
| 283 | /// It will return true while there is data in the receive buffer, and if there isn’t, | ||
| 284 | /// as long as the remote endpoint has not closed the connection. | ||
| 168 | pub fn may_recv(&self) -> bool { | 285 | pub fn may_recv(&self) -> bool { |
| 169 | unsafe { self.io.with(|s, _| s.may_recv()) } | 286 | self.io.with(|s, _| s.may_recv()) |
| 287 | } | ||
| 288 | |||
| 289 | /// Get whether the socket is ready to receive data, i.e. whether there is some pending data in the receive buffer. | ||
| 290 | pub fn can_recv(&self) -> bool { | ||
| 291 | self.io.with(|s, _| s.can_recv()) | ||
| 170 | } | 292 | } |
| 171 | } | 293 | } |
| 172 | 294 | ||
| 173 | impl<'a> Drop for TcpSocket<'a> { | 295 | impl<'a> Drop for TcpSocket<'a> { |
| 174 | fn drop(&mut self) { | 296 | fn drop(&mut self) { |
| 175 | // safety: not accessed reentrantly. | 297 | self.io.stack.borrow_mut().sockets.remove(self.io.handle); |
| 176 | let s = unsafe { &mut *self.io.stack.get() }; | ||
| 177 | s.sockets.remove(self.io.handle); | ||
| 178 | } | 298 | } |
| 179 | } | 299 | } |
| 180 | 300 | ||
| @@ -182,21 +302,19 @@ impl<'a> Drop for TcpSocket<'a> { | |||
| 182 | 302 | ||
| 183 | #[derive(Copy, Clone)] | 303 | #[derive(Copy, Clone)] |
| 184 | struct TcpIo<'a> { | 304 | struct TcpIo<'a> { |
| 185 | stack: &'a UnsafeCell<SocketStack>, | 305 | stack: &'a RefCell<SocketStack>, |
| 186 | handle: SocketHandle, | 306 | handle: SocketHandle, |
| 187 | } | 307 | } |
| 188 | 308 | ||
| 189 | impl<'d> TcpIo<'d> { | 309 | impl<'d> TcpIo<'d> { |
| 190 | /// SAFETY: must not call reentrantly. | 310 | fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { |
| 191 | unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { | 311 | let s = &*self.stack.borrow(); |
| 192 | let s = &*self.stack.get(); | ||
| 193 | let socket = s.sockets.get::<tcp::Socket>(self.handle); | 312 | let socket = s.sockets.get::<tcp::Socket>(self.handle); |
| 194 | f(socket, &s.iface) | 313 | f(socket, &s.iface) |
| 195 | } | 314 | } |
| 196 | 315 | ||
| 197 | /// SAFETY: must not call reentrantly. | 316 | fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { |
| 198 | unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { | 317 | let s = &mut *self.stack.borrow_mut(); |
| 199 | let s = &mut *self.stack.get(); | ||
| 200 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); | 318 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); |
| 201 | let res = f(socket, &mut s.iface); | 319 | let res = f(socket, &mut s.iface); |
| 202 | s.waker.wake(); | 320 | s.waker.wake(); |
| @@ -204,7 +322,7 @@ impl<'d> TcpIo<'d> { | |||
| 204 | } | 322 | } |
| 205 | 323 | ||
| 206 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { | 324 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { |
| 207 | poll_fn(move |cx| unsafe { | 325 | poll_fn(move |cx| { |
| 208 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect | 326 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect |
| 209 | // from posix-like IO, so we have to tweak things here. | 327 | // from posix-like IO, so we have to tweak things here. |
| 210 | self.with_mut(|s, _| match s.recv_slice(buf) { | 328 | self.with_mut(|s, _| match s.recv_slice(buf) { |
| @@ -225,7 +343,7 @@ impl<'d> TcpIo<'d> { | |||
| 225 | } | 343 | } |
| 226 | 344 | ||
| 227 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { | 345 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { |
| 228 | poll_fn(move |cx| unsafe { | 346 | poll_fn(move |cx| { |
| 229 | self.with_mut(|s, _| match s.send_slice(buf) { | 347 | self.with_mut(|s, _| match s.send_slice(buf) { |
| 230 | // Not ready to send (no space in the tx buffer) | 348 | // Not ready to send (no space in the tx buffer) |
| 231 | Ok(0) => { | 349 | Ok(0) => { |
| @@ -242,95 +360,89 @@ impl<'d> TcpIo<'d> { | |||
| 242 | } | 360 | } |
| 243 | 361 | ||
| 244 | async fn flush(&mut self) -> Result<(), Error> { | 362 | async fn flush(&mut self) -> Result<(), Error> { |
| 245 | poll_fn(move |_| { | 363 | poll_fn(move |cx| { |
| 246 | Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? | 364 | self.with_mut(|s, _| { |
| 365 | let waiting_close = s.state() == tcp::State::Closed && s.remote_endpoint().is_some(); | ||
| 366 | // If there are outstanding send operations, register for wake up and wait | ||
| 367 | // smoltcp issues wake-ups when octets are dequeued from the send buffer | ||
| 368 | if s.send_queue() > 0 || waiting_close { | ||
| 369 | s.register_send_waker(cx.waker()); | ||
| 370 | Poll::Pending | ||
| 371 | // No outstanding sends, socket is flushed | ||
| 372 | } else { | ||
| 373 | Poll::Ready(Ok(())) | ||
| 374 | } | ||
| 375 | }) | ||
| 247 | }) | 376 | }) |
| 248 | .await | 377 | .await |
| 249 | } | 378 | } |
| 250 | } | 379 | } |
| 251 | 380 | ||
| 252 | impl embedded_io::Error for ConnectError { | 381 | #[cfg(feature = "nightly")] |
| 253 | fn kind(&self) -> embedded_io::ErrorKind { | 382 | mod embedded_io_impls { |
| 254 | embedded_io::ErrorKind::Other | 383 | use super::*; |
| 255 | } | ||
| 256 | } | ||
| 257 | 384 | ||
| 258 | impl embedded_io::Error for Error { | 385 | impl embedded_io::Error for ConnectError { |
| 259 | fn kind(&self) -> embedded_io::ErrorKind { | 386 | fn kind(&self) -> embedded_io::ErrorKind { |
| 260 | embedded_io::ErrorKind::Other | 387 | embedded_io::ErrorKind::Other |
| 388 | } | ||
| 261 | } | 389 | } |
| 262 | } | ||
| 263 | |||
| 264 | impl<'d> embedded_io::Io for TcpSocket<'d> { | ||
| 265 | type Error = Error; | ||
| 266 | } | ||
| 267 | 390 | ||
| 268 | impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { | 391 | impl embedded_io::Error for Error { |
| 269 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 392 | fn kind(&self) -> embedded_io::ErrorKind { |
| 270 | where | 393 | embedded_io::ErrorKind::Other |
| 271 | Self: 'a; | 394 | } |
| 272 | |||
| 273 | fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { | ||
| 274 | self.io.read(buf) | ||
| 275 | } | 395 | } |
| 276 | } | ||
| 277 | 396 | ||
| 278 | impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { | 397 | impl<'d> embedded_io::Io for TcpSocket<'d> { |
| 279 | type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 398 | type Error = Error; |
| 280 | where | ||
| 281 | Self: 'a; | ||
| 282 | |||
| 283 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { | ||
| 284 | self.io.write(buf) | ||
| 285 | } | 399 | } |
| 286 | 400 | ||
| 287 | type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> | 401 | impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { |
| 288 | where | 402 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> { |
| 289 | Self: 'a; | 403 | self.io.read(buf).await |
| 290 | 404 | } | |
| 291 | fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { | ||
| 292 | self.io.flush() | ||
| 293 | } | 405 | } |
| 294 | } | ||
| 295 | |||
| 296 | impl<'d> embedded_io::Io for TcpReader<'d> { | ||
| 297 | type Error = Error; | ||
| 298 | } | ||
| 299 | 406 | ||
| 300 | impl<'d> embedded_io::asynch::Read for TcpReader<'d> { | 407 | impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { |
| 301 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 408 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> { |
| 302 | where | 409 | self.io.write(buf).await |
| 303 | Self: 'a; | 410 | } |
| 304 | 411 | ||
| 305 | fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { | 412 | async fn flush(&mut self) -> Result<(), Self::Error> { |
| 306 | self.io.read(buf) | 413 | self.io.flush().await |
| 414 | } | ||
| 307 | } | 415 | } |
| 308 | } | ||
| 309 | 416 | ||
| 310 | impl<'d> embedded_io::Io for TcpWriter<'d> { | 417 | impl<'d> embedded_io::Io for TcpReader<'d> { |
| 311 | type Error = Error; | 418 | type Error = Error; |
| 312 | } | 419 | } |
| 313 | 420 | ||
| 314 | impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { | 421 | impl<'d> embedded_io::asynch::Read for TcpReader<'d> { |
| 315 | type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 422 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> { |
| 316 | where | 423 | self.io.read(buf).await |
| 317 | Self: 'a; | 424 | } |
| 425 | } | ||
| 318 | 426 | ||
| 319 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { | 427 | impl<'d> embedded_io::Io for TcpWriter<'d> { |
| 320 | self.io.write(buf) | 428 | type Error = Error; |
| 321 | } | 429 | } |
| 322 | 430 | ||
| 323 | type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> | 431 | impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { |
| 324 | where | 432 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> { |
| 325 | Self: 'a; | 433 | self.io.write(buf).await |
| 434 | } | ||
| 326 | 435 | ||
| 327 | fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { | 436 | async fn flush(&mut self) -> Result<(), Self::Error> { |
| 328 | self.io.flush() | 437 | self.io.flush().await |
| 438 | } | ||
| 329 | } | 439 | } |
| 330 | } | 440 | } |
| 331 | 441 | ||
| 332 | #[cfg(feature = "unstable-traits")] | 442 | /// TCP client compatible with `embedded-nal-async` traits. |
| 443 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] | ||
| 333 | pub mod client { | 444 | pub mod client { |
| 445 | use core::cell::UnsafeCell; | ||
| 334 | use core::mem::MaybeUninit; | 446 | use core::mem::MaybeUninit; |
| 335 | use core::ptr::NonNull; | 447 | use core::ptr::NonNull; |
| 336 | 448 | ||
| @@ -339,49 +451,56 @@ pub mod client { | |||
| 339 | 451 | ||
| 340 | use super::*; | 452 | use super::*; |
| 341 | 453 | ||
| 342 | /// TCP client capable of creating up to N multiple connections with tx and rx buffers according to TX_SZ and RX_SZ. | 454 | /// TCP client connection pool compatible with `embedded-nal-async` traits. |
| 343 | pub struct TcpClient<'d, D: Device, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { | 455 | /// |
| 456 | /// The pool is capable of managing up to N concurrent connections with tx and rx buffers according to TX_SZ and RX_SZ. | ||
| 457 | pub struct TcpClient<'d, D: Driver, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { | ||
| 344 | stack: &'d Stack<D>, | 458 | stack: &'d Stack<D>, |
| 345 | state: &'d TcpClientState<N, TX_SZ, RX_SZ>, | 459 | state: &'d TcpClientState<N, TX_SZ, RX_SZ>, |
| 346 | } | 460 | } |
| 347 | 461 | ||
| 348 | impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { | 462 | impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { |
| 349 | /// Create a new TcpClient | 463 | /// Create a new `TcpClient`. |
| 350 | pub fn new(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self { | 464 | pub fn new(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self { |
| 351 | Self { stack, state } | 465 | Self { stack, state } |
| 352 | } | 466 | } |
| 353 | } | 467 | } |
| 354 | 468 | ||
| 355 | impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_nal_async::TcpConnect | 469 | impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_nal_async::TcpConnect |
| 356 | for TcpClient<'d, D, N, TX_SZ, RX_SZ> | 470 | for TcpClient<'d, D, N, TX_SZ, RX_SZ> |
| 357 | { | 471 | { |
| 358 | type Error = Error; | 472 | type Error = Error; |
| 359 | type Connection<'m> = TcpConnection<'m, N, TX_SZ, RX_SZ> where Self: 'm; | 473 | type Connection<'m> = TcpConnection<'m, N, TX_SZ, RX_SZ> where Self: 'm; |
| 360 | type ConnectFuture<'m> = impl Future<Output = Result<Self::Connection<'m>, Self::Error>> + 'm | 474 | |
| 361 | where | 475 | async fn connect<'a>( |
| 362 | Self: 'm; | 476 | &'a self, |
| 363 | 477 | remote: embedded_nal_async::SocketAddr, | |
| 364 | fn connect<'m>(&'m self, remote: embedded_nal_async::SocketAddr) -> Self::ConnectFuture<'m> { | 478 | ) -> Result<Self::Connection<'a>, Self::Error> |
| 365 | async move { | 479 | where |
| 366 | let addr: crate::IpAddress = match remote.ip() { | 480 | Self: 'a, |
| 367 | IpAddr::V4(addr) => crate::IpAddress::Ipv4(crate::Ipv4Address::from_bytes(&addr.octets())), | 481 | { |
| 368 | #[cfg(feature = "proto-ipv6")] | 482 | let addr: crate::IpAddress = match remote.ip() { |
| 369 | IpAddr::V6(addr) => crate::IpAddress::Ipv6(crate::Ipv6Address::from_bytes(&addr.octets())), | 483 | #[cfg(feature = "proto-ipv4")] |
| 370 | #[cfg(not(feature = "proto-ipv6"))] | 484 | IpAddr::V4(addr) => crate::IpAddress::Ipv4(crate::Ipv4Address::from_bytes(&addr.octets())), |
| 371 | IpAddr::V6(_) => panic!("ipv6 support not enabled"), | 485 | #[cfg(not(feature = "proto-ipv4"))] |
| 372 | }; | 486 | IpAddr::V4(_) => panic!("ipv4 support not enabled"), |
| 373 | let remote_endpoint = (addr, remote.port()); | 487 | #[cfg(feature = "proto-ipv6")] |
| 374 | let mut socket = TcpConnection::new(&self.stack, self.state)?; | 488 | IpAddr::V6(addr) => crate::IpAddress::Ipv6(crate::Ipv6Address::from_bytes(&addr.octets())), |
| 375 | socket | 489 | #[cfg(not(feature = "proto-ipv6"))] |
| 376 | .socket | 490 | IpAddr::V6(_) => panic!("ipv6 support not enabled"), |
| 377 | .connect(remote_endpoint) | 491 | }; |
| 378 | .await | 492 | let remote_endpoint = (addr, remote.port()); |
| 379 | .map_err(|_| Error::ConnectionReset)?; | 493 | let mut socket = TcpConnection::new(&self.stack, self.state)?; |
| 380 | Ok(socket) | 494 | socket |
| 381 | } | 495 | .socket |
| 496 | .connect(remote_endpoint) | ||
| 497 | .await | ||
| 498 | .map_err(|_| Error::ConnectionReset)?; | ||
| 499 | Ok(socket) | ||
| 382 | } | 500 | } |
| 383 | } | 501 | } |
| 384 | 502 | ||
| 503 | /// Opened TCP connection in a [`TcpClient`]. | ||
| 385 | pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> { | 504 | pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> { |
| 386 | socket: TcpSocket<'d>, | 505 | socket: TcpSocket<'d>, |
| 387 | state: &'d TcpClientState<N, TX_SZ, RX_SZ>, | 506 | state: &'d TcpClientState<N, TX_SZ, RX_SZ>, |
| @@ -389,10 +508,10 @@ pub mod client { | |||
| 389 | } | 508 | } |
| 390 | 509 | ||
| 391 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> { | 510 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> { |
| 392 | fn new<D: Device>(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Result<Self, Error> { | 511 | fn new<D: Driver>(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Result<Self, Error> { |
| 393 | let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?; | 512 | let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?; |
| 394 | Ok(Self { | 513 | Ok(Self { |
| 395 | socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().0, &mut bufs.as_mut().1) }, | 514 | socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().1, &mut bufs.as_mut().0) }, |
| 396 | state, | 515 | state, |
| 397 | bufs, | 516 | bufs, |
| 398 | }) | 517 | }) |
| @@ -417,32 +536,20 @@ pub mod client { | |||
| 417 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Read | 536 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Read |
| 418 | for TcpConnection<'d, N, TX_SZ, RX_SZ> | 537 | for TcpConnection<'d, N, TX_SZ, RX_SZ> |
| 419 | { | 538 | { |
| 420 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 539 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> { |
| 421 | where | 540 | self.socket.read(buf).await |
| 422 | Self: 'a; | ||
| 423 | |||
| 424 | fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { | ||
| 425 | self.socket.read(buf) | ||
| 426 | } | 541 | } |
| 427 | } | 542 | } |
| 428 | 543 | ||
| 429 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Write | 544 | impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io::asynch::Write |
| 430 | for TcpConnection<'d, N, TX_SZ, RX_SZ> | 545 | for TcpConnection<'d, N, TX_SZ, RX_SZ> |
| 431 | { | 546 | { |
| 432 | type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 547 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> { |
| 433 | where | 548 | self.socket.write(buf).await |
| 434 | Self: 'a; | ||
| 435 | |||
| 436 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { | ||
| 437 | self.socket.write(buf) | ||
| 438 | } | 549 | } |
| 439 | 550 | ||
| 440 | type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> | 551 | async fn flush(&mut self) -> Result<(), Self::Error> { |
| 441 | where | 552 | self.socket.flush().await |
| 442 | Self: 'a; | ||
| 443 | |||
| 444 | fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { | ||
| 445 | self.socket.flush() | ||
| 446 | } | 553 | } |
| 447 | } | 554 | } |
| 448 | 555 | ||
| @@ -452,6 +559,7 @@ pub mod client { | |||
| 452 | } | 559 | } |
| 453 | 560 | ||
| 454 | impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> { | 561 | impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> { |
| 562 | /// Create a new `TcpClientState`. | ||
| 455 | pub const fn new() -> Self { | 563 | pub const fn new() -> Self { |
| 456 | Self { pool: Pool::new() } | 564 | Self { pool: Pool::new() } |
| 457 | } | 565 | } |
diff --git a/embassy-net/src/time.rs b/embassy-net/src/time.rs new file mode 100644 index 000000000..b98d40fdc --- /dev/null +++ b/embassy-net/src/time.rs | |||
| @@ -0,0 +1,20 @@ | |||
| 1 | #![allow(unused)] | ||
| 2 | |||
| 3 | use embassy_time::{Duration, Instant}; | ||
| 4 | use smoltcp::time::{Duration as SmolDuration, Instant as SmolInstant}; | ||
| 5 | |||
| 6 | pub(crate) fn instant_to_smoltcp(instant: Instant) -> SmolInstant { | ||
| 7 | SmolInstant::from_micros(instant.as_micros() as i64) | ||
| 8 | } | ||
| 9 | |||
| 10 | pub(crate) fn instant_from_smoltcp(instant: SmolInstant) -> Instant { | ||
| 11 | Instant::from_micros(instant.total_micros() as u64) | ||
| 12 | } | ||
| 13 | |||
| 14 | pub(crate) fn duration_to_smoltcp(duration: Duration) -> SmolDuration { | ||
| 15 | SmolDuration::from_micros(duration.as_micros()) | ||
| 16 | } | ||
| 17 | |||
| 18 | pub(crate) fn duration_from_smoltcp(duration: SmolDuration) -> Duration { | ||
| 19 | Duration::from_micros(duration.total_micros()) | ||
| 20 | } | ||
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs index 78b09a492..0d97b6db1 100644 --- a/embassy-net/src/udp.rs +++ b/embassy-net/src/udp.rs | |||
| @@ -1,15 +1,19 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | //! UDP sockets. |
| 2 | |||
| 3 | use core::cell::RefCell; | ||
| 4 | use core::future::poll_fn; | ||
| 2 | use core::mem; | 5 | use core::mem; |
| 3 | use core::task::Poll; | 6 | use core::task::{Context, Poll}; |
| 4 | 7 | ||
| 5 | use futures::future::poll_fn; | 8 | use embassy_net_driver::Driver; |
| 6 | use smoltcp::iface::{Interface, SocketHandle}; | 9 | use smoltcp::iface::{Interface, SocketHandle}; |
| 7 | use smoltcp::socket::udp::{self, PacketMetadata}; | 10 | use smoltcp::socket::udp; |
| 11 | pub use smoltcp::socket::udp::PacketMetadata; | ||
| 8 | use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; | 12 | use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; |
| 9 | 13 | ||
| 10 | use super::stack::SocketStack; | 14 | use crate::{SocketStack, Stack}; |
| 11 | use crate::{Device, Stack}; | ||
| 12 | 15 | ||
| 16 | /// Error returned by [`UdpSocket::bind`]. | ||
| 13 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] | 17 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] |
| 14 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | 18 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] |
| 15 | pub enum BindError { | 19 | pub enum BindError { |
| @@ -19,6 +23,7 @@ pub enum BindError { | |||
| 19 | NoRoute, | 23 | NoRoute, |
| 20 | } | 24 | } |
| 21 | 25 | ||
| 26 | /// Error returned by [`UdpSocket::recv_from`] and [`UdpSocket::send_to`]. | ||
| 22 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] | 27 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] |
| 23 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | 28 | #[cfg_attr(feature = "defmt", derive(defmt::Format))] |
| 24 | pub enum Error { | 29 | pub enum Error { |
| @@ -26,21 +31,22 @@ pub enum Error { | |||
| 26 | NoRoute, | 31 | NoRoute, |
| 27 | } | 32 | } |
| 28 | 33 | ||
| 34 | /// An UDP socket. | ||
| 29 | pub struct UdpSocket<'a> { | 35 | pub struct UdpSocket<'a> { |
| 30 | stack: &'a UnsafeCell<SocketStack>, | 36 | stack: &'a RefCell<SocketStack>, |
| 31 | handle: SocketHandle, | 37 | handle: SocketHandle, |
| 32 | } | 38 | } |
| 33 | 39 | ||
| 34 | impl<'a> UdpSocket<'a> { | 40 | impl<'a> UdpSocket<'a> { |
| 35 | pub fn new<D: Device>( | 41 | /// Create a new UDP socket using the provided stack and buffers. |
| 42 | pub fn new<D: Driver>( | ||
| 36 | stack: &'a Stack<D>, | 43 | stack: &'a Stack<D>, |
| 37 | rx_meta: &'a mut [PacketMetadata], | 44 | rx_meta: &'a mut [PacketMetadata], |
| 38 | rx_buffer: &'a mut [u8], | 45 | rx_buffer: &'a mut [u8], |
| 39 | tx_meta: &'a mut [PacketMetadata], | 46 | tx_meta: &'a mut [PacketMetadata], |
| 40 | tx_buffer: &'a mut [u8], | 47 | tx_buffer: &'a mut [u8], |
| 41 | ) -> Self { | 48 | ) -> Self { |
| 42 | // safety: not accessed reentrantly. | 49 | let s = &mut *stack.socket.borrow_mut(); |
| 43 | let s = unsafe { &mut *stack.socket.get() }; | ||
| 44 | 50 | ||
| 45 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; | 51 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; |
| 46 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 52 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| @@ -57,101 +63,131 @@ impl<'a> UdpSocket<'a> { | |||
| 57 | } | 63 | } |
| 58 | } | 64 | } |
| 59 | 65 | ||
| 66 | /// Bind the socket to a local endpoint. | ||
| 60 | pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError> | 67 | pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError> |
| 61 | where | 68 | where |
| 62 | T: Into<IpListenEndpoint>, | 69 | T: Into<IpListenEndpoint>, |
| 63 | { | 70 | { |
| 64 | let mut endpoint = endpoint.into(); | 71 | let mut endpoint = endpoint.into(); |
| 65 | 72 | ||
| 66 | // safety: not accessed reentrantly. | ||
| 67 | if endpoint.port == 0 { | 73 | if endpoint.port == 0 { |
| 68 | // If user didn't specify port allocate a dynamic port. | 74 | // If user didn't specify port allocate a dynamic port. |
| 69 | endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); | 75 | endpoint.port = self.stack.borrow_mut().get_local_port(); |
| 70 | } | 76 | } |
| 71 | 77 | ||
| 72 | // safety: not accessed reentrantly. | 78 | match self.with_mut(|s, _| s.bind(endpoint)) { |
| 73 | match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } { | ||
| 74 | Ok(()) => Ok(()), | 79 | Ok(()) => Ok(()), |
| 75 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), | 80 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), |
| 76 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), | 81 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), |
| 77 | } | 82 | } |
| 78 | } | 83 | } |
| 79 | 84 | ||
| 80 | /// SAFETY: must not call reentrantly. | 85 | fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { |
| 81 | unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { | 86 | let s = &*self.stack.borrow(); |
| 82 | let s = &*self.stack.get(); | ||
| 83 | let socket = s.sockets.get::<udp::Socket>(self.handle); | 87 | let socket = s.sockets.get::<udp::Socket>(self.handle); |
| 84 | f(socket, &s.iface) | 88 | f(socket, &s.iface) |
| 85 | } | 89 | } |
| 86 | 90 | ||
| 87 | /// SAFETY: must not call reentrantly. | 91 | fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { |
| 88 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { | 92 | let s = &mut *self.stack.borrow_mut(); |
| 89 | let s = &mut *self.stack.get(); | ||
| 90 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); | 93 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); |
| 91 | let res = f(socket, &mut s.iface); | 94 | let res = f(socket, &mut s.iface); |
| 92 | s.waker.wake(); | 95 | s.waker.wake(); |
| 93 | res | 96 | res |
| 94 | } | 97 | } |
| 95 | 98 | ||
| 99 | /// Receive a datagram. | ||
| 100 | /// | ||
| 101 | /// This method will wait until a datagram is received. | ||
| 102 | /// | ||
| 103 | /// Returns the number of bytes received and the remote endpoint. | ||
| 96 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { | 104 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { |
| 97 | poll_fn(move |cx| unsafe { | 105 | poll_fn(move |cx| self.poll_recv_from(buf, cx)).await |
| 98 | self.with_mut(|s, _| match s.recv_slice(buf) { | 106 | } |
| 99 | Ok(x) => Poll::Ready(Ok(x)), | 107 | |
| 100 | // No data ready | 108 | /// Receive a datagram. |
| 101 | Err(udp::RecvError::Exhausted) => { | 109 | /// |
| 102 | //s.register_recv_waker(cx.waker()); | 110 | /// When no datagram is available, this method will return `Poll::Pending` and |
| 103 | cx.waker().wake_by_ref(); | 111 | /// register the current task to be notified when a datagram is received. |
| 104 | Poll::Pending | 112 | /// |
| 105 | } | 113 | /// When a datagram is received, this method will return `Poll::Ready` with the |
| 106 | }) | 114 | /// number of bytes received and the remote endpoint. |
| 115 | pub fn poll_recv_from(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<(usize, IpEndpoint), Error>> { | ||
| 116 | self.with_mut(|s, _| match s.recv_slice(buf) { | ||
| 117 | Ok((n, meta)) => Poll::Ready(Ok((n, meta.endpoint))), | ||
| 118 | // No data ready | ||
| 119 | Err(udp::RecvError::Exhausted) => { | ||
| 120 | s.register_recv_waker(cx.waker()); | ||
| 121 | Poll::Pending | ||
| 122 | } | ||
| 107 | }) | 123 | }) |
| 108 | .await | ||
| 109 | } | 124 | } |
| 110 | 125 | ||
| 126 | /// Send a datagram to the specified remote endpoint. | ||
| 127 | /// | ||
| 128 | /// This method will wait until the datagram has been sent. | ||
| 129 | /// | ||
| 130 | /// When the remote endpoint is not reachable, this method will return `Err(Error::NoRoute)` | ||
| 111 | pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error> | 131 | pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error> |
| 112 | where | 132 | where |
| 113 | T: Into<IpEndpoint>, | 133 | T: Into<IpEndpoint>, |
| 114 | { | 134 | { |
| 115 | let remote_endpoint = remote_endpoint.into(); | 135 | let remote_endpoint: IpEndpoint = remote_endpoint.into(); |
| 116 | poll_fn(move |cx| unsafe { | 136 | poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await |
| 117 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { | 137 | } |
| 118 | // Entire datagram has been sent | 138 | |
| 119 | Ok(()) => Poll::Ready(Ok(())), | 139 | /// Send a datagram to the specified remote endpoint. |
| 120 | Err(udp::SendError::BufferFull) => { | 140 | /// |
| 121 | s.register_send_waker(cx.waker()); | 141 | /// When the datagram has been sent, this method will return `Poll::Ready(Ok())`. |
| 122 | Poll::Pending | 142 | /// |
| 123 | } | 143 | /// When the socket's send buffer is full, this method will return `Poll::Pending` |
| 124 | Err(udp::SendError::Unaddressable) => Poll::Ready(Err(Error::NoRoute)), | 144 | /// and register the current task to be notified when the buffer has space available. |
| 125 | }) | 145 | /// |
| 146 | /// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`. | ||
| 147 | pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), Error>> | ||
| 148 | where | ||
| 149 | T: Into<IpEndpoint>, | ||
| 150 | { | ||
| 151 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { | ||
| 152 | // Entire datagram has been sent | ||
| 153 | Ok(()) => Poll::Ready(Ok(())), | ||
| 154 | Err(udp::SendError::BufferFull) => { | ||
| 155 | s.register_send_waker(cx.waker()); | ||
| 156 | Poll::Pending | ||
| 157 | } | ||
| 158 | Err(udp::SendError::Unaddressable) => Poll::Ready(Err(Error::NoRoute)), | ||
| 126 | }) | 159 | }) |
| 127 | .await | ||
| 128 | } | 160 | } |
| 129 | 161 | ||
| 162 | /// Returns the local endpoint of the socket. | ||
| 130 | pub fn endpoint(&self) -> IpListenEndpoint { | 163 | pub fn endpoint(&self) -> IpListenEndpoint { |
| 131 | unsafe { self.with(|s, _| s.endpoint()) } | 164 | self.with(|s, _| s.endpoint()) |
| 132 | } | 165 | } |
| 133 | 166 | ||
| 167 | /// Returns whether the socket is open. | ||
| 168 | |||
| 134 | pub fn is_open(&self) -> bool { | 169 | pub fn is_open(&self) -> bool { |
| 135 | unsafe { self.with(|s, _| s.is_open()) } | 170 | self.with(|s, _| s.is_open()) |
| 136 | } | 171 | } |
| 137 | 172 | ||
| 173 | /// Close the socket. | ||
| 138 | pub fn close(&mut self) { | 174 | pub fn close(&mut self) { |
| 139 | unsafe { self.with_mut(|s, _| s.close()) } | 175 | self.with_mut(|s, _| s.close()) |
| 140 | } | 176 | } |
| 141 | 177 | ||
| 178 | /// Returns whether the socket is ready to send data, i.e. it has enough buffer space to hold a packet. | ||
| 142 | pub fn may_send(&self) -> bool { | 179 | pub fn may_send(&self) -> bool { |
| 143 | unsafe { self.with(|s, _| s.can_send()) } | 180 | self.with(|s, _| s.can_send()) |
| 144 | } | 181 | } |
| 145 | 182 | ||
| 183 | /// Returns whether the socket is ready to receive data, i.e. it has received a packet that's now in the buffer. | ||
| 146 | pub fn may_recv(&self) -> bool { | 184 | pub fn may_recv(&self) -> bool { |
| 147 | unsafe { self.with(|s, _| s.can_recv()) } | 185 | self.with(|s, _| s.can_recv()) |
| 148 | } | 186 | } |
| 149 | } | 187 | } |
| 150 | 188 | ||
| 151 | impl Drop for UdpSocket<'_> { | 189 | impl Drop for UdpSocket<'_> { |
| 152 | fn drop(&mut self) { | 190 | fn drop(&mut self) { |
| 153 | // safety: not accessed reentrantly. | 191 | self.stack.borrow_mut().sockets.remove(self.handle); |
| 154 | let s = unsafe { &mut *self.stack.get() }; | ||
| 155 | s.sockets.remove(self.handle); | ||
| 156 | } | 192 | } |
| 157 | } | 193 | } |
