diff options
| author | Dario Nieuwenhuis <[email protected]> | 2022-12-03 00:56:16 +0100 |
|---|---|---|
| committer | Dario Nieuwenhuis <[email protected]> | 2022-12-03 00:56:16 +0100 |
| commit | 02abe00439ba873945bd6b60546a200b3da751f1 (patch) | |
| tree | 62724bbe40f58380ce7ce67125e10348c5867adb /embassy-net | |
| parent | f109e73c6d7ef2ad93102b7c8223f5cef30ef36f (diff) | |
net: don't use UnsafeCell.
The "must not be called reentrantly" invariant is too "global" to
maintain comfortably, and the cost of the RefCell is negligible,
so this was a case of premature optimization.
Diffstat (limited to 'embassy-net')
| -rw-r--r-- | embassy-net/src/stack.rs | 32 | ||||
| -rw-r--r-- | embassy-net/src/tcp.rs | 59 | ||||
| -rw-r--r-- | embassy-net/src/udp.rs | 41 |
3 files changed, 57 insertions, 75 deletions
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs index 3a7610758..631087405 100644 --- a/embassy-net/src/stack.rs +++ b/embassy-net/src/stack.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::{poll_fn, Future}; | 2 | use core::future::{poll_fn, Future}; |
| 3 | use core::task::{Context, Poll}; | 3 | use core::task::{Context, Poll}; |
| 4 | 4 | ||
| @@ -62,8 +62,8 @@ pub enum ConfigStrategy { | |||
| 62 | } | 62 | } |
| 63 | 63 | ||
| 64 | pub struct Stack<D: Device> { | 64 | pub struct Stack<D: Device> { |
| 65 | pub(crate) socket: UnsafeCell<SocketStack>, | 65 | pub(crate) socket: RefCell<SocketStack>, |
| 66 | inner: UnsafeCell<Inner<D>>, | 66 | inner: RefCell<Inner<D>>, |
| 67 | } | 67 | } |
| 68 | 68 | ||
| 69 | struct Inner<D: Device> { | 69 | struct Inner<D: Device> { |
| @@ -81,8 +81,6 @@ pub(crate) struct SocketStack { | |||
| 81 | next_local_port: u16, | 81 | next_local_port: u16, |
| 82 | } | 82 | } |
| 83 | 83 | ||
| 84 | unsafe impl<D: Device> Send for Stack<D> {} | ||
| 85 | |||
| 86 | impl<D: Device + 'static> Stack<D> { | 84 | impl<D: Device + 'static> Stack<D> { |
| 87 | pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( | 85 | pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( |
| 88 | device: D, | 86 | device: D, |
| @@ -143,40 +141,38 @@ impl<D: Device + 'static> Stack<D> { | |||
| 143 | } | 141 | } |
| 144 | 142 | ||
| 145 | Self { | 143 | Self { |
| 146 | socket: UnsafeCell::new(socket), | 144 | socket: RefCell::new(socket), |
| 147 | inner: UnsafeCell::new(inner), | 145 | inner: RefCell::new(inner), |
| 148 | } | 146 | } |
| 149 | } | 147 | } |
| 150 | 148 | ||
| 151 | /// SAFETY: must not call reentrantly. | 149 | fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { |
| 152 | unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { | 150 | f(&*self.socket.borrow(), &*self.inner.borrow()) |
| 153 | f(&*self.socket.get(), &*self.inner.get()) | ||
| 154 | } | 151 | } |
| 155 | 152 | ||
| 156 | /// SAFETY: must not call reentrantly. | 153 | fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { |
| 157 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { | 154 | f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut()) |
| 158 | f(&mut *self.socket.get(), &mut *self.inner.get()) | ||
| 159 | } | 155 | } |
| 160 | 156 | ||
| 161 | pub fn ethernet_address(&self) -> [u8; 6] { | 157 | pub fn ethernet_address(&self) -> [u8; 6] { |
| 162 | unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } | 158 | self.with(|_s, i| i.device.device.ethernet_address()) |
| 163 | } | 159 | } |
| 164 | 160 | ||
| 165 | pub fn is_link_up(&self) -> bool { | 161 | pub fn is_link_up(&self) -> bool { |
| 166 | unsafe { self.with(|_s, i| i.link_up) } | 162 | self.with(|_s, i| i.link_up) |
| 167 | } | 163 | } |
| 168 | 164 | ||
| 169 | pub fn is_config_up(&self) -> bool { | 165 | pub fn is_config_up(&self) -> bool { |
| 170 | unsafe { self.with(|_s, i| i.config.is_some()) } | 166 | self.with(|_s, i| i.config.is_some()) |
| 171 | } | 167 | } |
| 172 | 168 | ||
| 173 | pub fn config(&self) -> Option<Config> { | 169 | pub fn config(&self) -> Option<Config> { |
| 174 | unsafe { self.with(|_s, i| i.config.clone()) } | 170 | self.with(|_s, i| i.config.clone()) |
| 175 | } | 171 | } |
| 176 | 172 | ||
| 177 | pub async fn run(&self) -> ! { | 173 | pub async fn run(&self) -> ! { |
| 178 | poll_fn(|cx| { | 174 | poll_fn(|cx| { |
| 179 | unsafe { self.with_mut(|s, i| i.poll(cx, s)) } | 175 | self.with_mut(|s, i| i.poll(cx, s)); |
| 180 | Poll::<()>::Pending | 176 | Poll::<()>::Pending |
| 181 | }) | 177 | }) |
| 182 | .await; | 178 | .await; |
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index 85d9e5ee1..60386535a 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::poll_fn; | 2 | use core::future::poll_fn; |
| 3 | use core::mem; | 3 | use core::mem; |
| 4 | use core::task::Poll; | 4 | use core::task::Poll; |
| @@ -68,8 +68,7 @@ impl<'a> TcpWriter<'a> { | |||
| 68 | 68 | ||
| 69 | impl<'a> TcpSocket<'a> { | 69 | impl<'a> TcpSocket<'a> { |
| 70 | pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { | 70 | pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { |
| 71 | // safety: not accessed reentrantly. | 71 | let s = &mut *stack.socket.borrow_mut(); |
| 72 | let s = unsafe { &mut *stack.socket.get() }; | ||
| 73 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 72 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| 74 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; | 73 | let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; |
| 75 | let handle = s.sockets.add(tcp::Socket::new( | 74 | let handle = s.sockets.add(tcp::Socket::new( |
| @@ -93,17 +92,15 @@ impl<'a> TcpSocket<'a> { | |||
| 93 | where | 92 | where |
| 94 | T: Into<IpEndpoint>, | 93 | T: Into<IpEndpoint>, |
| 95 | { | 94 | { |
| 96 | // safety: not accessed reentrantly. | 95 | let local_port = self.io.stack.borrow_mut().get_local_port(); |
| 97 | let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port(); | ||
| 98 | 96 | ||
| 99 | // safety: not accessed reentrantly. | 97 | match { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { |
| 100 | match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { | ||
| 101 | Ok(()) => {} | 98 | Ok(()) => {} |
| 102 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), | 99 | Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), |
| 103 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), | 100 | Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), |
| 104 | } | 101 | } |
| 105 | 102 | ||
| 106 | poll_fn(|cx| unsafe { | 103 | poll_fn(|cx| { |
| 107 | self.io.with_mut(|s, _| match s.state() { | 104 | self.io.with_mut(|s, _| match s.state() { |
| 108 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), | 105 | tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), |
| 109 | tcp::State::Listen => unreachable!(), | 106 | tcp::State::Listen => unreachable!(), |
| @@ -121,14 +118,13 @@ impl<'a> TcpSocket<'a> { | |||
| 121 | where | 118 | where |
| 122 | T: Into<IpListenEndpoint>, | 119 | T: Into<IpListenEndpoint>, |
| 123 | { | 120 | { |
| 124 | // safety: not accessed reentrantly. | 121 | match self.io.with_mut(|s, _| s.listen(local_endpoint)) { |
| 125 | match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } { | ||
| 126 | Ok(()) => {} | 122 | Ok(()) => {} |
| 127 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), | 123 | Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), |
| 128 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), | 124 | Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), |
| 129 | } | 125 | } |
| 130 | 126 | ||
| 131 | poll_fn(|cx| unsafe { | 127 | poll_fn(|cx| { |
| 132 | self.io.with_mut(|s, _| match s.state() { | 128 | self.io.with_mut(|s, _| match s.state() { |
| 133 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { | 129 | tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { |
| 134 | s.register_send_waker(cx.waker()); | 130 | s.register_send_waker(cx.waker()); |
| @@ -149,51 +145,49 @@ impl<'a> TcpSocket<'a> { | |||
| 149 | } | 145 | } |
| 150 | 146 | ||
| 151 | pub fn set_timeout(&mut self, duration: Option<Duration>) { | 147 | pub fn set_timeout(&mut self, duration: Option<Duration>) { |
| 152 | unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } | 148 | self.io.with_mut(|s, _| s.set_timeout(duration)) |
| 153 | } | 149 | } |
| 154 | 150 | ||
| 155 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { | 151 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { |
| 156 | unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } | 152 | self.io.with_mut(|s, _| s.set_keep_alive(interval)) |
| 157 | } | 153 | } |
| 158 | 154 | ||
| 159 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { | 155 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { |
| 160 | unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } | 156 | self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) |
| 161 | } | 157 | } |
| 162 | 158 | ||
| 163 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { | 159 | pub fn local_endpoint(&self) -> Option<IpEndpoint> { |
| 164 | unsafe { self.io.with(|s, _| s.local_endpoint()) } | 160 | self.io.with(|s, _| s.local_endpoint()) |
| 165 | } | 161 | } |
| 166 | 162 | ||
| 167 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { | 163 | pub fn remote_endpoint(&self) -> Option<IpEndpoint> { |
| 168 | unsafe { self.io.with(|s, _| s.remote_endpoint()) } | 164 | self.io.with(|s, _| s.remote_endpoint()) |
| 169 | } | 165 | } |
| 170 | 166 | ||
| 171 | pub fn state(&self) -> tcp::State { | 167 | pub fn state(&self) -> tcp::State { |
| 172 | unsafe { self.io.with(|s, _| s.state()) } | 168 | self.io.with(|s, _| s.state()) |
| 173 | } | 169 | } |
| 174 | 170 | ||
| 175 | pub fn close(&mut self) { | 171 | pub fn close(&mut self) { |
| 176 | unsafe { self.io.with_mut(|s, _| s.close()) } | 172 | self.io.with_mut(|s, _| s.close()) |
| 177 | } | 173 | } |
| 178 | 174 | ||
| 179 | pub fn abort(&mut self) { | 175 | pub fn abort(&mut self) { |
| 180 | unsafe { self.io.with_mut(|s, _| s.abort()) } | 176 | self.io.with_mut(|s, _| s.abort()) |
| 181 | } | 177 | } |
| 182 | 178 | ||
| 183 | pub fn may_send(&self) -> bool { | 179 | pub fn may_send(&self) -> bool { |
| 184 | unsafe { self.io.with(|s, _| s.may_send()) } | 180 | self.io.with(|s, _| s.may_send()) |
| 185 | } | 181 | } |
| 186 | 182 | ||
| 187 | pub fn may_recv(&self) -> bool { | 183 | pub fn may_recv(&self) -> bool { |
| 188 | unsafe { self.io.with(|s, _| s.may_recv()) } | 184 | self.io.with(|s, _| s.may_recv()) |
| 189 | } | 185 | } |
| 190 | } | 186 | } |
| 191 | 187 | ||
| 192 | impl<'a> Drop for TcpSocket<'a> { | 188 | impl<'a> Drop for TcpSocket<'a> { |
| 193 | fn drop(&mut self) { | 189 | fn drop(&mut self) { |
| 194 | // safety: not accessed reentrantly. | 190 | self.io.stack.borrow_mut().sockets.remove(self.io.handle); |
| 195 | let s = unsafe { &mut *self.io.stack.get() }; | ||
| 196 | s.sockets.remove(self.io.handle); | ||
| 197 | } | 191 | } |
| 198 | } | 192 | } |
| 199 | 193 | ||
| @@ -201,21 +195,19 @@ impl<'a> Drop for TcpSocket<'a> { | |||
| 201 | 195 | ||
| 202 | #[derive(Copy, Clone)] | 196 | #[derive(Copy, Clone)] |
| 203 | struct TcpIo<'a> { | 197 | struct TcpIo<'a> { |
| 204 | stack: &'a UnsafeCell<SocketStack>, | 198 | stack: &'a RefCell<SocketStack>, |
| 205 | handle: SocketHandle, | 199 | handle: SocketHandle, |
| 206 | } | 200 | } |
| 207 | 201 | ||
| 208 | impl<'d> TcpIo<'d> { | 202 | impl<'d> TcpIo<'d> { |
| 209 | /// SAFETY: must not call reentrantly. | 203 | fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { |
| 210 | unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { | 204 | let s = &*self.stack.borrow(); |
| 211 | let s = &*self.stack.get(); | ||
| 212 | let socket = s.sockets.get::<tcp::Socket>(self.handle); | 205 | let socket = s.sockets.get::<tcp::Socket>(self.handle); |
| 213 | f(socket, &s.iface) | 206 | f(socket, &s.iface) |
| 214 | } | 207 | } |
| 215 | 208 | ||
| 216 | /// SAFETY: must not call reentrantly. | 209 | fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { |
| 217 | unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { | 210 | let s = &mut *self.stack.borrow_mut(); |
| 218 | let s = &mut *self.stack.get(); | ||
| 219 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); | 211 | let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); |
| 220 | let res = f(socket, &mut s.iface); | 212 | let res = f(socket, &mut s.iface); |
| 221 | s.waker.wake(); | 213 | s.waker.wake(); |
| @@ -223,7 +215,7 @@ impl<'d> TcpIo<'d> { | |||
| 223 | } | 215 | } |
| 224 | 216 | ||
| 225 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { | 217 | async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { |
| 226 | poll_fn(move |cx| unsafe { | 218 | poll_fn(move |cx| { |
| 227 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect | 219 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect |
| 228 | // from posix-like IO, so we have to tweak things here. | 220 | // from posix-like IO, so we have to tweak things here. |
| 229 | self.with_mut(|s, _| match s.recv_slice(buf) { | 221 | self.with_mut(|s, _| match s.recv_slice(buf) { |
| @@ -244,7 +236,7 @@ impl<'d> TcpIo<'d> { | |||
| 244 | } | 236 | } |
| 245 | 237 | ||
| 246 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { | 238 | async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { |
| 247 | poll_fn(move |cx| unsafe { | 239 | poll_fn(move |cx| { |
| 248 | self.with_mut(|s, _| match s.send_slice(buf) { | 240 | self.with_mut(|s, _| match s.send_slice(buf) { |
| 249 | // Not ready to send (no space in the tx buffer) | 241 | // Not ready to send (no space in the tx buffer) |
| 250 | Ok(0) => { | 242 | Ok(0) => { |
| @@ -332,6 +324,7 @@ mod embedded_io_impls { | |||
| 332 | 324 | ||
| 333 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] | 325 | #[cfg(all(feature = "unstable-traits", feature = "nightly"))] |
| 334 | pub mod client { | 326 | pub mod client { |
| 327 | use core::cell::UnsafeCell; | ||
| 335 | use core::mem::MaybeUninit; | 328 | use core::mem::MaybeUninit; |
| 336 | use core::ptr::NonNull; | 329 | use core::ptr::NonNull; |
| 337 | 330 | ||
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs index f2e33493c..4ddad77d4 100644 --- a/embassy-net/src/udp.rs +++ b/embassy-net/src/udp.rs | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | use core::cell::UnsafeCell; | 1 | use core::cell::RefCell; |
| 2 | use core::future::poll_fn; | 2 | use core::future::poll_fn; |
| 3 | use core::mem; | 3 | use core::mem; |
| 4 | use core::task::Poll; | 4 | use core::task::Poll; |
| @@ -27,7 +27,7 @@ pub enum Error { | |||
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | pub struct UdpSocket<'a> { | 29 | pub struct UdpSocket<'a> { |
| 30 | stack: &'a UnsafeCell<SocketStack>, | 30 | stack: &'a RefCell<SocketStack>, |
| 31 | handle: SocketHandle, | 31 | handle: SocketHandle, |
| 32 | } | 32 | } |
| 33 | 33 | ||
| @@ -39,8 +39,7 @@ impl<'a> UdpSocket<'a> { | |||
| 39 | tx_meta: &'a mut [PacketMetadata], | 39 | tx_meta: &'a mut [PacketMetadata], |
| 40 | tx_buffer: &'a mut [u8], | 40 | tx_buffer: &'a mut [u8], |
| 41 | ) -> Self { | 41 | ) -> Self { |
| 42 | // safety: not accessed reentrantly. | 42 | let s = &mut *stack.socket.borrow_mut(); |
| 43 | let s = unsafe { &mut *stack.socket.get() }; | ||
| 44 | 43 | ||
| 45 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; | 44 | let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; |
| 46 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; | 45 | let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; |
| @@ -63,30 +62,26 @@ impl<'a> UdpSocket<'a> { | |||
| 63 | { | 62 | { |
| 64 | let mut endpoint = endpoint.into(); | 63 | let mut endpoint = endpoint.into(); |
| 65 | 64 | ||
| 66 | // safety: not accessed reentrantly. | ||
| 67 | if endpoint.port == 0 { | 65 | if endpoint.port == 0 { |
| 68 | // If user didn't specify port allocate a dynamic port. | 66 | // If user didn't specify port allocate a dynamic port. |
| 69 | endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); | 67 | endpoint.port = self.stack.borrow_mut().get_local_port(); |
| 70 | } | 68 | } |
| 71 | 69 | ||
| 72 | // safety: not accessed reentrantly. | 70 | match self.with_mut(|s, _| s.bind(endpoint)) { |
| 73 | match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } { | ||
| 74 | Ok(()) => Ok(()), | 71 | Ok(()) => Ok(()), |
| 75 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), | 72 | Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), |
| 76 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), | 73 | Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), |
| 77 | } | 74 | } |
| 78 | } | 75 | } |
| 79 | 76 | ||
| 80 | /// SAFETY: must not call reentrantly. | 77 | fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { |
| 81 | unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { | 78 | let s = &*self.stack.borrow(); |
| 82 | let s = &*self.stack.get(); | ||
| 83 | let socket = s.sockets.get::<udp::Socket>(self.handle); | 79 | let socket = s.sockets.get::<udp::Socket>(self.handle); |
| 84 | f(socket, &s.iface) | 80 | f(socket, &s.iface) |
| 85 | } | 81 | } |
| 86 | 82 | ||
| 87 | /// SAFETY: must not call reentrantly. | 83 | fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { |
| 88 | unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { | 84 | let s = &mut *self.stack.borrow_mut(); |
| 89 | let s = &mut *self.stack.get(); | ||
| 90 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); | 85 | let socket = s.sockets.get_mut::<udp::Socket>(self.handle); |
| 91 | let res = f(socket, &mut s.iface); | 86 | let res = f(socket, &mut s.iface); |
| 92 | s.waker.wake(); | 87 | s.waker.wake(); |
| @@ -94,7 +89,7 @@ impl<'a> UdpSocket<'a> { | |||
| 94 | } | 89 | } |
| 95 | 90 | ||
| 96 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { | 91 | pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { |
| 97 | poll_fn(move |cx| unsafe { | 92 | poll_fn(move |cx| { |
| 98 | self.with_mut(|s, _| match s.recv_slice(buf) { | 93 | self.with_mut(|s, _| match s.recv_slice(buf) { |
| 99 | Ok(x) => Poll::Ready(Ok(x)), | 94 | Ok(x) => Poll::Ready(Ok(x)), |
| 100 | // No data ready | 95 | // No data ready |
| @@ -113,7 +108,7 @@ impl<'a> UdpSocket<'a> { | |||
| 113 | T: Into<IpEndpoint>, | 108 | T: Into<IpEndpoint>, |
| 114 | { | 109 | { |
| 115 | let remote_endpoint = remote_endpoint.into(); | 110 | let remote_endpoint = remote_endpoint.into(); |
| 116 | poll_fn(move |cx| unsafe { | 111 | poll_fn(move |cx| { |
| 117 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { | 112 | self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { |
| 118 | // Entire datagram has been sent | 113 | // Entire datagram has been sent |
| 119 | Ok(()) => Poll::Ready(Ok(())), | 114 | Ok(()) => Poll::Ready(Ok(())), |
| @@ -128,30 +123,28 @@ impl<'a> UdpSocket<'a> { | |||
| 128 | } | 123 | } |
| 129 | 124 | ||
| 130 | pub fn endpoint(&self) -> IpListenEndpoint { | 125 | pub fn endpoint(&self) -> IpListenEndpoint { |
| 131 | unsafe { self.with(|s, _| s.endpoint()) } | 126 | self.with(|s, _| s.endpoint()) |
| 132 | } | 127 | } |
| 133 | 128 | ||
| 134 | pub fn is_open(&self) -> bool { | 129 | pub fn is_open(&self) -> bool { |
| 135 | unsafe { self.with(|s, _| s.is_open()) } | 130 | self.with(|s, _| s.is_open()) |
| 136 | } | 131 | } |
| 137 | 132 | ||
| 138 | pub fn close(&mut self) { | 133 | pub fn close(&mut self) { |
| 139 | unsafe { self.with_mut(|s, _| s.close()) } | 134 | self.with_mut(|s, _| s.close()) |
| 140 | } | 135 | } |
| 141 | 136 | ||
| 142 | pub fn may_send(&self) -> bool { | 137 | pub fn may_send(&self) -> bool { |
| 143 | unsafe { self.with(|s, _| s.can_send()) } | 138 | self.with(|s, _| s.can_send()) |
| 144 | } | 139 | } |
| 145 | 140 | ||
| 146 | pub fn may_recv(&self) -> bool { | 141 | pub fn may_recv(&self) -> bool { |
| 147 | unsafe { self.with(|s, _| s.can_recv()) } | 142 | self.with(|s, _| s.can_recv()) |
| 148 | } | 143 | } |
| 149 | } | 144 | } |
| 150 | 145 | ||
| 151 | impl Drop for UdpSocket<'_> { | 146 | impl Drop for UdpSocket<'_> { |
| 152 | fn drop(&mut self) { | 147 | fn drop(&mut self) { |
| 153 | // safety: not accessed reentrantly. | 148 | self.stack.borrow_mut().sockets.remove(self.handle); |
| 154 | let s = unsafe { &mut *self.stack.get() }; | ||
| 155 | s.sockets.remove(self.handle); | ||
| 156 | } | 149 | } |
| 157 | } | 150 | } |
