diff options
author | Marcel Müller <neikos@neikos.email> | 2024-04-04 12:28:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 12:28:14 +0200 |
commit | de260daf73add9c28ceedaef1345a250c241b1de (patch) | |
tree | 385c71654cf90501c1ef59cccee443dcb7dded87 | |
parent | ae07b4c04a27be8e21e0da36a45e80ed6b9ce498 (diff) | |
parent | 8dccfc038dce577f6e7ff4e17fb86debfd97b3cd (diff) |
Merge pull request #273 from TheNeikos/feature/expand_client
Refactor callbacks to remove hashmap and enum matching
-rw-r--r-- | src/client/mod.rs | 7 | ||||
-rw-r--r-- | src/client/receive.rs | 75 | ||||
-rw-r--r-- | src/client/send.rs | 104 |
3 files changed, 89 insertions, 97 deletions
diff --git a/src/client/mod.rs b/src/client/mod.rs index 12175bb..72780b3 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -14,9 +14,8 @@ use std::sync::Arc; use futures::lock::Mutex; use self::send::Acknowledge; -use self::send::CallbackState; +use self::send::Callbacks; use self::send::ClientHandlers; -use self::send::Id; use self::state::ConnectState; use self::state::SessionState; @@ -24,7 +23,7 @@ struct InnerClient { connection_state: Option<ConnectState>, session_state: Option<SessionState>, default_handlers: ClientHandlers, - outstanding_completions: std::collections::HashMap<Id, CallbackState>, + outstanding_callbacks: Callbacks, } pub struct MqttClient { @@ -41,7 +40,7 @@ impl MqttClient { on_packet_recv: Box::new(|_| ()), handle_acknowledge: Box::new(|_| Acknowledge::Yes), }, - outstanding_completions: std::collections::HashMap::new(), + outstanding_callbacks: Callbacks::new(), })), } } diff --git a/src/client/receive.rs b/src/client/receive.rs index e7d55c2..0231fd7 100644 --- a/src/client/receive.rs +++ b/src/client/receive.rs @@ -15,8 +15,6 @@ use tracing::Instrument; use yoke::Yoke; use super::InnerClient; -use crate::client::CallbackState; -use crate::client::Id; use crate::codecs::MqttPacketCodec; use crate::packets::MqttPacket; use crate::packets::MqttWriter; @@ -107,21 +105,12 @@ async fn handle_pingresp( let mut inner = inner.lock().await; let inner = &mut *inner; - if let Some(callback) = inner.outstanding_completions.get_mut(&Id::PingReq) { - match callback { - CallbackState::PingReq { on_pingresp } => { - if let Some(cb) = on_pingresp.pop_front() { - if cb.send(()).is_err() { - tracing::debug!( - "PingReq completion handler was dropped before receiving response" - ) - } - } else { - tracing::warn!("Received an unwarranted PingResp from the server, continuing") - } - } - _ => todo!("Had non pingreq in pingreq callback state"), + if let Some(cb) = inner.outstanding_callbacks.take_ping_req() { + if cb.send(()).is_err() { + tracing::debug!("PingReq completion handler was dropped before receiving response") } + } else { + tracing::warn!("Received an unwarranted PingResp from the server, continuing") } Ok(()) @@ -170,22 +159,12 @@ async fn handle_pubcomp( session_state.outstanding_packets.remove_by_id(pident); tracing::trace!("Removed packet id from outstanding packets"); - if let Some(callback) = inner - .outstanding_completions - .get_mut(&Id::PacketIdentifier(pident)) - { - match callback { - CallbackState::Qos2 { on_complete, .. } => { - if let Some(on_complete) = on_complete.take() { - if let Err(_) = on_complete.send(packet.clone()) { - tracing::trace!("Could not send ack, receiver was dropped.") - } - } else { - todo!("Invariant broken: Double on_complete for a single pid: {pident}") - } - } - _ => todo!(), + if let Some(callback) = inner.outstanding_callbacks.take_qos2_complete(pident) { + if let Err(_) = callback.on_complete.send(packet.clone()) { + tracing::trace!("Could not send ack, receiver was dropped.") } + } else { + todo!("Invariant broken: Received on_complete for unknown packet") } } } @@ -221,17 +200,9 @@ async fn handle_puback( session_state.outstanding_packets.remove_by_id(pident); tracing::trace!("Removed packet id from outstanding packets"); - if let Some(callback) = inner - .outstanding_completions - .remove(&Id::PacketIdentifier(pident)) - { - match callback { - CallbackState::Qos1 { on_acknowledge } => { - if let Err(_) = on_acknowledge.send(packet.clone()) { - tracing::trace!("Could not send ack, receiver was dropped.") - } - } - _ => todo!(), + if let Some(callback) = inner.outstanding_callbacks.take_qos1(pident) { + if let Err(_) = callback.on_acknowledge.send(packet.clone()) { + tracing::trace!("Could not send ack, receiver was dropped.") } } } else { @@ -296,22 +267,12 @@ async fn handle_pubrec( tracing::trace!("Update packet from outstanding packets"); conn_state.conn_write.send(pubrel).await.map_err(drop)?; - if let Some(callback) = inner - .outstanding_completions - .get_mut(&Id::PacketIdentifier(pident)) - { - match callback { - CallbackState::Qos2 { on_receive, .. } => { - if let Some(on_receive) = on_receive.take() { - if let Err(_) = on_receive.send(packet.clone()) { - tracing::trace!("Could not send ack, receiver was dropped.") - } - } else { - todo!("Invariant broken: Double on_receive for a single pid: {pident}") - } - } - _ => todo!(), + if let Some(callback) = inner.outstanding_callbacks.take_qos2_receive(pident) { + if let Err(_) = callback.on_receive.send(packet.clone()) { + tracing::trace!("Could not send ack, receiver was dropped.") } + } else { + todo!("Invariant broken: Receive PubRec for unawaited {pident}") } } } diff --git a/src/client/send.rs b/src/client/send.rs index 0223551..052d0a3 100644 --- a/src/client/send.rs +++ b/src/client/send.rs @@ -4,6 +4,7 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. // +use std::collections::HashMap; use std::collections::VecDeque; use std::num::NonZeroU16; @@ -104,21 +105,18 @@ impl MqttClient { QualityOfService::AtMostOnce => unreachable!(), QualityOfService::AtLeastOnce => { let (on_acknowledge, recv) = futures::channel::oneshot::channel(); - inner.outstanding_completions.insert( - Id::PacketIdentifier(pi), - CallbackState::Qos1 { on_acknowledge }, - ); + inner + .outstanding_callbacks + .add_qos1(pi, Qos1Callbacks { on_acknowledge }); published_recv = PublishedReceiver::Once(PublishedQos1 { recv }); } QualityOfService::ExactlyOnce => { let (on_receive, recv) = futures::channel::oneshot::channel(); let (on_complete, comp_recv) = futures::channel::oneshot::channel(); - inner.outstanding_completions.insert( - Id::PacketIdentifier(pi), - CallbackState::Qos2 { - on_receive: Some(on_receive), - on_complete: Some(on_complete), - }, + inner.outstanding_callbacks.add_qos2( + pi, + Qos2ReceiveCallback { on_receive }, + Qos2CompleteCallback { on_complete }, ); published_recv = PublishedReceiver::Twice(PublishedQos2Received { recv, comp_recv }); @@ -229,23 +227,67 @@ pub(crate) enum Acknowledge { YesWithProps {}, } -#[derive(Debug, Hash, PartialEq, Eq)] -pub(crate) enum Id { - PingReq, - PacketIdentifier(NonZeroU16), +pub(crate) struct Callbacks { + ping_req: VecDeque<futures::channel::oneshot::Sender<()>>, + qos1: HashMap<NonZeroU16, Qos1Callbacks>, + qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>, + qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>, +} + +impl Callbacks { + pub(crate) fn new() -> Callbacks { + Callbacks { + ping_req: Default::default(), + qos1: HashMap::default(), + qos2_receive: HashMap::default(), + qos2_complete: HashMap::default(), + } + } + + pub(crate) fn add_ping_req(&mut self, cb: futures::channel::oneshot::Sender<()>) { + self.ping_req.push_back(cb); + } + + pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) { + self.qos1.insert(id, cb); + } + + pub(crate) fn add_qos2( + &mut self, + id: NonZeroU16, + rec: Qos2ReceiveCallback, + comp: Qos2CompleteCallback, + ) { + self.qos2_receive.insert(id, rec); + self.qos2_complete.insert(id, comp); + } + + pub(crate) fn take_ping_req(&mut self) -> Option<futures::channel::oneshot::Sender<()>> { + self.ping_req.pop_front() + } + + pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> { + self.qos1.remove(&id) + } + + pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> { + self.qos2_receive.remove(&id) + } + + pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> { + self.qos2_complete.remove(&id) + } +} + +pub(crate) struct Qos1Callbacks { + pub(crate) on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, } -pub(crate) enum CallbackState { - PingReq { - on_pingresp: VecDeque<futures::channel::oneshot::Sender<()>>, - }, - Qos1 { - on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, - }, - Qos2 { - on_receive: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>, - on_complete: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>, - }, +pub(crate) struct Qos2ReceiveCallback { + pub(crate) on_receive: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, +} +pub(crate) struct Qos2CompleteCallback { + pub(crate) on_complete: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, } pub struct Publish { @@ -365,17 +407,7 @@ impl MqttClient { let (sender, recv) = futures::channel::oneshot::channel(); - let cbs = inner - .outstanding_completions - .entry(Id::PingReq) - .or_insert_with(|| CallbackState::PingReq { - on_pingresp: Default::default(), - }); - - match cbs { - CallbackState::PingReq { on_pingresp } => on_pingresp.push_back(sender), - _ => unreachable!("Had a non-pingreq in a pingreq response"), - } + inner.outstanding_callbacks.add_ping_req(sender); conn_state.conn_write.send(packet).await.map_err(drop)?; |