diff options
| author | diogo464 <[email protected]> | 2025-07-16 10:46:41 +0100 |
|---|---|---|
| committer | diogo464 <[email protected]> | 2025-07-16 10:46:41 +0100 |
| commit | f319d7ab5278a3cfb43d38875d81c28cc2dce1e1 (patch) | |
| tree | cb161fd990643e267bbc373fb09ccd7b689a23b5 /src | |
Initial commit - extracted urpc from monorepo
Diffstat (limited to 'src')
| -rw-r--r-- | src/channel.rs | 162 | ||||
| -rw-r--r-- | src/client_channel.rs | 124 | ||||
| -rw-r--r-- | src/internal.rs | 2 | ||||
| -rw-r--r-- | src/lib.rs | 52 | ||||
| -rw-r--r-- | src/protocol.rs | 107 | ||||
| -rw-r--r-- | src/server.rs | 166 | ||||
| -rw-r--r-- | src/tcp.rs | 156 | ||||
| -rw-r--r-- | src/unix.rs | 160 |
8 files changed, 929 insertions, 0 deletions
diff --git a/src/channel.rs b/src/channel.rs new file mode 100644 index 0000000..f80ca87 --- /dev/null +++ b/src/channel.rs | |||
| @@ -0,0 +1,162 @@ | |||
| 1 | //! MPSC channel backed channel and listener implementations. | ||
| 2 | //! | ||
| 3 | //! This module provides an mpsc channel based listener/channel implementation that can be useful | ||
| 4 | //! to for tests. | ||
| 5 | //! | ||
| 6 | //! The current implementation uses [`tokio`]'s unbounded mpsc channels. | ||
| 7 | //! | ||
| 8 | //! ``` | ||
| 9 | //! #[urpc::service] | ||
| 10 | //! trait Hello { | ||
| 11 | //! type Error = (); | ||
| 12 | //! | ||
| 13 | //! async fn hello(name: String) -> String; | ||
| 14 | //! } | ||
| 15 | //! | ||
| 16 | //! struct HelloServer; | ||
| 17 | //! | ||
| 18 | //! impl Hello for HelloServer { | ||
| 19 | //! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> { | ||
| 20 | //! Ok(format!("Hello, {name}!")) | ||
| 21 | //! } | ||
| 22 | //! } | ||
| 23 | //! | ||
| 24 | //! #[tokio::main] | ||
| 25 | //! async fn main() -> Result<(), Box<dyn std::error::Error>>{ | ||
| 26 | //! let (dialer, listener) = urpc::channel::new(); | ||
| 27 | //! | ||
| 28 | //! // spawn the server | ||
| 29 | //! tokio::spawn(async move { | ||
| 30 | //! urpc::Server::default() | ||
| 31 | //! .with_listener(listener) | ||
| 32 | //! .with_service(HelloServer.into_service()) | ||
| 33 | //! .serve() | ||
| 34 | //! .await | ||
| 35 | //! }); | ||
| 36 | //! | ||
| 37 | //! // create a client | ||
| 38 | //! let channel = urpc::ClientChannel::new(dialer.connect()?); | ||
| 39 | //! let client = HelloClient::new(channel); | ||
| 40 | //! let greeting = client.hello("World".into()).await.unwrap(); | ||
| 41 | //! assert_eq!(greeting, "Hello, World!"); | ||
| 42 | //! Ok(()) | ||
| 43 | //! } | ||
| 44 | //! ``` | ||
| 45 | |||
| 46 | use futures::{Sink, Stream}; | ||
| 47 | use tokio::sync::mpsc; | ||
| 48 | |||
| 49 | use crate::protocol::RpcMessage; | ||
| 50 | |||
| 51 | pub struct ChannelListener { | ||
| 52 | receiver: mpsc::UnboundedReceiver<Channel>, | ||
| 53 | } | ||
| 54 | |||
| 55 | impl Stream for ChannelListener { | ||
| 56 | type Item = std::io::Result<Channel>; | ||
| 57 | |||
| 58 | fn poll_next( | ||
| 59 | self: std::pin::Pin<&mut Self>, | ||
| 60 | cx: &mut std::task::Context<'_>, | ||
| 61 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 62 | match self.get_mut().receiver.poll_recv(cx) { | ||
| 63 | std::task::Poll::Ready(Some(c)) => std::task::Poll::Ready(Some(Ok(c))), | ||
| 64 | std::task::Poll::Ready(None) => std::task::Poll::Ready(None), | ||
| 65 | std::task::Poll::Pending => std::task::Poll::Pending, | ||
| 66 | } | ||
| 67 | } | ||
| 68 | } | ||
| 69 | |||
| 70 | impl crate::Listener<Channel> for ChannelListener {} | ||
| 71 | |||
| 72 | pub struct ChannelDialer { | ||
| 73 | sender: mpsc::UnboundedSender<Channel>, | ||
| 74 | } | ||
| 75 | |||
| 76 | impl ChannelDialer { | ||
| 77 | fn new() -> (Self, ChannelListener) { | ||
| 78 | let (sender, receiver) = mpsc::unbounded_channel(); | ||
| 79 | (Self { sender }, ChannelListener { receiver }) | ||
| 80 | } | ||
| 81 | |||
| 82 | pub fn connect(&self) -> std::io::Result<Channel> { | ||
| 83 | let (ch1, ch2) = Channel::new(); | ||
| 84 | self.sender.send(ch1).expect("TODO: remove this"); | ||
| 85 | Ok(ch2) | ||
| 86 | } | ||
| 87 | } | ||
| 88 | |||
| 89 | pub struct Channel { | ||
| 90 | sender: mpsc::UnboundedSender<RpcMessage>, | ||
| 91 | receiver: mpsc::UnboundedReceiver<RpcMessage>, | ||
| 92 | } | ||
| 93 | |||
| 94 | impl Channel { | ||
| 95 | fn new() -> (Self, Self) { | ||
| 96 | let (sender0, receiver0) = mpsc::unbounded_channel(); | ||
| 97 | let (sender1, receiver1) = mpsc::unbounded_channel(); | ||
| 98 | ( | ||
| 99 | Self { | ||
| 100 | sender: sender0, | ||
| 101 | receiver: receiver1, | ||
| 102 | }, | ||
| 103 | Self { | ||
| 104 | sender: sender1, | ||
| 105 | receiver: receiver0, | ||
| 106 | }, | ||
| 107 | ) | ||
| 108 | } | ||
| 109 | } | ||
| 110 | |||
| 111 | impl Stream for Channel { | ||
| 112 | type Item = std::io::Result<RpcMessage>; | ||
| 113 | |||
| 114 | fn poll_next( | ||
| 115 | self: std::pin::Pin<&mut Self>, | ||
| 116 | cx: &mut std::task::Context<'_>, | ||
| 117 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 118 | match self.get_mut().receiver.poll_recv(cx) { | ||
| 119 | std::task::Poll::Ready(Some(msg)) => std::task::Poll::Ready(Some(Ok(msg))), | ||
| 120 | std::task::Poll::Ready(None) => std::task::Poll::Ready(None), | ||
| 121 | std::task::Poll::Pending => std::task::Poll::Pending, | ||
| 122 | } | ||
| 123 | } | ||
| 124 | } | ||
| 125 | |||
| 126 | impl Sink<RpcMessage> for Channel { | ||
| 127 | type Error = std::io::Error; | ||
| 128 | |||
| 129 | fn poll_ready( | ||
| 130 | self: std::pin::Pin<&mut Self>, | ||
| 131 | _cx: &mut std::task::Context<'_>, | ||
| 132 | ) -> std::task::Poll<Result<(), Self::Error>> { | ||
| 133 | return std::task::Poll::Ready(Ok(())); | ||
| 134 | } | ||
| 135 | |||
| 136 | fn start_send(self: std::pin::Pin<&mut Self>, item: RpcMessage) -> Result<(), Self::Error> { | ||
| 137 | match self.sender.send(item) { | ||
| 138 | Ok(()) => Ok(()), | ||
| 139 | Err(err) => Err(std::io::Error::other(err)), | ||
| 140 | } | ||
| 141 | } | ||
| 142 | |||
| 143 | fn poll_flush( | ||
| 144 | self: std::pin::Pin<&mut Self>, | ||
| 145 | _cx: &mut std::task::Context<'_>, | ||
| 146 | ) -> std::task::Poll<Result<(), Self::Error>> { | ||
| 147 | std::task::Poll::Ready(Ok(())) | ||
| 148 | } | ||
| 149 | |||
| 150 | fn poll_close( | ||
| 151 | self: std::pin::Pin<&mut Self>, | ||
| 152 | _cx: &mut std::task::Context<'_>, | ||
| 153 | ) -> std::task::Poll<Result<(), Self::Error>> { | ||
| 154 | std::task::Poll::Ready(Ok(())) | ||
| 155 | } | ||
| 156 | } | ||
| 157 | |||
| 158 | impl crate::Channel for Channel {} | ||
| 159 | |||
| 160 | pub fn new() -> (ChannelDialer, ChannelListener) { | ||
| 161 | ChannelDialer::new() | ||
| 162 | } | ||
diff --git a/src/client_channel.rs b/src/client_channel.rs new file mode 100644 index 0000000..5666a16 --- /dev/null +++ b/src/client_channel.rs | |||
| @@ -0,0 +1,124 @@ | |||
| 1 | use std::{collections::HashMap, sync::Arc}; | ||
| 2 | |||
| 3 | use bytes::Bytes; | ||
| 4 | use futures::{SinkExt, StreamExt}; | ||
| 5 | use tokio::{ | ||
| 6 | sync::{mpsc, oneshot}, | ||
| 7 | task::AbortHandle, | ||
| 8 | }; | ||
| 9 | |||
| 10 | use crate::{ | ||
| 11 | Channel, | ||
| 12 | protocol::{RpcCall, RpcMessage, RpcResponse}, | ||
| 13 | }; | ||
| 14 | |||
| 15 | const CLIENT_CHANNEL_BUFFER_SIZE: usize = 64; | ||
| 16 | |||
| 17 | struct ClientChannelMessage { | ||
| 18 | service: String, | ||
| 19 | method: String, | ||
| 20 | arguments: Bytes, | ||
| 21 | responder: oneshot::Sender<std::io::Result<Bytes>>, | ||
| 22 | } | ||
| 23 | |||
| 24 | struct ClientChannelInner { | ||
| 25 | sender: mpsc::Sender<ClientChannelMessage>, | ||
| 26 | abort_handle: AbortHandle, | ||
| 27 | } | ||
| 28 | |||
| 29 | impl Drop for ClientChannelInner { | ||
| 30 | fn drop(&mut self) { | ||
| 31 | self.abort_handle.abort(); | ||
| 32 | } | ||
| 33 | } | ||
| 34 | |||
| 35 | #[derive(Clone)] | ||
| 36 | pub struct ClientChannel(Arc<ClientChannelInner>); | ||
| 37 | |||
| 38 | impl ClientChannel { | ||
| 39 | pub fn new<C: Channel>(channel: C) -> Self { | ||
| 40 | let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_BUFFER_SIZE); | ||
| 41 | let abort_handle = tokio::spawn(client_channel_loop(channel, rx)).abort_handle(); | ||
| 42 | Self(Arc::new(ClientChannelInner { | ||
| 43 | sender: tx, | ||
| 44 | abort_handle, | ||
| 45 | })) | ||
| 46 | } | ||
| 47 | |||
| 48 | pub async fn call( | ||
| 49 | &self, | ||
| 50 | service: String, | ||
| 51 | method: String, | ||
| 52 | arguments: Bytes, | ||
| 53 | ) -> std::io::Result<Bytes> { | ||
| 54 | let (tx, rx) = oneshot::channel(); | ||
| 55 | self.0 | ||
| 56 | .sender | ||
| 57 | .send(ClientChannelMessage { | ||
| 58 | service, | ||
| 59 | method, | ||
| 60 | arguments, | ||
| 61 | responder: tx, | ||
| 62 | }) | ||
| 63 | .await | ||
| 64 | .expect("client channel task should never shutdown while a client is alive"); | ||
| 65 | rx.await | ||
| 66 | .expect("client channel task should never shutdown while a client is alive") | ||
| 67 | } | ||
| 68 | } | ||
| 69 | |||
| 70 | async fn client_channel_loop<C: Channel>( | ||
| 71 | mut channel: C, | ||
| 72 | mut rx: mpsc::Receiver<ClientChannelMessage>, | ||
| 73 | ) { | ||
| 74 | enum Select { | ||
| 75 | RpcMessage(RpcMessage), | ||
| 76 | ClientChannelMessage(ClientChannelMessage), | ||
| 77 | } | ||
| 78 | |||
| 79 | let mut responders = HashMap::<u64, oneshot::Sender<std::io::Result<Bytes>>>::default(); | ||
| 80 | let mut rpc_call_id = 0; | ||
| 81 | |||
| 82 | loop { | ||
| 83 | let select = tokio::select! { | ||
| 84 | Some(Ok(v)) = channel.next() => Select::RpcMessage(v), | ||
| 85 | Some(v) = rx.recv() => Select::ClientChannelMessage(v), | ||
| 86 | }; | ||
| 87 | |||
| 88 | match select { | ||
| 89 | Select::RpcMessage(RpcMessage::Response(RpcResponse { id, value })) => { | ||
| 90 | if let Some(responder) = responders.remove(&id) { | ||
| 91 | let _ = responder.send(Ok(value)); | ||
| 92 | } | ||
| 93 | } | ||
| 94 | Select::RpcMessage(_) => todo!(), | ||
| 95 | Select::ClientChannelMessage(ClientChannelMessage { | ||
| 96 | service, | ||
| 97 | method, | ||
| 98 | arguments, | ||
| 99 | responder, | ||
| 100 | }) => { | ||
| 101 | let id = rpc_call_id; | ||
| 102 | rpc_call_id += 1; | ||
| 103 | |||
| 104 | let result = channel | ||
| 105 | .send(RpcMessage::Call(RpcCall { | ||
| 106 | id, | ||
| 107 | service, | ||
| 108 | method, | ||
| 109 | arguments, | ||
| 110 | })) | ||
| 111 | .await; | ||
| 112 | |||
| 113 | match result { | ||
| 114 | Ok(()) => { | ||
| 115 | responders.insert(id, responder); | ||
| 116 | } | ||
| 117 | Err(err) => { | ||
| 118 | let _ = responder.send(Err(err)); | ||
| 119 | } | ||
| 120 | } | ||
| 121 | } | ||
| 122 | } | ||
| 123 | } | ||
| 124 | } | ||
diff --git a/src/internal.rs b/src/internal.rs new file mode 100644 index 0000000..a3c203a --- /dev/null +++ b/src/internal.rs | |||
| @@ -0,0 +1,2 @@ | |||
| 1 | pub use bincode; | ||
| 2 | pub use bytes; | ||
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..e65f659 --- /dev/null +++ b/src/lib.rs | |||
| @@ -0,0 +1,52 @@ | |||
| 1 | #[doc(hidden)] | ||
| 2 | pub mod internal; | ||
| 3 | |||
| 4 | pub mod channel; | ||
| 5 | pub mod protocol; | ||
| 6 | pub mod tcp; | ||
| 7 | pub mod unix; | ||
| 8 | |||
| 9 | mod client_channel; | ||
| 10 | mod server; | ||
| 11 | |||
| 12 | pub use client_channel::ClientChannel; | ||
| 13 | pub use server::Server; | ||
| 14 | pub use urpc_macro::service; | ||
| 15 | |||
| 16 | use protocol::RpcMessage; | ||
| 17 | |||
| 18 | use std::pin::Pin; | ||
| 19 | use std::future::Future; | ||
| 20 | |||
| 21 | use bytes::Bytes; | ||
| 22 | use futures::{Sink, Stream}; | ||
| 23 | |||
| 24 | #[derive(Debug, Default)] | ||
| 25 | pub struct Context; | ||
| 26 | |||
| 27 | pub trait Service: Send + Sync + 'static { | ||
| 28 | fn name() -> &'static str | ||
| 29 | where | ||
| 30 | Self: Sized; | ||
| 31 | |||
| 32 | fn call( | ||
| 33 | &self, | ||
| 34 | method: String, | ||
| 35 | arguments: Bytes, | ||
| 36 | ) -> Pin<Box<dyn Future<Output = std::io::Result<Bytes>> + Send + '_>>; | ||
| 37 | } | ||
| 38 | |||
| 39 | pub trait Channel: | ||
| 40 | Stream<Item = std::io::Result<RpcMessage>> | ||
| 41 | + Sink<RpcMessage, Error = std::io::Error> | ||
| 42 | + Send | ||
| 43 | + Unpin | ||
| 44 | + 'static | ||
| 45 | { | ||
| 46 | } | ||
| 47 | |||
| 48 | pub trait Listener<C>: Stream<Item = std::io::Result<C>> + Send + Unpin + 'static | ||
| 49 | where | ||
| 50 | C: Channel, | ||
| 51 | { | ||
| 52 | } | ||
diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..baf886b --- /dev/null +++ b/src/protocol.rs | |||
| @@ -0,0 +1,107 @@ | |||
| 1 | //! Types used by the RPC protocol. | ||
| 2 | use bytes::Bytes; | ||
| 3 | use serde::{Deserialize, Serialize}; | ||
| 4 | use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; | ||
| 5 | |||
| 6 | #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] | ||
| 7 | pub struct RpcCall { | ||
| 8 | pub id: u64, | ||
| 9 | pub service: String, | ||
| 10 | pub method: String, | ||
| 11 | pub arguments: Bytes, | ||
| 12 | } | ||
| 13 | |||
| 14 | #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] | ||
| 15 | pub struct RpcCancel { | ||
| 16 | pub id: u64, | ||
| 17 | } | ||
| 18 | |||
| 19 | #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] | ||
| 20 | pub struct RpcResponse { | ||
| 21 | pub id: u64, | ||
| 22 | pub value: Bytes, | ||
| 23 | } | ||
| 24 | |||
| 25 | #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] | ||
| 26 | pub enum RpcMessage { | ||
| 27 | Call(RpcCall), | ||
| 28 | Response(RpcResponse), | ||
| 29 | Cancel(RpcCancel), | ||
| 30 | } | ||
| 31 | |||
| 32 | #[derive(Debug)] | ||
| 33 | pub enum RpcError<E> { | ||
| 34 | Transport(std::io::Error), | ||
| 35 | Remote(E), | ||
| 36 | } | ||
| 37 | |||
| 38 | impl<E> std::fmt::Display for RpcError<E> | ||
| 39 | where | ||
| 40 | E: std::fmt::Display, | ||
| 41 | { | ||
| 42 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| 43 | match self { | ||
| 44 | RpcError::Transport(error) => write!(f, "transport error: {error}"), | ||
| 45 | RpcError::Remote(error) => write!(f, "remote error: {error}"), | ||
| 46 | } | ||
| 47 | } | ||
| 48 | } | ||
| 49 | |||
| 50 | impl<E> std::error::Error for RpcError<E> | ||
| 51 | where | ||
| 52 | E: std::error::Error, | ||
| 53 | { | ||
| 54 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { | ||
| 55 | match self { | ||
| 56 | RpcError::Transport(error) => error.source(), | ||
| 57 | RpcError::Remote(error) => error.source(), | ||
| 58 | } | ||
| 59 | } | ||
| 60 | } | ||
| 61 | |||
| 62 | impl<E> From<std::io::Error> for RpcError<E> { | ||
| 63 | fn from(value: std::io::Error) -> Self { | ||
| 64 | Self::Transport(value) | ||
| 65 | } | ||
| 66 | } | ||
| 67 | |||
| 68 | #[derive(Default)] | ||
| 69 | pub struct RpcMessageCodec(LengthDelimitedCodec); | ||
| 70 | |||
| 71 | impl Encoder<RpcMessage> for RpcMessageCodec { | ||
| 72 | type Error = std::io::Error; | ||
| 73 | |||
| 74 | fn encode( | ||
| 75 | &mut self, | ||
| 76 | item: RpcMessage, | ||
| 77 | dst: &mut bytes::BytesMut, | ||
| 78 | ) -> std::result::Result<(), Self::Error> { | ||
| 79 | let encoded = bincode::serde::encode_to_vec(&item, bincode::config::standard()) | ||
| 80 | .map_err(std::io::Error::other)?; | ||
| 81 | let encoded = Bytes::from(encoded); | ||
| 82 | self.0.encode(encoded, dst).map_err(std::io::Error::other)?; | ||
| 83 | Ok(()) | ||
| 84 | } | ||
| 85 | } | ||
| 86 | |||
| 87 | impl Decoder for RpcMessageCodec { | ||
| 88 | type Item = RpcMessage; | ||
| 89 | |||
| 90 | type Error = std::io::Error; | ||
| 91 | |||
| 92 | fn decode( | ||
| 93 | &mut self, | ||
| 94 | src: &mut bytes::BytesMut, | ||
| 95 | ) -> std::result::Result<Option<Self::Item>, Self::Error> { | ||
| 96 | match self.0.decode(src) { | ||
| 97 | Ok(Some(frame)) => { | ||
| 98 | let (message, _) = | ||
| 99 | bincode::serde::decode_from_slice(&frame, bincode::config::standard()) | ||
| 100 | .map_err(std::io::Error::other)?; | ||
| 101 | Ok(Some(message)) | ||
| 102 | } | ||
| 103 | Ok(None) => Ok(None), | ||
| 104 | Err(err) => Err(std::io::Error::other(err)), | ||
| 105 | } | ||
| 106 | } | ||
| 107 | } | ||
diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..414340e --- /dev/null +++ b/src/server.rs | |||
| @@ -0,0 +1,166 @@ | |||
| 1 | use std::{collections::HashMap, sync::Arc}; | ||
| 2 | |||
| 3 | use futures::{SinkExt, StreamExt}; | ||
| 4 | use tokio::task::{AbortHandle, JoinSet}; | ||
| 5 | |||
| 6 | use crate::{ | ||
| 7 | Channel, Listener, Service, | ||
| 8 | protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse}, | ||
| 9 | }; | ||
| 10 | |||
| 11 | #[derive(Clone)] | ||
| 12 | struct Services(Arc<HashMap<String, Arc<dyn Service>>>); | ||
| 13 | |||
| 14 | impl Services { | ||
| 15 | fn new(services: HashMap<String, Arc<dyn Service>>) -> Self { | ||
| 16 | Self(Arc::new(services)) | ||
| 17 | } | ||
| 18 | |||
| 19 | fn get_service(&self, name: &str) -> std::io::Result<Arc<dyn Service>> { | ||
| 20 | match self.0.get(name) { | ||
| 21 | Some(service) => Ok(service.clone()), | ||
| 22 | None => Err(std::io::Error::new( | ||
| 23 | std::io::ErrorKind::NotFound, | ||
| 24 | "service not found", | ||
| 25 | )), | ||
| 26 | } | ||
| 27 | } | ||
| 28 | } | ||
| 29 | |||
| 30 | type ListenerSpawner = | ||
| 31 | Box<dyn FnOnce(&mut JoinSet<std::io::Result<()>>, Services) -> AbortHandle + Send + 'static>; | ||
| 32 | |||
| 33 | #[derive(Debug, Default)] | ||
| 34 | struct AbortHandles(Vec<AbortHandle>); | ||
| 35 | |||
| 36 | impl Drop for AbortHandles { | ||
| 37 | fn drop(&mut self) { | ||
| 38 | for handle in &self.0 { | ||
| 39 | handle.abort(); | ||
| 40 | } | ||
| 41 | } | ||
| 42 | } | ||
| 43 | |||
| 44 | impl AbortHandles { | ||
| 45 | pub fn push(&mut self, handle: AbortHandle) { | ||
| 46 | self.0.push(handle); | ||
| 47 | } | ||
| 48 | } | ||
| 49 | |||
| 50 | #[derive(Default)] | ||
| 51 | pub struct Server { | ||
| 52 | services: HashMap<String, Arc<dyn Service>>, | ||
| 53 | listener_spawners: Vec<ListenerSpawner>, | ||
| 54 | } | ||
| 55 | |||
| 56 | impl Server { | ||
| 57 | pub fn with_service<T>(mut self, service: T) -> Self | ||
| 58 | where | ||
| 59 | T: Service, | ||
| 60 | { | ||
| 61 | let name = T::name(); | ||
| 62 | let service = Arc::new(service); | ||
| 63 | self.services.insert(name.to_string(), service); | ||
| 64 | self | ||
| 65 | } | ||
| 66 | |||
| 67 | pub fn with_listener<L, C>(mut self, listener: L) -> Self | ||
| 68 | where | ||
| 69 | C: Channel, | ||
| 70 | L: Listener<C>, | ||
| 71 | { | ||
| 72 | self.listener_spawners | ||
| 73 | .push(Box::new(move |join_set, services| { | ||
| 74 | join_set.spawn(listener_loop(listener, services)) | ||
| 75 | })); | ||
| 76 | self | ||
| 77 | } | ||
| 78 | |||
| 79 | pub async fn serve(self) -> std::io::Result<()> { | ||
| 80 | let services = Services::new(self.services); | ||
| 81 | let mut join_set = JoinSet::default(); | ||
| 82 | let mut abort_handles = AbortHandles::default(); | ||
| 83 | for spawner in self.listener_spawners { | ||
| 84 | let abort_handle = (spawner)(&mut join_set, services.clone()); | ||
| 85 | abort_handles.push(abort_handle); | ||
| 86 | } | ||
| 87 | match join_set.join_next().await { | ||
| 88 | Some(Ok(Ok(()))) => Ok(()), | ||
| 89 | Some(Ok(Err(err))) => Err(err), | ||
| 90 | Some(Err(err)) => Err(std::io::Error::other(err)), | ||
| 91 | None => Ok(()), | ||
| 92 | } | ||
| 93 | } | ||
| 94 | } | ||
| 95 | |||
| 96 | async fn listener_loop<L, C>(mut listener: L, services: Services) -> std::io::Result<()> | ||
| 97 | where | ||
| 98 | C: Channel, | ||
| 99 | L: Listener<C>, | ||
| 100 | { | ||
| 101 | while let Some(result) = listener.next().await { | ||
| 102 | let channel = result?; | ||
| 103 | let services = services.clone(); | ||
| 104 | tokio::spawn(channel_handler(channel, services)); | ||
| 105 | } | ||
| 106 | Ok(()) | ||
| 107 | } | ||
| 108 | |||
| 109 | async fn channel_handler<C: Channel>(mut channel: C, services: Services) -> std::io::Result<()> { | ||
| 110 | enum Select { | ||
| 111 | Empty, | ||
| 112 | Message(RpcMessage), | ||
| 113 | } | ||
| 114 | |||
| 115 | let (response_tx, mut response_rx) = | ||
| 116 | tokio::sync::mpsc::unbounded_channel::<std::io::Result<RpcResponse>>(); | ||
| 117 | let mut requests: HashMap<u64, AbortHandle> = Default::default(); | ||
| 118 | loop { | ||
| 119 | let select = tokio::select! { | ||
| 120 | reqopt = channel.next() => match reqopt { | ||
| 121 | Some(Ok(message)) => Select::Message(message), | ||
| 122 | Some(Err(err)) => return Err(err), | ||
| 123 | None => Select::Empty, | ||
| 124 | }, | ||
| 125 | Some(response) = response_rx.recv() => match response { | ||
| 126 | Ok(response) => Select::Message(RpcMessage::Response(response)), | ||
| 127 | Err(err) => return Err(err), | ||
| 128 | } | ||
| 129 | }; | ||
| 130 | |||
| 131 | match select { | ||
| 132 | Select::Empty => break, | ||
| 133 | Select::Message(message) => match message { | ||
| 134 | RpcMessage::Call(RpcCall { | ||
| 135 | id, | ||
| 136 | service, | ||
| 137 | method, | ||
| 138 | arguments, | ||
| 139 | }) => { | ||
| 140 | let response_tx = response_tx.clone(); | ||
| 141 | let service = services.get_service(&service)?; | ||
| 142 | let handle = tokio::spawn(async move { | ||
| 143 | let response = match service.call(method, arguments).await { | ||
| 144 | Ok(value) => Ok(RpcResponse { id, value }), | ||
| 145 | Err(err) => Err(err), | ||
| 146 | }; | ||
| 147 | let _ = response_tx.send(response); | ||
| 148 | }) | ||
| 149 | .abort_handle(); | ||
| 150 | requests.insert(id, handle); | ||
| 151 | } | ||
| 152 | RpcMessage::Cancel(RpcCancel { id }) => { | ||
| 153 | if let Some(handle) = requests.remove(&id) { | ||
| 154 | handle.abort(); | ||
| 155 | } | ||
| 156 | } | ||
| 157 | RpcMessage::Response(response) => { | ||
| 158 | requests.remove(&response.id); | ||
| 159 | channel.send(RpcMessage::Response(response)).await?; | ||
| 160 | } | ||
| 161 | }, | ||
| 162 | } | ||
| 163 | } | ||
| 164 | |||
| 165 | Ok(()) | ||
| 166 | } | ||
diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..5684c41 --- /dev/null +++ b/src/tcp.rs | |||
| @@ -0,0 +1,156 @@ | |||
| 1 | //! TCP backed channel and listener implementations. | ||
| 2 | //! | ||
| 3 | //! ```no_run | ||
| 4 | //! #[urpc::service] | ||
| 5 | //! trait Hello { | ||
| 6 | //! type Error = (); | ||
| 7 | //! | ||
| 8 | //! async fn hello(name: String) -> String; | ||
| 9 | //! } | ||
| 10 | //! | ||
| 11 | //! struct HelloServer; | ||
| 12 | //! | ||
| 13 | //! impl Hello for HelloServer { | ||
| 14 | //! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> { | ||
| 15 | //! Ok(format!("Hello, {name}!")) | ||
| 16 | //! } | ||
| 17 | //! } | ||
| 18 | //! | ||
| 19 | //! #[tokio::main] | ||
| 20 | //! async fn main() -> Result<(), Box<dyn std::error::Error>>{ | ||
| 21 | //! let listener = urpc::tcp::bind("0.0.0.0:3000").await?; | ||
| 22 | //! | ||
| 23 | //! // spawn the server | ||
| 24 | //! tokio::spawn(async move { | ||
| 25 | //! urpc::Server::default() | ||
| 26 | //! .with_listener(listener) | ||
| 27 | //! .with_service(HelloServer.into_service()) | ||
| 28 | //! .serve() | ||
| 29 | //! .await | ||
| 30 | //! }); | ||
| 31 | //! | ||
| 32 | //! // create a client | ||
| 33 | //! let channel = urpc::ClientChannel::new(urpc::tcp::connect("127.0.0.1:3000").await?); | ||
| 34 | //! let client = HelloClient::new(channel); | ||
| 35 | //! let greeting = client.hello("World".into()).await.unwrap(); | ||
| 36 | //! assert_eq!(greeting, "Hello, World!"); | ||
| 37 | //! Ok(()) | ||
| 38 | //! } | ||
| 39 | //! ``` | ||
| 40 | use std::pin::Pin; | ||
| 41 | |||
| 42 | use futures::{Sink, Stream}; | ||
| 43 | use tokio::{ | ||
| 44 | net::{TcpListener, TcpStream, ToSocketAddrs}, | ||
| 45 | sync::mpsc::Receiver, | ||
| 46 | task::AbortHandle, | ||
| 47 | }; | ||
| 48 | use tokio_util::codec::Framed; | ||
| 49 | |||
| 50 | use crate::{ | ||
| 51 | Channel, Listener, | ||
| 52 | protocol::{RpcMessage, RpcMessageCodec}, | ||
| 53 | }; | ||
| 54 | |||
| 55 | pub struct TcpChannel(Framed<TcpStream, RpcMessageCodec>); | ||
| 56 | |||
| 57 | impl TcpChannel { | ||
| 58 | fn new(stream: TcpStream) -> Self { | ||
| 59 | Self(Framed::new(stream, RpcMessageCodec::default())) | ||
| 60 | } | ||
| 61 | |||
| 62 | pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result<Self> { | ||
| 63 | let stream = TcpStream::connect(addrs).await?; | ||
| 64 | Ok(Self::new(stream)) | ||
| 65 | } | ||
| 66 | } | ||
| 67 | |||
| 68 | impl Sink<RpcMessage> for TcpChannel { | ||
| 69 | type Error = std::io::Error; | ||
| 70 | |||
| 71 | fn poll_ready( | ||
| 72 | self: Pin<&mut Self>, | ||
| 73 | cx: &mut std::task::Context<'_>, | ||
| 74 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 75 | Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx) | ||
| 76 | } | ||
| 77 | |||
| 78 | fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> { | ||
| 79 | Sink::start_send(Pin::new(&mut self.get_mut().0), item) | ||
| 80 | } | ||
| 81 | |||
| 82 | fn poll_flush( | ||
| 83 | self: Pin<&mut Self>, | ||
| 84 | cx: &mut std::task::Context<'_>, | ||
| 85 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 86 | Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx) | ||
| 87 | } | ||
| 88 | |||
| 89 | fn poll_close( | ||
| 90 | self: Pin<&mut Self>, | ||
| 91 | cx: &mut std::task::Context<'_>, | ||
| 92 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 93 | Sink::poll_close(Pin::new(&mut self.get_mut().0), cx) | ||
| 94 | } | ||
| 95 | } | ||
| 96 | |||
| 97 | impl Stream for TcpChannel { | ||
| 98 | type Item = std::io::Result<RpcMessage>; | ||
| 99 | |||
| 100 | fn poll_next( | ||
| 101 | self: Pin<&mut Self>, | ||
| 102 | cx: &mut std::task::Context<'_>, | ||
| 103 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 104 | Stream::poll_next(Pin::new(&mut self.get_mut().0), cx) | ||
| 105 | } | ||
| 106 | } | ||
| 107 | |||
| 108 | impl Channel for TcpChannel {} | ||
| 109 | |||
| 110 | pub struct TcpChannelListener { | ||
| 111 | receiver: Receiver<TcpChannel>, | ||
| 112 | abort: AbortHandle, | ||
| 113 | } | ||
| 114 | |||
| 115 | impl Drop for TcpChannelListener { | ||
| 116 | fn drop(&mut self) { | ||
| 117 | self.abort.abort(); | ||
| 118 | } | ||
| 119 | } | ||
| 120 | |||
| 121 | impl TcpChannelListener { | ||
| 122 | pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result<Self> { | ||
| 123 | let listener = TcpListener::bind(addrs).await?; | ||
| 124 | let (sender, receiver) = tokio::sync::mpsc::channel(8); | ||
| 125 | let abort = tokio::spawn(async move { | ||
| 126 | while let Ok((stream, _addr)) = listener.accept().await { | ||
| 127 | if sender.send(TcpChannel::new(stream)).await.is_err() { | ||
| 128 | break; | ||
| 129 | } | ||
| 130 | } | ||
| 131 | }) | ||
| 132 | .abort_handle(); | ||
| 133 | Ok(Self { receiver, abort }) | ||
| 134 | } | ||
| 135 | } | ||
| 136 | |||
| 137 | impl Stream for TcpChannelListener { | ||
| 138 | type Item = std::io::Result<TcpChannel>; | ||
| 139 | |||
| 140 | fn poll_next( | ||
| 141 | self: Pin<&mut Self>, | ||
| 142 | cx: &mut std::task::Context<'_>, | ||
| 143 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 144 | self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok)) | ||
| 145 | } | ||
| 146 | } | ||
| 147 | |||
| 148 | impl Listener<TcpChannel> for TcpChannelListener {} | ||
| 149 | |||
| 150 | pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result<TcpChannelListener> { | ||
| 151 | TcpChannelListener::bind(addrs).await | ||
| 152 | } | ||
| 153 | |||
| 154 | pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result<TcpChannel> { | ||
| 155 | TcpChannel::connect(addrs).await | ||
| 156 | } | ||
diff --git a/src/unix.rs b/src/unix.rs new file mode 100644 index 0000000..e241b64 --- /dev/null +++ b/src/unix.rs | |||
| @@ -0,0 +1,160 @@ | |||
| 1 | //! UNIX Domain Socket backed channel and listener implementations. | ||
| 2 | //! | ||
| 3 | //! ```no_run | ||
| 4 | //! #[urpc::service] | ||
| 5 | //! trait Hello { | ||
| 6 | //! type Error = (); | ||
| 7 | //! | ||
| 8 | //! async fn hello(name: String) -> String; | ||
| 9 | //! } | ||
| 10 | //! | ||
| 11 | //! struct HelloServer; | ||
| 12 | //! | ||
| 13 | //! impl Hello for HelloServer { | ||
| 14 | //! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> { | ||
| 15 | //! Ok(format!("Hello, {name}!")) | ||
| 16 | //! } | ||
| 17 | //! } | ||
| 18 | //! | ||
| 19 | //! #[tokio::main] | ||
| 20 | //! async fn main() -> Result<(), Box<dyn std::error::Error>>{ | ||
| 21 | //! let listener = urpc::unix::bind("./hello.service").await?; | ||
| 22 | //! | ||
| 23 | //! // spawn the server | ||
| 24 | //! tokio::spawn(async move { | ||
| 25 | //! urpc::Server::default() | ||
| 26 | //! .with_listener(listener) | ||
| 27 | //! .with_service(HelloServer.into_service()) | ||
| 28 | //! .serve() | ||
| 29 | //! .await | ||
| 30 | //! }); | ||
| 31 | //! | ||
| 32 | //! // create a client | ||
| 33 | //! let channel = urpc::ClientChannel::new(urpc::unix::connect("./hello.service").await?); | ||
| 34 | //! let client = HelloClient::new(channel); | ||
| 35 | //! let greeting = client.hello("World".into()).await.unwrap(); | ||
| 36 | //! assert_eq!(greeting, "Hello, World!"); | ||
| 37 | //! Ok(()) | ||
| 38 | //! } | ||
| 39 | //! ``` | ||
| 40 | use std::{path::Path, pin::Pin}; | ||
| 41 | |||
| 42 | use futures::{Sink, Stream}; | ||
| 43 | use tokio::{ | ||
| 44 | net::{UnixListener, UnixStream}, | ||
| 45 | sync::mpsc::Receiver, | ||
| 46 | task::AbortHandle, | ||
| 47 | }; | ||
| 48 | use tokio_util::codec::Framed; | ||
| 49 | |||
| 50 | use crate::{ | ||
| 51 | Channel, Listener, | ||
| 52 | protocol::{RpcMessage, RpcMessageCodec}, | ||
| 53 | }; | ||
| 54 | |||
| 55 | pub struct UnixChannel(Framed<UnixStream, RpcMessageCodec>); | ||
| 56 | |||
| 57 | impl UnixChannel { | ||
| 58 | fn new(stream: UnixStream) -> Self { | ||
| 59 | Self(Framed::new(stream, RpcMessageCodec::default())) | ||
| 60 | } | ||
| 61 | |||
| 62 | pub async fn connect(path: impl AsRef<Path>) -> std::io::Result<Self> { | ||
| 63 | let stream = UnixStream::connect(path).await?; | ||
| 64 | Ok(Self::new(stream)) | ||
| 65 | } | ||
| 66 | } | ||
| 67 | |||
| 68 | impl Sink<RpcMessage> for UnixChannel { | ||
| 69 | type Error = std::io::Error; | ||
| 70 | |||
| 71 | fn poll_ready( | ||
| 72 | self: Pin<&mut Self>, | ||
| 73 | cx: &mut std::task::Context<'_>, | ||
| 74 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 75 | Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx) | ||
| 76 | } | ||
| 77 | |||
| 78 | fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> { | ||
| 79 | Sink::start_send(Pin::new(&mut self.get_mut().0), item) | ||
| 80 | } | ||
| 81 | |||
| 82 | fn poll_flush( | ||
| 83 | self: Pin<&mut Self>, | ||
| 84 | cx: &mut std::task::Context<'_>, | ||
| 85 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 86 | Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx) | ||
| 87 | } | ||
| 88 | |||
| 89 | fn poll_close( | ||
| 90 | self: Pin<&mut Self>, | ||
| 91 | cx: &mut std::task::Context<'_>, | ||
| 92 | ) -> std::task::Poll<std::result::Result<(), Self::Error>> { | ||
| 93 | Sink::poll_close(Pin::new(&mut self.get_mut().0), cx) | ||
| 94 | } | ||
| 95 | } | ||
| 96 | |||
| 97 | impl Stream for UnixChannel { | ||
| 98 | type Item = std::io::Result<RpcMessage>; | ||
| 99 | |||
| 100 | fn poll_next( | ||
| 101 | self: Pin<&mut Self>, | ||
| 102 | cx: &mut std::task::Context<'_>, | ||
| 103 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 104 | Stream::poll_next(Pin::new(&mut self.get_mut().0), cx) | ||
| 105 | } | ||
| 106 | } | ||
| 107 | |||
| 108 | impl Channel for UnixChannel {} | ||
| 109 | |||
| 110 | pub struct UnixChannelListener { | ||
| 111 | receiver: Receiver<UnixChannel>, | ||
| 112 | abort: AbortHandle, | ||
| 113 | } | ||
| 114 | |||
| 115 | impl Drop for UnixChannelListener { | ||
| 116 | fn drop(&mut self) { | ||
| 117 | self.abort.abort(); | ||
| 118 | } | ||
| 119 | } | ||
| 120 | |||
| 121 | impl UnixChannelListener { | ||
| 122 | pub async fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> { | ||
| 123 | let path = path.as_ref(); | ||
| 124 | if tokio::fs::try_exists(path).await? { | ||
| 125 | tokio::fs::remove_file(path).await?; | ||
| 126 | } | ||
| 127 | let listener = UnixListener::bind(path)?; | ||
| 128 | let (sender, receiver) = tokio::sync::mpsc::channel(8); | ||
| 129 | let abort = tokio::spawn(async move { | ||
| 130 | while let Ok((stream, _addr)) = listener.accept().await { | ||
| 131 | if sender.send(UnixChannel::new(stream)).await.is_err() { | ||
| 132 | break; | ||
| 133 | } | ||
| 134 | } | ||
| 135 | }) | ||
| 136 | .abort_handle(); | ||
| 137 | Ok(Self { receiver, abort }) | ||
| 138 | } | ||
| 139 | } | ||
| 140 | |||
| 141 | impl Stream for UnixChannelListener { | ||
| 142 | type Item = std::io::Result<UnixChannel>; | ||
| 143 | |||
| 144 | fn poll_next( | ||
| 145 | self: Pin<&mut Self>, | ||
| 146 | cx: &mut std::task::Context<'_>, | ||
| 147 | ) -> std::task::Poll<Option<Self::Item>> { | ||
| 148 | self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok)) | ||
| 149 | } | ||
| 150 | } | ||
| 151 | |||
| 152 | impl Listener<UnixChannel> for UnixChannelListener {} | ||
| 153 | |||
| 154 | pub async fn bind(path: impl AsRef<Path>) -> std::io::Result<UnixChannelListener> { | ||
| 155 | UnixChannelListener::bind(path).await | ||
| 156 | } | ||
| 157 | |||
| 158 | pub async fn connect(path: impl AsRef<Path>) -> std::io::Result<UnixChannel> { | ||
| 159 | UnixChannel::connect(path).await | ||
| 160 | } | ||
