Fix read (#9)

* Add receive function
This commit is contained in:
obabec 2022-03-17 09:16:59 +01:00 committed by GitHub
parent b58e8318b6
commit 19087016a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 202 additions and 18 deletions

View File

@ -42,7 +42,9 @@ use crate::utils::types::BufferError;
use heapless::Vec; use heapless::Vec;
use rand_core::RngCore; use rand_core::RngCore;
use crate::encoding::variable_byte_integer::{VariableByteInteger, VariableByteIntegerDecoder, VariableByteIntegerEncoder};
use crate::network::NetworkError::Connection; use crate::network::NetworkError::Connection;
use crate::utils::buffer_writer::BuffWriter;
pub struct MqttClientV5<'a, T, const MAX_PROPERTIES: usize> { pub struct MqttClientV5<'a, T, const MAX_PROPERTIES: usize> {
connection: Option<T>, connection: Option<T>,
@ -106,13 +108,15 @@ where
//connack //connack
let reason: Result<u8, BufferError> = { let reason: Result<u8, BufferError> = {
trace!("Waiting for connack"); trace!("Waiting for connack");
conn.receive(self.recv_buffer).await?;
let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new(); let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new();
if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
if err == BufferError::PacketTypeMismatch { if err == BufferError::PacketTypeMismatch {
let mut disc = DisconnectPacket::<'b, MAX_PROPERTIES>::new(); let mut disc = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
if disc if disc
.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) .decode(&mut BuffReader::new(self.buffer, read))
.is_ok() .is_ok()
{ {
error!("Client was disconnected with reason: "); error!("Client was disconnected with reason: ");
@ -196,9 +200,10 @@ where
{ {
let reason: Result<[u16; 2], BufferError> = { let reason: Result<[u16; 2], BufferError> = {
trace!("Waiting for ack"); trace!("Waiting for ack");
conn.receive(self.recv_buffer).await?; let read = receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await?;
trace!("[PUBACK] Received packet with len");
let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new(); let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new();
if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read))
{ {
Err(err) Err(err)
} else { } else {
@ -252,10 +257,10 @@ where
conn.send(&self.buffer[0..len.unwrap()]).await?; conn.send(&self.buffer[0..len.unwrap()]).await?;
let reason: Result<Vec<u8, TOPICS>, BufferError> = { let reason: Result<Vec<u8, TOPICS>, BufferError> = {
conn.receive(self.recv_buffer).await?; let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
let mut packet = SubackPacket::<'b, TOPICS, MAX_PROPERTIES>::new(); let mut packet = SubackPacket::<'b, TOPICS, MAX_PROPERTIES>::new();
if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
Err(err) Err(err)
} else { } else {
Ok(packet.reason_codes) Ok(packet.reason_codes)
@ -302,10 +307,10 @@ where
conn.send(&self.buffer[0..len.unwrap()]).await?; conn.send(&self.buffer[0..len.unwrap()]).await?;
let reason: Result<u8, BufferError> = { let reason: Result<u8, BufferError> = {
conn.receive(self.recv_buffer).await?; let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
let mut packet = SubackPacket::<'b, 5, MAX_PROPERTIES>::new(); let mut packet = SubackPacket::<'b, 5, MAX_PROPERTIES>::new();
if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
Err(err) Err(err)
} else { } else {
Ok(*packet.reason_codes.get(0).unwrap()) Ok(*packet.reason_codes.get(0).unwrap())
@ -330,15 +335,16 @@ where
return Err(ReasonCode::NetworkError); return Err(ReasonCode::NetworkError);
} }
let mut conn = self.connection.as_mut().unwrap(); let mut conn = self.connection.as_mut().unwrap();
conn.receive(self.recv_buffer).await?; let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
let mut packet = PublishPacket::<'b, 5>::new(); let mut packet = PublishPacket::<'b, 5>::new();
if let Err(err) = if let Err(err) = {
packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) packet.decode(&mut BuffReader::new(self.buffer, read))
}
{ {
if err == BufferError::PacketTypeMismatch { if err == BufferError::PacketTypeMismatch {
let mut disc = DisconnectPacket::<'b, 5>::new(); let mut disc = DisconnectPacket::<'b, 5>::new();
if disc if disc.decode(&mut BuffReader::new(self.buffer, read))
.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len))
.is_ok() .is_ok()
{ {
error!("Client was disconnected with reason: "); error!("Client was disconnected with reason: ");
@ -356,12 +362,12 @@ where
puback.packet_identifier = packet.packet_identifier; puback.packet_identifier = packet.packet_identifier;
puback.reason_code = 0x00; puback.reason_code = 0x00;
{ {
let len = puback.encode(self.buffer, self.buffer_len); let len = { puback.encode(self.recv_buffer, self.recv_buffer_len) };
if let Err(err) = len { if let Err(err) = len {
error!("[DECODE ERR]: {}", err); error!("[DECODE ERR]: {}", err);
return Err(ReasonCode::BuffError); return Err(ReasonCode::BuffError);
} }
conn.send(&self.buffer[0..len.unwrap()]).await?; conn.send(&self.recv_buffer[0..len.unwrap()]).await?;
} }
} }
@ -385,9 +391,9 @@ where
conn.send(&self.buffer[0..len.unwrap()]).await?; conn.send(&self.buffer[0..len.unwrap()]).await?;
conn.receive(self.recv_buffer).await?; let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
let mut packet = PingrespPacket::new(); let mut packet = PingrespPacket::new();
if let Err(err) = packet.decode(&mut BuffReader::new(self.recv_buffer, self.recv_buffer_len)) { if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
error!("[DECODE ERR]: {}", err); error!("[DECODE ERR]: {}", err);
return Err(ReasonCode::BuffError); return Err(ReasonCode::BuffError);
} else { } else {
@ -395,3 +401,36 @@ where
} }
} }
} }
async fn receive_packet<'c, T:NetworkConnection>(buffer: & mut [u8],buffer_len: usize, recv_buffer: & mut [u8], conn: &'c mut T) -> Result<usize, ReasonCode> {
let mut target_len = 0;
let mut rem_len: VariableByteInteger = [0; 4];
let mut rem_len_len: usize = 0;
let mut complete_len: bool = false;
let mut writer = BuffWriter::new(buffer, buffer_len);
loop {
let len: usize = conn.receive(recv_buffer).await?;
if len > 0 {
writer.insert_ref(len, &recv_buffer);
if writer.position >= 1 && target_len == 0 {
let tmp_rem_len = writer.get_rem_len();
if tmp_rem_len.is_err() {
continue;
}
rem_len = tmp_rem_len.unwrap();
rem_len_len = VariableByteIntegerEncoder::len(rem_len);
if let Ok(res) = VariableByteIntegerDecoder::decode(rem_len) {
target_len = res as usize;
} else {
return Err(ReasonCode::BuffError);
}
}
if target_len != 0 && (target_len + rem_len_len + 1) >= writer.position {
trace!("Just read packet with len {}", (target_len + rem_len_len + 1));
return Ok(target_len + rem_len_len + 1);
}
}
}
}

View File

@ -27,6 +27,8 @@ use crate::utils::buffer_writer::BuffWriter;
use crate::utils::types::{BinaryData, BufferError, EncodedString, StringPair, TopicFilter}; use crate::utils::types::{BinaryData, BufferError, EncodedString, StringPair, TopicFilter};
use heapless::Vec; use heapless::Vec;
use tokio_test::{assert_err, assert_ok};
use crate::encoding::variable_byte_integer::VariableByteInteger;
#[test] #[test]
fn buffer_write_ref() { fn buffer_write_ref() {
@ -380,3 +382,119 @@ fn buffer_write_filters_oob() {
assert!(test_write.is_err()); assert!(test_write.is_err());
assert_eq!(test_write.unwrap_err(), BufferError::InsufficientBufferSize) assert_eq!(test_write.unwrap_err(), BufferError::InsufficientBufferSize)
} }
#[test]
fn buffer_get_rem_len_one() {
static BUFFER: [u8; 5] = [0x82, 0x02, 0x03, 0x85, 0x84];
static REF: VariableByteInteger = [0x02, 0x00, 0x00, 0x00];
let mut res_buffer: [u8; 5] = [0; 5];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5);
let test_write = writer.insert_ref(5, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_two() {
static BUFFER: [u8; 5] = [0x82, 0x82, 0x03, 0x85, 0x84];
static REF: VariableByteInteger = [0x82, 0x03, 0x00, 0x00];
let mut res_buffer: [u8; 5] = [0; 5];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5);
let test_write = writer.insert_ref(5, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_three() {
static BUFFER: [u8; 5] = [0x82, 0x82, 0x83, 0x05, 0x84];
static REF: VariableByteInteger = [0x82, 0x83, 0x05, 0x00];
let mut res_buffer: [u8; 5] = [0; 5];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5);
let test_write = writer.insert_ref(5, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_all() {
static BUFFER: [u8; 5] = [0x82, 0x82, 0x83, 0x85, 0x04];
static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x04];
let mut res_buffer: [u8; 5] = [0; 5];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 5);
let test_write = writer.insert_ref(5, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_over() {
static BUFFER: [u8; 6] = [0x82, 0x82, 0x83, 0x85, 0x84, 0x34];
static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x84];
let mut res_buffer: [u8; 6] = [0; 6];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6);
let test_write = writer.insert_ref(6, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_zero_end() {
static BUFFER: [u8; 6] = [0x82, 0x82, 0x83, 0x85, 0x04, 0x34];
static REF: VariableByteInteger = [0x82, 0x83, 0x85, 0x04];
let mut res_buffer: [u8; 6] = [0; 6];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6);
let test_write = writer.insert_ref(6, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_zero() {
static BUFFER: [u8; 6] = [0x82, 0x00, 0x83, 0x85, 0x04, 0x34];
static REF: VariableByteInteger = [0x00, 0x00, 0x00, 0x00];
let mut res_buffer: [u8; 6] = [0; 6];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6);
let test_write = writer.insert_ref(6, &BUFFER);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_ok!(rm_len);
assert_eq!(rm_len.unwrap(), REF);
}
#[test]
fn buffer_get_rem_len_cont() {
static BUFFER: [u8; 6] = [0x82, 0x00, 0x83, 0x85, 0x04, 0x34];
static REF: VariableByteInteger = [0x00, 0x00, 0x00, 0x00];
let mut res_buffer: [u8; 6] = [0; 6];
let mut writer: BuffWriter = BuffWriter::new(&mut res_buffer, 6);
let mut test_write = writer.insert_ref(2, &[0x82, 0x81]);
let rm_len = writer.get_rem_len();
assert_ok!(test_write);
assert_err!(rm_len);
test_write = writer.insert_ref(2, &[0x82, 0x01]);
let rm_len_sec = writer.get_rem_len();
assert_ok!(rm_len_sec);
assert_eq!(rm_len_sec.unwrap(), [0x81, 0x82, 0x01, 0x00]);
}

View File

@ -48,6 +48,32 @@ impl<'a> BuffWriter<'a> {
self.position = self.position + increment; self.position = self.position + increment;
} }
pub fn get_n_byte(& mut self, n: usize) -> u8 {
if self.position >= n {
return self.buffer[n]
}
return 0
}
pub fn get_rem_len(& mut self) -> Result<VariableByteInteger, ()> {
let mut max = if self.position >= 5 {4} else {self.position - 1};
let mut i = 1;
let mut len: VariableByteInteger = [0; 4];
loop {
len[i - 1] = self.buffer[i];
if len[i - 1] & 0x80 == 0 {
return Ok(len);
}
if len[i - 1] & 0x80 != 0 && i == max && i != 4 {
return Err(());
}
if i == max {
return Ok(len);
}
i = i + 1;
}
}
pub fn insert_ref(&mut self, len: usize, array: &[u8]) -> Result<(), BufferError> { pub fn insert_ref(&mut self, len: usize, array: &[u8]) -> Result<(), BufferError> {
let mut x: usize = 0; let mut x: usize = 0;
if self.position + len > self.len { if self.position + len > self.len {

View File

@ -74,6 +74,7 @@ async fn publish_core<'b>(
info!("[Publisher] Sending new message {} to topic {}", MSG, topic); info!("[Publisher] Sending new message {} to topic {}", MSG, topic);
result = { client.send_message(topic, MSG).await }; result = { client.send_message(topic, MSG).await };
info!("[PUBLISHER] sent");
assert_ok!(result); assert_ok!(result);
info!("[Publisher] Disconnecting!"); info!("[Publisher] Disconnecting!");