aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/channel.rs162
-rw-r--r--src/client_channel.rs124
-rw-r--r--src/internal.rs2
-rw-r--r--src/lib.rs52
-rw-r--r--src/protocol.rs107
-rw-r--r--src/server.rs166
-rw-r--r--src/tcp.rs156
-rw-r--r--src/unix.rs160
8 files changed, 929 insertions, 0 deletions
diff --git a/src/channel.rs b/src/channel.rs
new file mode 100644
index 0000000..f80ca87
--- /dev/null
+++ b/src/channel.rs
@@ -0,0 +1,162 @@
1//! MPSC channel backed channel and listener implementations.
2//!
3//! This module provides an mpsc channel based listener/channel implementation that can be useful
4//! to for tests.
5//!
6//! The current implementation uses [`tokio`]'s unbounded mpsc channels.
7//!
8//! ```
9//! #[urpc::service]
10//! trait Hello {
11//! type Error = ();
12//!
13//! async fn hello(name: String) -> String;
14//! }
15//!
16//! struct HelloServer;
17//!
18//! impl Hello for HelloServer {
19//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> {
20//! Ok(format!("Hello, {name}!"))
21//! }
22//! }
23//!
24//! #[tokio::main]
25//! async fn main() -> Result<(), Box<dyn std::error::Error>>{
26//! let (dialer, listener) = urpc::channel::new();
27//!
28//! // spawn the server
29//! tokio::spawn(async move {
30//! urpc::Server::default()
31//! .with_listener(listener)
32//! .with_service(HelloServer.into_service())
33//! .serve()
34//! .await
35//! });
36//!
37//! // create a client
38//! let channel = urpc::ClientChannel::new(dialer.connect()?);
39//! let client = HelloClient::new(channel);
40//! let greeting = client.hello("World".into()).await.unwrap();
41//! assert_eq!(greeting, "Hello, World!");
42//! Ok(())
43//! }
44//! ```
45
46use futures::{Sink, Stream};
47use tokio::sync::mpsc;
48
49use crate::protocol::RpcMessage;
50
51pub struct ChannelListener {
52 receiver: mpsc::UnboundedReceiver<Channel>,
53}
54
55impl Stream for ChannelListener {
56 type Item = std::io::Result<Channel>;
57
58 fn poll_next(
59 self: std::pin::Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 ) -> std::task::Poll<Option<Self::Item>> {
62 match self.get_mut().receiver.poll_recv(cx) {
63 std::task::Poll::Ready(Some(c)) => std::task::Poll::Ready(Some(Ok(c))),
64 std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
65 std::task::Poll::Pending => std::task::Poll::Pending,
66 }
67 }
68}
69
70impl crate::Listener<Channel> for ChannelListener {}
71
72pub struct ChannelDialer {
73 sender: mpsc::UnboundedSender<Channel>,
74}
75
76impl ChannelDialer {
77 fn new() -> (Self, ChannelListener) {
78 let (sender, receiver) = mpsc::unbounded_channel();
79 (Self { sender }, ChannelListener { receiver })
80 }
81
82 pub fn connect(&self) -> std::io::Result<Channel> {
83 let (ch1, ch2) = Channel::new();
84 self.sender.send(ch1).expect("TODO: remove this");
85 Ok(ch2)
86 }
87}
88
89pub struct Channel {
90 sender: mpsc::UnboundedSender<RpcMessage>,
91 receiver: mpsc::UnboundedReceiver<RpcMessage>,
92}
93
94impl Channel {
95 fn new() -> (Self, Self) {
96 let (sender0, receiver0) = mpsc::unbounded_channel();
97 let (sender1, receiver1) = mpsc::unbounded_channel();
98 (
99 Self {
100 sender: sender0,
101 receiver: receiver1,
102 },
103 Self {
104 sender: sender1,
105 receiver: receiver0,
106 },
107 )
108 }
109}
110
111impl Stream for Channel {
112 type Item = std::io::Result<RpcMessage>;
113
114 fn poll_next(
115 self: std::pin::Pin<&mut Self>,
116 cx: &mut std::task::Context<'_>,
117 ) -> std::task::Poll<Option<Self::Item>> {
118 match self.get_mut().receiver.poll_recv(cx) {
119 std::task::Poll::Ready(Some(msg)) => std::task::Poll::Ready(Some(Ok(msg))),
120 std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
121 std::task::Poll::Pending => std::task::Poll::Pending,
122 }
123 }
124}
125
126impl Sink<RpcMessage> for Channel {
127 type Error = std::io::Error;
128
129 fn poll_ready(
130 self: std::pin::Pin<&mut Self>,
131 _cx: &mut std::task::Context<'_>,
132 ) -> std::task::Poll<Result<(), Self::Error>> {
133 return std::task::Poll::Ready(Ok(()));
134 }
135
136 fn start_send(self: std::pin::Pin<&mut Self>, item: RpcMessage) -> Result<(), Self::Error> {
137 match self.sender.send(item) {
138 Ok(()) => Ok(()),
139 Err(err) => Err(std::io::Error::other(err)),
140 }
141 }
142
143 fn poll_flush(
144 self: std::pin::Pin<&mut Self>,
145 _cx: &mut std::task::Context<'_>,
146 ) -> std::task::Poll<Result<(), Self::Error>> {
147 std::task::Poll::Ready(Ok(()))
148 }
149
150 fn poll_close(
151 self: std::pin::Pin<&mut Self>,
152 _cx: &mut std::task::Context<'_>,
153 ) -> std::task::Poll<Result<(), Self::Error>> {
154 std::task::Poll::Ready(Ok(()))
155 }
156}
157
158impl crate::Channel for Channel {}
159
160pub fn new() -> (ChannelDialer, ChannelListener) {
161 ChannelDialer::new()
162}
diff --git a/src/client_channel.rs b/src/client_channel.rs
new file mode 100644
index 0000000..5666a16
--- /dev/null
+++ b/src/client_channel.rs
@@ -0,0 +1,124 @@
1use std::{collections::HashMap, sync::Arc};
2
3use bytes::Bytes;
4use futures::{SinkExt, StreamExt};
5use tokio::{
6 sync::{mpsc, oneshot},
7 task::AbortHandle,
8};
9
10use crate::{
11 Channel,
12 protocol::{RpcCall, RpcMessage, RpcResponse},
13};
14
15const CLIENT_CHANNEL_BUFFER_SIZE: usize = 64;
16
17struct ClientChannelMessage {
18 service: String,
19 method: String,
20 arguments: Bytes,
21 responder: oneshot::Sender<std::io::Result<Bytes>>,
22}
23
24struct ClientChannelInner {
25 sender: mpsc::Sender<ClientChannelMessage>,
26 abort_handle: AbortHandle,
27}
28
29impl Drop for ClientChannelInner {
30 fn drop(&mut self) {
31 self.abort_handle.abort();
32 }
33}
34
35#[derive(Clone)]
36pub struct ClientChannel(Arc<ClientChannelInner>);
37
38impl ClientChannel {
39 pub fn new<C: Channel>(channel: C) -> Self {
40 let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_BUFFER_SIZE);
41 let abort_handle = tokio::spawn(client_channel_loop(channel, rx)).abort_handle();
42 Self(Arc::new(ClientChannelInner {
43 sender: tx,
44 abort_handle,
45 }))
46 }
47
48 pub async fn call(
49 &self,
50 service: String,
51 method: String,
52 arguments: Bytes,
53 ) -> std::io::Result<Bytes> {
54 let (tx, rx) = oneshot::channel();
55 self.0
56 .sender
57 .send(ClientChannelMessage {
58 service,
59 method,
60 arguments,
61 responder: tx,
62 })
63 .await
64 .expect("client channel task should never shutdown while a client is alive");
65 rx.await
66 .expect("client channel task should never shutdown while a client is alive")
67 }
68}
69
70async fn client_channel_loop<C: Channel>(
71 mut channel: C,
72 mut rx: mpsc::Receiver<ClientChannelMessage>,
73) {
74 enum Select {
75 RpcMessage(RpcMessage),
76 ClientChannelMessage(ClientChannelMessage),
77 }
78
79 let mut responders = HashMap::<u64, oneshot::Sender<std::io::Result<Bytes>>>::default();
80 let mut rpc_call_id = 0;
81
82 loop {
83 let select = tokio::select! {
84 Some(Ok(v)) = channel.next() => Select::RpcMessage(v),
85 Some(v) = rx.recv() => Select::ClientChannelMessage(v),
86 };
87
88 match select {
89 Select::RpcMessage(RpcMessage::Response(RpcResponse { id, value })) => {
90 if let Some(responder) = responders.remove(&id) {
91 let _ = responder.send(Ok(value));
92 }
93 }
94 Select::RpcMessage(_) => todo!(),
95 Select::ClientChannelMessage(ClientChannelMessage {
96 service,
97 method,
98 arguments,
99 responder,
100 }) => {
101 let id = rpc_call_id;
102 rpc_call_id += 1;
103
104 let result = channel
105 .send(RpcMessage::Call(RpcCall {
106 id,
107 service,
108 method,
109 arguments,
110 }))
111 .await;
112
113 match result {
114 Ok(()) => {
115 responders.insert(id, responder);
116 }
117 Err(err) => {
118 let _ = responder.send(Err(err));
119 }
120 }
121 }
122 }
123 }
124}
diff --git a/src/internal.rs b/src/internal.rs
new file mode 100644
index 0000000..a3c203a
--- /dev/null
+++ b/src/internal.rs
@@ -0,0 +1,2 @@
1pub use bincode;
2pub use bytes;
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..e65f659
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,52 @@
1#[doc(hidden)]
2pub mod internal;
3
4pub mod channel;
5pub mod protocol;
6pub mod tcp;
7pub mod unix;
8
9mod client_channel;
10mod server;
11
12pub use client_channel::ClientChannel;
13pub use server::Server;
14pub use urpc_macro::service;
15
16use protocol::RpcMessage;
17
18use std::pin::Pin;
19use std::future::Future;
20
21use bytes::Bytes;
22use futures::{Sink, Stream};
23
24#[derive(Debug, Default)]
25pub struct Context;
26
27pub trait Service: Send + Sync + 'static {
28 fn name() -> &'static str
29 where
30 Self: Sized;
31
32 fn call(
33 &self,
34 method: String,
35 arguments: Bytes,
36 ) -> Pin<Box<dyn Future<Output = std::io::Result<Bytes>> + Send + '_>>;
37}
38
39pub trait Channel:
40 Stream<Item = std::io::Result<RpcMessage>>
41 + Sink<RpcMessage, Error = std::io::Error>
42 + Send
43 + Unpin
44 + 'static
45{
46}
47
48pub trait Listener<C>: Stream<Item = std::io::Result<C>> + Send + Unpin + 'static
49where
50 C: Channel,
51{
52}
diff --git a/src/protocol.rs b/src/protocol.rs
new file mode 100644
index 0000000..baf886b
--- /dev/null
+++ b/src/protocol.rs
@@ -0,0 +1,107 @@
1//! Types used by the RPC protocol.
2use bytes::Bytes;
3use serde::{Deserialize, Serialize};
4use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub struct RpcCall {
8 pub id: u64,
9 pub service: String,
10 pub method: String,
11 pub arguments: Bytes,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15pub struct RpcCancel {
16 pub id: u64,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20pub struct RpcResponse {
21 pub id: u64,
22 pub value: Bytes,
23}
24
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub enum RpcMessage {
27 Call(RpcCall),
28 Response(RpcResponse),
29 Cancel(RpcCancel),
30}
31
32#[derive(Debug)]
33pub enum RpcError<E> {
34 Transport(std::io::Error),
35 Remote(E),
36}
37
38impl<E> std::fmt::Display for RpcError<E>
39where
40 E: std::fmt::Display,
41{
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 RpcError::Transport(error) => write!(f, "transport error: {error}"),
45 RpcError::Remote(error) => write!(f, "remote error: {error}"),
46 }
47 }
48}
49
50impl<E> std::error::Error for RpcError<E>
51where
52 E: std::error::Error,
53{
54 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
55 match self {
56 RpcError::Transport(error) => error.source(),
57 RpcError::Remote(error) => error.source(),
58 }
59 }
60}
61
62impl<E> From<std::io::Error> for RpcError<E> {
63 fn from(value: std::io::Error) -> Self {
64 Self::Transport(value)
65 }
66}
67
68#[derive(Default)]
69pub struct RpcMessageCodec(LengthDelimitedCodec);
70
71impl Encoder<RpcMessage> for RpcMessageCodec {
72 type Error = std::io::Error;
73
74 fn encode(
75 &mut self,
76 item: RpcMessage,
77 dst: &mut bytes::BytesMut,
78 ) -> std::result::Result<(), Self::Error> {
79 let encoded = bincode::serde::encode_to_vec(&item, bincode::config::standard())
80 .map_err(std::io::Error::other)?;
81 let encoded = Bytes::from(encoded);
82 self.0.encode(encoded, dst).map_err(std::io::Error::other)?;
83 Ok(())
84 }
85}
86
87impl Decoder for RpcMessageCodec {
88 type Item = RpcMessage;
89
90 type Error = std::io::Error;
91
92 fn decode(
93 &mut self,
94 src: &mut bytes::BytesMut,
95 ) -> std::result::Result<Option<Self::Item>, Self::Error> {
96 match self.0.decode(src) {
97 Ok(Some(frame)) => {
98 let (message, _) =
99 bincode::serde::decode_from_slice(&frame, bincode::config::standard())
100 .map_err(std::io::Error::other)?;
101 Ok(Some(message))
102 }
103 Ok(None) => Ok(None),
104 Err(err) => Err(std::io::Error::other(err)),
105 }
106 }
107}
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..414340e
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,166 @@
1use std::{collections::HashMap, sync::Arc};
2
3use futures::{SinkExt, StreamExt};
4use tokio::task::{AbortHandle, JoinSet};
5
6use crate::{
7 Channel, Listener, Service,
8 protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse},
9};
10
11#[derive(Clone)]
12struct Services(Arc<HashMap<String, Arc<dyn Service>>>);
13
14impl Services {
15 fn new(services: HashMap<String, Arc<dyn Service>>) -> Self {
16 Self(Arc::new(services))
17 }
18
19 fn get_service(&self, name: &str) -> std::io::Result<Arc<dyn Service>> {
20 match self.0.get(name) {
21 Some(service) => Ok(service.clone()),
22 None => Err(std::io::Error::new(
23 std::io::ErrorKind::NotFound,
24 "service not found",
25 )),
26 }
27 }
28}
29
30type ListenerSpawner =
31 Box<dyn FnOnce(&mut JoinSet<std::io::Result<()>>, Services) -> AbortHandle + Send + 'static>;
32
33#[derive(Debug, Default)]
34struct AbortHandles(Vec<AbortHandle>);
35
36impl Drop for AbortHandles {
37 fn drop(&mut self) {
38 for handle in &self.0 {
39 handle.abort();
40 }
41 }
42}
43
44impl AbortHandles {
45 pub fn push(&mut self, handle: AbortHandle) {
46 self.0.push(handle);
47 }
48}
49
50#[derive(Default)]
51pub struct Server {
52 services: HashMap<String, Arc<dyn Service>>,
53 listener_spawners: Vec<ListenerSpawner>,
54}
55
56impl Server {
57 pub fn with_service<T>(mut self, service: T) -> Self
58 where
59 T: Service,
60 {
61 let name = T::name();
62 let service = Arc::new(service);
63 self.services.insert(name.to_string(), service);
64 self
65 }
66
67 pub fn with_listener<L, C>(mut self, listener: L) -> Self
68 where
69 C: Channel,
70 L: Listener<C>,
71 {
72 self.listener_spawners
73 .push(Box::new(move |join_set, services| {
74 join_set.spawn(listener_loop(listener, services))
75 }));
76 self
77 }
78
79 pub async fn serve(self) -> std::io::Result<()> {
80 let services = Services::new(self.services);
81 let mut join_set = JoinSet::default();
82 let mut abort_handles = AbortHandles::default();
83 for spawner in self.listener_spawners {
84 let abort_handle = (spawner)(&mut join_set, services.clone());
85 abort_handles.push(abort_handle);
86 }
87 match join_set.join_next().await {
88 Some(Ok(Ok(()))) => Ok(()),
89 Some(Ok(Err(err))) => Err(err),
90 Some(Err(err)) => Err(std::io::Error::other(err)),
91 None => Ok(()),
92 }
93 }
94}
95
96async fn listener_loop<L, C>(mut listener: L, services: Services) -> std::io::Result<()>
97where
98 C: Channel,
99 L: Listener<C>,
100{
101 while let Some(result) = listener.next().await {
102 let channel = result?;
103 let services = services.clone();
104 tokio::spawn(channel_handler(channel, services));
105 }
106 Ok(())
107}
108
109async fn channel_handler<C: Channel>(mut channel: C, services: Services) -> std::io::Result<()> {
110 enum Select {
111 Empty,
112 Message(RpcMessage),
113 }
114
115 let (response_tx, mut response_rx) =
116 tokio::sync::mpsc::unbounded_channel::<std::io::Result<RpcResponse>>();
117 let mut requests: HashMap<u64, AbortHandle> = Default::default();
118 loop {
119 let select = tokio::select! {
120 reqopt = channel.next() => match reqopt {
121 Some(Ok(message)) => Select::Message(message),
122 Some(Err(err)) => return Err(err),
123 None => Select::Empty,
124 },
125 Some(response) = response_rx.recv() => match response {
126 Ok(response) => Select::Message(RpcMessage::Response(response)),
127 Err(err) => return Err(err),
128 }
129 };
130
131 match select {
132 Select::Empty => break,
133 Select::Message(message) => match message {
134 RpcMessage::Call(RpcCall {
135 id,
136 service,
137 method,
138 arguments,
139 }) => {
140 let response_tx = response_tx.clone();
141 let service = services.get_service(&service)?;
142 let handle = tokio::spawn(async move {
143 let response = match service.call(method, arguments).await {
144 Ok(value) => Ok(RpcResponse { id, value }),
145 Err(err) => Err(err),
146 };
147 let _ = response_tx.send(response);
148 })
149 .abort_handle();
150 requests.insert(id, handle);
151 }
152 RpcMessage::Cancel(RpcCancel { id }) => {
153 if let Some(handle) = requests.remove(&id) {
154 handle.abort();
155 }
156 }
157 RpcMessage::Response(response) => {
158 requests.remove(&response.id);
159 channel.send(RpcMessage::Response(response)).await?;
160 }
161 },
162 }
163 }
164
165 Ok(())
166}
diff --git a/src/tcp.rs b/src/tcp.rs
new file mode 100644
index 0000000..5684c41
--- /dev/null
+++ b/src/tcp.rs
@@ -0,0 +1,156 @@
1//! TCP backed channel and listener implementations.
2//!
3//! ```no_run
4//! #[urpc::service]
5//! trait Hello {
6//! type Error = ();
7//!
8//! async fn hello(name: String) -> String;
9//! }
10//!
11//! struct HelloServer;
12//!
13//! impl Hello for HelloServer {
14//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> {
15//! Ok(format!("Hello, {name}!"))
16//! }
17//! }
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>>{
21//! let listener = urpc::tcp::bind("0.0.0.0:3000").await?;
22//!
23//! // spawn the server
24//! tokio::spawn(async move {
25//! urpc::Server::default()
26//! .with_listener(listener)
27//! .with_service(HelloServer.into_service())
28//! .serve()
29//! .await
30//! });
31//!
32//! // create a client
33//! let channel = urpc::ClientChannel::new(urpc::tcp::connect("127.0.0.1:3000").await?);
34//! let client = HelloClient::new(channel);
35//! let greeting = client.hello("World".into()).await.unwrap();
36//! assert_eq!(greeting, "Hello, World!");
37//! Ok(())
38//! }
39//! ```
40use std::pin::Pin;
41
42use futures::{Sink, Stream};
43use tokio::{
44 net::{TcpListener, TcpStream, ToSocketAddrs},
45 sync::mpsc::Receiver,
46 task::AbortHandle,
47};
48use tokio_util::codec::Framed;
49
50use crate::{
51 Channel, Listener,
52 protocol::{RpcMessage, RpcMessageCodec},
53};
54
55pub struct TcpChannel(Framed<TcpStream, RpcMessageCodec>);
56
57impl TcpChannel {
58 fn new(stream: TcpStream) -> Self {
59 Self(Framed::new(stream, RpcMessageCodec::default()))
60 }
61
62 pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result<Self> {
63 let stream = TcpStream::connect(addrs).await?;
64 Ok(Self::new(stream))
65 }
66}
67
68impl Sink<RpcMessage> for TcpChannel {
69 type Error = std::io::Error;
70
71 fn poll_ready(
72 self: Pin<&mut Self>,
73 cx: &mut std::task::Context<'_>,
74 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
75 Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx)
76 }
77
78 fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> {
79 Sink::start_send(Pin::new(&mut self.get_mut().0), item)
80 }
81
82 fn poll_flush(
83 self: Pin<&mut Self>,
84 cx: &mut std::task::Context<'_>,
85 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
86 Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx)
87 }
88
89 fn poll_close(
90 self: Pin<&mut Self>,
91 cx: &mut std::task::Context<'_>,
92 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
93 Sink::poll_close(Pin::new(&mut self.get_mut().0), cx)
94 }
95}
96
97impl Stream for TcpChannel {
98 type Item = std::io::Result<RpcMessage>;
99
100 fn poll_next(
101 self: Pin<&mut Self>,
102 cx: &mut std::task::Context<'_>,
103 ) -> std::task::Poll<Option<Self::Item>> {
104 Stream::poll_next(Pin::new(&mut self.get_mut().0), cx)
105 }
106}
107
108impl Channel for TcpChannel {}
109
110pub struct TcpChannelListener {
111 receiver: Receiver<TcpChannel>,
112 abort: AbortHandle,
113}
114
115impl Drop for TcpChannelListener {
116 fn drop(&mut self) {
117 self.abort.abort();
118 }
119}
120
121impl TcpChannelListener {
122 pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result<Self> {
123 let listener = TcpListener::bind(addrs).await?;
124 let (sender, receiver) = tokio::sync::mpsc::channel(8);
125 let abort = tokio::spawn(async move {
126 while let Ok((stream, _addr)) = listener.accept().await {
127 if sender.send(TcpChannel::new(stream)).await.is_err() {
128 break;
129 }
130 }
131 })
132 .abort_handle();
133 Ok(Self { receiver, abort })
134 }
135}
136
137impl Stream for TcpChannelListener {
138 type Item = std::io::Result<TcpChannel>;
139
140 fn poll_next(
141 self: Pin<&mut Self>,
142 cx: &mut std::task::Context<'_>,
143 ) -> std::task::Poll<Option<Self::Item>> {
144 self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok))
145 }
146}
147
148impl Listener<TcpChannel> for TcpChannelListener {}
149
150pub async fn bind(addrs: impl ToSocketAddrs) -> std::io::Result<TcpChannelListener> {
151 TcpChannelListener::bind(addrs).await
152}
153
154pub async fn connect(addrs: impl ToSocketAddrs) -> std::io::Result<TcpChannel> {
155 TcpChannel::connect(addrs).await
156}
diff --git a/src/unix.rs b/src/unix.rs
new file mode 100644
index 0000000..e241b64
--- /dev/null
+++ b/src/unix.rs
@@ -0,0 +1,160 @@
1//! UNIX Domain Socket backed channel and listener implementations.
2//!
3//! ```no_run
4//! #[urpc::service]
5//! trait Hello {
6//! type Error = ();
7//!
8//! async fn hello(name: String) -> String;
9//! }
10//!
11//! struct HelloServer;
12//!
13//! impl Hello for HelloServer {
14//! async fn hello(&self, _ctx: urpc::Context, name: String) -> Result<String, ()> {
15//! Ok(format!("Hello, {name}!"))
16//! }
17//! }
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>>{
21//! let listener = urpc::unix::bind("./hello.service").await?;
22//!
23//! // spawn the server
24//! tokio::spawn(async move {
25//! urpc::Server::default()
26//! .with_listener(listener)
27//! .with_service(HelloServer.into_service())
28//! .serve()
29//! .await
30//! });
31//!
32//! // create a client
33//! let channel = urpc::ClientChannel::new(urpc::unix::connect("./hello.service").await?);
34//! let client = HelloClient::new(channel);
35//! let greeting = client.hello("World".into()).await.unwrap();
36//! assert_eq!(greeting, "Hello, World!");
37//! Ok(())
38//! }
39//! ```
40use std::{path::Path, pin::Pin};
41
42use futures::{Sink, Stream};
43use tokio::{
44 net::{UnixListener, UnixStream},
45 sync::mpsc::Receiver,
46 task::AbortHandle,
47};
48use tokio_util::codec::Framed;
49
50use crate::{
51 Channel, Listener,
52 protocol::{RpcMessage, RpcMessageCodec},
53};
54
55pub struct UnixChannel(Framed<UnixStream, RpcMessageCodec>);
56
57impl UnixChannel {
58 fn new(stream: UnixStream) -> Self {
59 Self(Framed::new(stream, RpcMessageCodec::default()))
60 }
61
62 pub async fn connect(path: impl AsRef<Path>) -> std::io::Result<Self> {
63 let stream = UnixStream::connect(path).await?;
64 Ok(Self::new(stream))
65 }
66}
67
68impl Sink<RpcMessage> for UnixChannel {
69 type Error = std::io::Error;
70
71 fn poll_ready(
72 self: Pin<&mut Self>,
73 cx: &mut std::task::Context<'_>,
74 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
75 Sink::poll_ready(Pin::new(&mut self.get_mut().0), cx)
76 }
77
78 fn start_send(self: Pin<&mut Self>, item: RpcMessage) -> std::result::Result<(), Self::Error> {
79 Sink::start_send(Pin::new(&mut self.get_mut().0), item)
80 }
81
82 fn poll_flush(
83 self: Pin<&mut Self>,
84 cx: &mut std::task::Context<'_>,
85 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
86 Sink::poll_flush(Pin::new(&mut self.get_mut().0), cx)
87 }
88
89 fn poll_close(
90 self: Pin<&mut Self>,
91 cx: &mut std::task::Context<'_>,
92 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
93 Sink::poll_close(Pin::new(&mut self.get_mut().0), cx)
94 }
95}
96
97impl Stream for UnixChannel {
98 type Item = std::io::Result<RpcMessage>;
99
100 fn poll_next(
101 self: Pin<&mut Self>,
102 cx: &mut std::task::Context<'_>,
103 ) -> std::task::Poll<Option<Self::Item>> {
104 Stream::poll_next(Pin::new(&mut self.get_mut().0), cx)
105 }
106}
107
108impl Channel for UnixChannel {}
109
110pub struct UnixChannelListener {
111 receiver: Receiver<UnixChannel>,
112 abort: AbortHandle,
113}
114
115impl Drop for UnixChannelListener {
116 fn drop(&mut self) {
117 self.abort.abort();
118 }
119}
120
121impl UnixChannelListener {
122 pub async fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
123 let path = path.as_ref();
124 if tokio::fs::try_exists(path).await? {
125 tokio::fs::remove_file(path).await?;
126 }
127 let listener = UnixListener::bind(path)?;
128 let (sender, receiver) = tokio::sync::mpsc::channel(8);
129 let abort = tokio::spawn(async move {
130 while let Ok((stream, _addr)) = listener.accept().await {
131 if sender.send(UnixChannel::new(stream)).await.is_err() {
132 break;
133 }
134 }
135 })
136 .abort_handle();
137 Ok(Self { receiver, abort })
138 }
139}
140
141impl Stream for UnixChannelListener {
142 type Item = std::io::Result<UnixChannel>;
143
144 fn poll_next(
145 self: Pin<&mut Self>,
146 cx: &mut std::task::Context<'_>,
147 ) -> std::task::Poll<Option<Self::Item>> {
148 self.get_mut().receiver.poll_recv(cx).map(|v| v.map(Ok))
149 }
150}
151
152impl Listener<UnixChannel> for UnixChannelListener {}
153
154pub async fn bind(path: impl AsRef<Path>) -> std::io::Result<UnixChannelListener> {
155 UnixChannelListener::bind(path).await
156}
157
158pub async fn connect(path: impl AsRef<Path>) -> std::io::Result<UnixChannel> {
159 UnixChannel::connect(path).await
160}