summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2024-04-04 16:18:19 +0200
committerGitHub <noreply@github.com>2024-04-04 16:18:19 +0200
commitc04592ada120ce2129a614d093f9c199f6b12c4b (patch)
tree0f6de5ada2fb5af7ef587ff4ca53d316fa64dd68
parentc0fc75025cba765820fa1b95b896ba291591313e (diff)
parentedd71b600ff9303c92a6e701197a8296f8c25708 (diff)
Merge pull request #276 from TheNeikos/feature/add_keep_alive
Add keep alive
-rw-r--r--Cargo.lock7
-rw-r--r--Cargo.toml1
-rw-r--r--cloudmqtt-bin/src/bin/client.rs19
-rw-r--r--src/client/connect.rs82
-rw-r--r--src/client/state.rs37
-rw-r--r--src/keep_alive.rs1
6 files changed, 140 insertions, 7 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 696e83e..e08663d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -204,6 +204,7 @@ name = "cloudmqtt"
version = "0.5.0"
dependencies = [
"futures",
+ "futures-timer",
"mqtt-format",
"paste",
"stable_deref_trait",
@@ -327,6 +328,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
+[[package]]
name = "futures-util"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 19041a3..0266690 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,7 @@ debug = ["winnow/debug"]
[dependencies]
futures = "0.3.30"
+futures-timer = "3.0.3"
mqtt-format = { version = "0.5.0", path = "mqtt-format", features = [
"yoke",
"mqttv5",
diff --git a/cloudmqtt-bin/src/bin/client.rs b/cloudmqtt-bin/src/bin/client.rs
index 7537cc3..43e9743 100644
--- a/cloudmqtt-bin/src/bin/client.rs
+++ b/cloudmqtt-bin/src/bin/client.rs
@@ -4,6 +4,8 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
+use std::time::Duration;
+
use clap::Parser;
use cloudmqtt::client::connect::MqttClientConnector;
use cloudmqtt::client::send::Publish;
@@ -47,7 +49,7 @@ async fn main() {
connection,
client_id,
cloudmqtt::client::connect::CleanStart::Yes,
- cloudmqtt::keep_alive::KeepAlive::Disabled,
+ cloudmqtt::keep_alive::KeepAlive::Seconds(5.try_into().unwrap()),
);
let client = MqttClient::new_with_default_handlers();
@@ -69,5 +71,20 @@ async fn main() {
client.ping().await.unwrap().response().await;
+ tokio::time::sleep(Duration::from_secs(3)).await;
+
+ client
+ .publish(Publish {
+ topic: "foo/bar".try_into().unwrap(),
+ qos: cloudmqtt::qos::QualityOfService::AtMostOnce,
+ retain: false,
+ payload: vec![123].try_into().unwrap(),
+ on_packet_recv: None,
+ })
+ .await
+ .unwrap();
+
+ tokio::time::sleep(Duration::from_secs(20)).await;
+
println!("Sent message! Bye");
}
diff --git a/src/client/connect.rs b/src/client/connect.rs
index ed668e9..8bb6f6a 100644
--- a/src/client/connect.rs
+++ b/src/client/connect.rs
@@ -4,6 +4,9 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
+use std::time::Duration;
+
+use futures::select;
use futures::FutureExt;
use futures::SinkExt;
use futures::StreamExt;
@@ -13,6 +16,7 @@ use tokio_util::codec::FramedWrite;
use super::MqttClient;
use crate::bytes::MqttBytes;
use crate::client::state::OutstandingPackets;
+use crate::client::state::TransportWriter;
use crate::client::ConnectState;
use crate::client::SessionState;
use crate::client_identifier::ProposedClientIdentifier;
@@ -213,6 +217,9 @@ impl MqttClient {
});
}
+ let (sender, heartbeat_receiver) = futures::channel::mpsc::channel(1);
+ let conn_write = TransportWriter::new(conn_write, sender);
+
let (conn_read_sender, conn_read_recv) = futures::channel::oneshot::channel();
let connect_client_state = ConnectState {
@@ -222,6 +229,15 @@ impl MqttClient {
retain_available: connack.properties.retain_available().map(|ra| ra.0),
maximum_packet_size: connack.properties.maximum_packet_size().map(|mps| mps.0),
topic_alias_maximum: connack.properties.topic_alias_maximum().map(|tam| tam.0),
+ keep_alive: connack
+ .properties
+ .server_keep_alive()
+ .map(|ska| {
+ std::num::NonZeroU16::try_from(ska.0)
+ .map(KeepAlive::Seconds)
+ .unwrap_or(KeepAlive::Disabled)
+ })
+ .unwrap_or(connector.keep_alive),
conn_write,
conn_read_recv,
next_packet_identifier: std::num::NonZeroU16::MIN,
@@ -257,6 +273,8 @@ impl MqttClient {
};
}
+ let keep_alive = connect_client_state.keep_alive;
+
inner.connection_state = Some(connect_client_state);
inner.session_state = Some(SessionState {
client_identifier,
@@ -267,11 +285,34 @@ impl MqttClient {
crate::packets::connack::ConnackPropertiesView::try_from(maybe_connack)
.expect("An already matched value suddenly changed?");
- let background_task = crate::client::receive::handle_background_receiving(
- inner_clone,
- conn_read,
- conn_read_sender,
- )
+ let background_task = async move {
+ let receiving_inner = inner_clone.clone();
+ let receiving = crate::client::receive::handle_background_receiving(
+ receiving_inner,
+ conn_read,
+ conn_read_sender,
+ );
+
+ let heartbeat_inner = inner_clone;
+
+ let heartbeat = if let KeepAlive::Seconds(time) = keep_alive {
+ handle_heartbeats(
+ heartbeat_receiver,
+ Duration::from_secs(time.get().into()),
+ heartbeat_inner,
+ )
+ .left_future()
+ } else {
+ tracing::info!(
+ "Keep Alive is disabled, will not send PingReq packets automatically"
+ );
+ futures::future::ok(()).right_future()
+ };
+
+ tokio::try_join!(receiving, heartbeat)
+ .map(drop)
+ .map_err(drop)
+ }
.boxed();
return Ok(Connected {
@@ -285,3 +326,34 @@ impl MqttClient {
todo!()
}
}
+
+async fn handle_heartbeats(
+ mut heartbeat_receiver: futures::channel::mpsc::Receiver<()>,
+ duration: Duration,
+ heartbeat_inner: std::sync::Arc<futures::lock::Mutex<super::InnerClient>>,
+) -> Result<(), ()> {
+ let mut timeout = futures_timer::Delay::new(duration).fuse();
+ loop {
+ select! {
+ heartbeat = heartbeat_receiver.next() => match heartbeat {
+ None => break,
+ Some(_) => {
+ timeout = futures_timer::Delay::new(duration).fuse();
+ },
+ },
+ _ = timeout => {
+ let mut inner = heartbeat_inner.lock().await;
+ let inner = &mut *inner;
+ let Some(conn_state) = inner.connection_state.as_mut() else {
+ todo!();
+ };
+
+ // We make sure that this won't deadlock in the send method
+ conn_state.conn_write.send(
+ mqtt_format::v5::packets::MqttPacket::Pingreq(mqtt_format::v5::packets::pingreq::MPingreq)
+ ).await.unwrap();
+ }
+ }
+ }
+ Ok(())
+}
diff --git a/src/client/state.rs b/src/client/state.rs
index 40f91d8..0c707d3 100644
--- a/src/client/state.rs
+++ b/src/client/state.rs
@@ -6,14 +6,48 @@
use std::num::NonZeroU16;
+use futures::SinkExt;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use crate::codecs::MqttPacketCodec;
+use crate::codecs::MqttPacketCodecError;
+use crate::keep_alive::KeepAlive;
use crate::packet_identifier::PacketIdentifier;
use crate::string::MqttString;
use crate::transport::MqttConnection;
+pub(super) struct TransportWriter {
+ conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
+ notify: futures::channel::mpsc::Sender<()>,
+}
+
+impl TransportWriter {
+ pub(super) fn new(
+ conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
+ notify: futures::channel::mpsc::Sender<()>,
+ ) -> Self {
+ Self { conn, notify }
+ }
+
+ pub(super) async fn send(
+ &mut self,
+ packet: mqtt_format::v5::packets::MqttPacket<'_>,
+ ) -> Result<(), MqttPacketCodecError> {
+ self.conn.send(packet).await?;
+ if let Err(e) = self.notify.try_send(()) {
+ if e.is_full() {
+ // This is fine, we are already notifying of a send
+ }
+ if e.is_disconnected() {
+ todo!("Could not send to heartbeat!?")
+ }
+ }
+
+ Ok(())
+ }
+}
+
pub(super) struct ConnectState {
pub(super) session_present: bool,
pub(super) receive_maximum: Option<NonZeroU16>,
@@ -21,13 +55,14 @@ pub(super) struct ConnectState {
pub(super) retain_available: Option<bool>,
pub(super) topic_alias_maximum: Option<u16>,
pub(super) maximum_packet_size: Option<u32>,
- pub(super) conn_write: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
+ pub(super) conn_write: TransportWriter,
pub(super) conn_read_recv: futures::channel::oneshot::Receiver<
FramedRead<tokio::io::ReadHalf<MqttConnection>, MqttPacketCodec>,
>,
pub(super) next_packet_identifier: std::num::NonZeroU16,
+ pub(crate) keep_alive: KeepAlive,
}
pub(super) struct SessionState {
diff --git a/src/keep_alive.rs b/src/keep_alive.rs
index 12cf6e9..a9cea33 100644
--- a/src/keep_alive.rs
+++ b/src/keep_alive.rs
@@ -7,6 +7,7 @@
use std::num::NonZeroU16;
use std::time::Duration;
+#[derive(Debug, Clone, Copy)]
pub enum KeepAlive {
Disabled,
Seconds(NonZeroU16),