diff options
| author | bors[bot] <26634292+bors[bot]@users.noreply.github.com> | 2022-12-06 23:29:15 +0000 |
|---|---|---|
| committer | GitHub <[email protected]> | 2022-12-06 23:29:15 +0000 |
| commit | 94010d33620bc83b613596c5201e39bd251271e3 (patch) | |
| tree | 04cd3ccc6b628b20a76fd8892e0094b2aba06e68 /embassy-net/src | |
| parent | 40f0272dd0007616c1c92b5fb51fb723a3d47d30 (diff) | |
| parent | f7fe0c1441843b04fa17ba0fe94f8c8d4f851882 (diff) | |
Merge #1100
1100: net: remove unsafe, update smoltcp. r=Dirbaio a=Dirbaio
bors r+
Co-authored-by: Dario Nieuwenhuis <[email protected]>
Diffstat (limited to 'embassy-net/src')
| -rw-r--r-- | embassy-net/src/device.rs | 14 | ||||
| -rw-r--r-- | embassy-net/src/stack.rs | 53 | ||||
| -rw-r--r-- | embassy-net/src/tcp.rs | 62 | ||||
| -rw-r--r-- | embassy-net/src/udp.rs | 41 |
4 files changed, 72 insertions, 98 deletions
diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs index c183bd58a..4bdfd7720 100644 --- a/embassy-net/src/device.rs +++ b/embassy-net/src/device.rs | |||
| @@ -12,8 +12,6 @@ pub enum LinkState { | |||
| 12 | Up, | 12 | Up, |
| 13 | } | 13 | } |
| 14 | 14 | ||
| 15 | // 'static required due to the "fake GAT" in smoltcp::phy::Device. | ||
| 16 | // https://github.com/smoltcp-rs/smoltcp/pull/572 | ||
| 17 | pub trait Device { | 15 | pub trait Device { |
| 18 | fn is_transmit_ready(&mut self) -> bool; | 16 | fn is_transmit_ready(&mut self) -> bool; |
| 19 | fn transmit(&mut self, pkt: PacketBuf); | 17 | fn transmit(&mut self, pkt: PacketBuf); |
| @@ -25,7 +23,7 @@ pub trait Device { | |||
| 25 | fn ethernet_address(&self) -> [u8; 6]; | 23 | fn ethernet_address(&self) -> [u8; 6]; |
| 26 | } | 24 | } |
| 27 | 25 | ||
| 28 | impl<T: ?Sized + Device> Device for &'static mut T { | 26 | impl<T: ?Sized + Device> Device for &mut T { |
| 29 | fn is_transmit_ready(&mut self) -> bool { | 27 | fn is_transmit_ready(&mut self) -> bool { |
| 30 | T::is_transmit_ready(self) | 28 | T::is_transmit_ready(self) |
| 31 | } | 29 | } |
| @@ -63,11 +61,11 @@ impl<D: Device> DeviceAdapter<D> { | |||
| 63 | } | 61 | } |
| 64 | } | 62 | } |
| 65 | 63 | ||
| 66 | impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> { | 64 | impl<D: Device> SmolDevice for DeviceAdapter<D> { |
| 67 | type RxToken = RxToken; | 65 | type RxToken<'a> = RxToken where Self: 'a; |
| 68 | type TxToken = TxToken<'a, D>; | 66 | type TxToken<'a> = TxToken<'a, D> where Self: 'a; |
| 69 | 67 | ||
| 70 | fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { | 68 | fn receive(&mut self) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { |
| 71 | let tx_pkt = PacketBox::new(Packet::new())?; | 69 | let tx_pkt = PacketBox::new(Packet::new())?; |
| 72 | let rx_pkt = self.device.receive()?; | 70 | let rx_pkt = self.device.receive()?; |
| 73 | let rx_token = RxToken { pkt: rx_pkt }; | 71 | let rx_token = RxToken { pkt: rx_pkt }; |
| @@ -80,7 +78,7 @@ impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> { | |||
| 80 | } | 78 | } |
| 81 | 79 | ||
| 82 | /// Construct a transmit token. | 80 | /// Construct a transmit token. |
| 83 | fn transmit(&'a mut self) -> Option<Self::TxToken> { | 81 | fn transmit(&mut self) -> Option<Self::TxToken<'_>> { |
| 84 | if !self.device.is_transmit_ready() { | 82 | if !self.device.is_transmit_ready() { |
| 85 | return None; | 83 | return None; |
| 86 | } | 84 | } |
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs index 3a7610758..5c4fb0442 100644 --- a/embassy-net/src/stack.rs +++ b/embassy-net/src/stack.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::{poll_fn, Future}; | 2 | use core::future::{poll_fn, Future}; |
| 3 | use core::task::{Context, Poll}; | 3 | use core::task::{Context, Poll}; |
| 4 | 4 | ||
| @@ -62,8 +62,8 @@ pub enum ConfigStrategy { | |||
| 62 | } | 62 | } |
| 63 | 63 | ||
| 64 | pub struct Stack<D: Device> { | 64 | pub struct Stack<D: Device> { |
| 65 | pub(crate) socket: UnsafeCell<SocketStack>, | 65 | pub(crate) socket: RefCell<SocketStack>, |
| 66 | inner: UnsafeCell<Inner<D>>, | 66 | inner: RefCell<Inner<D>>, |
| 67 | } | 67 | } |
| 68 | 68 | ||
| 69 | struct Inner<D: Device> { | 69 | struct Inner<D: Device> { |
| @@ -81,8 +81,6 @@ pub(crate) struct SocketStack { | |||
| 81 | next_local_port: u16, | 81 | next_local_port: u16, |
| 82 | } | 82 | } |
| 83 | 83 | ||
| 84 | unsafe impl<D: Device> Send for Stack<D> {} | ||
| 85 | |||
| 86 | impl<D: Device + 'static> Stack<D> { | 84 | impl<D: Device + 'static> Stack<D> { |
| 87 | pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( | 85 | pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( |
| 88 | device: D, | 86 | device: D, |
| @@ -143,40 +141,38 @@ impl<D: Device + 'static> Stack<D> { | |||
| 143 | } | 141 | } |
| 144 | 142 | ||
| 145 | Self { | 143 | Self { |
| 146 | socket: UnsafeCell::new(socket), | 144 | socket: RefCell::new(socket), |
| 147 | inner: UnsafeCell::new(inner), | 145 | inner: RefCell::new(inner), |
| 148 | } | 146 | } |
| 149 | } | 147 | } |
| 150 | 148 | ||
| 151 | /// SAFETY: must not call reentrantly. | 149 | fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { |
| 152 | unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { | 150 | f(&*self.socket.borrow(), &*self.inner.borrow()) |
| 153 | f(&*self.socket.get(), &*self.inner.get()) | ||
| 154 | } | 151 | } |
| 155 | 152 | ||
| 156 | /// SAFETY: must not call reentrantly. | 153 | fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { |
| 157 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { | 154 | f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut()) |
| 158 | f(&mut *self.socket.get(), &mut *self.inner.get()) | ||
| 159 | } | 155 | } |
| 160 | 156 | ||
| 161 | pub fn ethernet_address(&self) -> [u8; 6] { | 157 | pub fn ethernet_address(&self) -> [u8; 6] { |
| 162 | unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } | 158 | self.with(|_s, i| i.device.device.ethernet_address()) |
| 163 | } | 159 | } |
| 164 | 160 | ||
| 165 | pub fn is_link_up(&self) -> bool { | 161 | pub fn is_link_up(&self) -> bool { |
| 166 | unsafe { self.with(|_s, i| i.link_up) } | 162 | self.with(|_s, i| i.link_up) |
| 167 | } | 163 | } |
| 168 | 164 | ||
| 169 | pub fn is_config_up(&self) -> bool { | 165 | pub fn is_config_up(&self) -> bool { |
| 170 | unsafe { self.with(|_s, i| i.config.is_some()) } | 166 | self.with(|_s, i| i.config.is_some()) |
| 171 | } | 167 | } |
| 172 | 168 | ||
| 173 | pub fn config(&self) -> Option<Config> { | 169 | pub fn config(&self) -> Option<Config> { |
| 174 | unsafe { self.with(|_s, i| i.config.clone()) } | 170 | self.with(|_s, i| i.config.clone()) |
| 175 | } | 171 | } |
| 176 | 172 | ||
| 177 | pub async fn run(&self) -> ! { | 173 | pub async fn run(&self) -> ! { |
| 178 | poll_fn(|cx| { | 174 | poll_fn(|cx| { |
| 179 | unsafe { self.with_mut(|s, i| i.poll(cx, s)) } | 175 | self.with_mut(|s, i| i.poll(cx, s)); |
| 180 | Poll::<()>::Pending | 176 | Poll::<()>::Pending |
| 181 | }) | 177 | }) |
| 182 | .await; | 178 | .await; |
| @@ -270,21 +266,12 @@ impl<D: Device + 'static> Inner<D> { | |||
| 270 | None => {} | 266 | None => {} |
| 271 | Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), | 267 | Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), |
| 272 | Some(dhcpv4::Event::Configured(config)) => { | 268 | Some(dhcpv4::Event::Configured(config)) => { |
| 273 | let mut dns_servers = Vec::new(); | 269 | let config = Config { |
| 274 | for s in &config.dns_servers { | 270 | address: config.address, |
| 275 | if let Some(addr) = s { | 271 | gateway: config.router, |
| 276 | dns_servers.push(addr.clone()).unwrap(); | 272 | dns_servers: config.dns_servers, |
| 277 | } | 273 | }; |
| 278 | } | 274 | self.apply_config(s, config) |
| 279 | |||
| 280 | self.apply_config( | ||
| 281 | s, | ||
| 282 | Config { | ||
| 283 | address: config.address, | ||
| 284 | gateway: config.router, | ||
| 285 | dns_servers, | ||
| 286 | }, | ||
| 287 | ) | ||
| 288 | } | 275 | } |
| 289 | } | 276 | } |
| 290 | } else if old_link_up { | 277 | } else if old_link_up { |
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index 85d9e5ee1..73cf2d4e4 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::poll_fn; | 2 | use core::future::poll_fn; |
| 3 | use core::mem; | 3 | use core::mem; |
| 4 | use core::task::Poll; | 4 | use core::task::Poll; |
| @@ -68,8 +68,7 @@ impl<'a> TcpWriter<'a> { | |||
| 68 | 68 | ||
| 69 | impl<'a> TcpSocket<'a> { | 69 | impl<'a> TcpSocket<'a> { |
| 70 | pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { | 70 | pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { |
| 71 | // safety: not accessed reentrantly. | 71 | let s = &mut *stack.socket.borrow_mut(); |
| 72 | let s = unsafe { &mut *stack.socket.get() }; | ||
| 73 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 72 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| 74 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; | 73 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; |
| 75 | let handle = s.sockets.add(tcp::Socket::new( | 74 | let handle = s.sockets.add(tcp::Socket::new( |
| @@ -93,17 +92,18 @@ impl<'a> TcpSocket<'a> { | |||
| 93 | where | 92 | where |
| 94 | T: Into<IpEndpoint>, | 93 | T: Into<IpEndpoint>, |
| 95 | { | 94 | { |
| 96 | // safety: not accessed reentrantly. | 95 | let local_port = self.io.stack.borrow_mut().get_local_port(); |
| 97 | let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port(); | ||
| 98 | 96 | ||
| 99 | // safety: not accessed reentrantly. | 97 | match { |
| 100 | match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { | 98 | self.io |
| 99 | .with_mut(|s, i| s.connect(i.context(), remote_endpoint, local_port)) | ||
| 100 | } { | ||
| 101 | Ok(()) => {} | 101 | Ok(()) => {} |
| 102 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), | 102 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), |
| 103 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), | 103 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), |
| 104 | } | 104 | } |
| 105 | 105 | ||
| 106 | poll_fn(|cx| unsafe { | 106 | poll_fn(|cx| { |
| 107 | self.io.with_mut(|s, _| match s.state() { | 107 | self.io.with_mut(|s, _| match s.state() { |
| 108 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), | 108 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), |
| 109 | tcp::State::Listen => unreachable!(), | 109 | tcp::State::Listen => unreachable!(), |
| @@ -121,14 +121,13 @@ impl<'a> TcpSocket<'a> { | |||
| 121 | where | 121 | where |
| 122 | T: Into<IpListenEndpoint>, | 122 | T: Into<IpListenEndpoint>, |
| 123 | { | 123 | { |
| 124 | // safety: not accessed reentrantly. | 124 | match self.io.with_mut(|s, _| s.listen(local_endpoint)) { |
| 125 | match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } { | ||
| 126 | Ok(()) => {} | 125 | Ok(()) => {} |
| 127 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), | 126 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), |
| 128 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), | 127 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), |
| 129 | } | 128 | } |
| 130 | 129 | ||
| 131 | poll_fn(|cx| unsafe { | 130 | poll_fn(|cx| { |
| 132 | self.io.with_mut(|s, _| match s.state() { | 131 | self.io.with_mut(|s, _| match s.state() { |
| 133 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { | 132 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { |
| 134 | s.register_send_waker(cx.waker()); | 133 | s.register_send_waker(cx.waker()); |
| @@ -149,51 +148,49 @@ impl<'a> TcpSocket<'a> { | |||
| 149 | } | 148 | } |
| 150 | 149 | ||
| 151 | pub fn set_timeout(&mut self, duration: Option<Duration>) { | 150 | pub fn set_timeout(&mut self, duration: Option<Duration>) { |
| 152 | unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } | 151 | self.io.with_mut(|s, _| s.set_timeout(duration)) |
| 153 | } | 152 | } |
| 154 | 153 | ||
| 155 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { | 154 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { |
| 156 | unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } | 155 | self.io.with_mut(|s, _| s.set_keep_alive(interval)) |
| 157 | } | 156 | } |
| 158 | 157 | ||
| 159 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { | 158 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { |
| 160 | unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } | 159 | self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) |
| 161 | } | 160 | } |
| 162 | 161 | ||
| 163 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { | 162 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { |
| 164 | unsafe { self.io.with(|s, _| s.local_endpoint()) } | 163 | self.io.with(|s, _| s.local_endpoint()) |
| 165 | } | 164 | } |
| 166 | 165 | ||
| 167 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { | 166 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { |
| 168 | unsafe { self.io.with(|s, _| s.remote_endpoint()) } | 167 | self.io.with(|s, _| s.remote_endpoint()) |
| 169 | } | 168 | } |
| 170 | 169 | ||
| 171 | pub fn state(&self) -> tcp::State { | 170 | pub fn state(&self) -> tcp::State { |
| 172 | unsafe { self.io.with(|s, _| s.state()) } | 171 | self.io.with(|s, _| s.state()) |
| 173 | } | 172 | } |
| 174 | 173 | ||
| 175 | pub fn close(&mut self) { | 174 | pub fn close(&mut self) { |
| 176 | unsafe { self.io.with_mut(|s, _| s.close()) } | 175 | self.io.with_mut(|s, _| s.close()) |
| 177 | } | 176 | } |
| 178 | 177 | ||
| 179 | pub fn abort(&mut self) { | 178 | pub fn abort(&mut self) { |
| 180 | unsafe { self.io.with_mut(|s, _| s.abort()) } | 179 | self.io.with_mut(|s, _| s.abort()) |
| 181 | } | 180 | } |
| 182 | 181 | ||
| 183 | pub fn may_send(&self) -> bool { | 182 | pub fn may_send(&self) -> bool { |
| 184 | unsafe { self.io.with(|s, _| s.may_send()) } | 183 | self.io.with(|s, _| s.may_send()) |
| 185 | } | 184 | } |
| 186 | 185 | ||
| 187 | pub fn may_recv(&self) -> bool { | 186 | pub fn may_recv(&self) -> bool { |
| 188 | unsafe { self.io.with(|s, _| s.may_recv()) } | 187 | self.io.with(|s, _| s.may_recv()) |
| 189 | } | 188 | } |
| 190 | } | 189 | } |
| 191 | 190 | ||
| 192 | impl<'a> Drop for TcpSocket<'a> { | 191 | impl<'a> Drop for TcpSocket<'a> { |
| 193 | fn drop(&mut self) { | 192 | fn drop(&mut self) { |
| 194 | // safety: not accessed reentrantly. | 193 | self.io.stack.borrow_mut().sockets.remove(self.io.handle); |
| 195 | let s = unsafe { &mut *self.io.stack.get() }; | ||
| 196 | s.sockets.remove(self.io.handle); | ||
| 197 | } | 194 | } |
| 198 | } | 195 | } |
| 199 | 196 | ||
| @@ -201,21 +198,19 @@ impl<'a> Drop for TcpSocket<'a> { | |||
| 201 | 198 | ||
| 202 | #[derive(Copy, Clone)] | 199 | #[derive(Copy, Clone)] |
| 203 | struct TcpIo<'a> { | 200 | struct TcpIo<'a> { |
| 204 | stack: &'a UnsafeCell<SocketStack>, | 201 | stack: &'a RefCell<SocketStack>, |
| 205 | handle: SocketHandle, | 202 | handle: SocketHandle, |
| 206 | } | 203 | } |
| 207 | 204 | ||
| 208 | impl<'d> TcpIo<'d> { | 205 | impl<'d> TcpIo<'d> { |
| 209 | /// SAFETY: must not call reentrantly. | 206 | fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { |
| 210 | unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { | 207 | let s = &*self.stack.borrow(); |
| 211 | let s = &*self.stack.get(); | ||
| 212 | let socket = s.sockets.get::<tcp::Socket>(self.handle); | 208 | let socket = s.sockets.get::<tcp::Socket>(self.handle); |
| 213 | f(socket, &s.iface) | 209 | f(socket, &s.iface) |
| 214 | } | 210 | } |
| 215 | 211 | ||
| 216 | /// SAFETY: must not call reentrantly. | 212 | fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { |
| 217 | unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { | 213 | let s = &mut *self.stack.borrow_mut(); |
| 218 | let s = &mut *self.stack.get(); | ||
| 219 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); | 214 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); |
| 220 | let res = f(socket, &mut s.iface); | 215 | let res = f(socket, &mut s.iface); |
| 221 | s.waker.wake(); | 216 | s.waker.wake(); |
| @@ -223,7 +218,7 @@ impl<'d> TcpIo<'d> { | |||
| 223 | } | 218 | } |
| 224 | 219 | ||
| 225 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { | 220 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { |
| 226 | poll_fn(move |cx| unsafe { | 221 | poll_fn(move |cx| { |
| 227 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect | 222 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect |
| 228 | // from posix-like IO, so we have to tweak things here. | 223 | // from posix-like IO, so we have to tweak things here. |
| 229 | self.with_mut(|s, _| match s.recv_slice(buf) { | 224 | self.with_mut(|s, _| match s.recv_slice(buf) { |
| @@ -244,7 +239,7 @@ impl<'d> TcpIo<'d> { | |||
| 244 | } | 239 | } |
| 245 | 240 | ||
| 246 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { | 241 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { |
| 247 | poll_fn(move |cx| unsafe { | 242 | poll_fn(move |cx| { |
| 248 | self.with_mut(|s, _| match s.send_slice(buf) { | 243 | self.with_mut(|s, _| match s.send_slice(buf) { |
| 249 | // Not ready to send (no space in the tx buffer) | 244 | // Not ready to send (no space in the tx buffer) |
| 250 | Ok(0) => { | 245 | Ok(0) => { |
| @@ -332,6 +327,7 @@ mod embedded_io_impls { | |||
| 332 | 327 | ||
| 333 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] | 328 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] |
| 334 | pub mod client { | 329 | pub mod client { |
| 330 | use core::cell::UnsafeCell; | ||
| 335 | use core::mem::MaybeUninit; | 331 | use core::mem::MaybeUninit; |
| 336 | use core::ptr::NonNull; | 332 | use core::ptr::NonNull; |
| 337 | 333 | ||
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs index f2e33493c..4ddad77d4 100644 --- a/embassy-net/src/udp.rs +++ b/embassy-net/src/udp.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::poll_fn; | 2 | use core::future::poll_fn; |
| 3 | use core::mem; | 3 | use core::mem; |
| 4 | use core::task::Poll; | 4 | use core::task::Poll; |
| @@ -27,7 +27,7 @@ pub enum Error { | |||
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | pub struct UdpSocket<'a> { | 29 | pub struct UdpSocket<'a> { |
| 30 | stack: &'a UnsafeCell<SocketStack>, | 30 | stack: &'a RefCell<SocketStack>, |
| 31 | handle: SocketHandle, | 31 | handle: SocketHandle, |
| 32 | } | 32 | } |
| 33 | 33 | ||
| @@ -39,8 +39,7 @@ impl<'a> UdpSocket<'a> { | |||
| 39 | tx_meta: &'a mut [PacketMetadata], | 39 | tx_meta: &'a mut [PacketMetadata], |
| 40 | tx_buffer: &'a mut [u8], | 40 | tx_buffer: &'a mut [u8], |
| 41 | ) -> Self { | 41 | ) -> Self { |
| 42 | // safety: not accessed reentrantly. | 42 | let s = &mut *stack.socket.borrow_mut(); |
| 43 | let s = unsafe { &mut *stack.socket.get() }; | ||
| 44 | 43 | ||
| 45 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; | 44 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; |
| 46 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 45 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| @@ -63,30 +62,26 @@ impl<'a> UdpSocket<'a> { | |||
| 63 | { | 62 | { |
| 64 | let mut endpoint = endpoint.into(); | 63 | let mut endpoint = endpoint.into(); |
| 65 | 64 | ||
| 66 | // safety: not accessed reentrantly. | ||
| 67 | if endpoint.port == 0 { | 65 | if endpoint.port == 0 { |
| 68 | // If user didn't specify port allocate a dynamic port. | 66 | // If user didn't specify port allocate a dynamic port. |
| 69 | endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); | 67 | endpoint.port = self.stack.borrow_mut().get_local_port(); |
| 70 | } | 68 | } |
| 71 | 69 | ||
| 72 | // safety: not accessed reentrantly. | 70 | match self.with_mut(|s, _| s.bind(endpoint)) { |
| 73 | match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } { | ||
| 74 | Ok(()) => Ok(()), | 71 | Ok(()) => Ok(()), |
| 75 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), | 72 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), |
| 76 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), | 73 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), |
| 77 | } | 74 | } |
| 78 | } | 75 | } |
| 79 | 76 | ||
| 80 | /// SAFETY: must not call reentrantly. | 77 | fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { |
| 81 | unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { | 78 | let s = &*self.stack.borrow(); |
| 82 | let s = &*self.stack.get(); | ||
| 83 | let socket = s.sockets.get::<udp::Socket>(self.handle); | 79 | let socket = s.sockets.get::<udp::Socket>(self.handle); |
| 84 | f(socket, &s.iface) | 80 | f(socket, &s.iface) |
| 85 | } | 81 | } |
| 86 | 82 | ||
| 87 | /// SAFETY: must not call reentrantly. | 83 | fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { |
| 88 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { | 84 | let s = &mut *self.stack.borrow_mut(); |
| 89 | let s = &mut *self.stack.get(); | ||
| 90 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); | 85 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); |
| 91 | let res = f(socket, &mut s.iface); | 86 | let res = f(socket, &mut s.iface); |
| 92 | s.waker.wake(); | 87 | s.waker.wake(); |
| @@ -94,7 +89,7 @@ impl<'a> UdpSocket<'a> { | |||
| 94 | } | 89 | } |
| 95 | 90 | ||
| 96 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { | 91 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { |
| 97 | poll_fn(move |cx| unsafe { | 92 | poll_fn(move |cx| { |
| 98 | self.with_mut(|s, _| match s.recv_slice(buf) { | 93 | self.with_mut(|s, _| match s.recv_slice(buf) { |
| 99 | Ok(x) => Poll::Ready(Ok(x)), | 94 | Ok(x) => Poll::Ready(Ok(x)), |
| 100 | // No data ready | 95 | // No data ready |
| @@ -113,7 +108,7 @@ impl<'a> UdpSocket<'a> { | |||
| 113 | T: Into<IpEndpoint>, | 108 | T: Into<IpEndpoint>, |
| 114 | { | 109 | { |
| 115 | let remote_endpoint = remote_endpoint.into(); | 110 | let remote_endpoint = remote_endpoint.into(); |
| 116 | poll_fn(move |cx| unsafe { | 111 | poll_fn(move |cx| { |
| 117 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { | 112 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { |
| 118 | // Entire datagram has been sent | 113 | // Entire datagram has been sent |
| 119 | Ok(()) => Poll::Ready(Ok(())), | 114 | Ok(()) => Poll::Ready(Ok(())), |
| @@ -128,30 +123,28 @@ impl<'a> UdpSocket<'a> { | |||
| 128 | } | 123 | } |
| 129 | 124 | ||
| 130 | pub fn endpoint(&self) -> IpListenEndpoint { | 125 | pub fn endpoint(&self) -> IpListenEndpoint { |
| 131 | unsafe { self.with(|s, _| s.endpoint()) } | 126 | self.with(|s, _| s.endpoint()) |
| 132 | } | 127 | } |
| 133 | 128 | ||
| 134 | pub fn is_open(&self) -> bool { | 129 | pub fn is_open(&self) -> bool { |
| 135 | unsafe { self.with(|s, _| s.is_open()) } | 130 | self.with(|s, _| s.is_open()) |
| 136 | } | 131 | } |
| 137 | 132 | ||
| 138 | pub fn close(&mut self) { | 133 | pub fn close(&mut self) { |
| 139 | unsafe { self.with_mut(|s, _| s.close()) } | 134 | self.with_mut(|s, _| s.close()) |
| 140 | } | 135 | } |
| 141 | 136 | ||
| 142 | pub fn may_send(&self) -> bool { | 137 | pub fn may_send(&self) -> bool { |
| 143 | unsafe { self.with(|s, _| s.can_send()) } | 138 | self.with(|s, _| s.can_send()) |
| 144 | } | 139 | } |
| 145 | 140 | ||
| 146 | pub fn may_recv(&self) -> bool { | 141 | pub fn may_recv(&self) -> bool { |
| 147 | unsafe { self.with(|s, _| s.can_recv()) } | 142 | self.with(|s, _| s.can_recv()) |
| 148 | } | 143 | } |
| 149 | } | 144 | } |
| 150 | 145 | ||
| 151 | impl Drop for UdpSocket<'_> { | 146 | impl Drop for UdpSocket<'_> { |
| 152 | fn drop(&mut self) { | 147 | fn drop(&mut self) { |
| 153 | // safety: not accessed reentrantly. | 148 | self.stack.borrow_mut().sockets.remove(self.handle); |
| 154 | let s = unsafe { &mut *self.stack.get() }; | ||
| 155 | s.sockets.remove(self.handle); | ||
| 156 | } | 149 | } |
| 157 | } | 150 | } |
