aboutsummaryrefslogtreecommitdiff
path: root/src/mqtt/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/mqtt/mod.rs')
-rw-r--r--src/mqtt/mod.rs485
1 files changed, 485 insertions, 0 deletions
diff --git a/src/mqtt/mod.rs b/src/mqtt/mod.rs
new file mode 100644
index 0000000..30e3a33
--- /dev/null
+++ b/src/mqtt/mod.rs
@@ -0,0 +1,485 @@
1mod connect_code;
2mod field;
3mod packet_id;
4mod protocol;
5mod qos;
6mod rx;
7mod transport;
8mod tx;
9mod varint;
10
11pub use connect_code::ConnectCode;
12use embedded_io_async::ReadExactError;
13pub use packet_id::PacketId;
14pub use qos::Qos;
15pub use transport::Transport;
16
17use self::{field::FieldBuffer, transport::TransportExt as _};
18
19const DEFAULT_CLIENT_RX_BUFFER_SIZE: usize = 512;
20const DEFAULT_CLIENT_TX_BUFFER_SIZE: usize = 512;
21
22pub enum Error<T: Transport> {
23 Transport(T::Error),
24 TransportEOF,
25 InsufficientBufferSpace,
26 ProtocolError(&'static str),
27 ConnectFailed(ConnectCode),
28}
29
30impl<T: Transport> core::fmt::Debug for Error<T> {
31 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32 match self {
33 Error::Transport(err) => f.debug_tuple("Transport").field(err).finish(),
34 Error::TransportEOF => f.write_str("TransportEOF"),
35 Error::InsufficientBufferSpace => f.write_str("InsufficientBufferSpace"),
36 Error::ProtocolError(msg) => f.debug_tuple("ProtocolError").field(msg).finish(),
37 Error::ConnectFailed(code) => f.debug_tuple("ConnectFailed").field(code).finish(),
38 }
39 }
40}
41
42impl<T: Transport> core::fmt::Display for Error<T> {
43 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44 match self {
45 Error::Transport(err) => write!(f, "transport error: {:?}", err),
46 Error::TransportEOF => write!(f, "unexpected end of transport stream"),
47 Error::InsufficientBufferSpace => {
48 write!(f, "insufficient buffer space to receive packet")
49 }
50 Error::ProtocolError(msg) => write!(f, "MQTT protocol error: {}", msg),
51 Error::ConnectFailed(code) => write!(f, "connection failed: {}", code),
52 }
53 }
54}
55
56impl<T: Transport> core::error::Error for Error<T>
57where
58 T::Error: core::error::Error + 'static,
59{
60 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
61 match self {
62 Error::Transport(err) => Some(err),
63 _ => None,
64 }
65 }
66}
67
68#[derive(Debug, Default)]
69pub struct ConnectParams<'a> {
70 pub will_topic: Option<&'a str>,
71 pub will_payload: Option<&'a [u8]>,
72 pub will_retain: bool,
73 pub username: Option<&'a str>,
74 pub password: Option<&'a [u8]>,
75 pub keepalive: Option<u16>,
76}
77
78#[derive(Debug, Default)]
79pub struct PublishParams {
80 pub qos: Qos,
81 pub retain: bool,
82}
83
84#[derive(Debug)]
85pub enum PublishData<'a> {
86 Inline(&'a [u8]),
87 Deferred(usize),
88}
89
90#[derive(Debug)]
91pub struct Publish<'a> {
92 pub topic: &'a str,
93 pub packet_id: Option<PacketId>,
94 pub qos: Qos,
95 pub retain: bool,
96 pub data_len: usize,
97}
98
99#[derive(Debug)]
100pub struct PublishAck {
101 pub packet_id: PacketId,
102}
103
104#[derive(Debug)]
105pub struct SubscribeAck {
106 pub packet_id: PacketId,
107 pub success: bool,
108}
109
110#[derive(Debug)]
111pub struct UnsubscribeAck {
112 pub packet_id: PacketId,
113}
114
115#[derive(Debug)]
116pub enum Packet<'a> {
117 Publish(Publish<'a>),
118 PublishAck(PublishAck),
119 SubscribeAck(SubscribeAck),
120 UnsubscribeAck(UnsubscribeAck),
121}
122
123pub struct ClientResources<
124 const RX: usize = DEFAULT_CLIENT_RX_BUFFER_SIZE,
125 const TX: usize = DEFAULT_CLIENT_TX_BUFFER_SIZE,
126> {
127 rx_buffer: [u8; RX],
128 tx_buffer: [u8; TX],
129}
130
131impl<const RX: usize, const TX: usize> Default for ClientResources<RX, TX> {
132 fn default() -> Self {
133 Self {
134 rx_buffer: [0u8; RX],
135 tx_buffer: [0u8; TX],
136 }
137 }
138}
139
140pub struct Client<'a, T> {
141 transport: T,
142 rx_buffer: &'a mut [u8],
143 rx_buffer_len: usize,
144 rx_buffer_skip: usize,
145 rx_buffer_data: usize,
146 tx_buffer: &'a mut [u8],
147 next_packet_id: u16,
148}
149
150impl<'a, T> Client<'a, T> {
151 pub fn new<const RX: usize, const TX: usize>(
152 resources: &'a mut ClientResources<RX, TX>,
153 transport: T,
154 ) -> Self {
155 Self {
156 transport,
157 rx_buffer: &mut resources.rx_buffer,
158 rx_buffer_len: 0,
159 rx_buffer_skip: 0,
160 rx_buffer_data: 0,
161 tx_buffer: &mut resources.tx_buffer,
162 next_packet_id: 1,
163 }
164 }
165}
166
167impl<'a, T> Client<'a, T>
168where
169 T: Transport,
170{
171 fn allocate_packet_id(&mut self) -> PacketId {
172 let packet_id = self.next_packet_id;
173 self.next_packet_id = self.next_packet_id.wrapping_add(1);
174 if self.next_packet_id == 0 {
175 self.next_packet_id = 1;
176 }
177 PacketId::from(packet_id)
178 }
179
180 pub async fn connect(&mut self, client_id: &str) -> Result<(), Error<T>> {
181 self.connect_with(client_id, Default::default()).await
182 }
183
184 pub async fn connect_with(
185 &mut self,
186 client_id: &str,
187 params: ConnectParams<'_>,
188 ) -> Result<(), Error<T>> {
189 let mut buffer = FieldBuffer::default();
190 tx::connect(
191 &mut buffer,
192 tx::Connect {
193 client_id,
194 clean_session: true,
195 username: params.username,
196 password: params.password,
197 will_topic: params.will_topic,
198 will_payload: params.will_payload,
199 will_retain: params.will_retain,
200 keepalive: None,
201 },
202 );
203 self.transport
204 .write_fields(&buffer)
205 .await
206 .map_err(Error::Transport)?;
207 self.transport.flush().await.map_err(Error::Transport)?;
208
209 // Wait for CONNACK response
210 match self.receive_inner().await? {
211 rx::Packet::ConnAck {
212 session_present,
213 code,
214 } => {
215 if code == ConnectCode::ConnectionAccepted {
216 Ok(())
217 } else {
218 Err(Error::ConnectFailed(code))
219 }
220 }
221 _ => Err(Error::ProtocolError(
222 "expected CONNACK packet after CONNECT",
223 )),
224 }
225 }
226
227 pub async fn publish(&mut self, topic: &str, data: &[u8]) -> Result<PacketId, Error<T>> {
228 self.publish_with(topic, data, Default::default()).await
229 }
230
231 pub async fn publish_with(
232 &mut self,
233 topic: &str,
234 data: &[u8],
235 params: PublishParams,
236 ) -> Result<PacketId, Error<T>> {
237 let packet_id = if params.qos.to_u8() > 0 {
238 Some(self.allocate_packet_id())
239 } else {
240 None
241 };
242
243 let mut buffer = FieldBuffer::default();
244 tx::publish(
245 &mut buffer,
246 tx::Publish {
247 topic,
248 payload: data,
249 qos: params.qos,
250 retain: params.retain,
251 dup: false,
252 packet_id,
253 },
254 );
255
256 self.transport
257 .write_fields(&buffer)
258 .await
259 .map_err(Error::Transport)?;
260 self.transport.flush().await.map_err(Error::Transport)?;
261
262 Ok(packet_id.unwrap_or(PacketId::from(0)))
263 }
264
265 pub async fn publish_ack(&mut self, packet_id: PacketId, qos: Qos) -> Result<(), Error<T>> {
266 let mut buffer = FieldBuffer::default();
267
268 match qos {
269 Qos::AtMostOnce => {
270 // QoS 0: No acknowledgment needed
271 return Ok(());
272 }
273 Qos::AtLeastOnce => {
274 // QoS 1: Send PUBACK
275 tx::puback(&mut buffer, packet_id);
276 }
277 Qos::ExactlyOnce => todo!("not implemented"),
278 }
279
280 self.transport
281 .write_fields(&buffer)
282 .await
283 .map_err(Error::Transport)?;
284 self.transport.flush().await.map_err(Error::Transport)?;
285
286 Ok(())
287 }
288
289 pub async fn subscribe(&mut self, topic: &str) -> Result<PacketId, Error<T>> {
290 self.subscribe_with(topic, Qos::AtMostOnce).await
291 }
292
293 pub async fn subscribe_with(&mut self, topic: &str, qos: Qos) -> Result<PacketId, Error<T>> {
294 let packet_id = self.allocate_packet_id();
295
296 let mut buffer = FieldBuffer::default();
297 tx::subscribe(
298 &mut buffer,
299 tx::Subscribe {
300 topic,
301 qos,
302 packet_id,
303 },
304 );
305
306 self.transport
307 .write_fields(&buffer)
308 .await
309 .map_err(Error::Transport)?;
310 self.transport.flush().await.map_err(Error::Transport)?;
311
312 Ok(packet_id)
313 }
314
315 pub async fn unsubscribe(&mut self, topic: &str) -> Result<PacketId, Error<T>> {
316 let packet_id = self.allocate_packet_id();
317
318 let mut buffer = FieldBuffer::default();
319 tx::unsubscribe(&mut buffer, tx::Unsubscribe { topic, packet_id });
320
321 self.transport
322 .write_fields(&buffer)
323 .await
324 .map_err(Error::Transport)?;
325 self.transport.flush().await.map_err(Error::Transport)?;
326
327 Ok(packet_id)
328 }
329
330 async fn receive_inner<'s>(&'s mut self) -> Result<rx::Packet<'s>, Error<T>> {
331 self.skip_if_required();
332 self.discard_data().await?;
333
334 loop {
335 let buf = &self.rx_buffer[..self.rx_buffer_len];
336 match rx::decode(buf) {
337 Ok(_) => {
338 // NOTE: stupid workaround for borrow checker, should not
339 // need to decode twice
340 let buf = &self.rx_buffer[..self.rx_buffer_len];
341 let (packet, n) = rx::decode(buf).unwrap();
342 self.rx_buffer_skip = n;
343 if let rx::Packet::Publish { data_len, .. } = &packet {
344 self.rx_buffer_data = *data_len;
345 }
346 return Ok(packet);
347 }
348 Err(err) => match err {
349 rx::Error::NeedMoreData => {
350 if self.rx_buffer.len() == self.rx_buffer_len {
351 return Err(Error::InsufficientBufferSpace);
352 }
353 }
354 rx::Error::InvalidPacket(msg) => return Err(Error::ProtocolError(msg)),
355 rx::Error::UnsupportedPacket { packet_type, .. } => {
356 return Err(Error::ProtocolError("unsupported packet type"));
357 }
358 rx::Error::UnknownPacket { packet_type, .. } => {
359 return Err(Error::ProtocolError("unknown packet type"));
360 }
361 },
362 }
363
364 self.fill_rx_buffer().await?;
365 }
366 }
367
368 pub async fn receive<'s>(&'s mut self) -> Result<Packet<'s>, Error<T>> {
369 match self.receive_inner().await? {
370 rx::Packet::ConnAck { .. } => {
371 return Err(Error::ProtocolError("unexpected CONNACK packet"));
372 }
373 rx::Packet::Publish {
374 topic,
375 packet_id,
376 qos,
377 retain,
378 dup: _dup,
379 data_len,
380 } => {
381 return Ok(Packet::Publish(Publish {
382 topic,
383 packet_id,
384 qos,
385 retain,
386 data_len,
387 }));
388 }
389 rx::Packet::PubAck { packet_id } => {
390 return Ok(Packet::PublishAck(PublishAck { packet_id }));
391 }
392 rx::Packet::SubscribeAck { packet_id, success } => {
393 return Ok(Packet::SubscribeAck(SubscribeAck { packet_id, success }));
394 }
395 rx::Packet::UnsubscribeAck { packet_id } => {
396 return Ok(Packet::UnsubscribeAck(UnsubscribeAck { packet_id }));
397 }
398 }
399 }
400
401 pub async fn receive_data(&mut self, buf: &mut [u8]) -> Result<(), Error<T>> {
402 self.skip_if_required();
403 if buf.len() != self.rx_buffer_data {
404 return Err(Error::InsufficientBufferSpace);
405 }
406
407 assert_eq!(self.rx_buffer_skip, 0);
408 let from_buffer = self.rx_buffer_data.min(self.rx_buffer_len);
409 let from_transport = self.rx_buffer_data.strict_sub(from_buffer);
410
411 buf[..from_buffer].copy_from_slice(&self.rx_buffer[..from_buffer]);
412 self.rx_buffer_len -= from_buffer;
413
414 if from_transport > 0 {
415 assert_eq!(self.rx_buffer_len, 0);
416 self.transport
417 .read_exact(&mut buf[from_buffer..])
418 .await
419 .map_err(|err| match err {
420 ReadExactError::UnexpectedEof => Error::<T>::TransportEOF,
421 ReadExactError::Other(e) => Error::Transport(e),
422 })?;
423 }
424 self.rx_buffer_data = 0;
425
426 Ok(())
427 }
428
429 pub async fn disconnect(&mut self) -> Result<(), Error<T>> {
430 let mut buffer = FieldBuffer::default();
431 tx::disconnect(&mut buffer);
432
433 self.transport
434 .write_fields(&buffer)
435 .await
436 .map_err(Error::Transport)?;
437 self.transport.flush().await.map_err(Error::Transport)?;
438
439 Ok(())
440 }
441
442 async fn fill_rx_buffer(&mut self) -> Result<(), Error<T>> {
443 let n = self
444 .transport
445 .read(&mut self.rx_buffer[self.rx_buffer_len..])
446 .await
447 .map_err(Error::Transport)?;
448 if n == 0 {
449 return Err(Error::TransportEOF);
450 }
451 self.rx_buffer_len += n;
452
453 Ok(())
454 }
455
456 fn skip_if_required(&mut self) {
457 assert!(self.rx_buffer_len >= self.rx_buffer_skip);
458 if self.rx_buffer_skip != 0 {
459 self.rx_buffer.copy_within(self.rx_buffer_skip.., 0);
460 self.rx_buffer_len = self.rx_buffer_len.strict_sub(self.rx_buffer_skip);
461 self.rx_buffer_skip = 0;
462 }
463 }
464
465 async fn discard_data(&mut self) -> Result<(), Error<T>> {
466 if self.rx_buffer_data == 0 {
467 return Ok(());
468 }
469
470 assert_eq!(self.rx_buffer_skip, 0);
471 while self.rx_buffer_data > 0 {
472 if self.rx_buffer_len <= self.rx_buffer_data {
473 self.rx_buffer_data -= self.rx_buffer_len;
474 self.rx_buffer_len = 0;
475 } else {
476 self.rx_buffer.copy_within(self.rx_buffer_data.., 0);
477 self.rx_buffer_len -= self.rx_buffer_data;
478 self.rx_buffer_data = 0;
479 }
480 self.fill_rx_buffer().await?;
481 }
482
483 Ok(())
484 }
485}