diff options
author | Marcel Müller <neikos@neikos.email> | 2024-04-04 12:34:44 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 12:34:44 +0200 |
commit | 04b96d6c1c43f90c3cb78cf6cc91c1c3771a4041 (patch) | |
tree | 990d77b90e1a241f4dc2f1d383e4d2b9a67af17c | |
parent | de260daf73add9c28ceedaef1345a250c241b1de (diff) | |
parent | 3e8d0ace867ff6894e16442e42febe31aa1f3cf8 (diff) |
Merge pull request #272 from matthiasbeyer/packetidentifier-type
Dedicated PacketIdentifier{,NonZero} type
-rw-r--r-- | src/client/receive.rs | 14 | ||||
-rw-r--r-- | src/client/send.rs | 27 | ||||
-rw-r--r-- | src/client/state.rs | 15 | ||||
-rw-r--r-- | src/lib.rs | 1 | ||||
-rw-r--r-- | src/packet_identifier.rs | 66 |
5 files changed, 99 insertions, 24 deletions
diff --git a/src/client/receive.rs b/src/client/receive.rs index 0231fd7..3ebf76b 100644 --- a/src/client/receive.rs +++ b/src/client/receive.rs @@ -4,7 +4,6 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. // -use std::num::NonZeroU16; use std::sync::Arc; use futures::lock::Mutex; @@ -16,6 +15,7 @@ use yoke::Yoke; use super::InnerClient; use crate::codecs::MqttPacketCodec; +use crate::packet_identifier::PacketIdentifierNonZero; use crate::packets::MqttPacket; use crate::packets::MqttWriter; use crate::packets::StableBytes; @@ -148,9 +148,9 @@ async fn handle_pubcomp( tracing::error!("No session state found"); todo!() }; - let pident = NonZeroU16::try_from(pubcomp.packet_identifier.0) + let pident = PacketIdentifierNonZero::try_from(pubcomp.packet_identifier) .expect("zero PacketIdentifier not valid here"); - tracing::Span::current().record("packet_identifier", pident); + tracing::Span::current().record("packet_identifier", tracing::field::display(pident)); if session_state .outstanding_packets @@ -189,9 +189,9 @@ async fn handle_puback( todo!() }; - let pident = std::num::NonZeroU16::try_from(mpuback.packet_identifier.0) + let pident = PacketIdentifierNonZero::try_from(mpuback.packet_identifier) .expect("Zero PacketIdentifier not valid here"); - tracing::Span::current().record("packet_identifier", pident); + tracing::Span::current().record("packet_identifier", tracing::field::display(pident)); if session_state .outstanding_packets @@ -236,9 +236,9 @@ async fn handle_pubrec( tracing::error!("No session state found"); todo!() }; - let pident = NonZeroU16::try_from(pubrec.packet_identifier.0) + let pident = PacketIdentifierNonZero::try_from(pubrec.packet_identifier) .expect("zero PacketIdentifier not valid here"); - tracing::Span::current().record("packet_identifier", pident); + tracing::Span::current().record("packet_identifier", tracing::field::display(pident)); if session_state .outstanding_packets diff --git a/src/client/send.rs b/src/client/send.rs index 052d0a3..91f9564 100644 --- a/src/client/send.rs +++ b/src/client/send.rs @@ -15,6 +15,7 @@ use tracing::Instrument; use super::state::OutstandingPackets; use super::MqttClient; +use crate::packet_identifier::PacketIdentifierNonZero; use crate::packets::MqttPacket; use crate::payload::MqttPayload; use crate::qos::QualityOfService; @@ -188,11 +189,11 @@ impl MqttClient { fn get_next_packet_ident( next_packet_ident: &mut std::num::NonZeroU16, outstanding_packets: &OutstandingPackets, -) -> Result<std::num::NonZeroU16, PacketIdentifierExhausted> { +) -> Result<PacketIdentifierNonZero, PacketIdentifierExhausted> { let start = *next_packet_ident; loop { - let next = *next_packet_ident; + let next = PacketIdentifierNonZero::from(*next_packet_ident); if !outstanding_packets.exists_outstanding_packet(next) { return Ok(next); @@ -229,9 +230,9 @@ pub(crate) enum Acknowledge { pub(crate) struct Callbacks { ping_req: VecDeque<futures::channel::oneshot::Sender<()>>, - qos1: HashMap<NonZeroU16, Qos1Callbacks>, - qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>, - qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>, + qos1: HashMap<PacketIdentifierNonZero, Qos1Callbacks>, + qos2_receive: HashMap<PacketIdentifierNonZero, Qos2ReceiveCallback>, + qos2_complete: HashMap<PacketIdentifierNonZero, Qos2CompleteCallback>, } impl Callbacks { @@ -248,13 +249,13 @@ impl Callbacks { self.ping_req.push_back(cb); } - pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) { + pub(crate) fn add_qos1(&mut self, id: PacketIdentifierNonZero, cb: Qos1Callbacks) { self.qos1.insert(id, cb); } pub(crate) fn add_qos2( &mut self, - id: NonZeroU16, + id: PacketIdentifierNonZero, rec: Qos2ReceiveCallback, comp: Qos2CompleteCallback, ) { @@ -266,15 +267,21 @@ impl Callbacks { self.ping_req.pop_front() } - pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> { + pub(crate) fn take_qos1(&mut self, id: PacketIdentifierNonZero) -> Option<Qos1Callbacks> { self.qos1.remove(&id) } - pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> { + pub(crate) fn take_qos2_receive( + &mut self, + id: PacketIdentifierNonZero, + ) -> Option<Qos2ReceiveCallback> { self.qos2_receive.remove(&id) } - pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> { + pub(crate) fn take_qos2_complete( + &mut self, + id: PacketIdentifierNonZero, + ) -> Option<Qos2CompleteCallback> { self.qos2_complete.remove(&id) } } diff --git a/src/client/state.rs b/src/client/state.rs index bbd9714..d44b182 100644 --- a/src/client/state.rs +++ b/src/client/state.rs @@ -10,6 +10,7 @@ use tokio_util::codec::FramedRead; use tokio_util::codec::FramedWrite; use crate::codecs::MqttPacketCodec; +use crate::packet_identifier::PacketIdentifierNonZero; use crate::string::MqttString; use crate::transport::MqttConnection; @@ -35,9 +36,9 @@ pub(super) struct SessionState { } pub(super) struct OutstandingPackets { - pub(super) packet_ident_order: Vec<std::num::NonZeroU16>, + pub(super) packet_ident_order: Vec<PacketIdentifierNonZero>, pub(super) outstanding_packets: - std::collections::BTreeMap<std::num::NonZeroU16, crate::packets::MqttPacket>, + std::collections::BTreeMap<PacketIdentifierNonZero, crate::packets::MqttPacket>, } impl OutstandingPackets { @@ -48,7 +49,7 @@ impl OutstandingPackets { } } - pub fn insert(&mut self, ident: std::num::NonZeroU16, packet: crate::packets::MqttPacket) { + pub fn insert(&mut self, ident: PacketIdentifierNonZero, packet: crate::packets::MqttPacket) { debug_assert_eq!( self.packet_ident_order.len(), self.outstanding_packets.len() @@ -62,7 +63,7 @@ impl OutstandingPackets { pub fn update_by_id( &mut self, - ident: std::num::NonZeroU16, + ident: PacketIdentifierNonZero, packet: crate::packets::MqttPacket, ) { debug_assert_eq!( @@ -75,19 +76,19 @@ impl OutstandingPackets { debug_assert!(removed.is_some()); } - pub fn exists_outstanding_packet(&self, ident: std::num::NonZeroU16) -> bool { + pub fn exists_outstanding_packet(&self, ident: PacketIdentifierNonZero) -> bool { self.outstanding_packets.contains_key(&ident) } pub fn iter_in_send_order( &self, - ) -> impl Iterator<Item = (std::num::NonZeroU16, &crate::packets::MqttPacket)> { + ) -> impl Iterator<Item = (PacketIdentifierNonZero, &crate::packets::MqttPacket)> { self.packet_ident_order .iter() .flat_map(|id| self.outstanding_packets.get(id).map(|p| (*id, p))) } - pub fn remove_by_id(&mut self, id: std::num::NonZeroU16) { + pub fn remove_by_id(&mut self, id: PacketIdentifierNonZero) { // Vec::retain() preserves order self.packet_ident_order.retain(|&elm| elm != id); self.outstanding_packets.remove(&id); @@ -12,6 +12,7 @@ pub mod client_identifier; mod codecs; mod error; pub mod keep_alive; +pub mod packet_identifier; pub mod packets; pub mod payload; mod properties; diff --git a/src/packet_identifier.rs b/src/packet_identifier.rs new file mode 100644 index 0000000..3b4c680 --- /dev/null +++ b/src/packet_identifier.rs @@ -0,0 +1,66 @@ +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. +// + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct PacketIdentifier(u16); + +impl std::fmt::Display for PacketIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl From<mqtt_format::v5::variable_header::PacketIdentifier> for PacketIdentifier { + fn from(value: mqtt_format::v5::variable_header::PacketIdentifier) -> Self { + Self(value.0) + } +} + +impl From<PacketIdentifier> for mqtt_format::v5::variable_header::PacketIdentifier { + fn from(value: PacketIdentifier) -> mqtt_format::v5::variable_header::PacketIdentifier { + mqtt_format::v5::variable_header::PacketIdentifier(value.0) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct PacketIdentifierNonZero(std::num::NonZeroU16); + +impl PacketIdentifierNonZero { + #[inline] + pub fn get(&self) -> u16 { + self.0.get() + } +} + +impl std::fmt::Display for PacketIdentifierNonZero { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl TryFrom<mqtt_format::v5::variable_header::PacketIdentifier> for PacketIdentifierNonZero { + type Error = (); // TODO + + fn try_from( + value: mqtt_format::v5::variable_header::PacketIdentifier, + ) -> Result<Self, Self::Error> { + std::num::NonZeroU16::try_from(value.0) + .map(Self) + .map_err(drop) // TODO + } +} + +impl From<PacketIdentifierNonZero> for mqtt_format::v5::variable_header::PacketIdentifier { + fn from(value: PacketIdentifierNonZero) -> mqtt_format::v5::variable_header::PacketIdentifier { + mqtt_format::v5::variable_header::PacketIdentifier(value.0.get()) + } +} + +impl From<std::num::NonZeroU16> for PacketIdentifierNonZero { + fn from(value: std::num::NonZeroU16) -> Self { + Self(value) + } +} |