From d224254d514508d46bd0fcd9cef17126c2e546a4 Mon Sep 17 00:00:00 2001 From: Matous Hybl Date: Mon, 26 Sep 2022 10:37:17 +0200 Subject: [PATCH] Improve send_message, recv_message API, add support for encoding last will, fix decoding of QOS (#24) * recv_message now returns topic as well as data * send_message accepts byte slice instead of &str * Add last will encoding * Fix QOS decoding --- src/client/client.rs | 20 +++++++++++++------ src/client/client_config.rs | 33 ++++++++++++++++++++++++++++++++ src/packet/v5/connect_packet.rs | 13 ++++++++++++- src/packet/v5/publish_packet.rs | 4 ++-- tests/integration_test_single.rs | 14 +++++++------- tests/load_test.rs | 4 ++-- 6 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index c3548fd..338f011 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -98,6 +98,14 @@ where if self.config.password_flag { connect.add_password(&self.config.password) } + if self.config.will_flag { + connect.add_will( + &self.config.will_topic, + &self.config.will_payload, + self.config.will_retain, + ) + } + connect.add_client_id(&self.config.client_id); connect.encode(self.buffer, self.buffer_len) }; @@ -191,7 +199,7 @@ where async fn send_message_v5<'b>( &'b mut self, topic_name: &'b str, - message: &'b str, + message: &'b [u8], ) -> Result<(), ReasonCode> { if self.connection.is_none() { return Err(ReasonCode::NetworkError); @@ -204,7 +212,7 @@ where packet.add_topic_name(topic_name); packet.add_qos(self.config.qos); packet.add_identifier(identifier); - packet.add_message(message.as_bytes()); + packet.add_message(message); packet.encode(self.buffer, self.buffer_len) }; @@ -255,7 +263,7 @@ where pub async fn send_message<'b>( &'b mut self, topic_name: &'b str, - message: &'b str, + message: &'b [u8], ) -> Result<(), ReasonCode> { match self.config.mqtt_version { MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion), @@ -451,7 +459,7 @@ where } } - async fn receive_message_v5<'b>(&'b mut self) -> Result<&'b [u8], ReasonCode> { + async fn receive_message_v5<'b>(&'b mut self) -> Result<(&'b str, &'b [u8]), ReasonCode> { if self.connection.is_none() { return Err(ReasonCode::NetworkError); } @@ -487,13 +495,13 @@ where } } - return Ok(packet.message.unwrap()); + return Ok((packet.topic_name.string, packet.message.unwrap())); } /// Method allows client receive a message. The work of this method strictly depends on the /// network implementation passed in the `ClientConfig`. It expects the PUBLISH packet /// from the broker. - pub async fn receive_message<'b>(&'b mut self) -> Result<&'b [u8], ReasonCode> { + pub async fn receive_message<'b>(&'b mut self) -> Result<(&'b str, &'b [u8]), ReasonCode> { match self.config.mqtt_version { MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion), MqttVersion::MQTTv5 => self.receive_message_v5().await, diff --git a/src/client/client_config.rs b/src/client/client_config.rs index 304a5f8..7b77455 100644 --- a/src/client/client_config.rs +++ b/src/client/client_config.rs @@ -56,6 +56,11 @@ pub struct ClientConfig<'a, const MAX_PROPERTIES: usize, T: RngCore> { pub max_packet_size: u32, pub mqtt_version: MqttVersion, pub rng: T, + pub will_flag: bool, + pub will_topic: EncodedString<'a>, + pub will_payload: BinaryData<'a>, + pub will_retain: bool, + pub client_id: EncodedString<'a>, } impl<'a, const MAX_PROPERTIES: usize, T: RngCore> ClientConfig<'a, MAX_PROPERTIES, T> { @@ -71,6 +76,11 @@ impl<'a, const MAX_PROPERTIES: usize, T: RngCore> ClientConfig<'a, MAX_PROPERTIE max_packet_size: 265_000, mqtt_version: version, rng, + will_flag: false, + will_topic: EncodedString::new(), + will_payload: BinaryData::new(), + will_retain: false, + client_id: EncodedString::new(), } } @@ -78,6 +88,21 @@ impl<'a, const MAX_PROPERTIES: usize, T: RngCore> ClientConfig<'a, MAX_PROPERTIE self.qos = qos; } + pub fn add_will(&mut self, topic: &'a str, payload: &'a [u8], retain: bool) { + let mut topic_s = EncodedString::new(); + topic_s.string = topic; + topic_s.len = topic.len() as u16; + + let mut payload_d = BinaryData::new(); + payload_d.bin = payload; + payload_d.len = payload.len() as u16; + + self.will_flag = true; + self.will_retain = retain; + self.will_topic = topic_s; + self.will_payload = payload_d; + } + /// Method adds the username array and also sets the username flag so client /// will use it for the authentication pub fn add_username(&mut self, username: &'a str) { @@ -113,4 +138,12 @@ impl<'a, const MAX_PROPERTIES: usize, T: RngCore> ClientConfig<'a, MAX_PROPERTIE } return 0; } + + pub fn add_client_id(&mut self, client_id: &'a str) { + let mut client_id_s = EncodedString::new(); + client_id_s.string = client_id; + client_id_s.len = client_id.len() as u16; + + self.client_id = client_id_s + } } diff --git a/src/packet/v5/connect_packet.rs b/src/packet/v5/connect_packet.rs index e29182e..babec24 100644 --- a/src/packet/v5/connect_packet.rs +++ b/src/packet/v5/connect_packet.rs @@ -98,6 +98,15 @@ impl<'a, const MAX_PROPERTIES: usize, const MAX_WILL_PROPERTIES: usize> self.connect_flags = self.connect_flags | 0x40; } + pub fn add_will(&mut self, topic: &EncodedString<'a>, payload: &BinaryData<'a>, retain: bool) { + self.will_topic = topic.clone(); + self.will_payload = payload.clone(); + self.connect_flags |= 0x04; + if retain { + self.connect_flags |= 0x20; + } + } + pub fn add_client_id(&mut self, id: &EncodedString<'a>) { self.client_id = (*id).clone(); } @@ -145,7 +154,9 @@ impl<'a, const MAX_PROPERTIES: usize, const MAX_WILL_PROPERTIES: usize> Packet<' + wil_prop_len_len as u32 + self.will_property_len as u32 + self.will_topic.len as u32 - + self.will_payload.len as u32; + + 2 + + self.will_payload.len as u32 + + 2; } if (self.connect_flags & 0x80) != 0 { rm_ln = rm_ln + self.username.len as u32 + 2; diff --git a/src/packet/v5/publish_packet.rs b/src/packet/v5/publish_packet.rs index 7dc503d..4b15d3a 100644 --- a/src/packet/v5/publish_packet.rs +++ b/src/packet/v5/publish_packet.rs @@ -116,7 +116,7 @@ impl<'a, const MAX_PROPERTIES: usize> Packet<'a> for PublishPacket<'a, MAX_PROPE rm_ln = rm_ln + property_len_len as u32 + msg_len + self.topic_name.len as u32 + 2; buff_writer.write_u8(self.fixed_header)?; - let qos = self.fixed_header & 0x03; + let qos = self.fixed_header & 0x06; if qos != 0 { rm_ln = rm_ln + 2; } @@ -140,7 +140,7 @@ impl<'a, const MAX_PROPERTIES: usize> Packet<'a> for PublishPacket<'a, MAX_PROPE return Err(BufferError::PacketTypeMismatch); } self.topic_name = buff_reader.read_string()?; - let qos = self.fixed_header & 0x03; + let qos = self.fixed_header & 0x06; if qos != 0 { // Decode only for QoS 1 / 2 self.packet_identifier = buff_reader.read_u16()?; diff --git a/tests/integration_test_single.rs b/tests/integration_test_single.rs index 3303f54..3c05cc3 100644 --- a/tests/integration_test_single.rs +++ b/tests/integration_test_single.rs @@ -80,7 +80,7 @@ async fn publish_core<'b>( "[Publisher] Sending new message {} to topic {}", message, topic ); - result = client.send_message(topic, message).await; + result = client.send_message(topic, message.as_bytes()).await; info!("[PUBLISHER] sent"); if err == true { assert_err!(result); @@ -174,7 +174,7 @@ async fn receive_core<'b>( info!("[Receiver] Waiting for new message!"); let msg = client.receive_message().await; assert_ok!(msg); - let act_message = String::from_utf8_lossy(msg?); + let act_message = String::from_utf8_lossy(msg?.1); info!("[Receiver] Got new message: {}", act_message); assert_eq!(act_message, MSG); @@ -208,14 +208,14 @@ async fn receive_core_multiple<'b, const TOPICS: usize>( { let msg = client.receive_message().await; assert_ok!(msg); - let act_message = String::from_utf8_lossy(msg?); + let act_message = String::from_utf8_lossy(msg?.1); info!("[Receiver] Got new message: {}", act_message); assert_eq!(act_message, MSG); } { let msg_sec = client.receive_message().await; assert_ok!(msg_sec); - let act_message_second = String::from_utf8_lossy(msg_sec?); + let act_message_second = String::from_utf8_lossy(msg_sec?.1); info!("[Receiver] Got new message: {}", act_message_second); assert_eq!(act_message_second, MSG); } @@ -365,14 +365,14 @@ async fn receive_multiple_second_unsub( { let msg = { client.receive_message().await }; assert_ok!(msg); - let act_message = String::from_utf8_lossy(msg?); + let act_message = String::from_utf8_lossy(msg?.1); info!("[Receiver] Got new message: {}", act_message); assert_eq!(act_message, msg_t1); } { let msg_sec = { client.receive_message().await }; assert_ok!(msg_sec); - let act_message_second = String::from_utf8_lossy(msg_sec?); + let act_message_second = String::from_utf8_lossy(msg_sec?.1); info!("[Receiver] Got new message: {}", act_message_second); assert_eq!(act_message_second, msg_t2); } @@ -386,7 +386,7 @@ async fn receive_multiple_second_unsub( { let msg = { client.receive_message().await }; assert_ok!(msg); - let act_message = String::from_utf8_lossy(msg?); + let act_message = String::from_utf8_lossy(msg?.1); info!("[Receiver] Got new message: {}", act_message); assert_eq!(act_message, msg_t1); } diff --git a/tests/load_test.rs b/tests/load_test.rs index ebf6253..50ee926 100644 --- a/tests/load_test.rs +++ b/tests/load_test.rs @@ -78,7 +78,7 @@ async fn publish_core<'b>( info!("[Publisher] Sending new message {} to topic {}", MSG, topic); let mut count = 0; loop { - result = client.send_message(topic, MSG).await; + result = client.send_message(topic, MSG.as_bytes()).await; info!("[PUBLISHER] sent {}", count); assert_ok!(result); count = count + 1; @@ -145,7 +145,7 @@ async fn receive_core<'b>( loop { let msg = client.receive_message().await; assert_ok!(msg); - let act_message = String::from_utf8_lossy(msg?); + let act_message = String::from_utf8_lossy(msg?.1); info!("[Receiver] Got new {}. message: {}", count, act_message); assert_eq!(act_message, MSG); count = count + 1;