aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authordiogo464 <[email protected]>2025-12-09 22:30:42 +0000
committerdiogo464 <[email protected]>2025-12-09 22:30:42 +0000
commita5845673cf052b606f722be10d48c5d963958050 (patch)
treee21bf5848163d07fce4bf8e3d7474bfeed5d1aff /src
parent6bb6d358f39c31b5486621b49da463f97226fea5 (diff)
moved embedded-mqtt crate to a module
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs14
-rw-r--r--src/mqtt/connect_code.rs58
-rw-r--r--src/mqtt/field.rs89
-rw-r--r--src/mqtt/mod.rs485
-rw-r--r--src/mqtt/packet_id.rs14
-rw-r--r--src/mqtt/protocol.rs65
-rw-r--r--src/mqtt/qos.rs64
-rw-r--r--src/mqtt/rx.rs241
-rw-r--r--src/mqtt/transport.rs39
-rw-r--r--src/mqtt/tx.rs203
-rw-r--r--src/mqtt/varint.rs69
11 files changed, 1335 insertions, 6 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 3e91272..714e186 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -28,6 +28,8 @@ use heapless::{
28}; 28};
29use serde::Serialize; 29use serde::Serialize;
30 30
31mod mqtt;
32
31pub mod log; 33pub mod log;
32pub use log::Format; 34pub use log::Format;
33 35
@@ -215,7 +217,7 @@ pub struct DeviceResources {
215 waker: AtomicWaker, 217 waker: AtomicWaker,
216 entities: [RefCell<Option<EntityData>>; Self::ENTITY_LIMIT], 218 entities: [RefCell<Option<EntityData>>; Self::ENTITY_LIMIT],
217 219
218 mqtt_resources: embedded_mqtt::ClientResources, 220 mqtt_resources: mqtt::ClientResources,
219 publish_buffer: Vec<u8, 2048>, 221 publish_buffer: Vec<u8, 2048>,
220 subscribe_buffer: Vec<u8, 128>, 222 subscribe_buffer: Vec<u8, 128>,
221 discovery_buffer: Vec<u8, 2048>, 223 discovery_buffer: Vec<u8, 2048>,
@@ -422,7 +424,7 @@ pub struct Device<'a> {
422 waker: &'a AtomicWaker, 424 waker: &'a AtomicWaker,
423 entities: &'a [RefCell<Option<EntityData>>], 425 entities: &'a [RefCell<Option<EntityData>>],
424 426
425 mqtt_resources: &'a mut embedded_mqtt::ClientResources, 427 mqtt_resources: &'a mut mqtt::ClientResources,
426 publish_buffer: &'a mut VecView<u8>, 428 publish_buffer: &'a mut VecView<u8>,
427 subscribe_buffer: &'a mut VecView<u8>, 429 subscribe_buffer: &'a mut VecView<u8>,
428 discovery_buffer: &'a mut VecView<u8>, 430 discovery_buffer: &'a mut VecView<u8>,
@@ -585,8 +587,8 @@ pub async fn run<T: Transport>(device: &mut Device<'_>, transport: &mut T) -> Re
585 .expect("device availability buffer too small"); 587 .expect("device availability buffer too small");
586 let availability_topic = device.availability_topic_buffer.as_str(); 588 let availability_topic = device.availability_topic_buffer.as_str();
587 589
588 let mut client = embedded_mqtt::Client::new(device.mqtt_resources, transport); 590 let mut client = mqtt::Client::new(device.mqtt_resources, transport);
589 let connect_params = embedded_mqtt::ConnectParams { 591 let connect_params = mqtt::ConnectParams {
590 will_topic: Some(availability_topic), 592 will_topic: Some(availability_topic),
591 will_payload: Some(NOT_AVAILABLE_PAYLOAD.as_bytes()), 593 will_payload: Some(NOT_AVAILABLE_PAYLOAD.as_bytes()),
592 will_retain: true, 594 will_retain: true,
@@ -744,7 +746,7 @@ pub async fn run<T: Transport>(device: &mut Device<'_>, transport: &mut T) -> Re
744 client.publish_with( 746 client.publish_with(
745 availability_topic, 747 availability_topic,
746 AVAILABLE_PAYLOAD.as_bytes(), 748 AVAILABLE_PAYLOAD.as_bytes(),
747 embedded_mqtt::PublishParams { 749 mqtt::PublishParams {
748 retain: true, 750 retain: true,
749 ..Default::default() 751 ..Default::default()
750 }, 752 },
@@ -858,7 +860,7 @@ pub async fn run<T: Transport>(device: &mut Device<'_>, transport: &mut T) -> Re
858 .await 860 .await
859 { 861 {
860 Ok(embassy_futures::select::Either::First(packet)) => match packet { 862 Ok(embassy_futures::select::Either::First(packet)) => match packet {
861 Ok(embedded_mqtt::Packet::Publish(publish)) => publish, 863 Ok(mqtt::Packet::Publish(publish)) => publish,
862 Err(err) => { 864 Err(err) => {
863 crate::log::error!( 865 crate::log::error!(
864 "mqtt receive failed with: {:?}", 866 "mqtt receive failed with: {:?}",
diff --git a/src/mqtt/connect_code.rs b/src/mqtt/connect_code.rs
new file mode 100644
index 0000000..570ce0f
--- /dev/null
+++ b/src/mqtt/connect_code.rs
@@ -0,0 +1,58 @@
1use super::protocol;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq)]
4pub enum ConnectCode {
5 ConnectionAccepted,
6 UnacceptableProtocolVersion,
7 IdentifierRejected,
8 ServerUnavailable,
9 BadUsernamePassword,
10 NotAuthorized,
11 Unknown(u8),
12}
13
14impl core::fmt::Display for ConnectCode {
15 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16 match self {
17 ConnectCode::ConnectionAccepted => write!(f, "Connection Accepted"),
18 ConnectCode::UnacceptableProtocolVersion => write!(f, "Unacceptable Protocol Version"),
19 ConnectCode::IdentifierRejected => write!(f, "Identifier Rejected"),
20 ConnectCode::ServerUnavailable => write!(f, "Server Unavailable"),
21 ConnectCode::BadUsernamePassword => write!(f, "Bad Username or Password"),
22 ConnectCode::NotAuthorized => write!(f, "Not Authorized"),
23 ConnectCode::Unknown(code) => write!(f, "Unknown({})", code),
24 }
25 }
26}
27
28impl From<u8> for ConnectCode {
29 fn from(value: u8) -> Self {
30 match value {
31 protocol::CONNACK_CODE_ACCEPTED => ConnectCode::ConnectionAccepted,
32 protocol::CONNACK_CODE_UNACCEPTABLE_PROTOCOL_VERSION => {
33 ConnectCode::UnacceptableProtocolVersion
34 }
35 protocol::CONNACK_CODE_IDENTIFIER_REJECTED => ConnectCode::IdentifierRejected,
36 protocol::CONNACK_CODE_SERVER_UNAVAILABLE => ConnectCode::ServerUnavailable,
37 protocol::CONNACK_CODE_BAD_USERNAME_PASSWORD => ConnectCode::BadUsernamePassword,
38 protocol::CONNACK_CODE_NOT_AUTHORIZED => ConnectCode::NotAuthorized,
39 code => ConnectCode::Unknown(code),
40 }
41 }
42}
43
44impl From<ConnectCode> for u8 {
45 fn from(value: ConnectCode) -> Self {
46 match value {
47 ConnectCode::ConnectionAccepted => protocol::CONNACK_CODE_ACCEPTED,
48 ConnectCode::UnacceptableProtocolVersion => {
49 protocol::CONNACK_CODE_UNACCEPTABLE_PROTOCOL_VERSION
50 }
51 ConnectCode::IdentifierRejected => protocol::CONNACK_CODE_IDENTIFIER_REJECTED,
52 ConnectCode::ServerUnavailable => protocol::CONNACK_CODE_SERVER_UNAVAILABLE,
53 ConnectCode::BadUsernamePassword => protocol::CONNACK_CODE_BAD_USERNAME_PASSWORD,
54 ConnectCode::NotAuthorized => protocol::CONNACK_CODE_NOT_AUTHORIZED,
55 ConnectCode::Unknown(code) => code,
56 }
57 }
58}
diff --git a/src/mqtt/field.rs b/src/mqtt/field.rs
new file mode 100644
index 0000000..9e67e63
--- /dev/null
+++ b/src/mqtt/field.rs
@@ -0,0 +1,89 @@
1use core::{mem::MaybeUninit, ops::Deref};
2
3use super::varint;
4
5const DEFAULT_FIELD_BUFFER_CAP: usize = 32;
6
7pub enum Field<'a> {
8 U8(u8),
9 U16(u16),
10 VarInt(u32),
11 Buffer(&'a [u8]),
12 LenPrefixedBuffer(&'a [u8]),
13 LenPrefixedString(&'a str),
14}
15
16pub struct FieldBuffer<'a, const N: usize = DEFAULT_FIELD_BUFFER_CAP> {
17 data: [MaybeUninit<Field<'a>>; N],
18 len: usize,
19}
20
21impl<'a, const N: usize> Default for FieldBuffer<'a, N> {
22 fn default() -> Self {
23 Self {
24 data: [const { MaybeUninit::uninit() }; N],
25 len: 0,
26 }
27 }
28}
29
30impl<'a, const N: usize> FieldBuffer<'a, N> {
31 pub fn clear(&mut self) {
32 self.len = 0;
33 }
34
35 pub fn push(&mut self, field: Field<'a>) {
36 assert!(self.len < N, "field buffer lenght limit exceeded");
37 self.data[self.len].write(field);
38 self.len += 1;
39 }
40
41 pub fn set(&mut self, n: usize, field: Field<'a>) {
42 assert!(self.len > n);
43 self.data[n].write(field);
44 }
45
46 pub fn as_slice<'s>(&'s self) -> &'s [Field<'a>] {
47 unsafe {
48 core::mem::transmute::<&'s [MaybeUninit<Field<'a>>], &'s [Field<'a>]>(
49 &self.data[..self.len],
50 )
51 }
52 }
53}
54
55impl<'a, const N: usize> AsRef<[Field<'a>]> for FieldBuffer<'a, N> {
56 fn as_ref(&self) -> &[Field<'a>] {
57 self.as_slice()
58 }
59}
60
61impl<'a, const N: usize> Deref for FieldBuffer<'a, N> {
62 type Target = [Field<'a>];
63
64 fn deref(&self) -> &Self::Target {
65 self.as_slice()
66 }
67}
68
69pub fn field_size(field: &Field) -> usize {
70 match field {
71 Field::U8(_) => 1,
72 Field::U16(_) => 2,
73 Field::VarInt(v) => {
74 let (_, n) = varint::encode(*v);
75 n
76 }
77 Field::Buffer(v) => v.len(),
78 Field::LenPrefixedBuffer(v) => v.len().strict_add(2),
79 Field::LenPrefixedString(v) => v.len().strict_add(2),
80 }
81}
82
83pub fn fields_size(fields: &[Field]) -> usize {
84 let mut total_size = 0usize;
85 for field in fields {
86 total_size = total_size.strict_add(field_size(field));
87 }
88 total_size
89}
diff --git a/src/mqtt/mod.rs b/src/mqtt/mod.rs
new file mode 100644
index 0000000..30e3a33
--- /dev/null
+++ b/src/mqtt/mod.rs
@@ -0,0 +1,485 @@
1mod connect_code;
2mod field;
3mod packet_id;
4mod protocol;
5mod qos;
6mod rx;
7mod transport;
8mod tx;
9mod varint;
10
11pub use connect_code::ConnectCode;
12use embedded_io_async::ReadExactError;
13pub use packet_id::PacketId;
14pub use qos::Qos;
15pub use transport::Transport;
16
17use self::{field::FieldBuffer, transport::TransportExt as _};
18
19const DEFAULT_CLIENT_RX_BUFFER_SIZE: usize = 512;
20const DEFAULT_CLIENT_TX_BUFFER_SIZE: usize = 512;
21
22pub enum Error<T: Transport> {
23 Transport(T::Error),
24 TransportEOF,
25 InsufficientBufferSpace,
26 ProtocolError(&'static str),
27 ConnectFailed(ConnectCode),
28}
29
30impl<T: Transport> core::fmt::Debug for Error<T> {
31 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32 match self {
33 Error::Transport(err) => f.debug_tuple("Transport").field(err).finish(),
34 Error::TransportEOF => f.write_str("TransportEOF"),
35 Error::InsufficientBufferSpace => f.write_str("InsufficientBufferSpace"),
36 Error::ProtocolError(msg) => f.debug_tuple("ProtocolError").field(msg).finish(),
37 Error::ConnectFailed(code) => f.debug_tuple("ConnectFailed").field(code).finish(),
38 }
39 }
40}
41
42impl<T: Transport> core::fmt::Display for Error<T> {
43 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44 match self {
45 Error::Transport(err) => write!(f, "transport error: {:?}", err),
46 Error::TransportEOF => write!(f, "unexpected end of transport stream"),
47 Error::InsufficientBufferSpace => {
48 write!(f, "insufficient buffer space to receive packet")
49 }
50 Error::ProtocolError(msg) => write!(f, "MQTT protocol error: {}", msg),
51 Error::ConnectFailed(code) => write!(f, "connection failed: {}", code),
52 }
53 }
54}
55
56impl<T: Transport> core::error::Error for Error<T>
57where
58 T::Error: core::error::Error + 'static,
59{
60 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
61 match self {
62 Error::Transport(err) => Some(err),
63 _ => None,
64 }
65 }
66}
67
68#[derive(Debug, Default)]
69pub struct ConnectParams<'a> {
70 pub will_topic: Option<&'a str>,
71 pub will_payload: Option<&'a [u8]>,
72 pub will_retain: bool,
73 pub username: Option<&'a str>,
74 pub password: Option<&'a [u8]>,
75 pub keepalive: Option<u16>,
76}
77
78#[derive(Debug, Default)]
79pub struct PublishParams {
80 pub qos: Qos,
81 pub retain: bool,
82}
83
84#[derive(Debug)]
85pub enum PublishData<'a> {
86 Inline(&'a [u8]),
87 Deferred(usize),
88}
89
90#[derive(Debug)]
91pub struct Publish<'a> {
92 pub topic: &'a str,
93 pub packet_id: Option<PacketId>,
94 pub qos: Qos,
95 pub retain: bool,
96 pub data_len: usize,
97}
98
99#[derive(Debug)]
100pub struct PublishAck {
101 pub packet_id: PacketId,
102}
103
104#[derive(Debug)]
105pub struct SubscribeAck {
106 pub packet_id: PacketId,
107 pub success: bool,
108}
109
110#[derive(Debug)]
111pub struct UnsubscribeAck {
112 pub packet_id: PacketId,
113}
114
115#[derive(Debug)]
116pub enum Packet<'a> {
117 Publish(Publish<'a>),
118 PublishAck(PublishAck),
119 SubscribeAck(SubscribeAck),
120 UnsubscribeAck(UnsubscribeAck),
121}
122
123pub struct ClientResources<
124 const RX: usize = DEFAULT_CLIENT_RX_BUFFER_SIZE,
125 const TX: usize = DEFAULT_CLIENT_TX_BUFFER_SIZE,
126> {
127 rx_buffer: [u8; RX],
128 tx_buffer: [u8; TX],
129}
130
131impl<const RX: usize, const TX: usize> Default for ClientResources<RX, TX> {
132 fn default() -> Self {
133 Self {
134 rx_buffer: [0u8; RX],
135 tx_buffer: [0u8; TX],
136 }
137 }
138}
139
140pub struct Client<'a, T> {
141 transport: T,
142 rx_buffer: &'a mut [u8],
143 rx_buffer_len: usize,
144 rx_buffer_skip: usize,
145 rx_buffer_data: usize,
146 tx_buffer: &'a mut [u8],
147 next_packet_id: u16,
148}
149
150impl<'a, T> Client<'a, T> {
151 pub fn new<const RX: usize, const TX: usize>(
152 resources: &'a mut ClientResources<RX, TX>,
153 transport: T,
154 ) -> Self {
155 Self {
156 transport,
157 rx_buffer: &mut resources.rx_buffer,
158 rx_buffer_len: 0,
159 rx_buffer_skip: 0,
160 rx_buffer_data: 0,
161 tx_buffer: &mut resources.tx_buffer,
162 next_packet_id: 1,
163 }
164 }
165}
166
167impl<'a, T> Client<'a, T>
168where
169 T: Transport,
170{
171 fn allocate_packet_id(&mut self) -> PacketId {
172 let packet_id = self.next_packet_id;
173 self.next_packet_id = self.next_packet_id.wrapping_add(1);
174 if self.next_packet_id == 0 {
175 self.next_packet_id = 1;
176 }
177 PacketId::from(packet_id)
178 }
179
180 pub async fn connect(&mut self, client_id: &str) -> Result<(), Error<T>> {
181 self.connect_with(client_id, Default::default()).await
182 }
183
184 pub async fn connect_with(
185 &mut self,
186 client_id: &str,
187 params: ConnectParams<'_>,
188 ) -> Result<(), Error<T>> {
189 let mut buffer = FieldBuffer::default();
190 tx::connect(
191 &mut buffer,
192 tx::Connect {
193 client_id,
194 clean_session: true,
195 username: params.username,
196 password: params.password,
197 will_topic: params.will_topic,
198 will_payload: params.will_payload,
199 will_retain: params.will_retain,
200 keepalive: None,
201 },
202 );
203 self.transport
204 .write_fields(&buffer)
205 .await
206 .map_err(Error::Transport)?;
207 self.transport.flush().await.map_err(Error::Transport)?;
208
209 // Wait for CONNACK response
210 match self.receive_inner().await? {
211 rx::Packet::ConnAck {
212 session_present,
213 code,
214 } => {
215 if code == ConnectCode::ConnectionAccepted {
216 Ok(())
217 } else {
218 Err(Error::ConnectFailed(code))
219 }
220 }
221 _ => Err(Error::ProtocolError(
222 "expected CONNACK packet after CONNECT",
223 )),
224 }
225 }
226
227 pub async fn publish(&mut self, topic: &str, data: &[u8]) -> Result<PacketId, Error<T>> {
228 self.publish_with(topic, data, Default::default()).await
229 }
230
231 pub async fn publish_with(
232 &mut self,
233 topic: &str,
234 data: &[u8],
235 params: PublishParams,
236 ) -> Result<PacketId, Error<T>> {
237 let packet_id = if params.qos.to_u8() > 0 {
238 Some(self.allocate_packet_id())
239 } else {
240 None
241 };
242
243 let mut buffer = FieldBuffer::default();
244 tx::publish(
245 &mut buffer,
246 tx::Publish {
247 topic,
248 payload: data,
249 qos: params.qos,
250 retain: params.retain,
251 dup: false,
252 packet_id,
253 },
254 );
255
256 self.transport
257 .write_fields(&buffer)
258 .await
259 .map_err(Error::Transport)?;
260 self.transport.flush().await.map_err(Error::Transport)?;
261
262 Ok(packet_id.unwrap_or(PacketId::from(0)))
263 }
264
265 pub async fn publish_ack(&mut self, packet_id: PacketId, qos: Qos) -> Result<(), Error<T>> {
266 let mut buffer = FieldBuffer::default();
267
268 match qos {
269 Qos::AtMostOnce => {
270 // QoS 0: No acknowledgment needed
271 return Ok(());
272 }
273 Qos::AtLeastOnce => {
274 // QoS 1: Send PUBACK
275 tx::puback(&mut buffer, packet_id);
276 }
277 Qos::ExactlyOnce => todo!("not implemented"),
278 }
279
280 self.transport
281 .write_fields(&buffer)
282 .await
283 .map_err(Error::Transport)?;
284 self.transport.flush().await.map_err(Error::Transport)?;
285
286 Ok(())
287 }
288
289 pub async fn subscribe(&mut self, topic: &str) -> Result<PacketId, Error<T>> {
290 self.subscribe_with(topic, Qos::AtMostOnce).await
291 }
292
293 pub async fn subscribe_with(&mut self, topic: &str, qos: Qos) -> Result<PacketId, Error<T>> {
294 let packet_id = self.allocate_packet_id();
295
296 let mut buffer = FieldBuffer::default();
297 tx::subscribe(
298 &mut buffer,
299 tx::Subscribe {
300 topic,
301 qos,
302 packet_id,
303 },
304 );
305
306 self.transport
307 .write_fields(&buffer)
308 .await
309 .map_err(Error::Transport)?;
310 self.transport.flush().await.map_err(Error::Transport)?;
311
312 Ok(packet_id)
313 }
314
315 pub async fn unsubscribe(&mut self, topic: &str) -> Result<PacketId, Error<T>> {
316 let packet_id = self.allocate_packet_id();
317
318 let mut buffer = FieldBuffer::default();
319 tx::unsubscribe(&mut buffer, tx::Unsubscribe { topic, packet_id });
320
321 self.transport
322 .write_fields(&buffer)
323 .await
324 .map_err(Error::Transport)?;
325 self.transport.flush().await.map_err(Error::Transport)?;
326
327 Ok(packet_id)
328 }
329
330 async fn receive_inner<'s>(&'s mut self) -> Result<rx::Packet<'s>, Error<T>> {
331 self.skip_if_required();
332 self.discard_data().await?;
333
334 loop {
335 let buf = &self.rx_buffer[..self.rx_buffer_len];
336 match rx::decode(buf) {
337 Ok(_) => {
338 // NOTE: stupid workaround for borrow checker, should not
339 // need to decode twice
340 let buf = &self.rx_buffer[..self.rx_buffer_len];
341 let (packet, n) = rx::decode(buf).unwrap();
342 self.rx_buffer_skip = n;
343 if let rx::Packet::Publish { data_len, .. } = &packet {
344 self.rx_buffer_data = *data_len;
345 }
346 return Ok(packet);
347 }
348 Err(err) => match err {
349 rx::Error::NeedMoreData => {
350 if self.rx_buffer.len() == self.rx_buffer_len {
351 return Err(Error::InsufficientBufferSpace);
352 }
353 }
354 rx::Error::InvalidPacket(msg) => return Err(Error::ProtocolError(msg)),
355 rx::Error::UnsupportedPacket { packet_type, .. } => {
356 return Err(Error::ProtocolError("unsupported packet type"));
357 }
358 rx::Error::UnknownPacket { packet_type, .. } => {
359 return Err(Error::ProtocolError("unknown packet type"));
360 }
361 },
362 }
363
364 self.fill_rx_buffer().await?;
365 }
366 }
367
368 pub async fn receive<'s>(&'s mut self) -> Result<Packet<'s>, Error<T>> {
369 match self.receive_inner().await? {
370 rx::Packet::ConnAck { .. } => {
371 return Err(Error::ProtocolError("unexpected CONNACK packet"));
372 }
373 rx::Packet::Publish {
374 topic,
375 packet_id,
376 qos,
377 retain,
378 dup: _dup,
379 data_len,
380 } => {
381 return Ok(Packet::Publish(Publish {
382 topic,
383 packet_id,
384 qos,
385 retain,
386 data_len,
387 }));
388 }
389 rx::Packet::PubAck { packet_id } => {
390 return Ok(Packet::PublishAck(PublishAck { packet_id }));
391 }
392 rx::Packet::SubscribeAck { packet_id, success } => {
393 return Ok(Packet::SubscribeAck(SubscribeAck { packet_id, success }));
394 }
395 rx::Packet::UnsubscribeAck { packet_id } => {
396 return Ok(Packet::UnsubscribeAck(UnsubscribeAck { packet_id }));
397 }
398 }
399 }
400
401 pub async fn receive_data(&mut self, buf: &mut [u8]) -> Result<(), Error<T>> {
402 self.skip_if_required();
403 if buf.len() != self.rx_buffer_data {
404 return Err(Error::InsufficientBufferSpace);
405 }
406
407 assert_eq!(self.rx_buffer_skip, 0);
408 let from_buffer = self.rx_buffer_data.min(self.rx_buffer_len);
409 let from_transport = self.rx_buffer_data.strict_sub(from_buffer);
410
411 buf[..from_buffer].copy_from_slice(&self.rx_buffer[..from_buffer]);
412 self.rx_buffer_len -= from_buffer;
413
414 if from_transport > 0 {
415 assert_eq!(self.rx_buffer_len, 0);
416 self.transport
417 .read_exact(&mut buf[from_buffer..])
418 .await
419 .map_err(|err| match err {
420 ReadExactError::UnexpectedEof => Error::<T>::TransportEOF,
421 ReadExactError::Other(e) => Error::Transport(e),
422 })?;
423 }
424 self.rx_buffer_data = 0;
425
426 Ok(())
427 }
428
429 pub async fn disconnect(&mut self) -> Result<(), Error<T>> {
430 let mut buffer = FieldBuffer::default();
431 tx::disconnect(&mut buffer);
432
433 self.transport
434 .write_fields(&buffer)
435 .await
436 .map_err(Error::Transport)?;
437 self.transport.flush().await.map_err(Error::Transport)?;
438
439 Ok(())
440 }
441
442 async fn fill_rx_buffer(&mut self) -> Result<(), Error<T>> {
443 let n = self
444 .transport
445 .read(&mut self.rx_buffer[self.rx_buffer_len..])
446 .await
447 .map_err(Error::Transport)?;
448 if n == 0 {
449 return Err(Error::TransportEOF);
450 }
451 self.rx_buffer_len += n;
452
453 Ok(())
454 }
455
456 fn skip_if_required(&mut self) {
457 assert!(self.rx_buffer_len >= self.rx_buffer_skip);
458 if self.rx_buffer_skip != 0 {
459 self.rx_buffer.copy_within(self.rx_buffer_skip.., 0);
460 self.rx_buffer_len = self.rx_buffer_len.strict_sub(self.rx_buffer_skip);
461 self.rx_buffer_skip = 0;
462 }
463 }
464
465 async fn discard_data(&mut self) -> Result<(), Error<T>> {
466 if self.rx_buffer_data == 0 {
467 return Ok(());
468 }
469
470 assert_eq!(self.rx_buffer_skip, 0);
471 while self.rx_buffer_data > 0 {
472 if self.rx_buffer_len <= self.rx_buffer_data {
473 self.rx_buffer_data -= self.rx_buffer_len;
474 self.rx_buffer_len = 0;
475 } else {
476 self.rx_buffer.copy_within(self.rx_buffer_data.., 0);
477 self.rx_buffer_len -= self.rx_buffer_data;
478 self.rx_buffer_data = 0;
479 }
480 self.fill_rx_buffer().await?;
481 }
482
483 Ok(())
484 }
485}
diff --git a/src/mqtt/packet_id.rs b/src/mqtt/packet_id.rs
new file mode 100644
index 0000000..4e0158f
--- /dev/null
+++ b/src/mqtt/packet_id.rs
@@ -0,0 +1,14 @@
1#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
2pub struct PacketId(u16);
3
4impl From<u16> for PacketId {
5 fn from(value: u16) -> Self {
6 Self(value)
7 }
8}
9
10impl From<PacketId> for u16 {
11 fn from(value: PacketId) -> Self {
12 value.0
13 }
14}
diff --git a/src/mqtt/protocol.rs b/src/mqtt/protocol.rs
new file mode 100644
index 0000000..bf77d78
--- /dev/null
+++ b/src/mqtt/protocol.rs
@@ -0,0 +1,65 @@
1pub const PACKET_TYPE_CONNECT: u8 = 1;
2pub const PACKET_TYPE_CONNACK: u8 = 2;
3pub const PACKET_TYPE_PUBLISH: u8 = 3;
4pub const PACKET_TYPE_PUBACK: u8 = 4;
5pub const PACKET_TYPE_PUBREC: u8 = 5;
6pub const PACKET_TYPE_PUBREL: u8 = 6;
7pub const PACKET_TYPE_PUBCOMP: u8 = 7;
8pub const PACKET_TYPE_SUBSCRIBE: u8 = 8;
9pub const PACKET_TYPE_SUBACK: u8 = 9;
10pub const PACKET_TYPE_UNSUBSCRIBE: u8 = 10;
11pub const PACKET_TYPE_UNSUBACK: u8 = 11;
12pub const PACKET_TYPE_PINGREQ: u8 = 12;
13pub const PACKET_TYPE_PINGRESP: u8 = 13;
14pub const PACKET_TYPE_DISCONNECT: u8 = 14;
15
16pub const PROTOCOL_NAME: &str = "MQTT";
17
18pub const PROTOCOL_LEVEL_3_1_1: u8 = 0x04;
19pub const PROTOCOL_LEVEL_5_0_0: u8 = 0x05;
20
21pub const CONNECT_FLAG_USERNAME: u8 = 1 << 7;
22pub const CONNECT_FLAG_PASSWORD: u8 = 1 << 6;
23pub const CONNECT_FLAG_WILL_RETAIN: u8 = 1 << 5;
24pub const CONNECT_FLAG_WILL_FLAG: u8 = 1 << 2;
25pub const CONNECT_FLAG_CLEAN_SESSION: u8 = 1 << 1;
26
27pub const SUBSCRIBE_HEADER_FLAGS: u8 = 0x02;
28pub const UNSUBSCRIBE_HEADER_FLAGS: u8 = 0x02;
29pub const PUBREL_HEADER_FLAGS: u8 = 0x02;
30
31pub const CONNACK_CODE_ACCEPTED: u8 = 0;
32pub const CONNACK_CODE_UNACCEPTABLE_PROTOCOL_VERSION: u8 = 1;
33pub const CONNACK_CODE_IDENTIFIER_REJECTED: u8 = 2;
34pub const CONNACK_CODE_SERVER_UNAVAILABLE: u8 = 3;
35pub const CONNACK_CODE_BAD_USERNAME_PASSWORD: u8 = 4;
36pub const CONNACK_CODE_NOT_AUTHORIZED: u8 = 5;
37
38pub const CONNACK_FLAG_SESSION_PRESENT: u8 = 0x01;
39pub const CONNACK_FLAG_RESERVED: u8 = 0xFE;
40
41pub const SUBACK_FAILURE: u8 = 0x80;
42
43pub const PUBLISH_FLAG_RETAIN: u8 = 0x01;
44pub const PUBLISH_FLAG_QOS_MASK: u8 = 0x06;
45pub const PUBLISH_FLAG_QOS_SHIFT: u8 = 1;
46pub const PUBLISH_FLAG_DUP: u8 = 0x08;
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq)]
49pub struct HeaderControl {
50 pub packet_type: u8,
51 pub packet_flags: u8,
52}
53
54pub fn create_header_control(packet_type: u8, flags: u8) -> u8 {
55 assert!(packet_type & 0xF0 == 0);
56 assert!(flags & 0xF0 == 0);
57 packet_type << 4 | flags
58}
59
60pub fn split_header_control(control: u8) -> HeaderControl {
61 HeaderControl {
62 packet_type: control >> 4,
63 packet_flags: control & 0x0F,
64 }
65}
diff --git a/src/mqtt/qos.rs b/src/mqtt/qos.rs
new file mode 100644
index 0000000..0d464b4
--- /dev/null
+++ b/src/mqtt/qos.rs
@@ -0,0 +1,64 @@
1#[derive(Debug)]
2pub struct InvalidQos(u8);
3
4impl core::fmt::Display for InvalidQos {
5 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
6 write!(f, "invalid QoS value: '{}'", self.0)
7 }
8}
9
10impl core::error::Error for InvalidQos {}
11
12#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
13pub enum Qos {
14 #[default]
15 AtMostOnce,
16 AtLeastOnce,
17 ExactlyOnce,
18}
19
20impl core::fmt::Display for Qos {
21 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22 f.write_str(match self {
23 Qos::AtMostOnce => "QoS::AtMostOnce",
24 Qos::AtLeastOnce => "Qos::AtLeastOnce",
25 Qos::ExactlyOnce => "Qos::ExactlyOnce",
26 })
27 }
28}
29
30impl Qos {
31 pub fn to_u8(&self) -> u8 {
32 match self {
33 Qos::AtMostOnce => 0,
34 Qos::AtLeastOnce => 1,
35 Qos::ExactlyOnce => 2,
36 }
37 }
38
39 pub fn from_u8(v: u8) -> Option<Self> {
40 match v {
41 0 => Some(Self::AtMostOnce),
42 1 => Some(Self::AtLeastOnce),
43 2 => Some(Self::ExactlyOnce),
44 _ => None,
45 }
46 }
47}
48
49impl TryFrom<u8> for Qos {
50 type Error = InvalidQos;
51
52 fn try_from(value: u8) -> Result<Self, Self::Error> {
53 match Self::from_u8(value) {
54 Some(v) => Ok(v),
55 None => Err(InvalidQos(value)),
56 }
57 }
58}
59
60impl From<Qos> for u8 {
61 fn from(value: Qos) -> Self {
62 value.to_u8()
63 }
64}
diff --git a/src/mqtt/rx.rs b/src/mqtt/rx.rs
new file mode 100644
index 0000000..e81171d
--- /dev/null
+++ b/src/mqtt/rx.rs
@@ -0,0 +1,241 @@
1use super::{ConnectCode, PacketId, Qos, protocol, varint};
2
3#[derive(Debug)]
4pub enum Error {
5 NeedMoreData,
6 InvalidPacket(&'static str),
7 UnsupportedPacket { packet_type: u8, packet_len: u32 },
8 UnknownPacket { packet_type: u8, packet_len: u32 },
9}
10
11impl core::fmt::Display for Error {
12 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13 match self {
14 Error::NeedMoreData => f.write_str("need more data"),
15 Error::InvalidPacket(msg) => write!(f, "invalid packet: {}", msg),
16 Error::UnsupportedPacket {
17 packet_type,
18 packet_len,
19 } => write!(
20 f,
21 "unsupported packet type {} with length {}",
22 packet_type, packet_len
23 ),
24 Error::UnknownPacket {
25 packet_type,
26 packet_len,
27 } => write!(
28 f,
29 "unknown packet type {} with length {}",
30 packet_type, packet_len
31 ),
32 }
33 }
34}
35
36impl From<varint::Error> for Error {
37 fn from(value: varint::Error) -> Self {
38 match value {
39 varint::Error::NeedMoreData => Self::NeedMoreData,
40 varint::Error::InvalidVarInt => Self::InvalidPacket("invalid variable integer encoding"),
41 }
42 }
43}
44
45pub enum Packet<'a> {
46 ConnAck {
47 session_present: bool,
48 code: ConnectCode,
49 },
50 Publish {
51 topic: &'a str,
52 packet_id: Option<PacketId>,
53 qos: Qos,
54 retain: bool,
55 dup: bool,
56 data_len: usize,
57 },
58 PubAck {
59 packet_id: PacketId,
60 },
61 SubscribeAck {
62 packet_id: PacketId,
63 success: bool,
64 },
65 UnsubscribeAck {
66 packet_id: PacketId,
67 },
68}
69
70pub fn decode<'a>(buf: &'a [u8]) -> Result<(Packet<'a>, usize), Error> {
71 let mut reader = Reader::new(buf);
72 let protocol::HeaderControl {
73 packet_type,
74 packet_flags,
75 } = protocol::split_header_control(reader.read_u8()?);
76 let packet_len = reader.read_varint()?;
77
78 let packet = match packet_type {
79 protocol::PACKET_TYPE_CONNACK => {
80 let flags = reader.read_u8()?;
81 let code = ConnectCode::from(reader.read_u8()?);
82 let session_present = flags & protocol::CONNACK_FLAG_SESSION_PRESENT != 0;
83 if flags & protocol::CONNACK_FLAG_RESERVED != 0 {
84 return Err(Error::InvalidPacket("CONNACK reserved flags must be zero"));
85 }
86 Packet::ConnAck {
87 session_present,
88 code,
89 }
90 }
91 protocol::PACKET_TYPE_PUBLISH => {
92 // Extract flags from the fixed header
93 let retain = (packet_flags & protocol::PUBLISH_FLAG_RETAIN) != 0;
94 let qos_value = (packet_flags & protocol::PUBLISH_FLAG_QOS_MASK) >> protocol::PUBLISH_FLAG_QOS_SHIFT;
95 let qos = Qos::from_u8(qos_value).ok_or(Error::InvalidPacket("PUBLISH has invalid QoS value"))?;
96 let dup = (packet_flags & protocol::PUBLISH_FLAG_DUP) != 0;
97
98 // Track position after fixed header to calculate data length
99 let variable_header_start = reader.num_read();
100
101 // Read topic name
102 let topic = reader.read_len_prefix_str()?;
103
104 // Read packet ID if QoS > 0
105 let packet_id = if qos.to_u8() > 0 {
106 Some(PacketId::from(reader.read_u16()?))
107 } else {
108 None
109 };
110
111 // Calculate payload length without reading it
112 let variable_header_len = reader.num_read() - variable_header_start;
113 let data_len = (packet_len as usize)
114 .checked_sub(variable_header_len)
115 .ok_or(Error::InvalidPacket("PUBLISH remaining length is too short for headers"))?;
116
117 Packet::Publish {
118 topic,
119 packet_id,
120 qos,
121 retain,
122 dup,
123 data_len,
124 }
125 }
126 protocol::PACKET_TYPE_PUBACK => {
127 if packet_flags != 0 {
128 return Err(Error::InvalidPacket("PUBACK flags must be zero"));
129 }
130 if packet_len != 2 {
131 return Err(Error::InvalidPacket("PUBACK remaining length must be 2"));
132 }
133 let packet_id = PacketId::from(reader.read_u16()?);
134 Packet::PubAck { packet_id }
135 }
136 protocol::PACKET_TYPE_SUBACK => {
137 if packet_flags != 0 {
138 return Err(Error::InvalidPacket("SUBACK flags must be zero"));
139 }
140 if packet_len < 3 {
141 // Minimum: 2 bytes packet ID + 1 byte return code
142 return Err(Error::InvalidPacket("SUBACK remaining length must be at least 3"));
143 }
144 let packet_id = PacketId::from(reader.read_u16()?);
145 let return_code = reader.read_u8()?;
146 let success = return_code != protocol::SUBACK_FAILURE;
147 Packet::SubscribeAck { packet_id, success }
148 }
149 protocol::PACKET_TYPE_UNSUBACK => {
150 if packet_flags != 0 {
151 return Err(Error::InvalidPacket("UNSUBACK flags must be zero"));
152 }
153 if packet_len != 2 {
154 return Err(Error::InvalidPacket("UNSUBACK remaining length must be 2"));
155 }
156 let packet_id = PacketId::from(reader.read_u16()?);
157 Packet::UnsubscribeAck { packet_id }
158 }
159 protocol::PACKET_TYPE_CONNECT
160 | protocol::PACKET_TYPE_PUBREC
161 | protocol::PACKET_TYPE_PUBREL
162 | protocol::PACKET_TYPE_PUBCOMP
163 | protocol::PACKET_TYPE_DISCONNECT
164 | protocol::PACKET_TYPE_SUBSCRIBE
165 | protocol::PACKET_TYPE_UNSUBSCRIBE
166 | protocol::PACKET_TYPE_PINGREQ
167 | protocol::PACKET_TYPE_PINGRESP => {
168 return Err(Error::UnsupportedPacket {
169 packet_type,
170 packet_len,
171 });
172 }
173 _ => {
174 return Err(Error::UnknownPacket {
175 packet_type,
176 packet_len,
177 });
178 }
179 };
180
181 Ok((packet, reader.num_read()))
182}
183
184struct Reader<'a> {
185 buf: &'a [u8],
186 off: usize,
187}
188
189impl<'a> Reader<'a> {
190 fn new(buf: &'a [u8]) -> Self {
191 Self { buf, off: 0 }
192 }
193
194 fn remain(&self) -> usize {
195 self.buf.len() - self.off
196 }
197
198 fn remain_slice(&self) -> &'a [u8] {
199 &self.buf[self.off..]
200 }
201
202 fn num_read(&self) -> usize {
203 self.off
204 }
205
206 fn read_buf(&mut self, n: usize) -> Result<&'a [u8], Error> {
207 if self.remain() < n {
208 return Err(Error::NeedMoreData);
209 }
210 let v = &self.buf[self.off..self.off + n];
211 self.off += n;
212 Ok(v)
213 }
214
215 fn read_u8(&mut self) -> Result<u8, Error> {
216 let v = self.read_buf(1)?;
217 Ok(v[0])
218 }
219
220 fn read_u16(&mut self) -> Result<u16, Error> {
221 let v = self.read_buf(2)?;
222 Ok(u16::from_be_bytes([v[0], v[1]]))
223 }
224
225 fn read_len_prefix_buf(&mut self) -> Result<&'a [u8], Error> {
226 let l = self.read_u16()?;
227 let v = self.read_buf(usize::from(l))?;
228 Ok(v)
229 }
230
231 fn read_len_prefix_str(&mut self) -> Result<&'a str, Error> {
232 let v = self.read_len_prefix_buf()?;
233 Ok(str::from_utf8(v).unwrap())
234 }
235
236 fn read_varint(&mut self) -> Result<u32, Error> {
237 let (value, consumed) = varint::decode(self.remain_slice())?;
238 self.off += consumed;
239 Ok(value)
240 }
241}
diff --git a/src/mqtt/transport.rs b/src/mqtt/transport.rs
new file mode 100644
index 0000000..780b0ba
--- /dev/null
+++ b/src/mqtt/transport.rs
@@ -0,0 +1,39 @@
1use super::{field::Field, varint};
2
3pub trait Transport: embedded_io_async::Read + embedded_io_async::Write {}
4
5impl<T> Transport for T where T: embedded_io_async::Read + embedded_io_async::Write {}
6
7pub(crate) trait TransportExt: Transport {
8 async fn write_fields(&mut self, fields: &[Field]) -> Result<(), Self::Error>;
9}
10
11impl<T> TransportExt for T
12where
13 T: Transport,
14{
15 async fn write_fields(&mut self, fields: &[Field<'_>]) -> Result<(), Self::Error> {
16 for field in fields {
17 match field {
18 Field::U8(v) => self.write_all(&[*v]).await?,
19 Field::U16(v) => self.write_all(&u16::to_be_bytes(*v)).await?,
20 Field::VarInt(v) => {
21 let (v_buf, v_len) = varint::encode(*v);
22 self.write_all(&v_buf[..v_len]).await?;
23 }
24 Field::Buffer(v) => self.write_all(v).await?,
25 Field::LenPrefixedBuffer(v) => {
26 self.write_all(&u16::to_be_bytes(u16::try_from(v.len()).unwrap()))
27 .await?;
28 self.write_all(v).await?;
29 }
30 Field::LenPrefixedString(v) => {
31 self.write_all(&u16::to_be_bytes(u16::try_from(v.len()).unwrap()))
32 .await?;
33 self.write_all(v.as_bytes()).await?;
34 }
35 }
36 }
37 Ok(())
38 }
39}
diff --git a/src/mqtt/tx.rs b/src/mqtt/tx.rs
new file mode 100644
index 0000000..cdf1c75
--- /dev/null
+++ b/src/mqtt/tx.rs
@@ -0,0 +1,203 @@
1use super::{
2 PacketId,
3 field::{self, Field, FieldBuffer},
4 protocol,
5 qos::Qos,
6};
7
8pub struct Connect<'a> {
9 pub client_id: &'a str,
10 pub clean_session: bool,
11 pub username: Option<&'a str>,
12 pub password: Option<&'a [u8]>,
13 pub will_topic: Option<&'a str>,
14 pub will_payload: Option<&'a [u8]>,
15 pub will_retain: bool,
16 pub keepalive: Option<u16>,
17}
18
19pub fn connect<'a>(buffer: &mut FieldBuffer<'a>, connect: Connect<'a>) {
20 let mut flags = 0;
21 if connect.clean_session {
22 flags |= protocol::CONNECT_FLAG_CLEAN_SESSION;
23 }
24 if connect.username.is_some() {
25 flags |= protocol::CONNECT_FLAG_USERNAME;
26 }
27 if connect.password.is_some() {
28 flags |= protocol::CONNECT_FLAG_PASSWORD;
29 }
30 if connect.will_topic.is_some() {
31 flags |= protocol::CONNECT_FLAG_WILL_FLAG;
32 }
33 if connect.will_retain {
34 flags |= protocol::CONNECT_FLAG_WILL_RETAIN;
35 }
36
37 buffer.push(Field::U8(protocol::create_header_control(
38 protocol::PACKET_TYPE_CONNECT,
39 0,
40 )));
41 buffer.push(Field::VarInt(0));
42
43 buffer.push(Field::LenPrefixedString(protocol::PROTOCOL_NAME));
44 buffer.push(Field::U8(protocol::PROTOCOL_LEVEL_3_1_1));
45 buffer.push(Field::U8(flags));
46 buffer.push(Field::U16(connect.keepalive.unwrap_or(0)));
47 buffer.push(Field::LenPrefixedString(connect.client_id));
48 if let Some(will_topic) = connect.will_topic {
49 buffer.push(Field::LenPrefixedString(will_topic));
50 buffer.push(Field::LenPrefixedBuffer(
51 connect.will_payload.unwrap_or(&[]),
52 ));
53 }
54 if let Some(username) = connect.username {
55 buffer.push(Field::LenPrefixedString(username));
56 }
57 if let Some(password) = connect.password {
58 buffer.push(Field::LenPrefixedBuffer(password));
59 }
60
61 let message_size = field::fields_size(&buffer.as_slice()[2..]);
62 buffer.set(1, Field::VarInt(u32::try_from(message_size).unwrap()));
63}
64
65pub struct Publish<'a> {
66 pub topic: &'a str,
67 pub payload: &'a [u8],
68 pub qos: Qos,
69 pub retain: bool,
70 pub dup: bool,
71 pub packet_id: Option<PacketId>,
72}
73
74pub fn publish<'a>(buffer: &mut FieldBuffer<'a>, publish: Publish<'a>) {
75 let mut flags = 0u8;
76
77 // Set QoS bits (bits 1-2)
78 flags |= (publish.qos.to_u8() & 0x03) << 1;
79
80 // Set RETAIN flag (bit 0)
81 if publish.retain {
82 flags |= 0x01;
83 }
84
85 // Set DUP flag (bit 3)
86 if publish.dup {
87 flags |= 0x08;
88 }
89
90 buffer.push(Field::U8(protocol::create_header_control(
91 protocol::PACKET_TYPE_PUBLISH,
92 flags,
93 )));
94 buffer.push(Field::VarInt(0));
95
96 buffer.push(Field::LenPrefixedString(publish.topic));
97
98 // Packet ID is only present for QoS 1 and 2
99 if publish.qos.to_u8() > 0 {
100 // TODO: turn this into a warning
101 let packet_id = publish.packet_id.expect("packet_id required for QoS > 0");
102 buffer.push(Field::U16(packet_id.into()));
103 }
104
105 buffer.push(Field::Buffer(publish.payload));
106
107 let message_size = field::fields_size(&buffer.as_slice()[2..]);
108 buffer.set(1, Field::VarInt(u32::try_from(message_size).unwrap()));
109}
110
111pub struct Subscribe<'a> {
112 pub topic: &'a str,
113 pub qos: Qos,
114 pub packet_id: PacketId,
115}
116
117pub fn subscribe<'a>(buffer: &mut FieldBuffer<'a>, subscribe: Subscribe<'a>) {
118 // SUBSCRIBE packets have fixed header flags (reserved bits)
119 buffer.push(Field::U8(protocol::create_header_control(
120 protocol::PACKET_TYPE_SUBSCRIBE,
121 protocol::SUBSCRIBE_HEADER_FLAGS,
122 )));
123 buffer.push(Field::VarInt(0));
124
125 // Variable header: packet identifier
126 buffer.push(Field::U16(subscribe.packet_id.into()));
127
128 // Payload: topic filter + QoS
129 buffer.push(Field::LenPrefixedString(subscribe.topic));
130 buffer.push(Field::U8(subscribe.qos.to_u8()));
131
132 let message_size = field::fields_size(&buffer.as_slice()[2..]);
133 buffer.set(1, Field::VarInt(u32::try_from(message_size).unwrap()));
134}
135
136pub struct Unsubscribe<'a> {
137 pub topic: &'a str,
138 pub packet_id: PacketId,
139}
140
141pub fn unsubscribe<'a>(buffer: &mut FieldBuffer<'a>, unsubscribe: Unsubscribe<'a>) {
142 // UNSUBSCRIBE packets have fixed header flags (reserved bits)
143 buffer.push(Field::U8(protocol::create_header_control(
144 protocol::PACKET_TYPE_UNSUBSCRIBE,
145 protocol::UNSUBSCRIBE_HEADER_FLAGS,
146 )));
147 buffer.push(Field::VarInt(0));
148
149 // Variable header: packet identifier
150 buffer.push(Field::U16(unsubscribe.packet_id.into()));
151
152 // Payload: topic filter (no QoS)
153 buffer.push(Field::LenPrefixedString(unsubscribe.topic));
154
155 let message_size = field::fields_size(&buffer.as_slice()[2..]);
156 buffer.set(1, Field::VarInt(u32::try_from(message_size).unwrap()));
157}
158
159pub fn disconnect(buffer: &mut FieldBuffer) {
160 // DISCONNECT has no variable header or payload
161 buffer.push(Field::U8(protocol::create_header_control(
162 protocol::PACKET_TYPE_DISCONNECT,
163 0,
164 )));
165 buffer.push(Field::VarInt(0));
166}
167
168pub fn puback(buffer: &mut FieldBuffer, packet_id: PacketId) {
169 buffer.push(Field::U8(protocol::create_header_control(
170 protocol::PACKET_TYPE_PUBACK,
171 0,
172 )));
173 buffer.push(Field::VarInt(2)); // Remaining length is always 2 (packet ID)
174 buffer.push(Field::U16(packet_id.into()));
175}
176
177pub fn pubrec(buffer: &mut FieldBuffer, packet_id: PacketId) {
178 buffer.push(Field::U8(protocol::create_header_control(
179 protocol::PACKET_TYPE_PUBREC,
180 0,
181 )));
182 buffer.push(Field::VarInt(2)); // Remaining length is always 2 (packet ID)
183 buffer.push(Field::U16(packet_id.into()));
184}
185
186pub fn pubrel(buffer: &mut FieldBuffer, packet_id: PacketId) {
187 buffer.push(Field::U8(protocol::create_header_control(
188 protocol::PACKET_TYPE_PUBREL,
189 protocol::PUBREL_HEADER_FLAGS,
190 )));
191 buffer.push(Field::VarInt(2)); // Remaining length is always 2 (packet ID)
192 buffer.push(Field::U16(packet_id.into()));
193}
194
195pub fn pubcomp(buffer: &mut FieldBuffer, packet_id: PacketId) {
196 buffer.push(Field::U8(protocol::create_header_control(
197 protocol::PACKET_TYPE_PUBCOMP,
198 0,
199 )));
200 buffer.push(Field::VarInt(2)); // Remaining length is always 2 (packet ID)
201 buffer.push(Field::U16(packet_id.into()));
202}
203
diff --git a/src/mqtt/varint.rs b/src/mqtt/varint.rs
new file mode 100644
index 0000000..63bdd06
--- /dev/null
+++ b/src/mqtt/varint.rs
@@ -0,0 +1,69 @@
1#[derive(Debug)]
2pub enum Error {
3 NeedMoreData,
4 InvalidVarInt,
5}
6
7impl core::fmt::Display for Error {
8 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
9 match self {
10 Error::NeedMoreData => f.write_str("NeedMoreData"),
11 Error::InvalidVarInt => f.write_str("InvalidVarInt"),
12 }
13 }
14}
15
16impl core::error::Error for Error {}
17
18pub fn encode(mut v: u32) -> ([u8; 4], usize) {
19 let mut encoded = [0u8; 4];
20 let mut count = 0;
21
22 loop {
23 let mut byte = (v % 128) as u8;
24 v /= 128;
25
26 if v > 0 {
27 byte |= 0x80; // Set continuation bit
28 }
29
30 encoded[count] = byte;
31 count += 1;
32
33 if v == 0 {
34 break;
35 }
36 }
37
38 (encoded, count)
39}
40
41pub fn decode(buf: &[u8]) -> Result<(u32, usize), Error> {
42 let mut value = 0u32;
43
44 let v = buf.get(0).ok_or(Error::NeedMoreData)?;
45 value |= ((v & 0x7F) as u32) << 0;
46 if v & 0x80 == 0 {
47 return Ok((value, 1));
48 }
49
50 let v = buf.get(1).ok_or(Error::NeedMoreData)?;
51 value |= ((v & 0x7F) as u32) << 7;
52 if v & 0x80 == 0 {
53 return Ok((value, 2));
54 }
55
56 let v = buf.get(2).ok_or(Error::NeedMoreData)?;
57 value |= ((v & 0x7F) as u32) << 14;
58 if v & 0x80 == 0 {
59 return Ok((value, 3));
60 }
61
62 let v = buf.get(3).ok_or(Error::NeedMoreData)?;
63 value |= ((v & 0x7F) as u32) << 21;
64 if v & 0x80 != 0 {
65 return Err(Error::InvalidVarInt);
66 }
67
68 Ok((value, 4))
69}