use std::{collections::HashMap, sync::Arc}; use bytes::Bytes; use futures::{SinkExt, StreamExt}; use tokio::{ sync::{mpsc, oneshot}, task::AbortHandle, }; use crate::{ Channel, protocol::{RpcCall, RpcMessage, RpcResponse}, }; const CLIENT_CHANNEL_BUFFER_SIZE: usize = 64; struct ClientChannelMessage { service: String, method: String, arguments: Bytes, responder: oneshot::Sender>, } struct ClientChannelInner { sender: mpsc::Sender, abort_handle: AbortHandle, } impl Drop for ClientChannelInner { fn drop(&mut self) { self.abort_handle.abort(); } } #[derive(Clone)] pub struct ClientChannel(Arc); impl ClientChannel { pub fn new(channel: C) -> Self { let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_BUFFER_SIZE); let abort_handle = tokio::spawn(client_channel_loop(channel, rx)).abort_handle(); Self(Arc::new(ClientChannelInner { sender: tx, abort_handle, })) } pub async fn call( &self, service: String, method: String, arguments: Bytes, ) -> std::io::Result { let (tx, rx) = oneshot::channel(); self.0 .sender .send(ClientChannelMessage { service, method, arguments, responder: tx, }) .await .expect("client channel task should never shutdown while a client is alive"); rx.await .expect("client channel task should never shutdown while a client is alive") } } async fn client_channel_loop( mut channel: C, mut rx: mpsc::Receiver, ) { enum Select { RpcMessage(RpcMessage), ClientChannelMessage(ClientChannelMessage), } let mut responders = HashMap::>>::default(); let mut rpc_call_id = 0; loop { let select = tokio::select! { Some(Ok(v)) = channel.next() => Select::RpcMessage(v), Some(v) = rx.recv() => Select::ClientChannelMessage(v), }; match select { Select::RpcMessage(RpcMessage::Response(RpcResponse { id, value })) => { if let Some(responder) = responders.remove(&id) { let _ = responder.send(Ok(value)); } } Select::RpcMessage(_) => todo!(), Select::ClientChannelMessage(ClientChannelMessage { service, method, arguments, responder, }) => { let id = rpc_call_id; rpc_call_id += 1; let result = channel .send(RpcMessage::Call(RpcCall { id, service, method, arguments, })) .await; match result { Ok(()) => { responders.insert(id, responder); } Err(err) => { let _ = responder.send(Err(err)); } } } } } }