aboutsummaryrefslogtreecommitdiff
path: root/src/client_channel.rs
blob: 5666a16ceb7a41d7898eee9766a3ae2d7c823518 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::{collections::HashMap, sync::Arc};

use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use tokio::{
    sync::{mpsc, oneshot},
    task::AbortHandle,
};

use crate::{
    Channel,
    protocol::{RpcCall, RpcMessage, RpcResponse},
};

/// Capacity of the bounded request queue between [`ClientChannel`] handles and the
/// background task created in [`ClientChannel::new`].
const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1 << 6;

/// A call request sent from a [`ClientChannel`] handle to the background task.
struct ClientChannelMessage {
    /// Name of the service to call.
    service: String,
    /// Name of the method to call on `service`.
    method: String,
    /// Call arguments, forwarded unchanged as the `arguments` of the outgoing [`RpcCall`].
    arguments: Bytes,
    /// Completed by the background task with the response value or an I/O error.
    responder: oneshot::Sender<std::io::Result<Bytes>>,
}

/// State shared (through an `Arc`) by all clones of a [`ClientChannel`].
struct ClientChannelInner {
    /// Request queue consumed by the background task.
    sender: mpsc::Sender<ClientChannelMessage>,
    /// Handle of the background task; aborted in this type's `Drop` impl.
    abort_handle: AbortHandle,
}

impl Drop for ClientChannelInner {
    fn drop(&mut self) {
        self.abort_handle.abort();
    }
}

/// Cloneable handle for issuing RPC calls over a shared [`Channel`].
///
/// All clones share one background task, spawned by [`ClientChannel::new`], which owns the
/// underlying channel. That task is aborted when the last clone is dropped.
#[derive(Clone)]
pub struct ClientChannel(Arc<ClientChannelInner>);

impl ClientChannel {
    /// Wraps `channel` and spawns the background task that drives it.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses [`tokio::spawn`].
    pub fn new<C: Channel>(channel: C) -> Self {
        let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_BUFFER_SIZE);
        let abort_handle = tokio::spawn(client_channel_loop(channel, rx)).abort_handle();
        Self(Arc::new(ClientChannelInner {
            sender: tx,
            abort_handle,
        }))
    }

    /// Calls `method` on `service` with the raw `arguments` bytes and waits for the reply.
    ///
    /// Waits for room if the request queue is full.
    ///
    /// # Errors
    ///
    /// Returns the error the background task reports for this call (for example a failure
    /// to send the request). Returns an error of kind [`std::io::ErrorKind::BrokenPipe`] if
    /// the background task is no longer running.
    pub async fn call(
        &self,
        service: String,
        method: String,
        arguments: Bytes,
    ) -> std::io::Result<Bytes> {
        // The task is meant to outlive every handle. If it has died anyway (e.g. it
        // panicked), report that to the caller as an I/O error rather than panicking.
        fn task_gone() -> std::io::Error {
            std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "client channel task is no longer running",
            )
        }

        let (tx, rx) = oneshot::channel();
        self.0
            .sender
            .send(ClientChannelMessage {
                service,
                method,
                arguments,
                responder: tx,
            })
            .await
            .map_err(|_| task_gone())?;
        rx.await.map_err(|_| task_gone())?
    }
}

/// Background task that multiplexes calls from [`ClientChannel`] handles over `channel`.
///
/// Each request received on `rx` gets a fresh call id and is sent to the peer as an
/// [`RpcMessage::Call`]. Its responder is parked under that id until the matching
/// [`RpcMessage::Response`] arrives.
///
/// If the incoming side of `channel` ends or yields an error:
/// - every parked responder is completed with an error instead of being left waiting;
/// - the stream is not polled again;
/// - later requests are rejected immediately.
///
/// The task returns once every sender for `rx` has been dropped.
async fn client_channel_loop<C: Channel>(
    mut channel: C,
    mut rx: mpsc::Receiver<ClientChannelMessage>,
) {
    enum Event {
        Message(RpcMessage),
        IncomingClosed,
        Request(ClientChannelMessage),
        RequestsClosed,
    }

    let mut responders = HashMap::<u64, oneshot::Sender<std::io::Result<Bytes>>>::default();
    let mut rpc_call_id: u64 = 0;
    // Once the incoming stream has finished it must not be polled again.
    let mut incoming_closed = false;

    loop {
        // The `rx` branch is always enabled and every outcome of both branches maps to an
        // event, so `select!` can never end up with all branches disabled.
        let event = tokio::select! {
            incoming = channel.next(), if !incoming_closed => match incoming {
                Some(Ok(message)) => Event::Message(message),
                // The stream's error type is unknown here, so it cannot be forwarded.
                // The read side is treated as dead either way.
                Some(Err(_)) | None => Event::IncomingClosed,
            },
            request = rx.recv() => match request {
                Some(request) => Event::Request(request),
                None => Event::RequestsClosed,
            },
        };

        match event {
            Event::Message(RpcMessage::Response(RpcResponse { id, value })) => {
                // Responses for unknown ids (or callers that stopped waiting) are dropped.
                if let Some(responder) = responders.remove(&id) {
                    let _ = responder.send(Ok(value));
                }
            }
            // This side only issues calls. Other message kinds from the peer are ignored
            // instead of panicking the task, which would break every later `call`.
            Event::Message(_) => {}
            Event::IncomingClosed => {
                incoming_closed = true;
                // No response can be read any more: fail every in-flight call instead
                // of leaving its caller waiting forever.
                for (_, responder) in responders.drain() {
                    let _ = responder.send(Err(std::io::Error::new(
                        std::io::ErrorKind::ConnectionAborted,
                        "rpc channel closed before a response was received",
                    )));
                }
            }
            Event::Request(ClientChannelMessage {
                service,
                method,
                arguments,
                responder,
            }) => {
                if incoming_closed {
                    // Sending would be pointless: the reply could never be received.
                    let _ = responder.send(Err(std::io::Error::new(
                        std::io::ErrorKind::NotConnected,
                        "rpc channel is closed",
                    )));
                } else {
                    let id = rpc_call_id;
                    rpc_call_id = rpc_call_id.wrapping_add(1);

                    let result = channel
                        .send(RpcMessage::Call(RpcCall {
                            id,
                            service,
                            method,
                            arguments,
                        }))
                        .await;

                    match result {
                        Ok(()) => {
                            responders.insert(id, responder);
                        }
                        Err(err) => {
                            let _ = responder.send(Err(err));
                        }
                    }
                }
            }
            // Every `ClientChannel` handle is gone, so no caller can issue or await a call.
            Event::RequestsClosed => break,
        }
    }
}