diff options
author | Marcel Müller <neikos@neikos.email> | 2024-04-04 16:18:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 16:18:19 +0200 |
commit | c04592ada120ce2129a614d093f9c199f6b12c4b (patch) | |
tree | 0f6de5ada2fb5af7ef587ff4ca53d316fa64dd68 | |
parent | c0fc75025cba765820fa1b95b896ba291591313e (diff) | |
parent | edd71b600ff9303c92a6e701197a8296f8c25708 (diff) |
Merge pull request #276 from TheNeikos/feature/add_keep_alive
Add keep alive
-rw-r--r-- | Cargo.lock | 7 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | cloudmqtt-bin/src/bin/client.rs | 19 | ||||
-rw-r--r-- | src/client/connect.rs | 82 | ||||
-rw-r--r-- | src/client/state.rs | 37 | ||||
-rw-r--r-- | src/keep_alive.rs | 1 |
6 files changed, 140 insertions, 7 deletions
@@ -204,6 +204,7 @@ name = "cloudmqtt" version = "0.5.0" dependencies = [ "futures", + "futures-timer", "mqtt-format", "paste", "stable_deref_trait", @@ -327,6 +328,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + +[[package]] name = "futures-util" version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -20,6 +20,7 @@ debug = ["winnow/debug"] [dependencies] futures = "0.3.30" +futures-timer = "3.0.3" mqtt-format = { version = "0.5.0", path = "mqtt-format", features = [ "yoke", "mqttv5", diff --git a/cloudmqtt-bin/src/bin/client.rs b/cloudmqtt-bin/src/bin/client.rs index 7537cc3..43e9743 100644 --- a/cloudmqtt-bin/src/bin/client.rs +++ b/cloudmqtt-bin/src/bin/client.rs @@ -4,6 +4,8 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. // +use std::time::Duration; + use clap::Parser; use cloudmqtt::client::connect::MqttClientConnector; use cloudmqtt::client::send::Publish; @@ -47,7 +49,7 @@ async fn main() { connection, client_id, cloudmqtt::client::connect::CleanStart::Yes, - cloudmqtt::keep_alive::KeepAlive::Disabled, + cloudmqtt::keep_alive::KeepAlive::Seconds(5.try_into().unwrap()), ); let client = MqttClient::new_with_default_handlers(); @@ -69,5 +71,20 @@ async fn main() { client.ping().await.unwrap().response().await; + tokio::time::sleep(Duration::from_secs(3)).await; + + client + .publish(Publish { + topic: "foo/bar".try_into().unwrap(), + qos: cloudmqtt::qos::QualityOfService::AtMostOnce, + retain: false, + payload: vec![123].try_into().unwrap(), + on_packet_recv: None, + }) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_secs(20)).await; + println!("Sent message! Bye"); } diff --git a/src/client/connect.rs b/src/client/connect.rs index ed668e9..8bb6f6a 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -4,6 +4,9 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. // +use std::time::Duration; + +use futures::select; use futures::FutureExt; use futures::SinkExt; use futures::StreamExt; @@ -13,6 +16,7 @@ use tokio_util::codec::FramedWrite; use super::MqttClient; use crate::bytes::MqttBytes; use crate::client::state::OutstandingPackets; +use crate::client::state::TransportWriter; use crate::client::ConnectState; use crate::client::SessionState; use crate::client_identifier::ProposedClientIdentifier; @@ -213,6 +217,9 @@ impl MqttClient { }); } + let (sender, heartbeat_receiver) = futures::channel::mpsc::channel(1); + let conn_write = TransportWriter::new(conn_write, sender); + let (conn_read_sender, conn_read_recv) = futures::channel::oneshot::channel(); let connect_client_state = ConnectState { @@ -222,6 +229,15 @@ impl MqttClient { retain_available: connack.properties.retain_available().map(|ra| ra.0), maximum_packet_size: connack.properties.maximum_packet_size().map(|mps| mps.0), topic_alias_maximum: connack.properties.topic_alias_maximum().map(|tam| tam.0), + keep_alive: connack + .properties + .server_keep_alive() + .map(|ska| { + std::num::NonZeroU16::try_from(ska.0) + .map(KeepAlive::Seconds) + .unwrap_or(KeepAlive::Disabled) + }) + .unwrap_or(connector.keep_alive), conn_write, conn_read_recv, next_packet_identifier: std::num::NonZeroU16::MIN, @@ -257,6 +273,8 @@ impl MqttClient { }; } + let keep_alive = connect_client_state.keep_alive; + inner.connection_state = Some(connect_client_state); inner.session_state = Some(SessionState { client_identifier, @@ -267,11 +285,34 @@ impl MqttClient { crate::packets::connack::ConnackPropertiesView::try_from(maybe_connack) .expect("An already matched value suddenly changed?"); - let background_task = crate::client::receive::handle_background_receiving( - inner_clone, - conn_read, - conn_read_sender, - ) + let background_task = async move { + let receiving_inner = inner_clone.clone(); + let receiving = crate::client::receive::handle_background_receiving( + receiving_inner, + conn_read, + conn_read_sender, + ); + + let heartbeat_inner = inner_clone; + + let heartbeat = if let KeepAlive::Seconds(time) = keep_alive { + handle_heartbeats( + heartbeat_receiver, + Duration::from_secs(time.get().into()), + heartbeat_inner, + ) + .left_future() + } else { + tracing::info!( + "Keep Alive is disabled, will not send PingReq packets automatically" + ); + futures::future::ok(()).right_future() + }; + + tokio::try_join!(receiving, heartbeat) + .map(drop) + .map_err(drop) + } .boxed(); return Ok(Connected { @@ -285,3 +326,34 @@ impl MqttClient { todo!() } } + +async fn handle_heartbeats( + mut heartbeat_receiver: futures::channel::mpsc::Receiver<()>, + duration: Duration, + heartbeat_inner: std::sync::Arc<futures::lock::Mutex<super::InnerClient>>, +) -> Result<(), ()> { + let mut timeout = futures_timer::Delay::new(duration).fuse(); + loop { + select! { + heartbeat = heartbeat_receiver.next() => match heartbeat { + None => break, + Some(_) => { + timeout = futures_timer::Delay::new(duration).fuse(); + }, + }, + _ = timeout => { + let mut inner = heartbeat_inner.lock().await; + let inner = &mut *inner; + let Some(conn_state) = inner.connection_state.as_mut() else { + todo!(); + }; + + // We make sure that this won't deadlock in the send method + conn_state.conn_write.send( + mqtt_format::v5::packets::MqttPacket::Pingreq(mqtt_format::v5::packets::pingreq::MPingreq) + ).await.unwrap(); + } + } + } + Ok(()) +} diff --git a/src/client/state.rs b/src/client/state.rs index 40f91d8..0c707d3 100644 --- a/src/client/state.rs +++ b/src/client/state.rs @@ -6,14 +6,48 @@ use std::num::NonZeroU16; +use futures::SinkExt; use tokio_util::codec::FramedRead; use tokio_util::codec::FramedWrite; use crate::codecs::MqttPacketCodec; +use crate::codecs::MqttPacketCodecError; +use crate::keep_alive::KeepAlive; use crate::packet_identifier::PacketIdentifier; use crate::string::MqttString; use crate::transport::MqttConnection; +pub(super) struct TransportWriter { + conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>, + notify: futures::channel::mpsc::Sender<()>, +} + +impl TransportWriter { + pub(super) fn new( + conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>, + notify: futures::channel::mpsc::Sender<()>, + ) -> Self { + Self { conn, notify } + } + + pub(super) async fn send( + &mut self, + packet: mqtt_format::v5::packets::MqttPacket<'_>, + ) -> Result<(), MqttPacketCodecError> { + self.conn.send(packet).await?; + if let Err(e) = self.notify.try_send(()) { + if e.is_full() { + // This is fine, we are already notifying of a send + } + if e.is_disconnected() { + todo!("Could not send to heartbeat!?") + } + } + + Ok(()) + } +} + pub(super) struct ConnectState { pub(super) session_present: bool, pub(super) receive_maximum: Option<NonZeroU16>, @@ -21,13 +55,14 @@ pub(super) struct ConnectState { pub(super) retain_available: Option<bool>, pub(super) topic_alias_maximum: Option<u16>, pub(super) maximum_packet_size: Option<u32>, - pub(super) conn_write: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>, + pub(super) conn_write: TransportWriter, pub(super) conn_read_recv: futures::channel::oneshot::Receiver< FramedRead<tokio::io::ReadHalf<MqttConnection>, MqttPacketCodec>, >, pub(super) next_packet_identifier: std::num::NonZeroU16, + pub(crate) keep_alive: KeepAlive, } pub(super) struct SessionState { diff --git a/src/keep_alive.rs b/src/keep_alive.rs index 12cf6e9..a9cea33 100644 --- a/src/keep_alive.rs +++ b/src/keep_alive.rs @@ -7,6 +7,7 @@ use std::num::NonZeroU16; use std::time::Duration; +#[derive(Debug, Clone, Copy)] pub enum KeepAlive { Disabled, Seconds(NonZeroU16), |