From a5aea995a802fea8fc1b3e4b5fe47bd6d1fca2a4 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Mon, 23 May 2022 03:50:43 +0200 Subject: WIP embassy-net v2 --- embassy-net/src/config/dhcp.rs | 55 ------ embassy-net/src/config/mod.rs | 35 ---- embassy-net/src/config/statik.rs | 29 --- embassy-net/src/device.rs | 50 +++-- embassy-net/src/lib.rs | 10 +- embassy-net/src/stack.rs | 393 ++++++++++++++++++++++----------------- embassy-net/src/tcp.rs | 299 ++++++++++++++--------------- 7 files changed, 405 insertions(+), 466 deletions(-) delete mode 100644 embassy-net/src/config/dhcp.rs delete mode 100644 embassy-net/src/config/mod.rs delete mode 100644 embassy-net/src/config/statik.rs (limited to 'embassy-net/src') diff --git a/embassy-net/src/config/dhcp.rs b/embassy-net/src/config/dhcp.rs deleted file mode 100644 index 298657ed6..000000000 --- a/embassy-net/src/config/dhcp.rs +++ /dev/null @@ -1,55 +0,0 @@ -use heapless::Vec; -use smoltcp::iface::SocketHandle; -use smoltcp::socket::{Dhcpv4Event, Dhcpv4Socket}; -use smoltcp::time::Instant; - -use super::*; -use crate::device::LinkState; -use crate::Interface; - -pub struct DhcpConfigurator { - handle: Option, -} - -impl DhcpConfigurator { - pub fn new() -> Self { - Self { handle: None } - } -} - -impl Configurator for DhcpConfigurator { - fn poll(&mut self, iface: &mut Interface, _timestamp: Instant) -> Event { - if self.handle.is_none() { - let handle = iface.add_socket(Dhcpv4Socket::new()); - self.handle = Some(handle) - } - - let link_up = iface.device_mut().device.link_state() == LinkState::Up; - - let socket = iface.get_socket::(self.handle.unwrap()); - - if !link_up { - socket.reset(); - return Event::Deconfigured; - } - - match socket.poll() { - None => Event::NoChange, - Some(Dhcpv4Event::Deconfigured) => Event::Deconfigured, - Some(Dhcpv4Event::Configured(config)) => { - let mut dns_servers = Vec::new(); - for s in &config.dns_servers { - if let Some(addr) = s { - dns_servers.push(addr.clone()).unwrap(); - } - } - - Event::Configured(Config { - address: config.address, - gateway: config.router, - dns_servers, - }) - } - } - } -} diff --git a/embassy-net/src/config/mod.rs b/embassy-net/src/config/mod.rs deleted file mode 100644 index eb1b6636a..000000000 --- a/embassy-net/src/config/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -use heapless::Vec; -use smoltcp::time::Instant; -use smoltcp::wire::{Ipv4Address, Ipv4Cidr}; - -use crate::Interface; - -mod statik; -pub use statik::StaticConfigurator; - -#[cfg(feature = "dhcpv4")] -mod dhcp; -#[cfg(feature = "dhcpv4")] -pub use dhcp::DhcpConfigurator; - -/// Return value for the `Configurator::poll` function -#[derive(Debug, Clone)] -pub enum Event { - /// No change has occured to the configuration. - NoChange, - /// Configuration has been lost (for example, DHCP lease has expired) - Deconfigured, - /// Configuration has been newly acquired, or modified. - Configured(Config), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Config { - pub address: Ipv4Cidr, - pub gateway: Option, - pub dns_servers: Vec, -} - -pub trait Configurator { - fn poll(&mut self, iface: &mut Interface, timestamp: Instant) -> Event; -} diff --git a/embassy-net/src/config/statik.rs b/embassy-net/src/config/statik.rs deleted file mode 100644 index e614db73b..000000000 --- a/embassy-net/src/config/statik.rs +++ /dev/null @@ -1,29 +0,0 @@ -use smoltcp::time::Instant; - -use super::*; -use crate::Interface; - -pub struct StaticConfigurator { - config: Config, - returned: bool, -} - -impl StaticConfigurator { - pub fn new(config: Config) -> Self { - Self { - config, - returned: false, - } - } -} - -impl Configurator for StaticConfigurator { - fn poll(&mut self, _iface: &mut Interface, _timestamp: Instant) -> Event { - if self.returned { - Event::NoChange - } else { - self.returned = true; - Event::Configured(self.config.clone()) - } - } -} diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs index 1f4fa5208..99c6a2212 100644 --- a/embassy-net/src/device.rs +++ b/embassy-net/src/device.rs @@ -12,24 +12,50 @@ pub enum LinkState { Up, } +// 'static required due to the "fake GAT" in smoltcp::phy::Device. +// https://github.com/smoltcp-rs/smoltcp/pull/572 pub trait Device { fn is_transmit_ready(&mut self) -> bool; fn transmit(&mut self, pkt: PacketBuf); fn receive(&mut self) -> Option; fn register_waker(&mut self, waker: &Waker); - fn capabilities(&mut self) -> DeviceCapabilities; + fn capabilities(&self) -> DeviceCapabilities; fn link_state(&mut self) -> LinkState; fn ethernet_address(&self) -> [u8; 6]; } -pub struct DeviceAdapter { - pub device: &'static mut dyn Device, +impl Device for &'static mut T { + fn is_transmit_ready(&mut self) -> bool { + T::is_transmit_ready(self) + } + fn transmit(&mut self, pkt: PacketBuf) { + T::transmit(self, pkt) + } + fn receive(&mut self) -> Option { + T::receive(self) + } + fn register_waker(&mut self, waker: &Waker) { + T::register_waker(self, waker) + } + fn capabilities(&self) -> DeviceCapabilities { + T::capabilities(self) + } + fn link_state(&mut self) -> LinkState { + T::link_state(self) + } + fn ethernet_address(&self) -> [u8; 6] { + T::ethernet_address(self) + } +} + +pub struct DeviceAdapter { + pub device: D, caps: DeviceCapabilities, } -impl DeviceAdapter { - pub(crate) fn new(device: &'static mut dyn Device) -> Self { +impl DeviceAdapter { + pub(crate) fn new(device: D) -> Self { Self { caps: device.capabilities(), device, @@ -37,16 +63,16 @@ impl DeviceAdapter { } } -impl<'a> SmolDevice<'a> for DeviceAdapter { +impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter { type RxToken = RxToken; - type TxToken = TxToken<'a>; + type TxToken = TxToken<'a, D>; fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { let tx_pkt = PacketBox::new(Packet::new())?; let rx_pkt = self.device.receive()?; let rx_token = RxToken { pkt: rx_pkt }; let tx_token = TxToken { - device: self.device, + device: &mut self.device, pkt: tx_pkt, }; @@ -61,7 +87,7 @@ impl<'a> SmolDevice<'a> for DeviceAdapter { let tx_pkt = PacketBox::new(Packet::new())?; Some(TxToken { - device: self.device, + device: &mut self.device, pkt: tx_pkt, }) } @@ -85,12 +111,12 @@ impl smoltcp::phy::RxToken for RxToken { } } -pub struct TxToken<'a> { - device: &'a mut dyn Device, +pub struct TxToken<'a, D: Device> { + device: &'a mut D, pkt: PacketBox, } -impl<'a> smoltcp::phy::TxToken for TxToken<'a> { +impl<'a, D: Device> smoltcp::phy::TxToken for TxToken<'a, D> { fn consume(self, _timestamp: SmolInstant, len: usize, f: F) -> smoltcp::Result where F: FnOnce(&mut [u8]) -> smoltcp::Result, diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index 18dc1ef61..7b5f29f16 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -5,20 +5,13 @@ // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; -mod config; mod device; mod packet_pool; mod stack; -#[cfg(feature = "dhcpv4")] -pub use config::DhcpConfigurator; -pub use config::{Config, Configurator, Event as ConfigEvent, StaticConfigurator}; - pub use device::{Device, LinkState}; pub use packet_pool::{Packet, PacketBox, PacketBoxExt, PacketBuf, MTU}; -pub use stack::{ - config, ethernet_address, init, is_config_up, is_init, is_link_up, run, StackResources, -}; +pub use stack::{Config, ConfigStrategy, Stack, StackResources}; #[cfg(feature = "tcp")] pub mod tcp; @@ -30,4 +23,3 @@ pub use smoltcp::time::Instant as SmolInstant; #[cfg(feature = "medium-ethernet")] pub use smoltcp::wire::{EthernetAddress, HardwareAddress}; pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Cidr}; -pub type Interface = smoltcp::iface::Interface<'static, device::DeviceAdapter>; diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs index 9461f832f..e28370df8 100644 --- a/embassy-net/src/stack.rs +++ b/embassy-net/src/stack.rs @@ -1,13 +1,18 @@ -use core::cell::RefCell; +use core::cell::UnsafeCell; use core::future::Future; use core::task::Context; use core::task::Poll; -use embassy::blocking_mutex::ThreadModeMutex; use embassy::time::{Instant, Timer}; use embassy::waitqueue::WakerRegistration; +use futures::future::poll_fn; use futures::pin_mut; -use smoltcp::iface::InterfaceBuilder; -use smoltcp::iface::SocketStorage; +use heapless::Vec; +#[cfg(feature = "dhcpv4")] +use smoltcp::iface::SocketHandle; +use smoltcp::iface::{Interface, InterfaceBuilder}; +use smoltcp::iface::{SocketSet, SocketStorage}; +#[cfg(feature = "dhcpv4")] +use smoltcp::socket::dhcpv4; use smoltcp::time::Instant as SmolInstant; use smoltcp::wire::{IpCidr, Ipv4Address, Ipv4Cidr}; @@ -18,10 +23,7 @@ use smoltcp::phy::{Device as _, Medium}; #[cfg(feature = "medium-ethernet")] use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress}; -use crate::config::Configurator; -use crate::config::Event; use crate::device::{Device, DeviceAdapter, LinkState}; -use crate::{Config, Interface}; const LOCAL_PORT_MIN: u16 = 1025; const LOCAL_PORT_MAX: u16 = 65535; @@ -51,24 +53,144 @@ impl } } -static STACK: ThreadModeMutex>> = ThreadModeMutex::new(RefCell::new(None)); +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Config { + pub address: Ipv4Cidr, + pub gateway: Option, + pub dns_servers: Vec, +} + +pub enum ConfigStrategy { + Static(Config), + #[cfg(feature = "dhcpv4")] + Dhcp, +} -pub(crate) struct Stack { - pub iface: Interface, +pub struct Stack { + pub(crate) socket: UnsafeCell, + inner: UnsafeCell>, +} + +struct Inner { + device: DeviceAdapter, link_up: bool, config: Option, + #[cfg(feature = "dhcpv4")] + dhcp_socket: Option, +} + +pub(crate) struct SocketStack { + pub(crate) sockets: SocketSet<'static>, + pub(crate) iface: Interface<'static>, + pub(crate) waker: WakerRegistration, next_local_port: u16, - configurator: &'static mut dyn Configurator, - waker: WakerRegistration, } -impl Stack { - pub(crate) fn with(f: impl FnOnce(&mut Stack) -> R) -> R { - let mut stack = STACK.borrow().borrow_mut(); - let stack = stack.as_mut().unwrap(); - f(stack) +unsafe impl Send for Stack {} + +impl Stack { + pub fn new( + device: D, + config: ConfigStrategy, + resources: &'static mut StackResources, + random_seed: u64, + ) -> Self { + #[cfg(feature = "medium-ethernet")] + let medium = device.capabilities().medium; + + #[cfg(feature = "medium-ethernet")] + let ethernet_addr = if medium == Medium::Ethernet { + device.ethernet_address() + } else { + [0, 0, 0, 0, 0, 0] + }; + + let mut device = DeviceAdapter::new(device); + + let mut b = InterfaceBuilder::new(); + b = b.ip_addrs(&mut resources.addresses[..]); + b = b.random_seed(random_seed); + + #[cfg(feature = "medium-ethernet")] + if medium == Medium::Ethernet { + b = b.hardware_addr(HardwareAddress::Ethernet(EthernetAddress(ethernet_addr))); + b = b.neighbor_cache(NeighborCache::new(&mut resources.neighbor_cache[..])); + b = b.routes(Routes::new(&mut resources.routes[..])); + } + + let iface = b.finalize(&mut device); + + let sockets = SocketSet::new(&mut resources.sockets[..]); + + let next_local_port = + (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; + + let mut inner = Inner { + device, + link_up: false, + config: None, + #[cfg(feature = "dhcpv4")] + dhcp_socket: None, + }; + let mut socket = SocketStack { + sockets, + iface, + waker: WakerRegistration::new(), + next_local_port, + }; + + match config { + ConfigStrategy::Static(config) => inner.apply_config(&mut socket, config), + #[cfg(feature = "dhcpv4")] + ConfigStrategy::Dhcp => { + let handle = socket.sockets.add(smoltcp::socket::dhcpv4::Socket::new()); + inner.dhcp_socket = Some(handle); + } + } + + Self { + socket: UnsafeCell::new(socket), + inner: UnsafeCell::new(inner), + } + } + + /// SAFETY: must not call reentrantly. + unsafe fn with(&self, f: impl FnOnce(&SocketStack, &Inner) -> R) -> R { + f(&*self.socket.get(), &*self.inner.get()) + } + + /// SAFETY: must not call reentrantly. + unsafe fn with_mut(&self, f: impl FnOnce(&mut SocketStack, &mut Inner) -> R) -> R { + f(&mut *self.socket.get(), &mut *self.inner.get()) + } + + pub fn ethernet_address(&self) -> [u8; 6] { + unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } + } + + pub fn is_link_up(&self) -> bool { + unsafe { self.with(|_s, i| i.link_up) } + } + + pub fn is_config_up(&self) -> bool { + unsafe { self.with(|_s, i| i.config.is_some()) } + } + + pub fn config(&self) -> Option { + unsafe { self.with(|_s, i| i.config.clone()) } + } + + pub async fn run(&self) -> ! { + poll_fn(|cx| { + unsafe { self.with_mut(|s, i| i.poll(cx, s)) } + Poll::<()>::Pending + }) + .await; + unreachable!() } +} +impl SocketStack { #[allow(clippy::absurd_extreme_comparisons)] pub fn get_local_port(&mut self) -> u16 { let res = self.next_local_port; @@ -79,60 +201,68 @@ impl Stack { }; res } +} + +impl Inner { + fn apply_config(&mut self, s: &mut SocketStack, config: Config) { + #[cfg(feature = "medium-ethernet")] + let medium = self.device.capabilities().medium; + + debug!("Acquired IP configuration:"); - pub(crate) fn wake(&mut self) { - self.waker.wake() + debug!(" IP address: {}", config.address); + self.set_ipv4_addr(s, config.address); + + #[cfg(feature = "medium-ethernet")] + if medium == Medium::Ethernet { + if let Some(gateway) = config.gateway { + debug!(" Default gateway: {}", gateway); + s.iface + .routes_mut() + .add_default_ipv4_route(gateway) + .unwrap(); + } else { + debug!(" Default gateway: None"); + s.iface.routes_mut().remove_default_ipv4_route(); + } + } + for (i, s) in config.dns_servers.iter().enumerate() { + debug!(" DNS server {}: {}", i, s); + } + + self.config = Some(config) } - fn poll_configurator(&mut self, timestamp: SmolInstant) { + #[allow(unused)] // used only with dhcp + fn unapply_config(&mut self, s: &mut SocketStack) { #[cfg(feature = "medium-ethernet")] - let medium = self.iface.device().capabilities().medium; - - match self.configurator.poll(&mut self.iface, timestamp) { - Event::NoChange => {} - Event::Configured(config) => { - debug!("Acquired IP configuration:"); - - debug!(" IP address: {}", config.address); - set_ipv4_addr(&mut self.iface, config.address); - - #[cfg(feature = "medium-ethernet")] - if medium == Medium::Ethernet { - if let Some(gateway) = config.gateway { - debug!(" Default gateway: {}", gateway); - self.iface - .routes_mut() - .add_default_ipv4_route(gateway) - .unwrap(); - } else { - debug!(" Default gateway: None"); - self.iface.routes_mut().remove_default_ipv4_route(); - } - } - for (i, s) in config.dns_servers.iter().enumerate() { - debug!(" DNS server {}: {}", i, s); - } + let medium = self.device.capabilities().medium; - self.config = Some(config) - } - Event::Deconfigured => { - debug!("Lost IP configuration"); - set_ipv4_addr(&mut self.iface, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)); - #[cfg(feature = "medium-ethernet")] - if medium == Medium::Ethernet { - self.iface.routes_mut().remove_default_ipv4_route(); - } - self.config = None - } + debug!("Lost IP configuration"); + self.set_ipv4_addr(s, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)); + #[cfg(feature = "medium-ethernet")] + if medium == Medium::Ethernet { + s.iface.routes_mut().remove_default_ipv4_route(); } + self.config = None + } + + fn set_ipv4_addr(&mut self, s: &mut SocketStack, cidr: Ipv4Cidr) { + s.iface.update_ip_addrs(|addrs| { + let dest = addrs.iter_mut().next().unwrap(); + *dest = IpCidr::Ipv4(cidr); + }); } - fn poll(&mut self, cx: &mut Context<'_>) { - self.iface.device_mut().device.register_waker(cx.waker()); - self.waker.register(cx.waker()); + fn poll(&mut self, cx: &mut Context<'_>, s: &mut SocketStack) { + self.device.device.register_waker(cx.waker()); + s.waker.register(cx.waker()); let timestamp = instant_to_smoltcp(Instant::now()); - if self.iface.poll(timestamp).is_err() { + if s.iface + .poll(timestamp, &mut self.device, &mut s.sockets) + .is_err() + { // If poll() returns error, it may not be done yet, so poll again later. cx.waker().wake_by_ref(); return; @@ -140,18 +270,49 @@ impl Stack { // Update link up let old_link_up = self.link_up; - self.link_up = self.iface.device_mut().device.link_state() == LinkState::Up; + self.link_up = self.device.device.link_state() == LinkState::Up; // Print when changed if old_link_up != self.link_up { info!("link_up = {:?}", self.link_up); } - if old_link_up || self.link_up { - self.poll_configurator(timestamp) + #[cfg(feature = "dhcpv4")] + if let Some(dhcp_handle) = self.dhcp_socket { + let socket = s.sockets.get_mut::(dhcp_handle); + + if self.link_up { + match socket.poll() { + None => {} + Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), + Some(dhcpv4::Event::Configured(config)) => { + let mut dns_servers = Vec::new(); + for s in &config.dns_servers { + if let Some(addr) = s { + dns_servers.push(addr.clone()).unwrap(); + } + } + + self.apply_config( + s, + Config { + address: config.address, + gateway: config.router, + dns_servers, + }, + ) + } + } + } else if old_link_up { + socket.reset(); + self.unapply_config(s); + } } + //if old_link_up || self.link_up { + // self.poll_configurator(timestamp) + //} - if let Some(poll_at) = self.iface.poll_at(timestamp) { + if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { let t = Timer::at(instant_from_smoltcp(poll_at)); pin_mut!(t); if t.poll(cx).is_ready() { @@ -161,100 +322,6 @@ impl Stack { } } -fn set_ipv4_addr(iface: &mut Interface, cidr: Ipv4Cidr) { - iface.update_ip_addrs(|addrs| { - let dest = addrs.iter_mut().next().unwrap(); - *dest = IpCidr::Ipv4(cidr); - }); -} - -/// Initialize embassy_net. -/// This function must be called from thread mode. -pub fn init( - device: &'static mut dyn Device, - configurator: &'static mut dyn Configurator, - resources: &'static mut StackResources, -) { - #[cfg(feature = "medium-ethernet")] - let medium = device.capabilities().medium; - - #[cfg(feature = "medium-ethernet")] - let ethernet_addr = if medium == Medium::Ethernet { - device.ethernet_address() - } else { - [0, 0, 0, 0, 0, 0] - }; - - let mut b = InterfaceBuilder::new(DeviceAdapter::new(device), &mut resources.sockets[..]); - b = b.ip_addrs(&mut resources.addresses[..]); - - #[cfg(feature = "medium-ethernet")] - if medium == Medium::Ethernet { - b = b.hardware_addr(HardwareAddress::Ethernet(EthernetAddress(ethernet_addr))); - b = b.neighbor_cache(NeighborCache::new(&mut resources.neighbor_cache[..])); - b = b.routes(Routes::new(&mut resources.routes[..])); - } - - let iface = b.finalize(); - - let local_port = loop { - let mut res = [0u8; 2]; - rand(&mut res); - let port = u16::from_le_bytes(res); - if (LOCAL_PORT_MIN..=LOCAL_PORT_MAX).contains(&port) { - break port; - } - }; - - let stack = Stack { - iface, - link_up: false, - config: None, - configurator, - next_local_port: local_port, - waker: WakerRegistration::new(), - }; - - *STACK.borrow().borrow_mut() = Some(stack); -} - -pub fn ethernet_address() -> [u8; 6] { - STACK - .borrow() - .borrow() - .as_ref() - .unwrap() - .iface - .device() - .device - .ethernet_address() -} - -pub fn is_init() -> bool { - STACK.borrow().borrow().is_some() -} - -pub fn is_link_up() -> bool { - STACK.borrow().borrow().as_ref().unwrap().link_up -} - -pub fn is_config_up() -> bool { - STACK.borrow().borrow().as_ref().unwrap().config.is_some() -} - -pub fn config() -> Option { - STACK.borrow().borrow().as_ref().unwrap().config.clone() -} - -pub async fn run() -> ! { - futures::future::poll_fn(|cx| { - Stack::with(|stack| stack.poll(cx)); - Poll::<()>::Pending - }) - .await; - unreachable!() -} - fn instant_to_smoltcp(instant: Instant) -> SmolInstant { SmolInstant::from_millis(instant.as_millis() as i64) } @@ -262,11 +329,3 @@ fn instant_to_smoltcp(instant: Instant) -> SmolInstant { fn instant_from_smoltcp(instant: SmolInstant) -> Instant { Instant::from_millis(instant.total_millis() as u64) } - -extern "Rust" { - fn _embassy_rand(buf: &mut [u8]); -} - -fn rand(buf: &mut [u8]) { - unsafe { _embassy_rand(buf) } -} diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index c18651b93..2d81e66bd 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs @@ -1,13 +1,16 @@ +use core::cell::UnsafeCell; use core::future::Future; -use core::marker::PhantomData; use core::mem; use core::task::Poll; use futures::future::poll_fn; -use smoltcp::iface::{Context as SmolContext, SocketHandle}; -use smoltcp::socket::TcpSocket as SyncTcpSocket; -use smoltcp::socket::{TcpSocketBuffer, TcpState}; +use smoltcp::iface::{Interface, SocketHandle}; +use smoltcp::socket::tcp; use smoltcp::time::Duration; use smoltcp::wire::IpEndpoint; +use smoltcp::wire::IpListenEndpoint; + +use crate::stack::SocketStack; +use crate::Device; use super::stack::Stack; @@ -42,78 +45,68 @@ pub enum AcceptError { } pub struct TcpSocket<'a> { - handle: SocketHandle, - ghost: PhantomData<&'a mut [u8]>, + io: TcpIo<'a>, } -impl<'a> Unpin for TcpSocket<'a> {} - pub struct TcpReader<'a> { - handle: SocketHandle, - ghost: PhantomData<&'a mut [u8]>, + io: TcpIo<'a>, } -impl<'a> Unpin for TcpReader<'a> {} - pub struct TcpWriter<'a> { - handle: SocketHandle, - ghost: PhantomData<&'a mut [u8]>, + io: TcpIo<'a>, } -impl<'a> Unpin for TcpWriter<'a> {} - impl<'a> TcpSocket<'a> { - pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { - let handle = Stack::with(|stack| { - let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; - let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; - stack.iface.add_socket(SyncTcpSocket::new( - TcpSocketBuffer::new(rx_buffer), - TcpSocketBuffer::new(tx_buffer), - )) - }); + pub fn new( + stack: &'a Stack, + rx_buffer: &'a mut [u8], + tx_buffer: &'a mut [u8], + ) -> Self { + // safety: not accessed reentrantly. + let s = unsafe { &mut *stack.socket.get() }; + let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; + let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; + let handle = s.sockets.add(tcp::Socket::new( + tcp::SocketBuffer::new(rx_buffer), + tcp::SocketBuffer::new(tx_buffer), + )); Self { - handle, - ghost: PhantomData, + io: TcpIo { + stack: &stack.socket, + handle, + }, } } pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { - ( - TcpReader { - handle: self.handle, - ghost: PhantomData, - }, - TcpWriter { - handle: self.handle, - ghost: PhantomData, - }, - ) + (TcpReader { io: self.io }, TcpWriter { io: self.io }) } pub async fn connect(&mut self, remote_endpoint: T) -> Result<(), ConnectError> where T: Into, { - let local_port = Stack::with(|stack| stack.get_local_port()); - match with_socket(self.handle, |s, cx| { - s.connect(cx, remote_endpoint, local_port) - }) { + // safety: not accessed reentrantly. + let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port(); + + // safety: not accessed reentrantly. + match unsafe { + self.io + .with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) + } { Ok(()) => {} - Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), - Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), + Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), + Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), } - futures::future::poll_fn(|cx| { - with_socket(self.handle, |s, _| match s.state() { - TcpState::Closed | TcpState::TimeWait => { + futures::future::poll_fn(|cx| unsafe { + self.io.with_mut(|s, _| match s.state() { + tcp::State::Closed | tcp::State::TimeWait => { Poll::Ready(Err(ConnectError::ConnectionReset)) } - TcpState::Listen => unreachable!(), - TcpState::SynSent | TcpState::SynReceived => { + tcp::State::Listen => unreachable!(), + tcp::State::SynSent | tcp::State::SynReceived => { s.register_send_waker(cx.waker()); Poll::Pending } @@ -125,19 +118,18 @@ impl<'a> TcpSocket<'a> { pub async fn accept(&mut self, local_endpoint: T) -> Result<(), AcceptError> where - T: Into, + T: Into, { - match with_socket(self.handle, |s, _| s.listen(local_endpoint)) { + // safety: not accessed reentrantly. + match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } { Ok(()) => {} - Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), - Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), + Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), + Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), } - futures::future::poll_fn(|cx| { - with_socket(self.handle, |s, _| match s.state() { - TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { + futures::future::poll_fn(|cx| unsafe { + self.io.with_mut(|s, _| match s.state() { + tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { s.register_send_waker(cx.waker()); Poll::Pending } @@ -148,88 +140,84 @@ impl<'a> TcpSocket<'a> { } pub fn set_timeout(&mut self, duration: Option) { - with_socket(self.handle, |s, _| s.set_timeout(duration)) + unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } } pub fn set_keep_alive(&mut self, interval: Option) { - with_socket(self.handle, |s, _| s.set_keep_alive(interval)) + unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } } pub fn set_hop_limit(&mut self, hop_limit: Option) { - with_socket(self.handle, |s, _| s.set_hop_limit(hop_limit)) + unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } } - pub fn local_endpoint(&self) -> IpEndpoint { - with_socket(self.handle, |s, _| s.local_endpoint()) + pub fn local_endpoint(&self) -> Option { + unsafe { self.io.with(|s, _| s.local_endpoint()) } } - pub fn remote_endpoint(&self) -> IpEndpoint { - with_socket(self.handle, |s, _| s.remote_endpoint()) + pub fn remote_endpoint(&self) -> Option { + unsafe { self.io.with(|s, _| s.remote_endpoint()) } } - pub fn state(&self) -> TcpState { - with_socket(self.handle, |s, _| s.state()) + pub fn state(&self) -> tcp::State { + unsafe { self.io.with(|s, _| s.state()) } } pub fn close(&mut self) { - with_socket(self.handle, |s, _| s.close()) + unsafe { self.io.with_mut(|s, _| s.close()) } } pub fn abort(&mut self) { - with_socket(self.handle, |s, _| s.abort()) + unsafe { self.io.with_mut(|s, _| s.abort()) } } pub fn may_send(&self) -> bool { - with_socket(self.handle, |s, _| s.may_send()) + unsafe { self.io.with(|s, _| s.may_send()) } } pub fn may_recv(&self) -> bool { - with_socket(self.handle, |s, _| s.may_recv()) + unsafe { self.io.with(|s, _| s.may_recv()) } } } -fn with_socket( - handle: SocketHandle, - f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R, -) -> R { - Stack::with(|stack| { - let res = { - let (s, cx) = stack.iface.get_socket_and_context::(handle); - f(s, cx) - }; - stack.wake(); - res - }) -} - impl<'a> Drop for TcpSocket<'a> { fn drop(&mut self) { - Stack::with(|stack| { - stack.iface.remove_socket(self.handle); - }) + // safety: not accessed reentrantly. + let s = unsafe { &mut *self.io.stack.get() }; + s.sockets.remove(self.io.handle); } } -impl embedded_io::Error for Error { - fn kind(&self) -> embedded_io::ErrorKind { - embedded_io::ErrorKind::Other - } -} +// ======================= -impl<'d> embedded_io::Io for TcpSocket<'d> { - type Error = Error; +#[derive(Copy, Clone)] +pub struct TcpIo<'a> { + stack: &'a UnsafeCell, + handle: SocketHandle, } -impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { - type ReadFuture<'a> = impl Future> - where - Self: 'a; +impl<'d> TcpIo<'d> { + /// SAFETY: must not call reentrantly. + unsafe fn with(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { + let s = &*self.stack.get(); + let socket = s.sockets.get::(self.handle); + f(socket, &s.iface) + } - fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { - poll_fn(move |cx| { + /// SAFETY: must not call reentrantly. + unsafe fn with_mut(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { + let s = &mut *self.stack.get(); + let socket = s.sockets.get_mut::(self.handle); + let res = f(socket, &mut s.iface); + s.waker.wake(); + res + } + + async fn read(&mut self, buf: &mut [u8]) -> Result { + poll_fn(move |cx| unsafe { // CAUTION: smoltcp semantics around EOF are different to what you'd expect // from posix-like IO, so we have to tweak things here. - with_socket(self.handle, |s, _| match s.recv_slice(buf) { + self.with_mut(|s, _| match s.recv_slice(buf) { // No data ready Ok(0) => { s.register_recv_waker(cx.waker()); @@ -238,24 +226,17 @@ impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { // Data ready! Ok(n) => Poll::Ready(Ok(n)), // EOF - Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), + Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)), // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), + Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)), }) }) + .await } -} - -impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { - type WriteFuture<'a> = impl Future> - where - Self: 'a; - fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { - poll_fn(move |cx| { - with_socket(self.handle, |s, _| match s.send_slice(buf) { + async fn write(&mut self, buf: &[u8]) -> Result { + poll_fn(move |cx| unsafe { + self.with_mut(|s, _| match s.send_slice(buf) { // Not ready to send (no space in the tx buffer) Ok(0) => { s.register_send_waker(cx.waker()); @@ -264,11 +245,47 @@ impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { // Some data sent Ok(n) => Poll::Ready(Ok(n)), // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), + Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)), }) }) + .await + } + + async fn flush(&mut self) -> Result<(), Error> { + poll_fn(move |_| { + Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? + }) + .await + } +} + +impl embedded_io::Error for Error { + fn kind(&self) -> embedded_io::ErrorKind { + embedded_io::ErrorKind::Other + } +} + +impl<'d> embedded_io::Io for TcpSocket<'d> { + type Error = Error; +} + +impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { + type ReadFuture<'a> = impl Future> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + self.io.read(buf) + } +} + +impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { + type WriteFuture<'a> = impl Future> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + self.io.write(buf) } type FlushFuture<'a> = impl Future> @@ -276,9 +293,7 @@ impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { Self: 'a; fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { - poll_fn(move |_| { - Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? - }) + self.io.flush() } } @@ -292,25 +307,7 @@ impl<'d> embedded_io::asynch::Read for TcpReader<'d> { Self: 'a; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { - poll_fn(move |cx| { - // CAUTION: smoltcp semantics around EOF are different to what you'd expect - // from posix-like IO, so we have to tweak things here. - with_socket(self.handle, |s, _| match s.recv_slice(buf) { - // No data ready - Ok(0) => { - s.register_recv_waker(cx.waker()); - Poll::Pending - } - // Data ready! - Ok(n) => Poll::Ready(Ok(n)), - // EOF - Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), - // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - }) - }) + self.io.read(buf) } } @@ -324,21 +321,7 @@ impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { Self: 'a; fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { - poll_fn(move |cx| { - with_socket(self.handle, |s, _| match s.send_slice(buf) { - // Not ready to send (no space in the tx buffer) - Ok(0) => { - s.register_send_waker(cx.waker()); - Poll::Pending - } - // Some data sent - Ok(n) => Poll::Ready(Ok(n)), - // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - }) - }) + self.io.write(buf) } type FlushFuture<'a> = impl Future> @@ -346,8 +329,6 @@ impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { Self: 'a; fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { - poll_fn(move |_| { - Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? - }) + self.io.flush() } } -- cgit