aboutsummaryrefslogtreecommitdiff
path: root/embassy-net/src/tcp_socket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-net/src/tcp_socket.rs')
-rw-r--r--embassy-net/src/tcp_socket.rs178
1 files changed, 178 insertions, 0 deletions
diff --git a/embassy-net/src/tcp_socket.rs b/embassy-net/src/tcp_socket.rs
new file mode 100644
index 000000000..7f4eb014c
--- /dev/null
+++ b/embassy-net/src/tcp_socket.rs
@@ -0,0 +1,178 @@
1use core::marker::PhantomData;
2use core::mem;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use embassy::io;
6use embassy::io::{AsyncBufRead, AsyncWrite};
7use smoltcp::socket::SocketHandle;
8use smoltcp::socket::TcpSocket as SyncTcpSocket;
9use smoltcp::socket::{TcpSocketBuffer, TcpState};
10use smoltcp::time::Duration;
11use smoltcp::wire::IpEndpoint;
12use smoltcp::{Error, Result};
13
14use super::stack::Stack;
15use crate::fmt::*;
16
17pub struct TcpSocket<'a> {
18 handle: SocketHandle,
19 ghost: PhantomData<&'a mut [u8]>,
20}
21
22impl<'a> Unpin for TcpSocket<'a> {}
23
24impl<'a> TcpSocket<'a> {
25 pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
26 let handle = Stack::with(|stack| {
27 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
28 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
29 stack.sockets.add(SyncTcpSocket::new(
30 TcpSocketBuffer::new(rx_buffer),
31 TcpSocketBuffer::new(tx_buffer),
32 ))
33 });
34
35 Self {
36 handle,
37 ghost: PhantomData,
38 }
39 }
40
41 pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<()>
42 where
43 T: Into<IpEndpoint>,
44 {
45 let local_port = Stack::with(|stack| stack.get_local_port());
46 self.with(|s| s.connect(remote_endpoint, local_port))?;
47
48 futures::future::poll_fn(|cx| {
49 self.with(|s| match s.state() {
50 TcpState::Closed | TcpState::TimeWait => Poll::Ready(Err(Error::Unaddressable)),
51 TcpState::Listen => Poll::Ready(Err(Error::Illegal)),
52 TcpState::SynSent | TcpState::SynReceived => {
53 s.register_send_waker(cx.waker());
54 Poll::Pending
55 }
56 _ => Poll::Ready(Ok(())),
57 })
58 })
59 .await
60 }
61
62 pub fn set_timeout(&mut self, duration: Option<Duration>) {
63 self.with(|s| s.set_timeout(duration))
64 }
65
66 pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
67 self.with(|s| s.set_keep_alive(interval))
68 }
69
70 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
71 self.with(|s| s.set_hop_limit(hop_limit))
72 }
73
74 pub fn local_endpoint(&self) -> IpEndpoint {
75 self.with(|s| s.local_endpoint())
76 }
77
78 pub fn remote_endpoint(&self) -> IpEndpoint {
79 self.with(|s| s.remote_endpoint())
80 }
81
82 pub fn state(&self) -> TcpState {
83 self.with(|s| s.state())
84 }
85
86 pub fn close(&mut self) {
87 self.with(|s| s.close())
88 }
89
90 pub fn abort(&mut self) {
91 self.with(|s| s.abort())
92 }
93
94 pub fn may_send(&self) -> bool {
95 self.with(|s| s.may_send())
96 }
97
98 pub fn may_recv(&self) -> bool {
99 self.with(|s| s.may_recv())
100 }
101
102 fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket) -> R) -> R {
103 Stack::with(|stack| {
104 let res = {
105 let mut s = stack.sockets.get::<SyncTcpSocket>(self.handle);
106 f(&mut *s)
107 };
108 stack.wake();
109 res
110 })
111 }
112}
113
114fn to_ioerr(e: Error) -> io::Error {
115 warn!("smoltcp err: {:?}", e);
116 // todo
117 io::Error::Other
118}
119
120impl<'a> Drop for TcpSocket<'a> {
121 fn drop(&mut self) {
122 Stack::with(|stack| {
123 stack.sockets.remove(self.handle);
124 })
125 }
126}
127
128impl<'a> AsyncBufRead for TcpSocket<'a> {
129 fn poll_fill_buf<'z>(
130 self: Pin<&'z mut Self>,
131 cx: &mut Context<'_>,
132 ) -> Poll<io::Result<&'z [u8]>> {
133 self.with(|socket| match socket.peek(1 << 30) {
134 // No data ready
135 Ok(buf) if buf.len() == 0 => {
136 socket.register_recv_waker(cx.waker());
137 Poll::Pending
138 }
139 // Data ready!
140 Ok(buf) => {
141 // Safety:
142 // - User can't touch the inner TcpSocket directly at all.
143 // - The socket itself won't touch these bytes until consume() is called, which
144 // requires the user to release this borrow.
145 let buf: &'z [u8] = unsafe { core::mem::transmute(&*buf) };
146 Poll::Ready(Ok(buf))
147 }
148 // EOF
149 Err(Error::Finished) => Poll::Ready(Ok(&[][..])),
150 // Error
151 Err(e) => Poll::Ready(Err(to_ioerr(e))),
152 })
153 }
154
155 fn consume(self: Pin<&mut Self>, amt: usize) {
156 self.with(|s| s.recv(|_| (amt, ()))).unwrap()
157 }
158}
159
160impl<'a> AsyncWrite for TcpSocket<'a> {
161 fn poll_write(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 buf: &[u8],
165 ) -> Poll<io::Result<usize>> {
166 self.with(|s| match s.send_slice(buf) {
167 // Not ready to send (no space in the tx buffer)
168 Ok(0) => {
169 s.register_send_waker(cx.waker());
170 Poll::Pending
171 }
172 // Some data sent
173 Ok(n) => Poll::Ready(Ok(n)),
174 // Error
175 Err(e) => Poll::Ready(Err(to_ioerr(e))),
176 })
177 }
178}