aboutsummaryrefslogtreecommitdiff
path: root/embassy-net
diff options
context:
space:
mode:
authorDario Nieuwenhuis <[email protected]>2022-12-03 00:56:16 +0100
committerDario Nieuwenhuis <[email protected]>2022-12-03 00:56:16 +0100
commit02abe00439ba873945bd6b60546a200b3da751f1 (patch)
tree62724bbe40f58380ce7ce67125e10348c5867adb /embassy-net
parentf109e73c6d7ef2ad93102b7c8223f5cef30ef36f (diff)
net: don't use UnsafeCell.
The "must not be called reentrantly" invariant is too "global" to maintain comfortably, and the cost of the RefCell is negligible, so this was a case of premature optimization.
Diffstat (limited to 'embassy-net')
-rw-r--r--embassy-net/src/stack.rs32
-rw-r--r--embassy-net/src/tcp.rs59
-rw-r--r--embassy-net/src/udp.rs41
3 files changed, 57 insertions, 75 deletions
diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs
index 3a7610758..631087405 100644
--- a/embassy-net/src/stack.rs
+++ b/embassy-net/src/stack.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::{poll_fn, Future}; 2use core::future::{poll_fn, Future};
3use core::task::{Context, Poll}; 3use core::task::{Context, Poll};
4 4
@@ -62,8 +62,8 @@ pub enum ConfigStrategy {
62} 62}
63 63
64pub struct Stack<D: Device> { 64pub struct Stack<D: Device> {
65 pub(crate) socket: UnsafeCell<SocketStack>, 65 pub(crate) socket: RefCell<SocketStack>,
66 inner: UnsafeCell<Inner<D>>, 66 inner: RefCell<Inner<D>>,
67} 67}
68 68
69struct Inner<D: Device> { 69struct Inner<D: Device> {
@@ -81,8 +81,6 @@ pub(crate) struct SocketStack {
81 next_local_port: u16, 81 next_local_port: u16,
82} 82}
83 83
84unsafe impl<D: Device> Send for Stack<D> {}
85
86impl<D: Device + 'static> Stack<D> { 84impl<D: Device + 'static> Stack<D> {
87 pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>( 85 pub fn new<const ADDR: usize, const SOCK: usize, const NEIGH: usize>(
88 device: D, 86 device: D,
@@ -143,40 +141,38 @@ impl<D: Device + 'static> Stack<D> {
143 } 141 }
144 142
145 Self { 143 Self {
146 socket: UnsafeCell::new(socket), 144 socket: RefCell::new(socket),
147 inner: UnsafeCell::new(inner), 145 inner: RefCell::new(inner),
148 } 146 }
149 } 147 }
150 148
151 /// SAFETY: must not call reentrantly. 149 fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R {
152 unsafe fn with<R>(&self, f: impl FnOnce(&SocketStack, &Inner<D>) -> R) -> R { 150 f(&*self.socket.borrow(), &*self.inner.borrow())
153 f(&*self.socket.get(), &*self.inner.get())
154 } 151 }
155 152
156 /// SAFETY: must not call reentrantly. 153 fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R {
157 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut SocketStack, &mut Inner<D>) -> R) -> R { 154 f(&mut *self.socket.borrow_mut(), &mut *self.inner.borrow_mut())
158 f(&mut *self.socket.get(), &mut *self.inner.get())
159 } 155 }
160 156
161 pub fn ethernet_address(&self) -> [u8; 6] { 157 pub fn ethernet_address(&self) -> [u8; 6] {
162 unsafe { self.with(|_s, i| i.device.device.ethernet_address()) } 158 self.with(|_s, i| i.device.device.ethernet_address())
163 } 159 }
164 160
165 pub fn is_link_up(&self) -> bool { 161 pub fn is_link_up(&self) -> bool {
166 unsafe { self.with(|_s, i| i.link_up) } 162 self.with(|_s, i| i.link_up)
167 } 163 }
168 164
169 pub fn is_config_up(&self) -> bool { 165 pub fn is_config_up(&self) -> bool {
170 unsafe { self.with(|_s, i| i.config.is_some()) } 166 self.with(|_s, i| i.config.is_some())
171 } 167 }
172 168
173 pub fn config(&self) -> Option<Config> { 169 pub fn config(&self) -> Option<Config> {
174 unsafe { self.with(|_s, i| i.config.clone()) } 170 self.with(|_s, i| i.config.clone())
175 } 171 }
176 172
177 pub async fn run(&self) -> ! { 173 pub async fn run(&self) -> ! {
178 poll_fn(|cx| { 174 poll_fn(|cx| {
179 unsafe { self.with_mut(|s, i| i.poll(cx, s)) } 175 self.with_mut(|s, i| i.poll(cx, s));
180 Poll::<()>::Pending 176 Poll::<()>::Pending
181 }) 177 })
182 .await; 178 .await;
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs
index 85d9e5ee1..60386535a 100644
--- a/embassy-net/src/tcp.rs
+++ b/embassy-net/src/tcp.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::poll_fn; 2use core::future::poll_fn;
3use core::mem; 3use core::mem;
4use core::task::Poll; 4use core::task::Poll;
@@ -68,8 +68,7 @@ impl<'a> TcpWriter<'a> {
68 68
69impl<'a> TcpSocket<'a> { 69impl<'a> TcpSocket<'a> {
70 pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { 70 pub fn new<D: Device>(stack: &'a Stack<D>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
71 // safety: not accessed reentrantly. 71 let s = &mut *stack.socket.borrow_mut();
72 let s = unsafe { &mut *stack.socket.get() };
73 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 72 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
74 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; 73 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
75 let handle = s.sockets.add(tcp::Socket::new( 74 let handle = s.sockets.add(tcp::Socket::new(
@@ -93,17 +92,15 @@ impl<'a> TcpSocket<'a> {
93 where 92 where
94 T: Into<IpEndpoint>, 93 T: Into<IpEndpoint>,
95 { 94 {
96 // safety: not accessed reentrantly. 95 let local_port = self.io.stack.borrow_mut().get_local_port();
97 let local_port = unsafe { &mut *self.io.stack.get() }.get_local_port();
98 96
99 // safety: not accessed reentrantly. 97 match { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } {
100 match unsafe { self.io.with_mut(|s, i| s.connect(i, remote_endpoint, local_port)) } {
101 Ok(()) => {} 98 Ok(()) => {}
102 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState), 99 Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState),
103 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute), 100 Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute),
104 } 101 }
105 102
106 poll_fn(|cx| unsafe { 103 poll_fn(|cx| {
107 self.io.with_mut(|s, _| match s.state() { 104 self.io.with_mut(|s, _| match s.state() {
108 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)), 105 tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)),
109 tcp::State::Listen => unreachable!(), 106 tcp::State::Listen => unreachable!(),
@@ -121,14 +118,13 @@ impl<'a> TcpSocket<'a> {
121 where 118 where
122 T: Into<IpListenEndpoint>, 119 T: Into<IpListenEndpoint>,
123 { 120 {
124 // safety: not accessed reentrantly. 121 match self.io.with_mut(|s, _| s.listen(local_endpoint)) {
125 match unsafe { self.io.with_mut(|s, _| s.listen(local_endpoint)) } {
126 Ok(()) => {} 122 Ok(()) => {}
127 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState), 123 Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState),
128 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort), 124 Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort),
129 } 125 }
130 126
131 poll_fn(|cx| unsafe { 127 poll_fn(|cx| {
132 self.io.with_mut(|s, _| match s.state() { 128 self.io.with_mut(|s, _| match s.state() {
133 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { 129 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
134 s.register_send_waker(cx.waker()); 130 s.register_send_waker(cx.waker());
@@ -149,51 +145,49 @@ impl<'a> TcpSocket<'a> {
149 } 145 }
150 146
151 pub fn set_timeout(&mut self, duration: Option<Duration>) { 147 pub fn set_timeout(&mut self, duration: Option<Duration>) {
152 unsafe { self.io.with_mut(|s, _| s.set_timeout(duration)) } 148 self.io.with_mut(|s, _| s.set_timeout(duration))
153 } 149 }
154 150
155 pub fn set_keep_alive(&mut self, interval: Option<Duration>) { 151 pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
156 unsafe { self.io.with_mut(|s, _| s.set_keep_alive(interval)) } 152 self.io.with_mut(|s, _| s.set_keep_alive(interval))
157 } 153 }
158 154
159 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { 155 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
160 unsafe { self.io.with_mut(|s, _| s.set_hop_limit(hop_limit)) } 156 self.io.with_mut(|s, _| s.set_hop_limit(hop_limit))
161 } 157 }
162 158
163 pub fn local_endpoint(&self) -> Option<IpEndpoint> { 159 pub fn local_endpoint(&self) -> Option<IpEndpoint> {
164 unsafe { self.io.with(|s, _| s.local_endpoint()) } 160 self.io.with(|s, _| s.local_endpoint())
165 } 161 }
166 162
167 pub fn remote_endpoint(&self) -> Option<IpEndpoint> { 163 pub fn remote_endpoint(&self) -> Option<IpEndpoint> {
168 unsafe { self.io.with(|s, _| s.remote_endpoint()) } 164 self.io.with(|s, _| s.remote_endpoint())
169 } 165 }
170 166
171 pub fn state(&self) -> tcp::State { 167 pub fn state(&self) -> tcp::State {
172 unsafe { self.io.with(|s, _| s.state()) } 168 self.io.with(|s, _| s.state())
173 } 169 }
174 170
175 pub fn close(&mut self) { 171 pub fn close(&mut self) {
176 unsafe { self.io.with_mut(|s, _| s.close()) } 172 self.io.with_mut(|s, _| s.close())
177 } 173 }
178 174
179 pub fn abort(&mut self) { 175 pub fn abort(&mut self) {
180 unsafe { self.io.with_mut(|s, _| s.abort()) } 176 self.io.with_mut(|s, _| s.abort())
181 } 177 }
182 178
183 pub fn may_send(&self) -> bool { 179 pub fn may_send(&self) -> bool {
184 unsafe { self.io.with(|s, _| s.may_send()) } 180 self.io.with(|s, _| s.may_send())
185 } 181 }
186 182
187 pub fn may_recv(&self) -> bool { 183 pub fn may_recv(&self) -> bool {
188 unsafe { self.io.with(|s, _| s.may_recv()) } 184 self.io.with(|s, _| s.may_recv())
189 } 185 }
190} 186}
191 187
192impl<'a> Drop for TcpSocket<'a> { 188impl<'a> Drop for TcpSocket<'a> {
193 fn drop(&mut self) { 189 fn drop(&mut self) {
194 // safety: not accessed reentrantly. 190 self.io.stack.borrow_mut().sockets.remove(self.io.handle);
195 let s = unsafe { &mut *self.io.stack.get() };
196 s.sockets.remove(self.io.handle);
197 } 191 }
198} 192}
199 193
@@ -201,21 +195,19 @@ impl<'a> Drop for TcpSocket<'a> {
201 195
202#[derive(Copy, Clone)] 196#[derive(Copy, Clone)]
203struct TcpIo<'a> { 197struct TcpIo<'a> {
204 stack: &'a UnsafeCell<SocketStack>, 198 stack: &'a RefCell<SocketStack>,
205 handle: SocketHandle, 199 handle: SocketHandle,
206} 200}
207 201
208impl<'d> TcpIo<'d> { 202impl<'d> TcpIo<'d> {
209 /// SAFETY: must not call reentrantly. 203 fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R {
210 unsafe fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R { 204 let s = &*self.stack.borrow();
211 let s = &*self.stack.get();
212 let socket = s.sockets.get::<tcp::Socket>(self.handle); 205 let socket = s.sockets.get::<tcp::Socket>(self.handle);
213 f(socket, &s.iface) 206 f(socket, &s.iface)
214 } 207 }
215 208
216 /// SAFETY: must not call reentrantly. 209 fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R {
217 unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R { 210 let s = &mut *self.stack.borrow_mut();
218 let s = &mut *self.stack.get();
219 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle); 211 let socket = s.sockets.get_mut::<tcp::Socket>(self.handle);
220 let res = f(socket, &mut s.iface); 212 let res = f(socket, &mut s.iface);
221 s.waker.wake(); 213 s.waker.wake();
@@ -223,7 +215,7 @@ impl<'d> TcpIo<'d> {
223 } 215 }
224 216
225 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { 217 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
226 poll_fn(move |cx| unsafe { 218 poll_fn(move |cx| {
227 // CAUTION: smoltcp semantics around EOF are different to what you'd expect 219 // CAUTION: smoltcp semantics around EOF are different to what you'd expect
228 // from posix-like IO, so we have to tweak things here. 220 // from posix-like IO, so we have to tweak things here.
229 self.with_mut(|s, _| match s.recv_slice(buf) { 221 self.with_mut(|s, _| match s.recv_slice(buf) {
@@ -244,7 +236,7 @@ impl<'d> TcpIo<'d> {
244 } 236 }
245 237
246 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> { 238 async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
247 poll_fn(move |cx| unsafe { 239 poll_fn(move |cx| {
248 self.with_mut(|s, _| match s.send_slice(buf) { 240 self.with_mut(|s, _| match s.send_slice(buf) {
249 // Not ready to send (no space in the tx buffer) 241 // Not ready to send (no space in the tx buffer)
250 Ok(0) => { 242 Ok(0) => {
@@ -332,6 +324,7 @@ mod embedded_io_impls {
332 324
333#[cfg(all(feature = "unstable-traits", feature = "nightly"))] 325#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
334pub mod client { 326pub mod client {
327 use core::cell::UnsafeCell;
335 use core::mem::MaybeUninit; 328 use core::mem::MaybeUninit;
336 use core::ptr::NonNull; 329 use core::ptr::NonNull;
337 330
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
index f2e33493c..4ddad77d4 100644
--- a/embassy-net/src/udp.rs
+++ b/embassy-net/src/udp.rs
@@ -1,4 +1,4 @@
1use core::cell::UnsafeCell; 1use core::cell::RefCell;
2use core::future::poll_fn; 2use core::future::poll_fn;
3use core::mem; 3use core::mem;
4use core::task::Poll; 4use core::task::Poll;
@@ -27,7 +27,7 @@ pub enum Error {
27} 27}
28 28
29pub struct UdpSocket<'a> { 29pub struct UdpSocket<'a> {
30 stack: &'a UnsafeCell<SocketStack>, 30 stack: &'a RefCell<SocketStack>,
31 handle: SocketHandle, 31 handle: SocketHandle,
32} 32}
33 33
@@ -39,8 +39,7 @@ impl<'a> UdpSocket<'a> {
39 tx_meta: &'a mut [PacketMetadata], 39 tx_meta: &'a mut [PacketMetadata],
40 tx_buffer: &'a mut [u8], 40 tx_buffer: &'a mut [u8],
41 ) -> Self { 41 ) -> Self {
42 // safety: not accessed reentrantly. 42 let s = &mut *stack.socket.borrow_mut();
43 let s = unsafe { &mut *stack.socket.get() };
44 43
45 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) }; 44 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
46 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; 45 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
@@ -63,30 +62,26 @@ impl<'a> UdpSocket<'a> {
63 { 62 {
64 let mut endpoint = endpoint.into(); 63 let mut endpoint = endpoint.into();
65 64
66 // safety: not accessed reentrantly.
67 if endpoint.port == 0 { 65 if endpoint.port == 0 {
68 // If user didn't specify port allocate a dynamic port. 66 // If user didn't specify port allocate a dynamic port.
69 endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); 67 endpoint.port = self.stack.borrow_mut().get_local_port();
70 } 68 }
71 69
72 // safety: not accessed reentrantly. 70 match self.with_mut(|s, _| s.bind(endpoint)) {
73 match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } {
74 Ok(()) => Ok(()), 71 Ok(()) => Ok(()),
75 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), 72 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
76 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), 73 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
77 } 74 }
78 } 75 }
79 76
80 /// SAFETY: must not call reentrantly. 77 fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
81 unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { 78 let s = &*self.stack.borrow();
82 let s = &*self.stack.get();
83 let socket = s.sockets.get::<udp::Socket>(self.handle); 79 let socket = s.sockets.get::<udp::Socket>(self.handle);
84 f(socket, &s.iface) 80 f(socket, &s.iface)
85 } 81 }
86 82
87 /// SAFETY: must not call reentrantly. 83 fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
88 unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { 84 let s = &mut *self.stack.borrow_mut();
89 let s = &mut *self.stack.get();
90 let socket = s.sockets.get_mut::<udp::Socket>(self.handle); 85 let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
91 let res = f(socket, &mut s.iface); 86 let res = f(socket, &mut s.iface);
92 s.waker.wake(); 87 s.waker.wake();
@@ -94,7 +89,7 @@ impl<'a> UdpSocket<'a> {
94 } 89 }
95 90
96 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { 91 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
97 poll_fn(move |cx| unsafe { 92 poll_fn(move |cx| {
98 self.with_mut(|s, _| match s.recv_slice(buf) { 93 self.with_mut(|s, _| match s.recv_slice(buf) {
99 Ok(x) => Poll::Ready(Ok(x)), 94 Ok(x) => Poll::Ready(Ok(x)),
100 // No data ready 95 // No data ready
@@ -113,7 +108,7 @@ impl<'a> UdpSocket<'a> {
113 T: Into<IpEndpoint>, 108 T: Into<IpEndpoint>,
114 { 109 {
115 let remote_endpoint = remote_endpoint.into(); 110 let remote_endpoint = remote_endpoint.into();
116 poll_fn(move |cx| unsafe { 111 poll_fn(move |cx| {
117 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { 112 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
118 // Entire datagram has been sent 113 // Entire datagram has been sent
119 Ok(()) => Poll::Ready(Ok(())), 114 Ok(()) => Poll::Ready(Ok(())),
@@ -128,30 +123,28 @@ impl<'a> UdpSocket<'a> {
128 } 123 }
129 124
130 pub fn endpoint(&self) -> IpListenEndpoint { 125 pub fn endpoint(&self) -> IpListenEndpoint {
131 unsafe { self.with(|s, _| s.endpoint()) } 126 self.with(|s, _| s.endpoint())
132 } 127 }
133 128
134 pub fn is_open(&self) -> bool { 129 pub fn is_open(&self) -> bool {
135 unsafe { self.with(|s, _| s.is_open()) } 130 self.with(|s, _| s.is_open())
136 } 131 }
137 132
138 pub fn close(&mut self) { 133 pub fn close(&mut self) {
139 unsafe { self.with_mut(|s, _| s.close()) } 134 self.with_mut(|s, _| s.close())
140 } 135 }
141 136
142 pub fn may_send(&self) -> bool { 137 pub fn may_send(&self) -> bool {
143 unsafe { self.with(|s, _| s.can_send()) } 138 self.with(|s, _| s.can_send())
144 } 139 }
145 140
146 pub fn may_recv(&self) -> bool { 141 pub fn may_recv(&self) -> bool {
147 unsafe { self.with(|s, _| s.can_recv()) } 142 self.with(|s, _| s.can_recv())
148 } 143 }
149} 144}
150 145
151impl Drop for UdpSocket<'_> { 146impl Drop for UdpSocket<'_> {
152 fn drop(&mut self) { 147 fn drop(&mut self) {
153 // safety: not accessed reentrantly. 148 self.stack.borrow_mut().sockets.remove(self.handle);
154 let s = unsafe { &mut *self.stack.get() };
155 s.sockets.remove(self.handle);
156 } 149 }
157} 150}