aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2022-12-06 23:29:15 +0000
committerGitHub <[email protected]>2022-12-06 23:29:15 +0000
commit94010d33620bc83b613596c5201e39bd251271e3 (patch)
tree04cd3ccc6b628b20a76fd8892e0094b2aba06e68
parent40f0272dd0007616c1c92b5fb51fb723a3d47d30 (diff)
parentf7fe0c1441843b04fa17ba0fe94f8c8d4f851882 (diff)
Merge #1100
1100: net: remove unsafe, update smoltcp. r=Dirbaio a=Dirbaio bors r+ Co-authored-by: Dario Nieuwenhuis <[email protected]>
-rw-r--r--embassy-net/Cargo.toml2
-rw-r--r--embassy-net/src/device.rs14
-rw-r--r--embassy-net/src/stack.rs53
-rw-r--r--embassy-net/src/tcp.rs62
-rw-r--r--embassy-net/src/udp.rs41
5 files changed, 73 insertions, 99 deletions
diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml
index 86d4aa105..ac338843d 100644
--- a/embassy-net/Cargo.toml
+++ b/embassy-net/Cargo.toml
@@ -57,7 +57,7 @@ embedded-nal-async = { version = "0.3.0", optional = true }
57[dependencies.smoltcp] 57[dependencies.smoltcp]
58version = "0.8.0" 58version = "0.8.0"
59git = "https://github.com/smoltcp-rs/smoltcp" 59git = "https://github.com/smoltcp-rs/smoltcp"
60rev = "ed0cf16750a42f30e31fcaf5347915592924b1e3" 60rev = "b7a7c4b1c56e8d4c2524c1e3a056c745a13cc09f"
61default-features = false 61default-features = false
62features = [ 62features = [
63 "proto-ipv4", 63 "proto-ipv4",
diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs
index c183bd58a..4bdfd7720 100644
--- a/embassy-net/src/device.rs
+++ b/embassy-net/src/device.rs
@@ -12,8 +12,6 @@ pub enum LinkState {
12 Up, 12 Up,
13} 13}
14 14
15// 'static required due to the "fake GAT" in smoltcp::phy::Device.
16// https://github.com/smoltcp-rs/smoltcp/pull/572
17pub trait Device { 15pub trait Device {
18 fn is_transmit_ready(&mut self) -> bool; 16 fn is_transmit_ready(&mut self) -> bool;
19 fn transmit(&mut self, pkt: PacketBuf); 17 fn transmit(&mut self, pkt: PacketBuf);
@@ -25,7 +23,7 @@ pub trait Device {
25 fn ethernet_address(&self) -> [u8; 6]; 23 fn ethernet_address(&self) -> [u8; 6];
26} 24}
27 25
28impl<T: ?Sized + Device> Device for &'static mut T { 26impl<T: ?Sized + Device> Device for &mut T {
29 fn is_transmit_ready(&mut self) -> bool { 27 fn is_transmit_ready(&mut self) -> bool {
30 T::is_transmit_ready(self) 28 T::is_transmit_ready(self)
31 } 29 }
@@ -63,11 +61,11 @@ impl<D: Device> DeviceAdapter<D> {
63 } 61 }
64} 62}
65 63
66impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> { 64impl<D: Device> SmolDevice for DeviceAdapter<D> {
67 type RxToken = RxToken; 65 type RxToken<'a> = RxToken where Self: 'a;
68 type TxToken = TxToken<'a, D>; 66 type TxToken<'a> = TxToken<'a, D> where Self: 'a;
69 67
70 fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { 68 fn receive(&mut self) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
71 let tx_pkt = PacketBox::new(Packet::new())?; 69 let tx_pkt = PacketBox::new(Packet::new())?;
72 let rx_pkt = self.device.receive()?; 70 let rx_pkt = self.device.receive()?;
73 let rx_token = RxToken { pkt: rx_pkt }; 71 let rx_token = RxToken { pkt: rx_pkt };
@@ -80,7 +78,7 @@ impl<'a, D: Device + 'static> SmolDevice<'a> for DeviceAdapter<D> {
80 } 78 }
81 79
82 /// Construct a transmit token. 80 /// Construct a transmit token.
83 fn transmit(&'a mut self) -> Option<Self::TxToken> { 81 fn transmit(&mut self) -> Option<Self::TxToken<'_>> {
84 if !self.device.is_transmit_ready() { 82 if !self.device.is_transmit_ready() {
85 return None; 83 return None;
86 } 84 }
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs
index 3a7610758..5c4fb0442 100644
--- a/embassy-net/src/stack.rs
+++ b/embassy-net/src/stack.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::{poll_fn, Future}; 2use core::future::{poll_fn, Future};
3use core::task::{Context, Poll}; 3use core::task::{Context, Poll};
4 4
@@ -62,8 +62,8 @@ pub enum ConfigStrategy {
62} 62}
63 63
64pub struct Stack<D: Device> { 64pub struct Stack<D: Device> {
65 pub(crate) socket: UnsafeCell<SocketStack>, 65 pub(crate) socket: RefCell<SocketStack>,
66 inner: UnsafeCell<Inner<D>>, 66 inner: RefCell<Inner<D>>,
67} 67}
68 68
69struct Inner<D: Device> { 69struct Inner<D: Device> {
@@ -81,8 +81,6 @@ pub(crate) struct SocketStack {
81 next_local_port: u16, 81 next_local_port: u16,
82} 82}
83 83
84unsafe impl<D: Device> Send for Stack<D> {}
85
86impl<D: Device + 'static> Stack<D> { 84impl<D: Device + 'static> Stack<D> {
87 pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( 85 pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>(
88 device: D, 86 device: D,
@@ -143,40 +141,38 @@ impl<D: Device + 'static> Stack<D> {
143 } 141 }
144 142
145 Self { 143 Self {
146 socket: UnsafeCell::new(socket), 144 socket: RefCell::new(socket),
147 inner: UnsafeCell::new(inner), 145 inner: RefCell::new(inner),
148 } 146 }
149 } 147 }
150 148
151 /// SAFETY: must not call reentrantly. 149 fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R {
152 unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { 150 f(&*self.socket.borrow(), &*self.inner.borrow())
153 f(&*self.socket.get(), &*self.inner.get())
154 } 151 }
155 152
156 /// SAFETY: must not call reentrantly. 153 fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R {
157 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { 154 f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut())
158 f(&mut *self.socket.get(), &mut *self.inner.get())
159 } 155 }
160 156
161 pub fn ethernet_address(&self) -> [u8; 6] { 157 pub fn ethernet_address(&self) -> [u8; 6] {
162 unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } 158 self.with(|_s, i| i.device.device.ethernet_address())
163 } 159 }
164 160
165 pub fn is_link_up(&self) -> bool { 161 pub fn is_link_up(&self) -> bool {
166 unsafe { self.with(|_s, i| i.link_up) } 162 self.with(|_s, i| i.link_up)
167 } 163 }
168 164
169 pub fn is_config_up(&self) -> bool { 165 pub fn is_config_up(&self) -> bool {
170 unsafe { self.with(|_s, i| i.config.is_some()) } 166 self.with(|_s, i| i.config.is_some())
171 } 167 }
172 168
173 pub fn config(&self) -> Option<Config> { 169 pub fn config(&self) -> Option<Config> {
174 unsafe { self.with(|_s, i| i.config.clone()) } 170 self.with(|_s, i| i.config.clone())
175 } 171 }
176 172
177 pub async fn run(&self) -> ! { 173 pub async fn run(&self) -> ! {
178 poll_fn(|cx| { 174 poll_fn(|cx| {
179 unsafe { self.with_mut(|s, i| i.poll(cx, s)) } 175 self.with_mut(|s, i| i.poll(cx, s));
180 Poll::<()>::Pending 176 Poll::<()>::Pending
181 }) 177 })
182 .await; 178 .await;
@@ -270,21 +266,12 @@ impl<D: Device + 'static> Inner<D> {
270 None => {} 266 None => {}
271 Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s), 267 Some(dhcpv4::Event::Deconfigured) => self.unapply_config(s),
272 Some(dhcpv4::Event::Configured(config)) => { 268 Some(dhcpv4::Event::Configured(config)) => {
273 let mut dns_servers = Vec::new(); 269 let config = Config {
274 for s in &config.dns_servers { 270 address: config.address,
275 if let Some(addr) = s { 271 gateway: config.router,
276 dns_servers.push(addr.clone()).unwrap(); 272 dns_servers: config.dns_servers,
277 } 273 };
278 } 274 self.apply_config(s, config)
279
280 self.apply_config(
281 s,
282 Config {
283 address: config.address,
284 gateway: config.router,
285 dns_servers,
286 },
287 )
288 } 275 }
289 } 276 }
290 } else if old_link_up { 277 } else if old_link_up {
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs
index 85d9e5ee1..73cf2d4e4 100644
--- a/embassy-net/src/tcp.rs
+++ b/embassy-net/src/tcp.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::poll_fn; 2use core::future::poll_fn;
3use core::mem; 3use core::mem;
4use core::task::Poll; 4use core::task::Poll;
@@ -68,8 +68,7 @@ impl<'a> TcpWriter<'a> {
68 68
69impl<'a> TcpSocket<'a> { 69impl<'a> TcpSocket<'a> {
70 pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { 70 pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
71 // safety: not accessed reentrantly. 71 let s = &mut *stack.socket.borrow_mut();
72 let s = unsafe { &mut *stack.socket.get() };
73 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 72 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
74 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; 73 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
75 let handle = s.sockets.add(tcp::Socket::new( 74 let handle = s.sockets.add(tcp::Socket::new(
@@ -93,17 +92,18 @@ impl<'a> TcpSocket<'a> {
93 where 92 where
94 T: Into<IpEndpoint>, 93 T: Into<IpEndpoint>,
95 { 94 {
96 // safety: not accessed reentrantly. 95 let local_port = self.io.stack.borrow_mut().get_local_port();
97 let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port();
98 96
99 // safety: not accessed reentrantly. 97 match {
100 match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } { 98 self.io
99 .with_mut(|s, i| s.connect(i.context(), remote_endpoint, local_port))
100 } {
101 Ok(()) => {} 101 Ok(()) => {}
102 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), 102 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState),
103 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), 103 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute),
104 } 104 }
105 105
106 poll_fn(|cx| unsafe { 106 poll_fn(|cx| {
107 self.io.with_mut(|s, _| match s.state() { 107 self.io.with_mut(|s, _| match s.state() {
108 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), 108 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)),
109 tcp::State::Listen => unreachable!(), 109 tcp::State::Listen => unreachable!(),
@@ -121,14 +121,13 @@ impl<'a> TcpSocket<'a> {
121 where 121 where
122 T: Into<IpListenEndpoint>, 122 T: Into<IpListenEndpoint>,
123 { 123 {
124 // safety: not accessed reentrantly. 124 match self.io.with_mut(|s, _| s.listen(local_endpoint)) {
125 match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } {
126 Ok(()) => {} 125 Ok(()) => {}
127 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), 126 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState),
128 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), 127 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort),
129 } 128 }
130 129
131 poll_fn(|cx| unsafe { 130 poll_fn(|cx| {
132 self.io.with_mut(|s, _| match s.state() { 131 self.io.with_mut(|s, _| match s.state() {
133 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { 132 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
134 s.register_send_waker(cx.waker()); 133 s.register_send_waker(cx.waker());
@@ -149,51 +148,49 @@ impl<'a> TcpSocket<'a> {
149 } 148 }
150 149
151 pub fn set_timeout(&mut self, duration: Option<Duration>) { 150 pub fn set_timeout(&mut self, duration: Option<Duration>) {
152 unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } 151 self.io.with_mut(|s, _| s.set_timeout(duration))
153 } 152 }
154 153
155 pub fn set_keep_alive(&mut self, interval: Option<Duration>) { 154 pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
156 unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } 155 self.io.with_mut(|s, _| s.set_keep_alive(interval))
157 } 156 }
158 157
159 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { 158 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
160 unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } 159 self.io.with_mut(|s, _| s.set_hop_limit(hop_limit))
161 } 160 }
162 161
163 pub fn local_endpoint(&self) -> Option<IpEndpoint> { 162 pub fn local_endpoint(&self) -> Option<IpEndpoint> {
164 unsafe { self.io.with(|s, _| s.local_endpoint()) } 163 self.io.with(|s, _| s.local_endpoint())
165 } 164 }
166 165
167 pub fn remote_endpoint(&self) -> Option<IpEndpoint> { 166 pub fn remote_endpoint(&self) -> Option<IpEndpoint> {
168 unsafe { self.io.with(|s, _| s.remote_endpoint()) } 167 self.io.with(|s, _| s.remote_endpoint())
169 } 168 }
170 169
171 pub fn state(&self) -> tcp::State { 170 pub fn state(&self) -> tcp::State {
172 unsafe { self.io.with(|s, _| s.state()) } 171 self.io.with(|s, _| s.state())
173 } 172 }
174 173
175 pub fn close(&mut self) { 174 pub fn close(&mut self) {
176 unsafe { self.io.with_mut(|s, _| s.close()) } 175 self.io.with_mut(|s, _| s.close())
177 } 176 }
178 177
179 pub fn abort(&mut self) { 178 pub fn abort(&mut self) {
180 unsafe { self.io.with_mut(|s, _| s.abort()) } 179 self.io.with_mut(|s, _| s.abort())
181 } 180 }
182 181
183 pub fn may_send(&self) -> bool { 182 pub fn may_send(&self) -> bool {
184 unsafe { self.io.with(|s, _| s.may_send()) } 183 self.io.with(|s, _| s.may_send())
185 } 184 }
186 185
187 pub fn may_recv(&self) -> bool { 186 pub fn may_recv(&self) -> bool {
188 unsafe { self.io.with(|s, _| s.may_recv()) } 187 self.io.with(|s, _| s.may_recv())
189 } 188 }
190} 189}
191 190
192impl<'a> Drop for TcpSocket<'a> { 191impl<'a> Drop for TcpSocket<'a> {
193 fn drop(&mut self) { 192 fn drop(&mut self) {
194 // safety: not accessed reentrantly. 193 self.io.stack.borrow_mut().sockets.remove(self.io.handle);
195 let s = unsafe { &mut *self.io.stack.get() };
196 s.sockets.remove(self.io.handle);
197 } 194 }
198} 195}
199 196
@@ -201,21 +198,19 @@ impl<'a> Drop for TcpSocket<'a> {
201 198
202#[derive(Copy, Clone)] 199#[derive(Copy, Clone)]
203struct TcpIo<'a> { 200struct TcpIo<'a> {
204 stack: &'a UnsafeCell<SocketStack>, 201 stack: &'a RefCell<SocketStack>,
205 handle: SocketHandle, 202 handle: SocketHandle,
206} 203}
207 204
208impl<'d> TcpIo<'d> { 205impl<'d> TcpIo<'d> {
209 /// SAFETY: must not call reentrantly. 206 fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R {
210 unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { 207 let s = &*self.stack.borrow();
211 let s = &*self.stack.get();
212 let socket = s.sockets.get::<tcp::Socket>(self.handle); 208 let socket = s.sockets.get::<tcp::Socket>(self.handle);
213 f(socket, &s.iface) 209 f(socket, &s.iface)
214 } 210 }
215 211
216 /// SAFETY: must not call reentrantly. 212 fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R {
217 unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { 213 let s = &mut *self.stack.borrow_mut();
218 let s = &mut *self.stack.get();
219 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); 214 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle);
220 let res = f(socket, &mut s.iface); 215 let res = f(socket, &mut s.iface);
221 s.waker.wake(); 216 s.waker.wake();
@@ -223,7 +218,7 @@ impl<'d> TcpIo<'d> {
223 } 218 }
224 219
225 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { 220 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
226 poll_fn(move |cx| unsafe { 221 poll_fn(move |cx| {
227 // CAUTION: smoltcp semantics around EOF are different to what you'd expect 222 // CAUTION: smoltcp semantics around EOF are different to what you'd expect
228 // from posix-like IO, so we have to tweak things here. 223 // from posix-like IO, so we have to tweak things here.
229 self.with_mut(|s, _| match s.recv_slice(buf) { 224 self.with_mut(|s, _| match s.recv_slice(buf) {
@@ -244,7 +239,7 @@ impl<'d> TcpIo<'d> {
244 } 239 }
245 240
246 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { 241 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
247 poll_fn(move |cx| unsafe { 242 poll_fn(move |cx| {
248 self.with_mut(|s, _| match s.send_slice(buf) { 243 self.with_mut(|s, _| match s.send_slice(buf) {
249 // Not ready to send (no space in the tx buffer) 244 // Not ready to send (no space in the tx buffer)
250 Ok(0) => { 245 Ok(0) => {
@@ -332,6 +327,7 @@ mod embedded_io_impls {
332 327
333#[cfg(all(feature = "unstable-traits", feature = "nightly"))] 328#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
334pub mod client { 329pub mod client {
330 use core::cell::UnsafeCell;
335 use core::mem::MaybeUninit; 331 use core::mem::MaybeUninit;
336 use core::ptr::NonNull; 332 use core::ptr::NonNull;
337 333
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
index f2e33493c..4ddad77d4 100644
--- a/embassy-net/src/udp.rs
+++ b/embassy-net/src/udp.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::poll_fn; 2use core::future::poll_fn;
3use core::mem; 3use core::mem;
4use core::task::Poll; 4use core::task::Poll;
@@ -27,7 +27,7 @@ pub enum Error {
27} 27}
28 28
29pub struct UdpSocket<'a> { 29pub struct UdpSocket<'a> {
30 stack: &'a UnsafeCell<SocketStack>, 30 stack: &'a RefCell<SocketStack>,
31 handle: SocketHandle, 31 handle: SocketHandle,
32} 32}
33 33
@@ -39,8 +39,7 @@ impl<'a> UdpSocket<'a> {
39 tx_meta: &'a mut [PacketMetadata], 39 tx_meta: &'a mut [PacketMetadata],
40 tx_buffer: &'a mut [u8], 40 tx_buffer: &'a mut [u8],
41 ) -> Self { 41 ) -> Self {
42 // safety: not accessed reentrantly. 42 let s = &mut *stack.socket.borrow_mut();
43 let s = unsafe { &mut *stack.socket.get() };
44 43
45 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; 44 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
46 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 45 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
@@ -63,30 +62,26 @@ impl<'a> UdpSocket<'a> {
63 { 62 {
64 let mut endpoint = endpoint.into(); 63 let mut endpoint = endpoint.into();
65 64
66 // safety: not accessed reentrantly.
67 if endpoint.port == 0 { 65 if endpoint.port == 0 {
68 // If user didn't specify port allocate a dynamic port. 66 // If user didn't specify port allocate a dynamic port.
69 endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); 67 endpoint.port = self.stack.borrow_mut().get_local_port();
70 } 68 }
71 69
72 // safety: not accessed reentrantly. 70 match self.with_mut(|s, _| s.bind(endpoint)) {
73 match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } {
74 Ok(()) => Ok(()), 71 Ok(()) => Ok(()),
75 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), 72 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
76 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), 73 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
77 } 74 }
78 } 75 }
79 76
80 /// SAFETY: must not call reentrantly. 77 fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
81 unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { 78 let s = &*self.stack.borrow();
82 let s = &*self.stack.get();
83 let socket = s.sockets.get::<udp::Socket>(self.handle); 79 let socket = s.sockets.get::<udp::Socket>(self.handle);
84 f(socket, &s.iface) 80 f(socket, &s.iface)
85 } 81 }
86 82
87 /// SAFETY: must not call reentrantly. 83 fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
88 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { 84 let s = &mut *self.stack.borrow_mut();
89 let s = &mut *self.stack.get();
90 let socket = s.sockets.get_mut::<udp::Socket>(self.handle); 85 let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
91 let res = f(socket, &mut s.iface); 86 let res = f(socket, &mut s.iface);
92 s.waker.wake(); 87 s.waker.wake();
@@ -94,7 +89,7 @@ impl<'a> UdpSocket<'a> {
94 } 89 }
95 90
96 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { 91 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
97 poll_fn(move |cx| unsafe { 92 poll_fn(move |cx| {
98 self.with_mut(|s, _| match s.recv_slice(buf) { 93 self.with_mut(|s, _| match s.recv_slice(buf) {
99 Ok(x) => Poll::Ready(Ok(x)), 94 Ok(x) => Poll::Ready(Ok(x)),
100 // No data ready 95 // No data ready
@@ -113,7 +108,7 @@ impl<'a> UdpSocket<'a> {
113 T: Into<IpEndpoint>, 108 T: Into<IpEndpoint>,
114 { 109 {
115 let remote_endpoint = remote_endpoint.into(); 110 let remote_endpoint = remote_endpoint.into();
116 poll_fn(move |cx| unsafe { 111 poll_fn(move |cx| {
117 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { 112 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
118 // Entire datagram has been sent 113 // Entire datagram has been sent
119 Ok(()) => Poll::Ready(Ok(())), 114 Ok(()) => Poll::Ready(Ok(())),
@@ -128,30 +123,28 @@ impl<'a> UdpSocket<'a> {
128 } 123 }
129 124
130 pub fn endpoint(&self) -> IpListenEndpoint { 125 pub fn endpoint(&self) -> IpListenEndpoint {
131 unsafe { self.with(|s, _| s.endpoint()) } 126 self.with(|s, _| s.endpoint())
132 } 127 }
133 128
134 pub fn is_open(&self) -> bool { 129 pub fn is_open(&self) -> bool {
135 unsafe { self.with(|s, _| s.is_open()) } 130 self.with(|s, _| s.is_open())
136 } 131 }
137 132
138 pub fn close(&mut self) { 133 pub fn close(&mut self) {
139 unsafe { self.with_mut(|s, _| s.close()) } 134 self.with_mut(|s, _| s.close())
140 } 135 }
141 136
142 pub fn may_send(&self) -> bool { 137 pub fn may_send(&self) -> bool {
143 unsafe { self.with(|s, _| s.can_send()) } 138 self.with(|s, _| s.can_send())
144 } 139 }
145 140
146 pub fn may_recv(&self) -> bool { 141 pub fn may_recv(&self) -> bool {
147 unsafe { self.with(|s, _| s.can_recv()) } 142 self.with(|s, _| s.can_recv())
148 } 143 }
149} 144}
150 145
151impl Drop for UdpSocket<'_> { 146impl Drop for UdpSocket<'_> {
152 fn drop(&mut self) { 147 fn drop(&mut self) {
153 // safety: not accessed reentrantly. 148 self.stack.borrow_mut().sockets.remove(self.handle);
154 let s = unsafe { &mut *self.stack.get() };
155 s.sockets.remove(self.handle);
156 } 149 }
157} 150}