aboutsummaryrefslogtreecommitdiff
path: root/src/server.rs
diff options
context:
space:
mode:
authordiogo464 <[email protected]>2025-07-16 10:46:41 +0100
committerdiogo464 <[email protected]>2025-07-16 10:46:41 +0100
commitf319d7ab5278a3cfb43d38875d81c28cc2dce1e1 (patch)
treecb161fd990643e267bbc373fb09ccd7b689a23b5 /src/server.rs
Initial commit - extracted urpc from monorepo
Diffstat (limited to 'src/server.rs')
-rw-r--r--src/server.rs166
1 files changed, 166 insertions, 0 deletions
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..414340e
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,166 @@
1use std::{collections::HashMap, sync::Arc};
2
3use futures::{SinkExt, StreamExt};
4use tokio::task::{AbortHandle, JoinSet};
5
6use crate::{
7 Channel, Listener, Service,
8 protocol::{RpcCall, RpcCancel, RpcMessage, RpcResponse},
9};
10
11#[derive(Clone)]
12struct Services(Arc<HashMap<String, Arc<dyn Service>>>);
13
14impl Services {
15 fn new(services: HashMap<String, Arc<dyn Service>>) -> Self {
16 Self(Arc::new(services))
17 }
18
19 fn get_service(&self, name: &str) -> std::io::Result<Arc<dyn Service>> {
20 match self.0.get(name) {
21 Some(service) => Ok(service.clone()),
22 None => Err(std::io::Error::new(
23 std::io::ErrorKind::NotFound,
24 "service not found",
25 )),
26 }
27 }
28}
29
30type ListenerSpawner =
31 Box<dyn FnOnce(&mut JoinSet<std::io::Result<()>>, Services) -> AbortHandle + Send + 'static>;
32
33#[derive(Debug, Default)]
34struct AbortHandles(Vec<AbortHandle>);
35
36impl Drop for AbortHandles {
37 fn drop(&mut self) {
38 for handle in &self.0 {
39 handle.abort();
40 }
41 }
42}
43
44impl AbortHandles {
45 pub fn push(&mut self, handle: AbortHandle) {
46 self.0.push(handle);
47 }
48}
49
50#[derive(Default)]
51pub struct Server {
52 services: HashMap<String, Arc<dyn Service>>,
53 listener_spawners: Vec<ListenerSpawner>,
54}
55
56impl Server {
57 pub fn with_service<T>(mut self, service: T) -> Self
58 where
59 T: Service,
60 {
61 let name = T::name();
62 let service = Arc::new(service);
63 self.services.insert(name.to_string(), service);
64 self
65 }
66
67 pub fn with_listener<L, C>(mut self, listener: L) -> Self
68 where
69 C: Channel,
70 L: Listener<C>,
71 {
72 self.listener_spawners
73 .push(Box::new(move |join_set, services| {
74 join_set.spawn(listener_loop(listener, services))
75 }));
76 self
77 }
78
79 pub async fn serve(self) -> std::io::Result<()> {
80 let services = Services::new(self.services);
81 let mut join_set = JoinSet::default();
82 let mut abort_handles = AbortHandles::default();
83 for spawner in self.listener_spawners {
84 let abort_handle = (spawner)(&mut join_set, services.clone());
85 abort_handles.push(abort_handle);
86 }
87 match join_set.join_next().await {
88 Some(Ok(Ok(()))) => Ok(()),
89 Some(Ok(Err(err))) => Err(err),
90 Some(Err(err)) => Err(std::io::Error::other(err)),
91 None => Ok(()),
92 }
93 }
94}
95
96async fn listener_loop<L, C>(mut listener: L, services: Services) -> std::io::Result<()>
97where
98 C: Channel,
99 L: Listener<C>,
100{
101 while let Some(result) = listener.next().await {
102 let channel = result?;
103 let services = services.clone();
104 tokio::spawn(channel_handler(channel, services));
105 }
106 Ok(())
107}
108
109async fn channel_handler<C: Channel>(mut channel: C, services: Services) -> std::io::Result<()> {
110 enum Select {
111 Empty,
112 Message(RpcMessage),
113 }
114
115 let (response_tx, mut response_rx) =
116 tokio::sync::mpsc::unbounded_channel::<std::io::Result<RpcResponse>>();
117 let mut requests: HashMap<u64, AbortHandle> = Default::default();
118 loop {
119 let select = tokio::select! {
120 reqopt = channel.next() => match reqopt {
121 Some(Ok(message)) => Select::Message(message),
122 Some(Err(err)) => return Err(err),
123 None => Select::Empty,
124 },
125 Some(response) = response_rx.recv() => match response {
126 Ok(response) => Select::Message(RpcMessage::Response(response)),
127 Err(err) => return Err(err),
128 }
129 };
130
131 match select {
132 Select::Empty => break,
133 Select::Message(message) => match message {
134 RpcMessage::Call(RpcCall {
135 id,
136 service,
137 method,
138 arguments,
139 }) => {
140 let response_tx = response_tx.clone();
141 let service = services.get_service(&service)?;
142 let handle = tokio::spawn(async move {
143 let response = match service.call(method, arguments).await {
144 Ok(value) => Ok(RpcResponse { id, value }),
145 Err(err) => Err(err),
146 };
147 let _ = response_tx.send(response);
148 })
149 .abort_handle();
150 requests.insert(id, handle);
151 }
152 RpcMessage::Cancel(RpcCancel { id }) => {
153 if let Some(handle) = requests.remove(&id) {
154 handle.abort();
155 }
156 }
157 RpcMessage::Response(response) => {
158 requests.remove(&response.id);
159 channel.send(RpcMessage::Response(response)).await?;
160 }
161 },
162 }
163 }
164
165 Ok(())
166}