Improved error handling and Updater now functions as a wrapper around the FirmwareUpdater
This commit is contained in:
parent
c279e52e2d
commit
7763269b06
|
@ -1,14 +1,20 @@
|
|||
#![no_std]
|
||||
#![no_main]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(never_type)]
|
||||
|
||||
use defmt::*;
|
||||
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater};
|
||||
use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack};
|
||||
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterError};
|
||||
use embassy_net::{
|
||||
dns::{self, DnsQueryType},
|
||||
driver::Driver,
|
||||
tcp::{ConnectError, TcpSocket},
|
||||
Stack,
|
||||
};
|
||||
use embassy_time::{Duration, Timer};
|
||||
use embedded_io_async::{Read, Write};
|
||||
use embedded_storage::nor_flash::NorFlash;
|
||||
use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext};
|
||||
use embedded_io_async::{Read, ReadExactError, Write};
|
||||
use embedded_storage::nor_flash::{NorFlash, NorFlashError};
|
||||
use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext, TlsError};
|
||||
use heapless::Vec;
|
||||
use nourl::Url;
|
||||
use rand_core::{CryptoRng, RngCore};
|
||||
|
@ -18,7 +24,7 @@ use reqwless::{
|
|||
};
|
||||
use rust_mqtt::{
|
||||
client::{client::MqttClient, client_config::ClientConfig},
|
||||
packet::v5::publish_packet::QualityOfService,
|
||||
packet::v5::{publish_packet::QualityOfService, reason_codes::ReasonCode},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use static_cell::make_static;
|
||||
|
@ -36,29 +42,109 @@ enum Status<'a> {
|
|||
}
|
||||
|
||||
impl Status<'_> {
|
||||
fn vec(&self) -> Vec<u8, 1024> {
|
||||
fn json(&self) -> Vec<u8, 128> {
|
||||
serde_json_core::to_vec(self)
|
||||
.expect("The buffer should be large enough to contain all the data")
|
||||
.expect("This buffers size should be large enough to contain the serialized status")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error<FE: NorFlashError> {
|
||||
Mqtt(ReasonCode),
|
||||
Dns(dns::Error),
|
||||
Connect(ConnectError),
|
||||
Tls(TlsError),
|
||||
Reqwless(reqwless::Error),
|
||||
FirmwareUpdater(FirmwareUpdaterError),
|
||||
FlashError(FE),
|
||||
UnexpectedEof,
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<ReasonCode> for Error<FE> {
|
||||
fn from(error: ReasonCode) -> Self {
|
||||
Self::Mqtt(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<dns::Error> for Error<FE> {
|
||||
fn from(error: dns::Error) -> Self {
|
||||
Self::Dns(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<ConnectError> for Error<FE> {
|
||||
fn from(error: ConnectError) -> Self {
|
||||
Self::Connect(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<TlsError> for Error<FE> {
|
||||
fn from(error: TlsError) -> Self {
|
||||
Self::Tls(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<reqwless::Error> for Error<FE> {
|
||||
fn from(error: reqwless::Error) -> Self {
|
||||
Self::Reqwless(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<FirmwareUpdaterError> for Error<FE> {
|
||||
fn from(error: FirmwareUpdaterError) -> Self {
|
||||
Self::FirmwareUpdater(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError> From<ReadExactError<reqwless::Error>> for Error<FE> {
|
||||
fn from(error: ReadExactError<reqwless::Error>) -> Self {
|
||||
match error {
|
||||
ReadExactError::UnexpectedEof => Self::UnexpectedEof,
|
||||
ReadExactError::Other(error) => Self::Reqwless(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<FE: NorFlashError + defmt::Format> Format for Error<FE> {
|
||||
fn format(&self, f: Formatter) {
|
||||
match self {
|
||||
Error::Mqtt(error) => defmt::write!(f, "Mqtt: {}", error),
|
||||
Error::Dns(error) => defmt::write!(f, "Dns: {}", error),
|
||||
Error::Connect(error) => defmt::write!(f, "Connect: {}", error),
|
||||
Error::Tls(error) => defmt::write!(f, "Tls: {}", error),
|
||||
Error::Reqwless(error) => defmt::write!(f, "Reqwless: {}", error),
|
||||
Error::FirmwareUpdater(error) => defmt::write!(f, "FirmwareUpdater: {}", error),
|
||||
Error::FlashError(error) => defmt::write!(f, "FlashError: {:?}", error),
|
||||
Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Make this the owner of the blocking firmware updater
|
||||
// TODO: When fixed, use the async firmware updater
|
||||
pub struct Updater {
|
||||
pub struct Updater<'a, DFU, STATE>
|
||||
where
|
||||
DFU: NorFlash,
|
||||
STATE: NorFlash,
|
||||
{
|
||||
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
|
||||
|
||||
topic_status: &'static str,
|
||||
topic_update: &'static str,
|
||||
version: &'static str,
|
||||
public_key: &'static [u8],
|
||||
}
|
||||
|
||||
impl Updater {
|
||||
impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> {
|
||||
pub fn new(
|
||||
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
|
||||
topic_status: &'static str,
|
||||
topic_update: &'static str,
|
||||
version: &'static str,
|
||||
public_key: &'static [u8],
|
||||
) -> Self {
|
||||
Self {
|
||||
updater,
|
||||
topic_status,
|
||||
topic_update,
|
||||
version,
|
||||
|
@ -70,42 +156,44 @@ impl Updater {
|
|||
&self,
|
||||
config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>,
|
||||
) {
|
||||
let msg = make_static!(Status::Disconnected.vec());
|
||||
let msg = make_static!(Status::Disconnected.json());
|
||||
config.add_will(self.topic_status, msg, true);
|
||||
}
|
||||
|
||||
pub async fn ready<const MAX_PROPERTIES: usize>(
|
||||
&self,
|
||||
&mut self,
|
||||
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
|
||||
) {
|
||||
) -> Result<(), Error<DFU::Error>> {
|
||||
let status = Status::Connected {
|
||||
version: self.version,
|
||||
}
|
||||
.vec();
|
||||
.json();
|
||||
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, true)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
client.subscribe_to_topic(self.topic_update).await.unwrap();
|
||||
client.subscribe_to_topic(self.topic_update).await?;
|
||||
|
||||
self.updater.mark_booted()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update<const MAX_PROPERTIES: usize>(
|
||||
&self,
|
||||
&mut self,
|
||||
stack: &'static Stack<impl Driver>,
|
||||
updater: &mut BlockingFirmwareUpdater<'_, impl NorFlash, impl NorFlash>,
|
||||
rng: &mut (impl RngCore + CryptoRng),
|
||||
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
|
||||
url: Url<'_>,
|
||||
) {
|
||||
) -> Result<!, Error<DFU::Error>> {
|
||||
info!("Preparing for OTA...");
|
||||
let status = Status::PreparingUpdate.vec();
|
||||
let status = Status::PreparingUpdate.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
let ip = stack.dns_query(url.host(), DnsQueryType::A).await.unwrap()[0];
|
||||
let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0];
|
||||
|
||||
let mut rx_buffer = [0; 1024];
|
||||
let mut tx_buffer = [0; 1024];
|
||||
|
@ -114,53 +202,44 @@ impl Updater {
|
|||
|
||||
let addr = (ip, url.port_or_default());
|
||||
debug!("Addr: {}", addr);
|
||||
socket.connect(addr).await.unwrap();
|
||||
socket.connect(addr).await?;
|
||||
|
||||
let mut read_record_buffer = [0; 16384 * 2];
|
||||
let mut write_record_buffer = [0; 16384];
|
||||
let mut tls: TlsConnection<TcpSocket, Aes128GcmSha256> =
|
||||
TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer);
|
||||
tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng))
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
debug!("Path: {}", url.path());
|
||||
Request::get(url.path())
|
||||
.host(url.host())
|
||||
.build()
|
||||
.write(&mut tls)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
let mut headers = [0; 1024];
|
||||
let resp = Response::read(&mut tls, Method::GET, &mut headers)
|
||||
.await
|
||||
.unwrap();
|
||||
let resp = Response::read(&mut tls, Method::GET, &mut headers).await?;
|
||||
|
||||
let mut body = resp.body().reader();
|
||||
|
||||
debug!("Erasing flash...");
|
||||
let status = Status::Erasing.vec();
|
||||
let status = Status::Erasing.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
let writer = updater
|
||||
.prepare_update()
|
||||
.map_err(|e| warn!("E: {:?}", Debug2Format(&e)))
|
||||
.unwrap();
|
||||
let writer = self.updater.prepare_update()?;
|
||||
|
||||
debug!("Writing...");
|
||||
let status = Status::Writing { progress: 0 }.vec();
|
||||
let status = Status::Writing { progress: 0 }.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
// The first 64 bytes of the file contain the signature
|
||||
let mut signature = [0; 64];
|
||||
body.read_exact(&mut signature).await.unwrap();
|
||||
body.read_exact(&mut signature).await?;
|
||||
|
||||
trace!("Signature: {:?}", signature);
|
||||
|
||||
|
@ -171,41 +250,37 @@ impl Updater {
|
|||
break;
|
||||
}
|
||||
debug!("Writing chunk: {}", read);
|
||||
writer.write(size, &buffer.0[..read]).unwrap();
|
||||
writer.write(size, &buffer.0[..read]).map_err(Error::FlashError)?;
|
||||
size += read as u32;
|
||||
|
||||
let status = Status::Writing { progress: size }.vec();
|
||||
let status = Status::Writing { progress: size }.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
debug!("Total size: {}", size);
|
||||
|
||||
let status = Status::Verifying.vec();
|
||||
let status = Status::Verifying.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
updater
|
||||
.verify_and_mark_updated(self.public_key, &signature, size)
|
||||
.unwrap();
|
||||
self.updater
|
||||
.verify_and_mark_updated(self.public_key, &signature, size)?;
|
||||
|
||||
// Update mqtt message should be send using retain
|
||||
// TODO: Clear the message
|
||||
|
||||
let status = Status::UpdateComplete.vec();
|
||||
let status = Status::UpdateComplete.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
client.disconnect().await.unwrap();
|
||||
client.disconnect().await?;
|
||||
|
||||
info!("Restarting in 5 seconds...");
|
||||
Timer::after(Duration::from_secs(5)).await;
|
||||
|
||||
cortex_m::peripheral::SCB::sys_reset();
|
||||
cortex_m::peripheral::SCB::sys_reset()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user