aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDario Nieuwenhuis <[email protected]>2022-05-13 22:15:27 +0200
committerDario Nieuwenhuis <[email protected]>2022-05-19 06:14:05 +0200
commit0b2f43c391f1f1f6525484c369a24e72610dd8ef (patch)
treeeb5b942e76d6b88d1ce3fed211ab281bb132398b
parent240bef8c9f97998bd4639c714077b7aa59f5e7bc (diff)
net: add split() to tcpsocket
-rw-r--r--embassy-net/src/tcp/io_impl.rs68
-rw-r--r--embassy-net/src/tcp/mod.rs90
2 files changed, 129 insertions, 29 deletions
diff --git a/embassy-net/src/tcp/io_impl.rs b/embassy-net/src/tcp/io_impl.rs
index 155733497..b30c920b8 100644
--- a/embassy-net/src/tcp/io_impl.rs
+++ b/embassy-net/src/tcp/io_impl.rs
@@ -2,7 +2,7 @@ use core::future::Future;
2use core::task::Poll; 2use core::task::Poll;
3use futures::future::poll_fn; 3use futures::future::poll_fn;
4 4
5use super::{Error, TcpSocket}; 5use super::{with_socket, Error, TcpReader, TcpSocket, TcpWriter};
6 6
7impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { 7impl<'d> embedded_io::asynch::Read for TcpSocket<'d> {
8 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> 8 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>>
@@ -13,7 +13,7 @@ impl<'d> embedded_io::asynch::Read for TcpSocket<'d> {
13 poll_fn(move |cx| { 13 poll_fn(move |cx| {
14 // CAUTION: smoltcp semantics around EOF are different to what you'd expect 14 // CAUTION: smoltcp semantics around EOF are different to what you'd expect
15 // from posix-like IO, so we have to tweak things here. 15 // from posix-like IO, so we have to tweak things here.
16 self.with(|s, _| match s.recv_slice(buf) { 16 with_socket(self.handle, |s, _| match s.recv_slice(buf) {
17 // No data ready 17 // No data ready
18 Ok(0) => { 18 Ok(0) => {
19 s.register_recv_waker(cx.waker()); 19 s.register_recv_waker(cx.waker());
@@ -39,7 +39,69 @@ impl<'d> embedded_io::asynch::Write for TcpSocket<'d> {
39 39
40 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { 40 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> {
41 poll_fn(move |cx| { 41 poll_fn(move |cx| {
42 self.with(|s, _| match s.send_slice(buf) { 42 with_socket(self.handle, |s, _| match s.send_slice(buf) {
43 // Not ready to send (no space in the tx buffer)
44 Ok(0) => {
45 s.register_send_waker(cx.waker());
46 Poll::Pending
47 }
48 // Some data sent
49 Ok(n) => Poll::Ready(Ok(n)),
50 // Connection reset. TODO: this can also be timeouts etc, investigate.
51 Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)),
52 // smoltcp returns no errors other than the above.
53 Err(_) => unreachable!(),
54 })
55 })
56 }
57
58 type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>>
59 where
60 Self: 'a;
61
62 fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> {
63 poll_fn(move |_| {
64 Poll::Ready(Ok(())) // TODO: Is there a better implementation for this?
65 })
66 }
67}
68
69impl<'d> embedded_io::asynch::Read for TcpReader<'d> {
70 type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>>
71 where
72 Self: 'a;
73
74 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
75 poll_fn(move |cx| {
76 // CAUTION: smoltcp semantics around EOF are different to what you'd expect
77 // from posix-like IO, so we have to tweak things here.
78 with_socket(self.handle, |s, _| match s.recv_slice(buf) {
79 // No data ready
80 Ok(0) => {
81 s.register_recv_waker(cx.waker());
82 Poll::Pending
83 }
84 // Data ready!
85 Ok(n) => Poll::Ready(Ok(n)),
86 // EOF
87 Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)),
88 // Connection reset. TODO: this can also be timeouts etc, investigate.
89 Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)),
90 // smoltcp returns no errors other than the above.
91 Err(_) => unreachable!(),
92 })
93 })
94 }
95}
96
97impl<'d> embedded_io::asynch::Write for TcpWriter<'d> {
98 type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>>
99 where
100 Self: 'a;
101
102 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> {
103 poll_fn(move |cx| {
104 with_socket(self.handle, |s, _| match s.send_slice(buf) {
43 // Not ready to send (no space in the tx buffer) 105 // Not ready to send (no space in the tx buffer)
44 Ok(0) => { 106 Ok(0) => {
45 s.register_send_waker(cx.waker()); 107 s.register_send_waker(cx.waker());
diff --git a/embassy-net/src/tcp/mod.rs b/embassy-net/src/tcp/mod.rs
index 3bfd4c7b6..425e6acbc 100644
--- a/embassy-net/src/tcp/mod.rs
+++ b/embassy-net/src/tcp/mod.rs
@@ -49,6 +49,20 @@ pub struct TcpSocket<'a> {
49 49
50impl<'a> Unpin for TcpSocket<'a> {} 50impl<'a> Unpin for TcpSocket<'a> {}
51 51
52pub struct TcpReader<'a> {
53 handle: SocketHandle,
54 ghost: PhantomData<&'a mut [u8]>,
55}
56
57impl<'a> Unpin for TcpReader<'a> {}
58
59pub struct TcpWriter<'a> {
60 handle: SocketHandle,
61 ghost: PhantomData<&'a mut [u8]>,
62}
63
64impl<'a> Unpin for TcpWriter<'a> {}
65
52impl<'a> TcpSocket<'a> { 66impl<'a> TcpSocket<'a> {
53 pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { 67 pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
54 let handle = Stack::with(|stack| { 68 let handle = Stack::with(|stack| {
@@ -66,12 +80,27 @@ impl<'a> TcpSocket<'a> {
66 } 80 }
67 } 81 }
68 82
83 pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) {
84 (
85 TcpReader {
86 handle: self.handle,
87 ghost: PhantomData,
88 },
89 TcpWriter {
90 handle: self.handle,
91 ghost: PhantomData,
92 },
93 )
94 }
95
69 pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError> 96 pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError>
70 where 97 where
71 T: Into<IpEndpoint>, 98 T: Into<IpEndpoint>,
72 { 99 {
73 let local_port = Stack::with(|stack| stack.get_local_port()); 100 let local_port = Stack::with(|stack| stack.get_local_port());
74 match self.with(|s, cx| s.connect(cx, remote_endpoint, local_port)) { 101 match with_socket(self.handle, |s, cx| {
102 s.connect(cx, remote_endpoint, local_port)
103 }) {
75 Ok(()) => {} 104 Ok(()) => {}
76 Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), 105 Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState),
77 Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), 106 Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute),
@@ -80,7 +109,7 @@ impl<'a> TcpSocket<'a> {
80 } 109 }
81 110
82 futures::future::poll_fn(|cx| { 111 futures::future::poll_fn(|cx| {
83 self.with(|s, _| match s.state() { 112 with_socket(self.handle, |s, _| match s.state() {
84 TcpState::Closed | TcpState::TimeWait => { 113 TcpState::Closed | TcpState::TimeWait => {
85 Poll::Ready(Err(ConnectError::ConnectionReset)) 114 Poll::Ready(Err(ConnectError::ConnectionReset))
86 } 115 }
@@ -99,7 +128,7 @@ impl<'a> TcpSocket<'a> {
99 where 128 where
100 T: Into<IpEndpoint>, 129 T: Into<IpEndpoint>,
101 { 130 {
102 match self.with(|s, _| s.listen(local_endpoint)) { 131 match with_socket(self.handle, |s, _| s.listen(local_endpoint)) {
103 Ok(()) => {} 132 Ok(()) => {}
104 Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), 133 Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState),
105 Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), 134 Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort),
@@ -108,7 +137,7 @@ impl<'a> TcpSocket<'a> {
108 } 137 }
109 138
110 futures::future::poll_fn(|cx| { 139 futures::future::poll_fn(|cx| {
111 self.with(|s, _| match s.state() { 140 with_socket(self.handle, |s, _| match s.state() {
112 TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { 141 TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => {
113 s.register_send_waker(cx.waker()); 142 s.register_send_waker(cx.waker());
114 Poll::Pending 143 Poll::Pending
@@ -120,57 +149,58 @@ impl<'a> TcpSocket<'a> {
120 } 149 }
121 150
122 pub fn set_timeout(&mut self, duration: Option<Duration>) { 151 pub fn set_timeout(&mut self, duration: Option<Duration>) {
123 self.with(|s, _| s.set_timeout(duration)) 152 with_socket(self.handle, |s, _| s.set_timeout(duration))
124 } 153 }
125 154
126 pub fn set_keep_alive(&mut self, interval: Option<Duration>) { 155 pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
127 self.with(|s, _| s.set_keep_alive(interval)) 156 with_socket(self.handle, |s, _| s.set_keep_alive(interval))
128 } 157 }
129 158
130 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { 159 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
131 self.with(|s, _| s.set_hop_limit(hop_limit)) 160 with_socket(self.handle, |s, _| s.set_hop_limit(hop_limit))
132 } 161 }
133 162
134 pub fn local_endpoint(&self) -> IpEndpoint { 163 pub fn local_endpoint(&self) -> IpEndpoint {
135 self.with(|s, _| s.local_endpoint()) 164 with_socket(self.handle, |s, _| s.local_endpoint())
136 } 165 }
137 166
138 pub fn remote_endpoint(&self) -> IpEndpoint { 167 pub fn remote_endpoint(&self) -> IpEndpoint {
139 self.with(|s, _| s.remote_endpoint()) 168 with_socket(self.handle, |s, _| s.remote_endpoint())
140 } 169 }
141 170
142 pub fn state(&self) -> TcpState { 171 pub fn state(&self) -> TcpState {
143 self.with(|s, _| s.state()) 172 with_socket(self.handle, |s, _| s.state())
144 } 173 }
145 174
146 pub fn close(&mut self) { 175 pub fn close(&mut self) {
147 self.with(|s, _| s.close()) 176 with_socket(self.handle, |s, _| s.close())
148 } 177 }
149 178
150 pub fn abort(&mut self) { 179 pub fn abort(&mut self) {
151 self.with(|s, _| s.abort()) 180 with_socket(self.handle, |s, _| s.abort())
152 } 181 }
153 182
154 pub fn may_send(&self) -> bool { 183 pub fn may_send(&self) -> bool {
155 self.with(|s, _| s.may_send()) 184 with_socket(self.handle, |s, _| s.may_send())
156 } 185 }
157 186
158 pub fn may_recv(&self) -> bool { 187 pub fn may_recv(&self) -> bool {
159 self.with(|s, _| s.may_recv()) 188 with_socket(self.handle, |s, _| s.may_recv())
160 } 189 }
190}
161 191
162 fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R) -> R { 192fn with_socket<R>(
163 Stack::with(|stack| { 193 handle: SocketHandle,
164 let res = { 194 f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R,
165 let (s, cx) = stack 195) -> R {
166 .iface 196 Stack::with(|stack| {
167 .get_socket_and_context::<SyncTcpSocket>(self.handle); 197 let res = {
168 f(s, cx) 198 let (s, cx) = stack.iface.get_socket_and_context::<SyncTcpSocket>(handle);
169 }; 199 f(s, cx)
170 stack.wake(); 200 };
171 res 201 stack.wake();
172 }) 202 res
173 } 203 })
174} 204}
175 205
176impl<'a> Drop for TcpSocket<'a> { 206impl<'a> Drop for TcpSocket<'a> {
@@ -190,3 +220,11 @@ impl embedded_io::Error for Error {
190impl<'d> embedded_io::Io for TcpSocket<'d> { 220impl<'d> embedded_io::Io for TcpSocket<'d> {
191 type Error = Error; 221 type Error = Error;
192} 222}
223
224impl<'d> embedded_io::Io for TcpReader<'d> {
225 type Error = Error;
226}
227
228impl<'d> embedded_io::Io for TcpWriter<'d> {
229 type Error = Error;
230}