summaryrefslogtreecommitdiffstats
path: root/src/client/receive.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/client/receive.rs')
-rw-r--r--src/client/receive.rs322
1 files changed, 322 insertions, 0 deletions
diff --git a/src/client/receive.rs b/src/client/receive.rs
new file mode 100644
index 0000000..e7d55c2
--- /dev/null
+++ b/src/client/receive.rs
@@ -0,0 +1,322 @@
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+//
+
+use std::num::NonZeroU16;
+use std::sync::Arc;
+
+use futures::lock::Mutex;
+use futures::SinkExt;
+use futures::StreamExt;
+use tokio_util::codec::FramedRead;
+use tracing::Instrument;
+use yoke::Yoke;
+
+use super::InnerClient;
+use crate::client::CallbackState;
+use crate::client::Id;
+use crate::codecs::MqttPacketCodec;
+use crate::packets::MqttPacket;
+use crate::packets::MqttWriter;
+use crate::packets::StableBytes;
+use crate::transport::MqttConnection;
+
+pub(super) async fn handle_background_receiving(
+ inner_clone: Arc<Mutex<InnerClient>>,
+ mut conn_read: FramedRead<tokio::io::ReadHalf<MqttConnection>, MqttPacketCodec>,
+ conn_read_sender: futures::channel::oneshot::Sender<
+ FramedRead<tokio::io::ReadHalf<MqttConnection>, MqttPacketCodec>,
+ >,
+) -> Result<(), ()> {
+ tracing::info!("Starting background task");
+ let inner: Arc<Mutex<InnerClient>> = inner_clone;
+
+ while let Some(next) = conn_read.next().await {
+ let process_span = tracing::debug_span!(
+ "Processing packet",
+ packet_kind = tracing::field::Empty,
+ packet_identifier = tracing::field::Empty
+ );
+ tracing::debug!(parent: &process_span, valid = next.is_ok(), "Received packet");
+ let packet = match next {
+ Ok(packet) => packet,
+ Err(e) => panic!("Received err: {e}"),
+ };
+ process_span.record(
+ "packet_kind",
+ tracing::field::debug(packet.get().get_kind()),
+ );
+
+ match packet.get() {
+ mqtt_format::v5::packets::MqttPacket::Auth(_) => todo!(),
+ mqtt_format::v5::packets::MqttPacket::Disconnect(_) => todo!(),
+ mqtt_format::v5::packets::MqttPacket::Pingreq(pingreq) => {
+ handle_pingreq(pingreq, &inner)
+ .instrument(process_span)
+ .await?
+ }
+ mqtt_format::v5::packets::MqttPacket::Pingresp(pingresp) => {
+ handle_pingresp(pingresp, &inner)
+ .instrument(process_span)
+ .await?
+ }
+ mqtt_format::v5::packets::MqttPacket::Puback(mpuback) => {
+ handle_puback(mpuback, &inner, &packet)
+ .instrument(process_span)
+ .await?
+ }
+ mqtt_format::v5::packets::MqttPacket::Pubrec(pubrec) => {
+ handle_pubrec(pubrec, &inner, &packet)
+ .instrument(process_span)
+ .await?
+ }
+ mqtt_format::v5::packets::MqttPacket::Pubcomp(pubcomp) => {
+ handle_pubcomp(pubcomp, &inner, &packet)
+ .instrument(process_span)
+ .await?
+ }
+ mqtt_format::v5::packets::MqttPacket::Publish(_) => todo!(),
+ mqtt_format::v5::packets::MqttPacket::Pubrel(_) => todo!(),
+ mqtt_format::v5::packets::MqttPacket::Suback(_) => todo!(),
+ mqtt_format::v5::packets::MqttPacket::Unsuback(_) => todo!(),
+
+ mqtt_format::v5::packets::MqttPacket::Connack(_)
+ | mqtt_format::v5::packets::MqttPacket::Connect(_)
+ | mqtt_format::v5::packets::MqttPacket::Subscribe(_)
+ | mqtt_format::v5::packets::MqttPacket::Unsubscribe(_) => {
+ todo!("Handle invalid packet")
+ }
+ }
+ }
+
+ tracing::debug!("Finished processing, returning reader");
+ if let Err(conn_read) = conn_read_sender.send(conn_read) {
+ tracing::error!("Failed to return reader");
+ todo!()
+ }
+
+ Ok(())
+}
+
+async fn handle_pingresp(
+ _pingresp: &mqtt_format::v5::packets::pingresp::MPingresp,
+ inner: &Arc<Mutex<InnerClient>>,
+) -> Result<(), ()> {
+ let mut inner = inner.lock().await;
+ let inner = &mut *inner;
+
+ if let Some(callback) = inner.outstanding_completions.get_mut(&Id::PingReq) {
+ match callback {
+ CallbackState::PingReq { on_pingresp } => {
+ if let Some(cb) = on_pingresp.pop_front() {
+ if cb.send(()).is_err() {
+ tracing::debug!(
+ "PingReq completion handler was dropped before receiving response"
+ )
+ }
+ } else {
+ tracing::warn!("Received an unwarranted PingResp from the server, continuing")
+ }
+ }
+ _ => todo!("Had non pingreq in pingreq callback state"),
+ }
+ }
+
+ Ok(())
+}
+
+async fn handle_pingreq(
+ _pingreq: &mqtt_format::v5::packets::pingreq::MPingreq,
+ inner: &Arc<Mutex<InnerClient>>,
+) -> Result<(), ()> {
+ let mut inner = inner.lock().await;
+ let inner = &mut *inner;
+ let Some(ref mut conn_state) = inner.connection_state else {
+ tracing::error!("No connection state found");
+ todo!()
+ };
+
+ let packet = mqtt_format::v5::packets::MqttPacket::Pingresp(
+ mqtt_format::v5::packets::pingresp::MPingresp,
+ );
+ conn_state.conn_write.send(packet).await.map_err(drop)?;
+
+ Ok(())
+}
+
+async fn handle_pubcomp(
+ pubcomp: &mqtt_format::v5::packets::pubcomp::MPubcomp<'_>,
+ inner: &Arc<Mutex<InnerClient>>,
+ packet: &MqttPacket,
+) -> Result<(), ()> {
+ match pubcomp.reason {
+ mqtt_format::v5::packets::pubcomp::PubcompReasonCode::Success => {
+ let mut inner = inner.lock().await;
+ let inner = &mut *inner;
+ let Some(ref mut session_state) = inner.session_state else {
+ tracing::error!("No session state found");
+ todo!()
+ };
+ let pident = NonZeroU16::try_from(pubcomp.packet_identifier.0)
+ .expect("zero PacketIdentifier not valid here");
+ tracing::Span::current().record("packet_identifier", pident);
+
+ if session_state
+ .outstanding_packets
+ .exists_outstanding_packet(pident)
+ {
+ session_state.outstanding_packets.remove_by_id(pident);
+ tracing::trace!("Removed packet id from outstanding packets");
+
+ if let Some(callback) = inner
+ .outstanding_completions
+ .get_mut(&Id::PacketIdentifier(pident))
+ {
+ match callback {
+ CallbackState::Qos2 { on_complete, .. } => {
+ if let Some(on_complete) = on_complete.take() {
+ if let Err(_) = on_complete.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
+ }
+ } else {
+ todo!("Invariant broken: Double on_complete for a single pid: {pident}")
+ }
+ }
+ _ => todo!(),
+ }
+ }
+ }
+ }
+ _ => todo!("Handle errors"),
+ }
+
+ Ok(())
+}
+
+async fn handle_puback(
+ mpuback: &mqtt_format::v5::packets::puback::MPuback<'_>,
+ inner: &Arc<Mutex<InnerClient>>,
+ packet: &MqttPacket,
+) -> Result<(), ()> {
+ match mpuback.reason {
+ mqtt_format::v5::packets::puback::PubackReasonCode::Success
+ | mqtt_format::v5::packets::puback::PubackReasonCode::NoMatchingSubscribers => {
+ let mut inner = inner.lock().await;
+ let inner = &mut *inner;
+ let Some(ref mut session_state) = inner.session_state else {
+ tracing::error!("No session state found");
+ todo!()
+ };
+
+ let pident = std::num::NonZeroU16::try_from(mpuback.packet_identifier.0)
+ .expect("Zero PacketIdentifier not valid here");
+ tracing::Span::current().record("packet_identifier", pident);
+
+ if session_state
+ .outstanding_packets
+ .exists_outstanding_packet(pident)
+ {
+ session_state.outstanding_packets.remove_by_id(pident);
+ tracing::trace!("Removed packet id from outstanding packets");
+
+ if let Some(callback) = inner
+ .outstanding_completions
+ .remove(&Id::PacketIdentifier(pident))
+ {
+ match callback {
+ CallbackState::Qos1 { on_acknowledge } => {
+ if let Err(_) = on_acknowledge.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
+ }
+ }
+ _ => todo!(),
+ }
+ }
+ } else {
+ tracing::error!("Packet id does not exist in outstanding packets");
+ todo!()
+ }
+
+ // TODO: Forward mpuback.properties etc to the user
+ }
+
+ _ => todo!("Handle errors"),
+ }
+
+ Ok(())
+}
+
+async fn handle_pubrec(
+ pubrec: &mqtt_format::v5::packets::pubrec::MPubrec<'_>,
+ inner: &Arc<Mutex<InnerClient>>,
+ packet: &MqttPacket,
+) -> Result<(), ()> {
+ match pubrec.reason {
+ mqtt_format::v5::packets::pubrec::PubrecReasonCode::Success => {
+ let mut inner = inner.lock().await;
+ let inner = &mut *inner;
+ let Some(ref mut session_state) = inner.session_state else {
+ tracing::error!("No session state found");
+ todo!()
+ };
+ let Some(ref mut conn_state) = inner.connection_state else {
+ tracing::error!("No session state found");
+ todo!()
+ };
+ let pident = NonZeroU16::try_from(pubrec.packet_identifier.0)
+ .expect("zero PacketIdentifier not valid here");
+ tracing::Span::current().record("packet_identifier", pident);
+
+ if session_state
+ .outstanding_packets
+ .exists_outstanding_packet(pident)
+ {
+ let pubrel = mqtt_format::v5::packets::MqttPacket::Pubrel(
+ mqtt_format::v5::packets::pubrel::MPubrel {
+ packet_identifier: pubrec.packet_identifier,
+ reason: mqtt_format::v5::packets::pubrel::PubrelReasonCode::Success,
+ properties: mqtt_format::v5::packets::pubrel::PubrelProperties::new(),
+ },
+ );
+
+ let mut bytes = tokio_util::bytes::BytesMut::new();
+ bytes.reserve(pubrel.binary_size() as usize);
+ pubrel.write(&mut MqttWriter(&mut bytes)).map_err(drop)?;
+ let pubrel_packet = MqttPacket {
+ packet: Yoke::try_attach_to_cart(StableBytes(bytes.freeze()), |bytes| {
+ mqtt_format::v5::packets::MqttPacket::parse_complete(bytes)
+ })
+ .unwrap(),
+ };
+ session_state
+ .outstanding_packets
+ .update_by_id(pident, pubrel_packet);
+ tracing::trace!("Update packet from outstanding packets");
+ conn_state.conn_write.send(pubrel).await.map_err(drop)?;
+
+ if let Some(callback) = inner
+ .outstanding_completions
+ .get_mut(&Id::PacketIdentifier(pident))
+ {
+ match callback {
+ CallbackState::Qos2 { on_receive, .. } => {
+ if let Some(on_receive) = on_receive.take() {
+ if let Err(_) = on_receive.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
+ }
+ } else {
+ todo!("Invariant broken: Double on_receive for a single pid: {pident}")
+ }
+ }
+ _ => todo!(),
+ }
+ }
+ }
+ }
+ _ => todo!("Handle errors"),
+ }
+
+ Ok(())
+}