From f319d7ab5278a3cfb43d38875d81c28cc2dce1e1 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Wed, 16 Jul 2025 10:46:41 +0100 Subject: Initial commit - extracted urpc from monorepo --- src/channel.rs | 162 ++++++++++++++++++++++++++++++++++++++++++++++++ src/client_channel.rs | 124 +++++++++++++++++++++++++++++++++++++ src/internal.rs | 2 + src/lib.rs | 52 ++++++++++++++++ src/protocol.rs | 107 ++++++++++++++++++++++++++++++++ src/server.rs | 166 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/tcp.rs | 156 +++++++++++++++++++++++++++++++++++++++++++++++ src/unix.rs | 160 ++++++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 929 insertions(+) create mode 100644 src/channel.rs create mode 100644 src/client_channel.rs create mode 100644 src/internal.rs create mode 100644 src/lib.rs create mode 100644 src/protocol.rs create mode 100644 src/server.rs create mode 100644 src/tcp.rs create mode 100644 src/unix.rs (limited to 'src') diff --git a/src/channel.rs b/src/channel.rs new file mode 100644 index 0000000..f80ca87 --- /dev/null +++ b/src/channel.rs @@ -0,0 +1,162 @@ +//! MPSC channel backed channel and listener implementations. +//! +//! This module provides an mpsc channel based listener/channel implementation that can be useful +//! to for tests. +//! +//! The current implementation uses [`tokio`]'s unbounded mpsc channels. +//! +//! ``` +//! #[urpc::service] +//! trait Hello { +//! type Error = (); +//! +//! async fn hello(name: String) -> String; +//! } +//! +//! struct HelloServer; +//! +//! impl Hello for HelloServer { +//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result { +//! Ok(format!("Hello, {name}!")) +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box>{ +//! let (dialer, listener) = urpc::channel::new(); +//! +//! // spawn the server +//! tokio::spawn(async move { +//! urpc::Server::default() +//! .with_listener(listener) +//! .with_service(HelloServer.into_service()) +//! .serve() +//! .await +//! }); +//! +//! // create a client +//! let channel = urpc::ClientChannel::new(dialer.connect()?); +//! let client = HelloClient::new(channel); +//! let greeting = client.hello("World".into()).await.unwrap(); +//! assert_eq!(greeting, "Hello, World!"); +//! Ok(()) +//! } +//! ``` + +use futures::{Sink, Stream}; +use tokio::sync::mpsc; + +use crate::protocol::RpcMessage; + +pub struct ChannelListener { + receiver: mpsc::UnboundedReceiver, +} + +impl Stream for ChannelListener { + type Item = std::io::Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut().receiver.poll_recv(cx) { + std::task::Poll::Ready(Some(c)) => std::task::Poll::Ready(Some(Ok(c))), + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +impl crate::Listener for ChannelListener {} + +pub struct ChannelDialer { + sender: mpsc::UnboundedSender, +} + +impl ChannelDialer { + fn new() -> (Self, ChannelListener) { + let (sender, receiver) = mpsc::unbounded_channel(); + (Self { sender }, ChannelListener { receiver }) + } + + pub fn connect(&self) -> std::io::Result { + let (ch1, ch2) = Channel::new(); + self.sender.send(ch1).expect("TODO: remove this"); + Ok(ch2) + } +} + +pub struct Channel { + sender: mpsc::UnboundedSender, + receiver: mpsc::UnboundedReceiver, +} + +impl Channel { + fn new() -> (Self, Self) { + let (sender0, receiver0) = mpsc::unbounded_channel(); + let (sender1, receiver1) = mpsc::unbounded_channel(); + ( + Self { + sender: sender0, + receiver: receiver1, + }, + Self { + sender: sender1, + receiver: receiver0, + }, + ) + } +} + +impl Stream for Channel { + type Item = std::io::Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut().receiver.poll_recv(cx) { + std::task::Poll::Ready(Some(msg)) => std::task::Poll::Ready(Some(Ok(msg))), + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +impl Sink for Channel { + type Error = std::io::Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + return std::task::Poll::Ready(Ok(())); + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: RpcMessage) -> Result<(), Self::Error> { + match self.sender.send(item) { + Ok(()) => Ok(()), + Err(err) => Err(std::io::Error::other(err)), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +impl crate::Channel for Channel {} + +pub fn new() -> (ChannelDialer, ChannelListener) { + ChannelDialer::new() +} diff --git a/src/client_channel.rs b/src/client_channel.rs new file mode 100644 index 0000000..5666a16 --- /dev/null +++ b/src/client_channel.rs @@ -0,0 +1,124 @@ +use std::{collections::HashMap, sync::Arc}; + +use bytes::Bytes; +use futures::{SinkExt, StreamExt}; +use tokio::{ + sync::{mpsc, oneshot}, + task::AbortHandle, +}; + +use crate::{ + Channel, + protocol::{RpcCall, RpcMessage, RpcResponse}, +}; + +const CLIENT_CHANNEL_BUFFER_SIZE: usize = 64; + +struct ClientChannelMessage { + service: String, + method: String, + arguments: Bytes, + responder: oneshot::Sender>, +} + +struct ClientChannelInner { + sender: mpsc::Sender, + abort_handle: AbortHandle, +} + +impl Drop for ClientChannelInner { + fn drop(&mut self) { + self.abort_handle.abort(); + } +} + +#[derive(Clone)] +pub struct ClientChannel(Arc); + +impl ClientChannel { + pub fn new(channel: C) -> Self { + let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_BUFFER_SIZE); + let abort_handle = tokio::spawn(client_channel_loop(channel, rx)).abort_handle(); + Self(Arc::new(ClientChannelInner { + sender: tx, + abort_handle, + })) + } + + pub async fn call( + &self, + service: String, + method: String, + arguments: Bytes, + ) -> std::io::Result { + let (tx, rx) = oneshot::channel(); + self.0 + .sender + .send(ClientChannelMessage { + service, + method, + arguments, + responder: tx, + }) + .await + .expect("client channel task should never shutdown while a client is alive"); + rx.await + .expect("client channel task should never shutdown while a client is alive") + } +} + +async fn client_channel_loop( + mut channel: C, + mut rx: mpsc::Receiver, +) { + enum Select { + RpcMessage(RpcMessage), + ClientChannelMessage(ClientChannelMessage), + } + + let mut responders = HashMap::>>::default(); + let mut rpc_call_id = 0; + + loop { + let select = tokio::select! { + Some(Ok(v)) = channel.next() => Select::RpcMessage(v), + Some(v) = rx.recv() => Select::ClientChannelMessage(v), + }; + + match select { + Select::RpcMessage(RpcMessage::Response(RpcResponse { id, value })) => { + if let Some(responder) = responders.remove(&id) { + let _ = responder.send(Ok(value)); + } + } + Select::RpcMessage(_) => todo!(), + Select::ClientChannelMessage(ClientChannelMessage { + service, + method, + arguments, + responder, + }) => { + let id = rpc_call_id; + rpc_call_id += 1; + + let result = channel + .send(RpcMessage::Call(RpcCall { + id, + service, + method, + arguments, + })) + .await; + + match result { + Ok(()) => { + responders.insert(id, responder); + } + Err(err) => { + let _ = responder.send(Err(err)); + } + } + } + } + } +} diff --git a/src/internal.rs b/src/internal.rs new file mode 100644 index 0000000..a3c203a --- /dev/null +++ b/src/internal.rs @@ -0,0 +1,2 @@ +pub use bincode; +pub use bytes; diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..e65f659 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,52 @@ +#[doc(hidden)] +pub mod internal; + +pub mod channel; +pub mod protocol; +pub mod tcp; +pub mod unix; + +mod client_channel; +mod server; + +pub use client_channel::ClientChannel; +pub use server::Server; +pub use urpc_macro::service; + +use protocol::RpcMessage; + +use std::pin::Pin; +use std::future::Future; + +use bytes::Bytes; +use futures::{Sink, Stream}; + +#[derive(Debug, Default)] +pub struct Context; + +pub trait Service: Send + Sync + 'static { + fn name() -> &'static str + where + Self: Sized; + + fn call( + &self, + method: String, + arguments: Bytes, + ) -> Pin> + Send + '_>>; +} + +pub trait Channel: + Stream> + + Sink + + Send + + Unpin + + 'static +{ +} + +pub trait Listener: Stream> + Send + Unpin + 'static +where + C: Channel, +{ +} diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..baf886b --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,107 @@ +//! Types used by the RPC protocol. +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RpcCall { + pub id: u64, + pub service: String, + pub method: String, + pub arguments: Bytes, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RpcCancel { + pub id: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RpcResponse { + pub id: u64, + pub value: Bytes, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RpcMessage { + Call(RpcCall), + Response(RpcResponse), + Cancel(RpcCancel), +} + +#[derive(Debug)] +pub enum RpcError { + Transport(std::io::Error), + Remote(E), +} + +impl std::fmt::Display for RpcError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RpcError::Transport(error) => write!(f, "transport error: {error}"), + RpcError::Remote(error) => write!(f, "remote error: {error}"), + } + } +} + +impl std::error::Error for RpcError +where + E: std::error::Error, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RpcError::Transport(error) => error.source(), + RpcError::Remote(error) => error.source(), + } + } +} + +impl From for RpcError { + fn from(value: std::io::Error) -> Self { + Self::Transport(value) + } +} + +#[derive(Default)] +pub struct RpcMessageCodec(LengthDelimitedCodec); + +impl Encoder for RpcMessageCodec { + type Error = std::io::Error; + + fn encode( + &mut self, + item: RpcMessage, + dst: &mut bytes::BytesMut, + ) -> std::result::Result<(), Self::Error> { + let encoded = bincode::serde::encode_to_vec(&item, bincode::config::standard()) + .map_err(std::io::Error::other)?; + let encoded = Bytes::from(encoded); + self.0.encode(encoded, dst).map_err(std::io::Error::other)?; + Ok(()) + } +} + +impl Decoder for RpcMessageCodec { + type Item = RpcMessage; + + type Error = std::io::Error; + + fn decode( + &mut self, + src: &mut bytes::BytesMut, + ) -> std::result::Result, Self::Error> { + match self.0.decode(src) { + Ok(Some(frame)) => { + let (message, _) = + bincode::serde::decode_from_slice(&frame, bincode::config::standard()) + .map_err(std::io::Error::other)?; + Ok(Some(message)) + } + Ok(None) => Ok(None), + Err(err) => Err(std::io::Error::other(err)), + } + } +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..414340e --- /dev/null +++ b/src/server.rs @@ -0,0 +1,166 @@ +use std::{collections::HashMap, sync::Arc}; + +use futures::{SinkExt, StreamExt}; +use tokio::task::{AbortHandle, JoinSet}; + +use crate::{ + Channel, Listener, Service, + protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse}, +}; + +#[derive(Clone)] +struct Services(Arc>>); + +impl Services { + fn new(services: HashMap>) -> Self { + Self(Arc::new(services)) + } + + fn get_service(&self, name: &str) -> std::io::Result> { + match self.0.get(name) { + Some(service) => Ok(service.clone()), + None => Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "service not found", + )), + } + } +} + +type ListenerSpawner = + Box>, Services) -> AbortHandle + Send + 'static>; + +#[derive(Debug, Default)] +struct AbortHandles(Vec); + +impl Drop for AbortHandles { + fn drop(&mut self) { + for handle in &self.0 { + handle.abort(); + } + } +} + +impl AbortHandles { + pub fn push(&mut self, handle: AbortHandle) { + self.0.push(handle); + } +} + +#[derive(Default)] +pub struct Server { + services: HashMap>, + listener_spawners: Vec, +} + +impl Server { + pub fn with_service(mut self, service: T) -> Self + where + T: Service, + { + let name = T::name(); + let service = Arc::new(service); + self.services.insert(name.to_string(), service); + self + } + + pub fn with_listener(mut self, listener: L) -> Self + where + C: Channel, + L: Listener, + { + self.listener_spawners + .push(Box::new(move |join_set, services| { + join_set.spawn(listener_loop(listener, services)) + })); + self + } + + pub async fn serve(self) -> std::io::Result<()> { + let services = Services::new(self.services); + let mut join_set = JoinSet::default(); + let mut abort_handles = AbortHandles::default(); + for spawner in self.listener_spawners { + let abort_handle = (spawner)(&mut join_set, services.clone()); + abort_handles.push(abort_handle); + } + match join_set.join_next().await { + Some(Ok(Ok(()))) => Ok(()), + Some(Ok(Err(err))) => Err(err), + Some(Err(err)) => Err(std::io::Error::other(err)), + None => Ok(()), + } + } +} + +async fn listener_loop(mut listener: L, services: Services) -> std::io::Result<()> +where + C: Channel, + L: Listener, +{ + while let Some(result) = listener.next().await { + let channel = result?; + let services = services.clone(); + tokio::spawn(channel_handler(channel, services)); + } + Ok(()) +} + +async fn channel_handler(mut channel: C, services: Services) -> std::io::Result<()> { + enum Select { + Empty, + Message(RpcMessage), + } + + let (response_tx, mut response_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + let mut requests: HashMap = Default::default(); + loop { + let select = tokio::select! { + reqopt = channel.next() => match reqopt { + Some(Ok(message)) => Select::Message(message), + Some(Err(err)) => return Err(err), + None => Select::Empty, + }, + Some(response) = response_rx.recv() => match response { + Ok(response) => Select::Message(RpcMessage::Response(response)), + Err(err) => return Err(err), + } + }; + + match select { + Select::Empty => break, + Select::Message(message) => match message { + RpcMessage::Call(RpcCall { + id, + service, + method, + arguments, + }) => { + let response_tx = response_tx.clone(); + let service = services.get_service(&service)?; + let handle = tokio::spawn(async move { + let response = match service.call(method, arguments).await { + Ok(value) => Ok(RpcResponse { id, value }), + Err(err) => Err(err), + }; + let _ = response_tx.send(response); + }) + .abort_handle(); + requests.insert(id, handle); + } + RpcMessage::Cancel(RpcCancel { id }) => { + if let Some(handle) = requests.remove(&id) { + handle.abort(); + } + } + RpcMessage::Response(response) => { + requests.remove(&response.id); + channel.send(RpcMessage::Response(response)).await?; + } + }, + } + } + + Ok(()) +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..5684c41 --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,156 @@ +//! TCP backed channel and listener implementations. +//! +//! ```no_run +//! #[urpc::service] +//! trait Hello { +//! type Error = (); +//! +//! async fn hello(name: String) -> String; +//! } +//! +//! struct HelloServer; +//! +//! impl Hello for HelloServer { +//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result { +//! Ok(format!("Hello, {name}!")) +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box>{ +//! let listener = urpc::tcp::bind("0.0.0.0:3000").await?; +//! +//! // spawn the server +//! tokio::spawn(async move { +//! urpc::Server::default() +//! .with_listener(listener) +//! .with_service(HelloServer.into_service()) +//! .serve() +//! .await +//! }); +//! +//! // create a client +//! let channel = urpc::ClientChannel::new(urpc::tcp::connect("127.0.0.1:3000").await?); +//! let client = HelloClient::new(channel); +//! let greeting = client.hello("World".into()).await.unwrap(); +//! assert_eq!(greeting, "Hello, World!"); +//! Ok(()) +//! } +//! ``` +use std::pin::Pin; + +use futures::{Sink, Stream}; +use tokio::{ + net::{TcpListener, TcpStream, ToSocketAddrs}, + sync::mpsc::Receiver, + task::AbortHandle, +}; +use tokio_util::codec::Framed; + +use crate::{ + Channel, Listener, + protocol::{RpcMessage, RpcMessageCodec}, +}; + +pub struct TcpChannel(Framed); + +impl TcpChannel { + fn new(stream: TcpStream) -> Self { + Self(Framed::new(stream, RpcMessageCodec::default())) + } + + pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result { + let stream = TcpStream::connect(addrs).await?; + Ok(Self::new(stream)) + } +} + +impl Sink for TcpChannel { + type Error = std::io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx) + } + + fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> { + Sink::start_send(Pin::new(&mut self.get_mut().0), item) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_close(Pin::new(&mut self.get_mut().0), cx) + } +} + +impl Stream for TcpChannel { + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Stream::poll_next(Pin::new(&mut self.get_mut().0), cx) + } +} + +impl Channel for TcpChannel {} + +pub struct TcpChannelListener { + receiver: Receiver, + abort: AbortHandle, +} + +impl Drop for TcpChannelListener { + fn drop(&mut self) { + self.abort.abort(); + } +} + +impl TcpChannelListener { + pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result { + let listener = TcpListener::bind(addrs).await?; + let (sender, receiver) = tokio::sync::mpsc::channel(8); + let abort = tokio::spawn(async move { + while let Ok((stream, _addr)) = listener.accept().await { + if sender.send(TcpChannel::new(stream)).await.is_err() { + break; + } + } + }) + .abort_handle(); + Ok(Self { receiver, abort }) + } +} + +impl Stream for TcpChannelListener { + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok)) + } +} + +impl Listener for TcpChannelListener {} + +pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result { + TcpChannelListener::bind(addrs).await +} + +pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result { + TcpChannel::connect(addrs).await +} diff --git a/src/unix.rs b/src/unix.rs new file mode 100644 index 0000000..e241b64 --- /dev/null +++ b/src/unix.rs @@ -0,0 +1,160 @@ +//! UNIX Domain Socket backed channel and listener implementations. +//! +//! ```no_run +//! #[urpc::service] +//! trait Hello { +//! type Error = (); +//! +//! async fn hello(name: String) -> String; +//! } +//! +//! struct HelloServer; +//! +//! impl Hello for HelloServer { +//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result { +//! Ok(format!("Hello, {name}!")) +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box>{ +//! let listener = urpc::unix::bind("./hello.service").await?; +//! +//! // spawn the server +//! tokio::spawn(async move { +//! urpc::Server::default() +//! .with_listener(listener) +//! .with_service(HelloServer.into_service()) +//! .serve() +//! .await +//! }); +//! +//! // create a client +//! let channel = urpc::ClientChannel::new(urpc::unix::connect("./hello.service").await?); +//! let client = HelloClient::new(channel); +//! let greeting = client.hello("World".into()).await.unwrap(); +//! assert_eq!(greeting, "Hello, World!"); +//! Ok(()) +//! } +//! ``` +use std::{path::Path, pin::Pin}; + +use futures::{Sink, Stream}; +use tokio::{ + net::{UnixListener, UnixStream}, + sync::mpsc::Receiver, + task::AbortHandle, +}; +use tokio_util::codec::Framed; + +use crate::{ + Channel, Listener, + protocol::{RpcMessage, RpcMessageCodec}, +}; + +pub struct UnixChannel(Framed); + +impl UnixChannel { + fn new(stream: UnixStream) -> Self { + Self(Framed::new(stream, RpcMessageCodec::default())) + } + + pub async fn connect(path: impl AsRef) -> std::io::Result { + let stream = UnixStream::connect(path).await?; + Ok(Self::new(stream)) + } +} + +impl Sink for UnixChannel { + type Error = std::io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx) + } + + fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> { + Sink::start_send(Pin::new(&mut self.get_mut().0), item) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Sink::poll_close(Pin::new(&mut self.get_mut().0), cx) + } +} + +impl Stream for UnixChannel { + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Stream::poll_next(Pin::new(&mut self.get_mut().0), cx) + } +} + +impl Channel for UnixChannel {} + +pub struct UnixChannelListener { + receiver: Receiver, + abort: AbortHandle, +} + +impl Drop for UnixChannelListener { + fn drop(&mut self) { + self.abort.abort(); + } +} + +impl UnixChannelListener { + pub async fn bind(path: impl AsRef) -> std::io::Result { + let path = path.as_ref(); + if tokio::fs::try_exists(path).await? { + tokio::fs::remove_file(path).await?; + } + let listener = UnixListener::bind(path)?; + let (sender, receiver) = tokio::sync::mpsc::channel(8); + let abort = tokio::spawn(async move { + while let Ok((stream, _addr)) = listener.accept().await { + if sender.send(UnixChannel::new(stream)).await.is_err() { + break; + } + } + }) + .abort_handle(); + Ok(Self { receiver, abort }) + } +} + +impl Stream for UnixChannelListener { + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok)) + } +} + +impl Listener for UnixChannelListener {} + +pub async fn bind(path: impl AsRef) -> std::io::Result { + UnixChannelListener::bind(path).await +} + +pub async fn connect(path: impl AsRef) -> std::io::Result { + UnixChannel::connect(path).await +} -- cgit