From f319d7ab5278a3cfb43d38875d81c28cc2dce1e1 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Wed, 16 Jul 2025 10:46:41 +0100 Subject: Initial commit - extracted urpc from monorepo --- src/server.rs | 166 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/server.rs (limited to 'src/server.rs') diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..414340e --- /dev/null +++ b/src/server.rs @@ -0,0 +1,166 @@ +use std::{collections::HashMap, sync::Arc}; + +use futures::{SinkExt, StreamExt}; +use tokio::task::{AbortHandle, JoinSet}; + +use crate::{ + Channel, Listener, Service, + protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse}, +}; + +#[derive(Clone)] +struct Services(Arc>>); + +impl Services { + fn new(services: HashMap>) -> Self { + Self(Arc::new(services)) + } + + fn get_service(&self, name: &str) -> std::io::Result> { + match self.0.get(name) { + Some(service) => Ok(service.clone()), + None => Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "service not found", + )), + } + } +} + +type ListenerSpawner = + Box>, Services) -> AbortHandle + Send + 'static>; + +#[derive(Debug, Default)] +struct AbortHandles(Vec); + +impl Drop for AbortHandles { + fn drop(&mut self) { + for handle in &self.0 { + handle.abort(); + } + } +} + +impl AbortHandles { + pub fn push(&mut self, handle: AbortHandle) { + self.0.push(handle); + } +} + +#[derive(Default)] +pub struct Server { + services: HashMap>, + listener_spawners: Vec, +} + +impl Server { + pub fn with_service(mut self, service: T) -> Self + where + T: Service, + { + let name = T::name(); + let service = Arc::new(service); + self.services.insert(name.to_string(), service); + self + } + + pub fn with_listener(mut self, listener: L) -> Self + where + C: Channel, + L: Listener, + { + self.listener_spawners + .push(Box::new(move |join_set, services| { + join_set.spawn(listener_loop(listener, services)) + })); + self + } + + pub async fn serve(self) -> std::io::Result<()> { + let services = Services::new(self.services); + let mut join_set = JoinSet::default(); + let mut abort_handles = AbortHandles::default(); + for spawner in self.listener_spawners { + let abort_handle = (spawner)(&mut join_set, services.clone()); + abort_handles.push(abort_handle); + } + match join_set.join_next().await { + Some(Ok(Ok(()))) => Ok(()), + Some(Ok(Err(err))) => Err(err), + Some(Err(err)) => Err(std::io::Error::other(err)), + None => Ok(()), + } + } +} + +async fn listener_loop(mut listener: L, services: Services) -> std::io::Result<()> +where + C: Channel, + L: Listener, +{ + while let Some(result) = listener.next().await { + let channel = result?; + let services = services.clone(); + tokio::spawn(channel_handler(channel, services)); + } + Ok(()) +} + +async fn channel_handler(mut channel: C, services: Services) -> std::io::Result<()> { + enum Select { + Empty, + Message(RpcMessage), + } + + let (response_tx, mut response_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + let mut requests: HashMap = Default::default(); + loop { + let select = tokio::select! { + reqopt = channel.next() => match reqopt { + Some(Ok(message)) => Select::Message(message), + Some(Err(err)) => return Err(err), + None => Select::Empty, + }, + Some(response) = response_rx.recv() => match response { + Ok(response) => Select::Message(RpcMessage::Response(response)), + Err(err) => return Err(err), + } + }; + + match select { + Select::Empty => break, + Select::Message(message) => match message { + RpcMessage::Call(RpcCall { + id, + service, + method, + arguments, + }) => { + let response_tx = response_tx.clone(); + let service = services.get_service(&service)?; + let handle = tokio::spawn(async move { + let response = match service.call(method, arguments).await { + Ok(value) => Ok(RpcResponse { id, value }), + Err(err) => Err(err), + }; + let _ = response_tx.send(response); + }) + .abort_handle(); + requests.insert(id, handle); + } + RpcMessage::Cancel(RpcCancel { id }) => { + if let Some(handle) = requests.remove(&id) { + handle.abort(); + } + } + RpcMessage::Response(response) => { + requests.remove(&response.id); + channel.send(RpcMessage::Response(response)).await?; + } + }, + } + } + + Ok(()) +} -- cgit