diff options
| author | Dario Nieuwenhuis <[email protected]> | 2022-05-13 22:15:27 +0200 |
|---|---|---|
| committer | Dario Nieuwenhuis <[email protected]> | 2022-05-19 06:14:05 +0200 |
| commit | 0b2f43c391f1f1f6525484c369a24e72610dd8ef (patch) | |
| tree | eb5b942e76d6b88d1ce3fed211ab281bb132398b | |
| parent | 240bef8c9f97998bd4639c714077b7aa59f5e7bc (diff) | |
net: add split() to tcpsocket
| -rw-r--r-- | embassy-net/src/tcp/io_impl.rs | 68 | ||||
| -rw-r--r-- | embassy-net/src/tcp/mod.rs | 90 |
2 files changed, 129 insertions, 29 deletions
diff --git a/embassy-net/src/tcp/io_impl.rs b/embassy-net/src/tcp/io_impl.rs index 155733497..b30c920b8 100644 --- a/embassy-net/src/tcp/io_impl.rs +++ b/embassy-net/src/tcp/io_impl.rs | |||
| @@ -2,7 +2,7 @@ use core::future::Future; | |||
| 2 | use core::task::Poll; | 2 | use core::task::Poll; |
| 3 | use futures::future::poll_fn; | 3 | use futures::future::poll_fn; |
| 4 | 4 | ||
| 5 | use super::{Error, TcpSocket}; | 5 | use super::{with_socket, Error, TcpReader, TcpSocket, TcpWriter}; |
| 6 | 6 | ||
| 7 | impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { | 7 | impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { |
| 8 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | 8 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> |
| @@ -13,7 +13,7 @@ impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { | |||
| 13 | poll_fn(move |cx| { | 13 | poll_fn(move |cx| { |
| 14 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect | 14 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect |
| 15 | // from posix-like IO, so we have to tweak things here. | 15 | // from posix-like IO, so we have to tweak things here. |
| 16 | self.with(|s, _| match s.recv_slice(buf) { | 16 | with_socket(self.handle, |s, _| match s.recv_slice(buf) { |
| 17 | // No data ready | 17 | // No data ready |
| 18 | Ok(0) => { | 18 | Ok(0) => { |
| 19 | s.register_recv_waker(cx.waker()); | 19 | s.register_recv_waker(cx.waker()); |
| @@ -39,7 +39,69 @@ impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { | |||
| 39 | 39 | ||
| 40 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { | 40 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { |
| 41 | poll_fn(move |cx| { | 41 | poll_fn(move |cx| { |
| 42 | self.with(|s, _| match s.send_slice(buf) { | 42 | with_socket(self.handle, |s, _| match s.send_slice(buf) { |
| 43 | // Not ready to send (no space in the tx buffer) | ||
| 44 | Ok(0) => { | ||
| 45 | s.register_send_waker(cx.waker()); | ||
| 46 | Poll::Pending | ||
| 47 | } | ||
| 48 | // Some data sent | ||
| 49 | Ok(n) => Poll::Ready(Ok(n)), | ||
| 50 | // Connection reset. TODO: this can also be timeouts etc, investigate. | ||
| 51 | Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), | ||
| 52 | // smoltcp returns no errors other than the above. | ||
| 53 | Err(_) => unreachable!(), | ||
| 54 | }) | ||
| 55 | }) | ||
| 56 | } | ||
| 57 | |||
| 58 | type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> | ||
| 59 | where | ||
| 60 | Self: 'a; | ||
| 61 | |||
| 62 | fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { | ||
| 63 | poll_fn(move |_| { | ||
| 64 | Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? | ||
| 65 | }) | ||
| 66 | } | ||
| 67 | } | ||
| 68 | |||
| 69 | impl<'d> embedded_io::asynch::Read for TcpReader<'d> { | ||
| 70 | type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | ||
| 71 | where | ||
| 72 | Self: 'a; | ||
| 73 | |||
| 74 | fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { | ||
| 75 | poll_fn(move |cx| { | ||
| 76 | // CAUTION: smoltcp semantics around EOF are different to what you'd expect | ||
| 77 | // from posix-like IO, so we have to tweak things here. | ||
| 78 | with_socket(self.handle, |s, _| match s.recv_slice(buf) { | ||
| 79 | // No data ready | ||
| 80 | Ok(0) => { | ||
| 81 | s.register_recv_waker(cx.waker()); | ||
| 82 | Poll::Pending | ||
| 83 | } | ||
| 84 | // Data ready! | ||
| 85 | Ok(n) => Poll::Ready(Ok(n)), | ||
| 86 | // EOF | ||
| 87 | Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), | ||
| 88 | // Connection reset. TODO: this can also be timeouts etc, investigate. | ||
| 89 | Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), | ||
| 90 | // smoltcp returns no errors other than the above. | ||
| 91 | Err(_) => unreachable!(), | ||
| 92 | }) | ||
| 93 | }) | ||
| 94 | } | ||
| 95 | } | ||
| 96 | |||
| 97 | impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { | ||
| 98 | type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> | ||
| 99 | where | ||
| 100 | Self: 'a; | ||
| 101 | |||
| 102 | fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { | ||
| 103 | poll_fn(move |cx| { | ||
| 104 | with_socket(self.handle, |s, _| match s.send_slice(buf) { | ||
| 43 | // Not ready to send (no space in the tx buffer) | 105 | // Not ready to send (no space in the tx buffer) |
| 44 | Ok(0) => { | 106 | Ok(0) => { |
| 45 | s.register_send_waker(cx.waker()); | 107 | s.register_send_waker(cx.waker()); |
diff --git a/embassy-net/src/tcp/mod.rs b/embassy-net/src/tcp/mod.rs index 3bfd4c7b6..425e6acbc 100644 --- a/embassy-net/src/tcp/mod.rs +++ b/embassy-net/src/tcp/mod.rs | |||
| @@ -49,6 +49,20 @@ pub struct TcpSocket<'a> { | |||
| 49 | 49 | ||
| 50 | impl<'a> Unpin for TcpSocket<'a> {} | 50 | impl<'a> Unpin for TcpSocket<'a> {} |
| 51 | 51 | ||
| 52 | pub struct TcpReader<'a> { | ||
| 53 | handle: SocketHandle, | ||
| 54 | ghost: PhantomData<&'a mut [u8]>, | ||
| 55 | } | ||
| 56 | |||
| 57 | impl<'a> Unpin for TcpReader<'a> {} | ||
| 58 | |||
| 59 | pub struct TcpWriter<'a> { | ||
| 60 | handle: SocketHandle, | ||
| 61 | ghost: PhantomData<&'a mut [u8]>, | ||
| 62 | } | ||
| 63 | |||
| 64 | impl<'a> Unpin for TcpWriter<'a> {} | ||
| 65 | |||
| 52 | impl<'a> TcpSocket<'a> { | 66 | impl<'a> TcpSocket<'a> { |
| 53 | pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { | 67 | pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { |
| 54 | let handle = Stack::with(|stack| { | 68 | let handle = Stack::with(|stack| { |
| @@ -66,12 +80,27 @@ impl<'a> TcpSocket<'a> { | |||
| 66 | } | 80 | } |
| 67 | } | 81 | } |
| 68 | 82 | ||
| 83 | pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { | ||
| 84 | ( | ||
| 85 | TcpReader { | ||
| 86 | handle: self.handle, | ||
| 87 | ghost: PhantomData, | ||
| 88 | }, | ||
| 89 | TcpWriter { | ||
| 90 | handle: self.handle, | ||
| 91 | ghost: PhantomData, | ||
| 92 | }, | ||
| 93 | ) | ||
| 94 | } | ||
| 95 | |||
| 69 | pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> | 96 | pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> |
| 70 | where | 97 | where |
| 71 | T: Into<IpEndpoint>, | 98 | T: Into<IpEndpoint>, |
| 72 | { | 99 | { |
| 73 | let local_port = Stack::with(|stack| stack.get_local_port()); | 100 | let local_port = Stack::with(|stack| stack.get_local_port()); |
| 74 | match self.with(|s, cx| s.connect(cx, remote_endpoint, local_port)) { | 101 | match with_socket(self.handle, |s, cx| { |
| 102 | s.connect(cx, remote_endpoint, local_port) | ||
| 103 | }) { | ||
| 75 | Ok(()) => {} | 104 | Ok(()) => {} |
| 76 | Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), | 105 | Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), |
| 77 | Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), | 106 | Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), |
| @@ -80,7 +109,7 @@ impl<'a> TcpSocket<'a> { | |||
| 80 | } | 109 | } |
| 81 | 110 | ||
| 82 | futures::future::poll_fn(|cx| { | 111 | futures::future::poll_fn(|cx| { |
| 83 | self.with(|s, _| match s.state() { | 112 | with_socket(self.handle, |s, _| match s.state() { |
| 84 | TcpState::Closed | TcpState::TimeWait => { | 113 | TcpState::Closed | TcpState::TimeWait => { |
| 85 | Poll::Ready(Err(ConnectError::ConnectionReset)) | 114 | Poll::Ready(Err(ConnectError::ConnectionReset)) |
| 86 | } | 115 | } |
| @@ -99,7 +128,7 @@ impl<'a> TcpSocket<'a> { | |||
| 99 | where | 128 | where |
| 100 | T: Into<IpEndpoint>, | 129 | T: Into<IpEndpoint>, |
| 101 | { | 130 | { |
| 102 | match self.with(|s, _| s.listen(local_endpoint)) { | 131 | match with_socket(self.handle, |s, _| s.listen(local_endpoint)) { |
| 103 | Ok(()) => {} | 132 | Ok(()) => {} |
| 104 | Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), | 133 | Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), |
| 105 | Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), | 134 | Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), |
| @@ -108,7 +137,7 @@ impl<'a> TcpSocket<'a> { | |||
| 108 | } | 137 | } |
| 109 | 138 | ||
| 110 | futures::future::poll_fn(|cx| { | 139 | futures::future::poll_fn(|cx| { |
| 111 | self.with(|s, _| match s.state() { | 140 | with_socket(self.handle, |s, _| match s.state() { |
| 112 | TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { | 141 | TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { |
| 113 | s.register_send_waker(cx.waker()); | 142 | s.register_send_waker(cx.waker()); |
| 114 | Poll::Pending | 143 | Poll::Pending |
| @@ -120,57 +149,58 @@ impl<'a> TcpSocket<'a> { | |||
| 120 | } | 149 | } |
| 121 | 150 | ||
| 122 | pub fn set_timeout(&mut self, duration: Option<Duration>) { | 151 | pub fn set_timeout(&mut self, duration: Option<Duration>) { |
| 123 | self.with(|s, _| s.set_timeout(duration)) | 152 | with_socket(self.handle, |s, _| s.set_timeout(duration)) |
| 124 | } | 153 | } |
| 125 | 154 | ||
| 126 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { | 155 | pub fn set_keep_alive(&mut self, interval: Option<Duration>) { |
| 127 | self.with(|s, _| s.set_keep_alive(interval)) | 156 | with_socket(self.handle, |s, _| s.set_keep_alive(interval)) |
| 128 | } | 157 | } |
| 129 | 158 | ||
| 130 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { | 159 | pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { |
| 131 | self.with(|s, _| s.set_hop_limit(hop_limit)) | 160 | with_socket(self.handle, |s, _| s.set_hop_limit(hop_limit)) |
| 132 | } | 161 | } |
| 133 | 162 | ||
| 134 | pub fn local_endpoint(&self) -> IpEndpoint { | 163 | pub fn local_endpoint(&self) -> IpEndpoint { |
| 135 | self.with(|s, _| s.local_endpoint()) | 164 | with_socket(self.handle, |s, _| s.local_endpoint()) |
| 136 | } | 165 | } |
| 137 | 166 | ||
| 138 | pub fn remote_endpoint(&self) -> IpEndpoint { | 167 | pub fn remote_endpoint(&self) -> IpEndpoint { |
| 139 | self.with(|s, _| s.remote_endpoint()) | 168 | with_socket(self.handle, |s, _| s.remote_endpoint()) |
| 140 | } | 169 | } |
| 141 | 170 | ||
| 142 | pub fn state(&self) -> TcpState { | 171 | pub fn state(&self) -> TcpState { |
| 143 | self.with(|s, _| s.state()) | 172 | with_socket(self.handle, |s, _| s.state()) |
| 144 | } | 173 | } |
| 145 | 174 | ||
| 146 | pub fn close(&mut self) { | 175 | pub fn close(&mut self) { |
| 147 | self.with(|s, _| s.close()) | 176 | with_socket(self.handle, |s, _| s.close()) |
| 148 | } | 177 | } |
| 149 | 178 | ||
| 150 | pub fn abort(&mut self) { | 179 | pub fn abort(&mut self) { |
| 151 | self.with(|s, _| s.abort()) | 180 | with_socket(self.handle, |s, _| s.abort()) |
| 152 | } | 181 | } |
| 153 | 182 | ||
| 154 | pub fn may_send(&self) -> bool { | 183 | pub fn may_send(&self) -> bool { |
| 155 | self.with(|s, _| s.may_send()) | 184 | with_socket(self.handle, |s, _| s.may_send()) |
| 156 | } | 185 | } |
| 157 | 186 | ||
| 158 | pub fn may_recv(&self) -> bool { | 187 | pub fn may_recv(&self) -> bool { |
| 159 | self.with(|s, _| s.may_recv()) | 188 | with_socket(self.handle, |s, _| s.may_recv()) |
| 160 | } | 189 | } |
| 190 | } | ||
| 161 | 191 | ||
| 162 | fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R) -> R { | 192 | fn with_socket<R>( |
| 163 | Stack::with(|stack| { | 193 | handle: SocketHandle, |
| 164 | let res = { | 194 | f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R, |
| 165 | let (s, cx) = stack | 195 | ) -> R { |
| 166 | .iface | 196 | Stack::with(|stack| { |
| 167 | .get_socket_and_context::<SyncTcpSocket>(self.handle); | 197 | let res = { |
| 168 | f(s, cx) | 198 | let (s, cx) = stack.iface.get_socket_and_context::<SyncTcpSocket>(handle); |
| 169 | }; | 199 | f(s, cx) |
| 170 | stack.wake(); | 200 | }; |
| 171 | res | 201 | stack.wake(); |
| 172 | }) | 202 | res |
| 173 | } | 203 | }) |
| 174 | } | 204 | } |
| 175 | 205 | ||
| 176 | impl<'a> Drop for TcpSocket<'a> { | 206 | impl<'a> Drop for TcpSocket<'a> { |
| @@ -190,3 +220,11 @@ impl embedded_io::Error for Error { | |||
| 190 | impl<'d> embedded_io::Io for TcpSocket<'d> { | 220 | impl<'d> embedded_io::Io for TcpSocket<'d> { |
| 191 | type Error = Error; | 221 | type Error = Error; |
| 192 | } | 222 | } |
| 223 | |||
| 224 | impl<'d> embedded_io::Io for TcpReader<'d> { | ||
| 225 | type Error = Error; | ||
| 226 | } | ||
| 227 | |||
| 228 | impl<'d> embedded_io::Io for TcpWriter<'d> { | ||
| 229 | type Error = Error; | ||
| 230 | } | ||
