Improved error handling and Updater now functions as a wrapper around the FirmwareUpdater

This commit is contained in:
Dreaded_X 2023-09-15 23:00:14 +02:00
parent c279e52e2d
commit 7763269b06
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4

View File

@ -1,14 +1,20 @@
#![no_std] #![no_std]
#![no_main] #![no_main]
#![feature(type_alias_impl_trait)] #![feature(type_alias_impl_trait)]
#![feature(never_type)]
use defmt::*; use defmt::*;
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater}; use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterError};
use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack}; use embassy_net::{
dns::{self, DnsQueryType},
driver::Driver,
tcp::{ConnectError, TcpSocket},
Stack,
};
use embassy_time::{Duration, Timer}; use embassy_time::{Duration, Timer};
use embedded_io_async::{Read, Write}; use embedded_io_async::{Read, ReadExactError, Write};
use embedded_storage::nor_flash::NorFlash; use embedded_storage::nor_flash::{NorFlash, NorFlashError};
use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext, TlsError};
use heapless::Vec; use heapless::Vec;
use nourl::Url; use nourl::Url;
use rand_core::{CryptoRng, RngCore}; use rand_core::{CryptoRng, RngCore};
@ -18,7 +24,7 @@ use reqwless::{
}; };
use rust_mqtt::{ use rust_mqtt::{
client::{client::MqttClient, client_config::ClientConfig}, client::{client::MqttClient, client_config::ClientConfig},
packet::v5::publish_packet::QualityOfService, packet::v5::{publish_packet::QualityOfService, reason_codes::ReasonCode},
}; };
use serde::Serialize; use serde::Serialize;
use static_cell::make_static; use static_cell::make_static;
@ -36,29 +42,109 @@ enum Status<'a> {
} }
impl Status<'_> { impl Status<'_> {
fn vec(&self) -> Vec<u8, 1024> { fn json(&self) -> Vec<u8, 128> {
serde_json_core::to_vec(self) serde_json_core::to_vec(self)
.expect("The buffer should be large enough to contain all the data") .expect("This buffers size should be large enough to contain the serialized status")
}
}
#[derive(Debug)]
pub enum Error<FE: NorFlashError> {
Mqtt(ReasonCode),
Dns(dns::Error),
Connect(ConnectError),
Tls(TlsError),
Reqwless(reqwless::Error),
FirmwareUpdater(FirmwareUpdaterError),
FlashError(FE),
UnexpectedEof,
}
impl<FE: NorFlashError> From<ReasonCode> for Error<FE> {
fn from(error: ReasonCode) -> Self {
Self::Mqtt(error)
}
}
impl<FE: NorFlashError> From<dns::Error> for Error<FE> {
fn from(error: dns::Error) -> Self {
Self::Dns(error)
}
}
impl<FE: NorFlashError> From<ConnectError> for Error<FE> {
fn from(error: ConnectError) -> Self {
Self::Connect(error)
}
}
impl<FE: NorFlashError> From<TlsError> for Error<FE> {
fn from(error: TlsError) -> Self {
Self::Tls(error)
}
}
impl<FE: NorFlashError> From<reqwless::Error> for Error<FE> {
fn from(error: reqwless::Error) -> Self {
Self::Reqwless(error)
}
}
impl<FE: NorFlashError> From<FirmwareUpdaterError> for Error<FE> {
fn from(error: FirmwareUpdaterError) -> Self {
Self::FirmwareUpdater(error)
}
}
impl<FE: NorFlashError> From<ReadExactError<reqwless::Error>> for Error<FE> {
fn from(error: ReadExactError<reqwless::Error>) -> Self {
match error {
ReadExactError::UnexpectedEof => Self::UnexpectedEof,
ReadExactError::Other(error) => Self::Reqwless(error),
}
}
}
impl<FE: NorFlashError + defmt::Format> Format for Error<FE> {
fn format(&self, f: Formatter) {
match self {
Error::Mqtt(error) => defmt::write!(f, "Mqtt: {}", error),
Error::Dns(error) => defmt::write!(f, "Dns: {}", error),
Error::Connect(error) => defmt::write!(f, "Connect: {}", error),
Error::Tls(error) => defmt::write!(f, "Tls: {}", error),
Error::Reqwless(error) => defmt::write!(f, "Reqwless: {}", error),
Error::FirmwareUpdater(error) => defmt::write!(f, "FirmwareUpdater: {}", error),
Error::FlashError(error) => defmt::write!(f, "FlashError: {:?}", error),
Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"),
}
} }
} }
// TODO: Make this the owner of the blocking firmware updater // TODO: Make this the owner of the blocking firmware updater
// TODO: When fixed, use the async firmware updater // TODO: When fixed, use the async firmware updater
pub struct Updater { pub struct Updater<'a, DFU, STATE>
where
DFU: NorFlash,
STATE: NorFlash,
{
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
topic_status: &'static str, topic_status: &'static str,
topic_update: &'static str, topic_update: &'static str,
version: &'static str, version: &'static str,
public_key: &'static [u8], public_key: &'static [u8],
} }
impl Updater { impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> {
pub fn new( pub fn new(
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
topic_status: &'static str, topic_status: &'static str,
topic_update: &'static str, topic_update: &'static str,
version: &'static str, version: &'static str,
public_key: &'static [u8], public_key: &'static [u8],
) -> Self { ) -> Self {
Self { Self {
updater,
topic_status, topic_status,
topic_update, topic_update,
version, version,
@ -70,42 +156,44 @@ impl Updater {
&self, &self,
config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>, config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>,
) { ) {
let msg = make_static!(Status::Disconnected.vec()); let msg = make_static!(Status::Disconnected.json());
config.add_will(self.topic_status, msg, true); config.add_will(self.topic_status, msg, true);
} }
pub async fn ready<const MAX_PROPERTIES: usize>( pub async fn ready<const MAX_PROPERTIES: usize>(
&self, &mut self,
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
) { ) -> Result<(), Error<DFU::Error>> {
let status = Status::Connected { let status = Status::Connected {
version: self.version, version: self.version,
} }
.vec(); .json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, true) .send_message(self.topic_status, &status, QualityOfService::QoS1, true)
.await .await?;
.unwrap();
client.subscribe_to_topic(self.topic_update).await.unwrap(); client.subscribe_to_topic(self.topic_update).await?;
self.updater.mark_booted()?;
Ok(())
} }
pub async fn update<const MAX_PROPERTIES: usize>( pub async fn update<const MAX_PROPERTIES: usize>(
&self, &mut self,
stack: &'static Stack<impl Driver>, stack: &'static Stack<impl Driver>,
updater: &mut BlockingFirmwareUpdater<'_, impl NorFlash, impl NorFlash>,
rng: &mut (impl RngCore + CryptoRng), rng: &mut (impl RngCore + CryptoRng),
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
url: Url<'_>, url: Url<'_>,
) { ) -> Result<!, Error<DFU::Error>> {
info!("Preparing for OTA..."); info!("Preparing for OTA...");
let status = Status::PreparingUpdate.vec(); let status = Status::PreparingUpdate.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
let ip = stack.dns_query(url.host(), DnsQueryType::A).await.unwrap()[0]; let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0];
let mut rx_buffer = [0; 1024]; let mut rx_buffer = [0; 1024];
let mut tx_buffer = [0; 1024]; let mut tx_buffer = [0; 1024];
@ -114,53 +202,44 @@ impl Updater {
let addr = (ip, url.port_or_default()); let addr = (ip, url.port_or_default());
debug!("Addr: {}", addr); debug!("Addr: {}", addr);
socket.connect(addr).await.unwrap(); socket.connect(addr).await?;
let mut read_record_buffer = [0; 16384 * 2]; let mut read_record_buffer = [0; 16384 * 2];
let mut write_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384];
let mut tls: TlsConnection<TcpSocket, Aes128GcmSha256> = let mut tls: TlsConnection<TcpSocket, Aes128GcmSha256> =
TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer);
tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng)) tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng))
.await .await?;
.unwrap();
debug!("Path: {}", url.path()); debug!("Path: {}", url.path());
Request::get(url.path()) Request::get(url.path())
.host(url.host()) .host(url.host())
.build() .build()
.write(&mut tls) .write(&mut tls)
.await .await?;
.unwrap();
let mut headers = [0; 1024]; let mut headers = [0; 1024];
let resp = Response::read(&mut tls, Method::GET, &mut headers) let resp = Response::read(&mut tls, Method::GET, &mut headers).await?;
.await
.unwrap();
let mut body = resp.body().reader(); let mut body = resp.body().reader();
debug!("Erasing flash..."); debug!("Erasing flash...");
let status = Status::Erasing.vec(); let status = Status::Erasing.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
let writer = updater let writer = self.updater.prepare_update()?;
.prepare_update()
.map_err(|e| warn!("E: {:?}", Debug2Format(&e)))
.unwrap();
debug!("Writing..."); debug!("Writing...");
let status = Status::Writing { progress: 0 }.vec(); let status = Status::Writing { progress: 0 }.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
// The first 64 bytes of the file contain the signature // The first 64 bytes of the file contain the signature
let mut signature = [0; 64]; let mut signature = [0; 64];
body.read_exact(&mut signature).await.unwrap(); body.read_exact(&mut signature).await?;
trace!("Signature: {:?}", signature); trace!("Signature: {:?}", signature);
@ -171,41 +250,37 @@ impl Updater {
break; break;
} }
debug!("Writing chunk: {}", read); debug!("Writing chunk: {}", read);
writer.write(size, &buffer.0[..read]).unwrap(); writer.write(size, &buffer.0[..read]).map_err(Error::FlashError)?;
size += read as u32; size += read as u32;
let status = Status::Writing { progress: size }.vec(); let status = Status::Writing { progress: size }.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
} }
debug!("Total size: {}", size); debug!("Total size: {}", size);
let status = Status::Verifying.vec(); let status = Status::Verifying.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
updater self.updater
.verify_and_mark_updated(self.public_key, &signature, size) .verify_and_mark_updated(self.public_key, &signature, size)?;
.unwrap();
// Update mqtt message should be send using retain // Update mqtt message should be send using retain
// TODO: Clear the message // TODO: Clear the message
let status = Status::UpdateComplete.vec(); let status = Status::UpdateComplete.json();
client client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false) .send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await .await?;
.unwrap();
client.disconnect().await.unwrap(); client.disconnect().await?;
info!("Restarting in 5 seconds..."); info!("Restarting in 5 seconds...");
Timer::after(Duration::from_secs(5)).await; Timer::after(Duration::from_secs(5)).await;
cortex_m::peripheral::SCB::sys_reset(); cortex_m::peripheral::SCB::sys_reset()
} }
} }