summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2024-04-04 12:34:44 +0200
committerGitHub <noreply@github.com>2024-04-04 12:34:44 +0200
commit04b96d6c1c43f90c3cb78cf6cc91c1c3771a4041 (patch)
tree990d77b90e1a241f4dc2f1d383e4d2b9a67af17c
parentde260daf73add9c28ceedaef1345a250c241b1de (diff)
parent3e8d0ace867ff6894e16442e42febe31aa1f3cf8 (diff)
Merge pull request #272 from matthiasbeyer/packetidentifier-type
Dedicated PacketIdentifier{,NonZero} type
-rw-r--r--src/client/receive.rs14
-rw-r--r--src/client/send.rs27
-rw-r--r--src/client/state.rs15
-rw-r--r--src/lib.rs1
-rw-r--r--src/packet_identifier.rs66
5 files changed, 99 insertions, 24 deletions
diff --git a/src/client/receive.rs b/src/client/receive.rs
index 0231fd7..3ebf76b 100644
--- a/src/client/receive.rs
+++ b/src/client/receive.rs
@@ -4,7 +4,6 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
-use std::num::NonZeroU16;
use std::sync::Arc;
use futures::lock::Mutex;
@@ -16,6 +15,7 @@ use yoke::Yoke;
use super::InnerClient;
use crate::codecs::MqttPacketCodec;
+use crate::packet_identifier::PacketIdentifierNonZero;
use crate::packets::MqttPacket;
use crate::packets::MqttWriter;
use crate::packets::StableBytes;
@@ -148,9 +148,9 @@ async fn handle_pubcomp(
tracing::error!("No session state found");
todo!()
};
- let pident = NonZeroU16::try_from(pubcomp.packet_identifier.0)
+ let pident = PacketIdentifierNonZero::try_from(pubcomp.packet_identifier)
.expect("zero PacketIdentifier not valid here");
- tracing::Span::current().record("packet_identifier", pident);
+ tracing::Span::current().record("packet_identifier", tracing::field::display(pident));
if session_state
.outstanding_packets
@@ -189,9 +189,9 @@ async fn handle_puback(
todo!()
};
- let pident = std::num::NonZeroU16::try_from(mpuback.packet_identifier.0)
+ let pident = PacketIdentifierNonZero::try_from(mpuback.packet_identifier)
.expect("Zero PacketIdentifier not valid here");
- tracing::Span::current().record("packet_identifier", pident);
+ tracing::Span::current().record("packet_identifier", tracing::field::display(pident));
if session_state
.outstanding_packets
@@ -236,9 +236,9 @@ async fn handle_pubrec(
tracing::error!("No session state found");
todo!()
};
- let pident = NonZeroU16::try_from(pubrec.packet_identifier.0)
+ let pident = PacketIdentifierNonZero::try_from(pubrec.packet_identifier)
.expect("zero PacketIdentifier not valid here");
- tracing::Span::current().record("packet_identifier", pident);
+ tracing::Span::current().record("packet_identifier", tracing::field::display(pident));
if session_state
.outstanding_packets
diff --git a/src/client/send.rs b/src/client/send.rs
index 052d0a3..91f9564 100644
--- a/src/client/send.rs
+++ b/src/client/send.rs
@@ -15,6 +15,7 @@ use tracing::Instrument;
use super::state::OutstandingPackets;
use super::MqttClient;
+use crate::packet_identifier::PacketIdentifierNonZero;
use crate::packets::MqttPacket;
use crate::payload::MqttPayload;
use crate::qos::QualityOfService;
@@ -188,11 +189,11 @@ impl MqttClient {
fn get_next_packet_ident(
next_packet_ident: &mut std::num::NonZeroU16,
outstanding_packets: &OutstandingPackets,
-) -> Result<std::num::NonZeroU16, PacketIdentifierExhausted> {
+) -> Result<PacketIdentifierNonZero, PacketIdentifierExhausted> {
let start = *next_packet_ident;
loop {
- let next = *next_packet_ident;
+ let next = PacketIdentifierNonZero::from(*next_packet_ident);
if !outstanding_packets.exists_outstanding_packet(next) {
return Ok(next);
@@ -229,9 +230,9 @@ pub(crate) enum Acknowledge {
pub(crate) struct Callbacks {
ping_req: VecDeque<futures::channel::oneshot::Sender<()>>,
- qos1: HashMap<NonZeroU16, Qos1Callbacks>,
- qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>,
- qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>,
+ qos1: HashMap<PacketIdentifierNonZero, Qos1Callbacks>,
+ qos2_receive: HashMap<PacketIdentifierNonZero, Qos2ReceiveCallback>,
+ qos2_complete: HashMap<PacketIdentifierNonZero, Qos2CompleteCallback>,
}
impl Callbacks {
@@ -248,13 +249,13 @@ impl Callbacks {
self.ping_req.push_back(cb);
}
- pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) {
+ pub(crate) fn add_qos1(&mut self, id: PacketIdentifierNonZero, cb: Qos1Callbacks) {
self.qos1.insert(id, cb);
}
pub(crate) fn add_qos2(
&mut self,
- id: NonZeroU16,
+ id: PacketIdentifierNonZero,
rec: Qos2ReceiveCallback,
comp: Qos2CompleteCallback,
) {
@@ -266,15 +267,21 @@ impl Callbacks {
self.ping_req.pop_front()
}
- pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> {
+ pub(crate) fn take_qos1(&mut self, id: PacketIdentifierNonZero) -> Option<Qos1Callbacks> {
self.qos1.remove(&id)
}
- pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> {
+ pub(crate) fn take_qos2_receive(
+ &mut self,
+ id: PacketIdentifierNonZero,
+ ) -> Option<Qos2ReceiveCallback> {
self.qos2_receive.remove(&id)
}
- pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> {
+ pub(crate) fn take_qos2_complete(
+ &mut self,
+ id: PacketIdentifierNonZero,
+ ) -> Option<Qos2CompleteCallback> {
self.qos2_complete.remove(&id)
}
}
diff --git a/src/client/state.rs b/src/client/state.rs
index bbd9714..d44b182 100644
--- a/src/client/state.rs
+++ b/src/client/state.rs
@@ -10,6 +10,7 @@ use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use crate::codecs::MqttPacketCodec;
+use crate::packet_identifier::PacketIdentifierNonZero;
use crate::string::MqttString;
use crate::transport::MqttConnection;
@@ -35,9 +36,9 @@ pub(super) struct SessionState {
}
pub(super) struct OutstandingPackets {
- pub(super) packet_ident_order: Vec<std::num::NonZeroU16>,
+ pub(super) packet_ident_order: Vec<PacketIdentifierNonZero>,
pub(super) outstanding_packets:
- std::collections::BTreeMap<std::num::NonZeroU16, crate::packets::MqttPacket>,
+ std::collections::BTreeMap<PacketIdentifierNonZero, crate::packets::MqttPacket>,
}
impl OutstandingPackets {
@@ -48,7 +49,7 @@ impl OutstandingPackets {
}
}
- pub fn insert(&mut self, ident: std::num::NonZeroU16, packet: crate::packets::MqttPacket) {
+ pub fn insert(&mut self, ident: PacketIdentifierNonZero, packet: crate::packets::MqttPacket) {
debug_assert_eq!(
self.packet_ident_order.len(),
self.outstanding_packets.len()
@@ -62,7 +63,7 @@ impl OutstandingPackets {
pub fn update_by_id(
&mut self,
- ident: std::num::NonZeroU16,
+ ident: PacketIdentifierNonZero,
packet: crate::packets::MqttPacket,
) {
debug_assert_eq!(
@@ -75,19 +76,19 @@ impl OutstandingPackets {
debug_assert!(removed.is_some());
}
- pub fn exists_outstanding_packet(&self, ident: std::num::NonZeroU16) -> bool {
+ pub fn exists_outstanding_packet(&self, ident: PacketIdentifierNonZero) -> bool {
self.outstanding_packets.contains_key(&ident)
}
pub fn iter_in_send_order(
&self,
- ) -> impl Iterator<Item = (std::num::NonZeroU16, &crate::packets::MqttPacket)> {
+ ) -> impl Iterator<Item = (PacketIdentifierNonZero, &crate::packets::MqttPacket)> {
self.packet_ident_order
.iter()
.flat_map(|id| self.outstanding_packets.get(id).map(|p| (*id, p)))
}
- pub fn remove_by_id(&mut self, id: std::num::NonZeroU16) {
+ pub fn remove_by_id(&mut self, id: PacketIdentifierNonZero) {
// Vec::retain() preserves order
self.packet_ident_order.retain(|&elm| elm != id);
self.outstanding_packets.remove(&id);
diff --git a/src/lib.rs b/src/lib.rs
index 929827c..92fcdc2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,7 @@ pub mod client_identifier;
mod codecs;
mod error;
pub mod keep_alive;
+pub mod packet_identifier;
pub mod packets;
pub mod payload;
mod properties;
diff --git a/src/packet_identifier.rs b/src/packet_identifier.rs
new file mode 100644
index 0000000..3b4c680
--- /dev/null
+++ b/src/packet_identifier.rs
@@ -0,0 +1,66 @@
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+//
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct PacketIdentifier(u16);
+
+impl std::fmt::Display for PacketIdentifier {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl From<mqtt_format::v5::variable_header::PacketIdentifier> for PacketIdentifier {
+ fn from(value: mqtt_format::v5::variable_header::PacketIdentifier) -> Self {
+ Self(value.0)
+ }
+}
+
+impl From<PacketIdentifier> for mqtt_format::v5::variable_header::PacketIdentifier {
+ fn from(value: PacketIdentifier) -> mqtt_format::v5::variable_header::PacketIdentifier {
+ mqtt_format::v5::variable_header::PacketIdentifier(value.0)
+ }
+}
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub struct PacketIdentifierNonZero(std::num::NonZeroU16);
+
+impl PacketIdentifierNonZero {
+ #[inline]
+ pub fn get(&self) -> u16 {
+ self.0.get()
+ }
+}
+
+impl std::fmt::Display for PacketIdentifierNonZero {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl TryFrom<mqtt_format::v5::variable_header::PacketIdentifier> for PacketIdentifierNonZero {
+ type Error = (); // TODO
+
+ fn try_from(
+ value: mqtt_format::v5::variable_header::PacketIdentifier,
+ ) -> Result<Self, Self::Error> {
+ std::num::NonZeroU16::try_from(value.0)
+ .map(Self)
+ .map_err(drop) // TODO
+ }
+}
+
+impl From<PacketIdentifierNonZero> for mqtt_format::v5::variable_header::PacketIdentifier {
+ fn from(value: PacketIdentifierNonZero) -> mqtt_format::v5::variable_header::PacketIdentifier {
+ mqtt_format::v5::variable_header::PacketIdentifier(value.0.get())
+ }
+}
+
+impl From<std::num::NonZeroU16> for PacketIdentifierNonZero {
+ fn from(value: std::num::NonZeroU16) -> Self {
+ Self(value)
+ }
+}