aboutsummaryrefslogtreecommitdiff
path: root/examples/common/std_async_tcp.rs
blob: bd97fa965572706492cc514025f5f9eee420e1ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use std::{
    io::{Read, Write},
    net::{TcpStream, ToSocketAddrs},
    sync::{Arc, Mutex},
    thread::JoinHandle,
};

use embassy_sync::waitqueue::AtomicWaker;

/// TCP connection whose socket I/O runs on background threads and which is exposed through the
/// `embedded_io_async` `Read` / `Write` traits.
pub struct AsyncTcp {
    /// Handle of the writer thread; `write` uses it to unpark the thread after queueing bytes.
    write_handle: JoinHandle<()>,
    /// Outgoing bytes queued by `write` and taken (emptied) by the writer thread.
    write_buffer: Arc<Mutex<Vec<u8>>>,
    /// Incoming bytes appended by the reader thread and consumed by `read`.
    read_buffer: Arc<Mutex<Vec<u8>>>,
    /// Waker signalled by the reader thread every time it appends to `read_buffer`.
    waker: Arc<AtomicWaker>,
}

impl AsyncTcp {
    pub fn connect(addr: impl ToSocketAddrs) -> Self {
        let stream = TcpStream::connect(addr).expect("failed to connect to remote");
        let mut read_stream = stream.try_clone().unwrap();
        let mut write_stream = stream;

        let read_buffer: Arc<Mutex<Vec<u8>>> = Default::default();
        let write_buffer: Arc<Mutex<Vec<u8>>> = Default::default();

        let waker = Arc::new(AtomicWaker::new());

        let write_handle = std::thread::spawn({
            let write_buffer = write_buffer.clone();
            move || {
                loop {
                    let buffer = {
                        let mut buffer = write_buffer.lock().unwrap();
                        std::mem::take(&mut *buffer)
                    };
                    if !buffer.is_empty() {
                        println!("writing {} bytes", buffer.len());
                        write_stream.write_all(&buffer).unwrap();
                        write_stream.flush().unwrap();
                    } else {
                        std::thread::park();
                    }
                }
            }
        });

        std::thread::spawn({
            let read_buffer = read_buffer.clone();
            let waker = waker.clone();
            move || {
                let mut scratch = [0u8; 1024];
                loop {
                    let n = read_stream.read(&mut scratch).unwrap();
                    if n == 0 {
                        panic!("EOF");
                    }

                    {
                        let mut buffer = read_buffer.lock().unwrap();
                        buffer.extend_from_slice(&scratch[..n]);
                        waker.wake();
                    }
                }
            }
        });

        Self {
            write_handle,
            write_buffer,
            read_buffer,
            waker,
        }
    }
}

// Declares `std::io::Error` as the error type used by the `embedded_io_async` trait
// implementations for `AsyncTcp`.
impl embedded_io_async::ErrorType for AsyncTcp {
    type Error = std::io::Error;
}

impl embedded_io_async::Write for AsyncTcp {
    /// Queues all of `buf` for the writer thread and unparks that thread.
    ///
    /// The bytes are only appended to the in-memory write buffer here; sending them over the
    /// socket is left to the writer thread. The call therefore always reports `buf.len()`
    /// bytes as written.
    ///
    /// # Panics
    ///
    /// Panics if the write buffer mutex is poisoned.
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let accepted = buf.len();
        // The guard is a temporary, so the lock is released before the thread is unparked.
        self.write_buffer.lock().unwrap().extend_from_slice(buf);
        self.write_handle.thread().unpark();
        Ok(accepted)
    }
}

impl embedded_io_async::Read for AsyncTcp {
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        struct WaitForWaker<'a>(&'a AtomicWaker, bool);

        impl<'a> Future for WaitForWaker<'a> {
            type Output = ();

            fn poll(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Self::Output> {
                if self.1 {
                    std::task::Poll::Ready(())
                } else {
                    self.as_mut().1 = true;
                    self.0.register(cx.waker());
                    std::task::Poll::Pending
                }
            }
        }

        loop {
            {
                let mut buffer = self.read_buffer.lock().unwrap();
                if !buffer.is_empty() {
                    let copy_n = buf.len().min(buffer.len());
                    buf[..copy_n].copy_from_slice(&buffer[..copy_n]);
                    buffer.drain(..copy_n);
                    return Ok(copy_n);
                }
            }
            WaitForWaker(&self.waker, false).await
        }
    }
}