From 19087016a64685b678f73b50b851169306ea6230 Mon Sep 17 00:00:00 2001 From: obabec Date: Thu, 17 Mar 2022 09:16:59 +0100 Subject: [PATCH] Fix read (#9) * Add receive function --- mqtt/src/client/client_v5.rs | 75 ++++++++--- .../tests/unit/utils/buffer_writer_unit.rs | 118 ++++++++++++++++++ mqtt/src/utils/buffer_writer.rs | 26 ++++ mqtt/tests/integration_test_single.rs | 1 + 4 files changed, 202 insertions(+), 18 deletions(-) diff --git a/mqtt/src/client/client_v5.rs b/mqtt/src/client/client_v5.rs index 34052ab..2a623c8 100644 --- a/mqtt/src/client/client_v5.rs +++ b/mqtt/src/client/client_v5.rs @@ -42,7 +42,9 @@ use crate::utils::types::BufferError; use heapless::Vec; use rand_core::RngCore; +use crate::encoding::variable_byte_integer::{VariableByteInteger, VariableByteIntegerDecoder, VariableByteIntegerEncoder}; use crate::network::NetworkError::Connection; +use crate::utils::buffer_writer::BuffWriter; pub struct MqttClientV5<'a, T, const MAX_PROPERTIES: usize> { connection: Option, @@ -106,13 +108,15 @@ where //connack let reason: Result = { trace!("Waiting for connack"); - conn.receive(self.recv_buffer).await?; + + let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? }; + let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new(); - if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { + if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) { if err == BufferError::PacketTypeMismatch { let mut disc = DisconnectPacket::<'b, MAX_PROPERTIES>::new(); if disc - .decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) + .decode(&mut BuffReader::new(self.buffer, read)) .is_ok() { error!("Client was disconnected with reason: "); @@ -196,9 +200,10 @@ where { let reason: Result<[u16; 2], BufferError> = { trace!("Waiting for ack"); - conn.receive(self.recv_buffer).await?; + let read = receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await?; + trace!("[PUBACK] Received packet with len"); let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new(); - if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) + if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) { Err(err) } else { @@ -252,10 +257,10 @@ where conn.send(&self.buffer[0..len.unwrap()]).await?; let reason: Result, BufferError> = { - conn.receive(self.recv_buffer).await?; + let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? }; let mut packet = SubackPacket::<'b, TOPICS, MAX_PROPERTIES>::new(); - if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { + if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) { Err(err) } else { Ok(packet.reason_codes) @@ -302,10 +307,10 @@ where conn.send(&self.buffer[0..len.unwrap()]).await?; let reason: Result = { - conn.receive(self.recv_buffer).await?; + let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? }; let mut packet = SubackPacket::<'b, 5, MAX_PROPERTIES>::new(); - if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { + if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) { Err(err) } else { Ok(*packet.reason_codes.get(0).unwrap()) @@ -330,15 +335,16 @@ where return Err(ReasonCode::NetworkError); } let mut conn = self.connection.as_mut().unwrap(); - conn.receive(self.recv_buffer).await?; + let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? }; let mut packet = PublishPacket::<'b, 5>::new(); - if let Err(err) = - packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) + if let Err(err) = { + packet.decode(&mut BuffReader::new(self.buffer, read)) + } + { if err == BufferError::PacketTypeMismatch { let mut disc = DisconnectPacket::<'b, 5>::new(); - if disc - .decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) + if disc.decode(&mut BuffReader::new(self.buffer, read)) .is_ok() { error!("Client was disconnected with reason: "); @@ -356,12 +362,12 @@ where puback.packet_identifier = packet.packet_identifier; puback.reason_code = 0x00; { - let len = puback.encode(self.buffer, self.buffer_len); + let len = { puback.encode(self.recv_buffer, self.recv_buffer_len) }; if let Err(err) = len { error!("[DECODE ERR]: {}", err); return Err(ReasonCode::BuffError); } - conn.send(&self.buffer[0..len.unwrap()]).await?; + conn.send(&self.recv_buffer[0..len.unwrap()]).await?; } } @@ -385,9 +391,9 @@ where conn.send(&self.buffer[0..len.unwrap()]).await?; - conn.receive(self.recv_buffer).await?; + let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? }; let mut packet = PingrespPacket::new(); - if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { + if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) { error!("[DECODE ERR]: {}", err); return Err(ReasonCode::BuffError); } else { @@ -395,3 +401,36 @@ where } } } + +async fn receive_packet<'c, T:NetworkConnection>(buffer: & mut [u8],buffer_len: usize, recv_buffer: & mut [u8], conn: &'c mut T) -> Result { + let mut target_len = 0; + let mut rem_len: VariableByteInteger = [0; 4]; + let mut rem_len_len: usize = 0; + let mut complete_len: bool = false; + let mut writer = BuffWriter::new(buffer, buffer_len); + loop { + let len: usize = conn.receive(recv_buffer).await?; + if len > 0 { + writer.insert_ref(len, &recv_buffer); + + if writer.position >= 1 && target_len == 0 { + let tmp_rem_len = writer.get_rem_len(); + if tmp_rem_len.is_err() { + continue; + } + rem_len = tmp_rem_len.unwrap(); + rem_len_len = VariableByteIntegerEncoder::len(rem_len); + if let Ok(res) = VariableByteIntegerDecoder::decode(rem_len) { + target_len = res as usize; + } else { + return Err(ReasonCode::BuffError); + } + } + + if target_len != 0 && (target_len + rem_len_len + 1) >= writer.position { + trace!("Just read packet with len {}", (target_len + rem_len_len + 1)); + return Ok(target_len + rem_len_len + 1); + } + } + } +} \ No newline at end of file diff --git a/mqtt/src/tests/unit/utils/buffer_writer_unit.rs b/mqtt/src/tests/unit/utils/buffer_writer_unit.rs index e60c7d3..58ac4cd 100644 --- a/mqtt/src/tests/unit/utils/buffer_writer_unit.rs +++ b/mqtt/src/tests/unit/utils/buffer_writer_unit.rs @@ -27,6 +27,8 @@ use crate::utils::buffer_writer::BuffWriter; use crate::utils::types::{BinaryData, BufferError, EncodedString, StringPair, TopicFilter}; use heapless::Vec; +use tokio_test::{assert_err, assert_ok}; +use crate::encoding::variable_byte_integer::VariableByteInteger; #[test] fn buffer_write_ref() { @@ -380,3 +382,119 @@ fn buffer_write_filters_oob() { assert!(test_write.is_err()); assert_eq!(test_write.unwrap_err(), BufferError::InsufficientBufferSize) } + +#[test] +fn buffer_get_rem_len_one() { + static BUFFER: [u8; 5] = [0x82, 0x02, 0x03, 0x85, 0x84]; + static REF: VariableByteInteger = [0x02, 0x00, 0x00, 0x00]; + let mut res_buffer: [u8; 5] = [0; 5]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5); + let test_write = writer.insert_ref(5, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_two() { + static BUFFER: [u8; 5] = [0x82, 0x82, 0x03, 0x85, 0x84]; + static REF: VariableByteInteger = [0x82, 0x03, 0x00, 0x00]; + let mut res_buffer: [u8; 5] = [0; 5]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5); + let test_write = writer.insert_ref(5, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + + +#[test] +fn buffer_get_rem_len_three() { + static BUFFER: [u8; 5] = [0x82, 0x82, 0x83, 0x05, 0x84]; + static REF: VariableByteInteger = [0x82, 0x83, 0x05, 0x00]; + let mut res_buffer: [u8; 5] = [0; 5]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5); + let test_write = writer.insert_ref(5, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_all() { + static BUFFER: [u8; 5] = [0x82, 0x82, 0x83, 0x85, 0x04]; + static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x04]; + let mut res_buffer: [u8; 5] = [0; 5]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5); + let test_write = writer.insert_ref(5, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_over() { + static BUFFER: [u8; 6] = [0x82, 0x82, 0x83, 0x85, 0x84, 0x34]; + static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x84]; + let mut res_buffer: [u8; 6] = [0; 6]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6); + let test_write = writer.insert_ref(6, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_zero_end() { + static BUFFER: [u8; 6] = [0x82, 0x82, 0x83, 0x85, 0x04, 0x34]; + static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x04]; + let mut res_buffer: [u8; 6] = [0; 6]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6); + let test_write = writer.insert_ref(6, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_zero() { + static BUFFER: [u8; 6] = [0x82, 0x00, 0x83, 0x85, 0x04, 0x34]; + static REF: VariableByteInteger = [0x00, 0x00, 0x00, 0x00]; + let mut res_buffer: [u8; 6] = [0; 6]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6); + let test_write = writer.insert_ref(6, &BUFFER); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_ok!(rm_len); + assert_eq!(rm_len.unwrap(), REF); +} + +#[test] +fn buffer_get_rem_len_cont() { + static BUFFER: [u8; 6] = [0x82, 0x00, 0x83, 0x85, 0x04, 0x34]; + static REF: VariableByteInteger = [0x00, 0x00, 0x00, 0x00]; + let mut res_buffer: [u8; 6] = [0; 6]; + + let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6); + let mut test_write = writer.insert_ref(2, &[0x82, 0x81]); + let rm_len = writer.get_rem_len(); + assert_ok!(test_write); + assert_err!(rm_len); + test_write = writer.insert_ref(2, &[0x82, 0x01]); + let rm_len_sec = writer.get_rem_len(); + assert_ok!(rm_len_sec); + assert_eq!(rm_len_sec.unwrap(), [0x81, 0x82, 0x01, 0x00]); +} \ No newline at end of file diff --git a/mqtt/src/utils/buffer_writer.rs b/mqtt/src/utils/buffer_writer.rs index 6b74faf..5dab43d 100644 --- a/mqtt/src/utils/buffer_writer.rs +++ b/mqtt/src/utils/buffer_writer.rs @@ -48,6 +48,32 @@ impl<'a> BuffWriter<'a> { self.position = self.position + increment; } + pub fn get_n_byte(& mut self, n: usize) -> u8 { + if self.position >= n { + return self.buffer[n] + } + return 0 + } + + pub fn get_rem_len(& mut self) -> Result { + let mut max = if self.position >= 5 {4} else {self.position - 1}; + let mut i = 1; + let mut len: VariableByteInteger = [0; 4]; + loop { + len[i - 1] = self.buffer[i]; + if len[i - 1] & 0x80 == 0 { + return Ok(len); + } + if len[i - 1] & 0x80 != 0 && i == max && i != 4 { + return Err(()); + } + if i == max { + return Ok(len); + } + i = i + 1; + } + } + pub fn insert_ref(&mut self, len: usize, array: &[u8]) -> Result<(), BufferError> { let mut x: usize = 0; if self.position + len > self.len { diff --git a/mqtt/tests/integration_test_single.rs b/mqtt/tests/integration_test_single.rs index 6810572..e8faa90 100644 --- a/mqtt/tests/integration_test_single.rs +++ b/mqtt/tests/integration_test_single.rs @@ -74,6 +74,7 @@ async fn publish_core<'b>( info!("[Publisher] Sending new message {} to topic {}", MSG, topic); result = { client.send_message(topic, MSG).await }; + info!("[PUBLISHER] sent"); assert_ok!(result); info!("[Publisher] Disconnecting!");