use std::{collections::HashMap, sync::Arc}; use futures::{SinkExt, StreamExt}; use tokio::task::{AbortHandle, JoinSet}; use crate::{ Channel, Listener, Service, protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse}, }; #[derive(Clone)] struct Services(Arc>>); impl Services { fn new(services: HashMap>) -> Self { Self(Arc::new(services)) } fn get_service(&self, name: &str) -> std::io::Result> { match self.0.get(name) { Some(service) => Ok(service.clone()), None => Err(std::io::Error::new( std::io::ErrorKind::NotFound, "service not found", )), } } } type ListenerSpawner = Box>, Services) -> AbortHandle + Send + 'static>; #[derive(Debug, Default)] struct AbortHandles(Vec); impl Drop for AbortHandles { fn drop(&mut self) { for handle in &self.0 { handle.abort(); } } } impl AbortHandles { pub fn push(&mut self, handle: AbortHandle) { self.0.push(handle); } } #[derive(Default)] pub struct Server { services: HashMap>, listener_spawners: Vec, } impl Server { pub fn with_service(mut self, service: T) -> Self where T: Service, { let name = T::name(); let service = Arc::new(service); self.services.insert(name.to_string(), service); self } pub fn with_listener(mut self, listener: L) -> Self where C: Channel, L: Listener, { self.listener_spawners .push(Box::new(move |join_set, services| { join_set.spawn(listener_loop(listener, services)) })); self } pub async fn serve(self) -> std::io::Result<()> { let services = Services::new(self.services); let mut join_set = JoinSet::default(); let mut abort_handles = AbortHandles::default(); for spawner in self.listener_spawners { let abort_handle = (spawner)(&mut join_set, services.clone()); abort_handles.push(abort_handle); } match join_set.join_next().await { Some(Ok(Ok(()))) => Ok(()), Some(Ok(Err(err))) => Err(err), Some(Err(err)) => Err(std::io::Error::other(err)), None => Ok(()), } } } async fn listener_loop(mut listener: L, services: Services) -> std::io::Result<()> where C: Channel, L: Listener, { while let Some(result) = listener.next().await { let channel = result?; let services = services.clone(); tokio::spawn(channel_handler(channel, services)); } Ok(()) } async fn channel_handler(mut channel: C, services: Services) -> std::io::Result<()> { enum Select { Empty, Message(RpcMessage), } let (response_tx, mut response_rx) = tokio::sync::mpsc::unbounded_channel::>(); let mut requests: HashMap = Default::default(); loop { let select = tokio::select! { reqopt = channel.next() => match reqopt { Some(Ok(message)) => Select::Message(message), Some(Err(err)) => return Err(err), None => Select::Empty, }, Some(response) = response_rx.recv() => match response { Ok(response) => Select::Message(RpcMessage::Response(response)), Err(err) => return Err(err), } }; match select { Select::Empty => break, Select::Message(message) => match message { RpcMessage::Call(RpcCall { id, service, method, arguments, }) => { let response_tx = response_tx.clone(); let service = services.get_service(&service)?; let handle = tokio::spawn(async move { let response = match service.call(method, arguments).await { Ok(value) => Ok(RpcResponse { id, value }), Err(err) => Err(err), }; let _ = response_tx.send(response); }) .abort_handle(); requests.insert(id, handle); } RpcMessage::Cancel(RpcCancel { id }) => { if let Some(handle) = requests.remove(&id) { handle.abort(); } } RpcMessage::Response(response) => { requests.remove(&response.id); channel.send(RpcMessage::Response(response)).await?; } }, } } Ok(()) }