diff --git a/updater/Cargo.lock b/updater/Cargo.lock index 0d74bf9..d30ae52 100644 --- a/updater/Cargo.lock +++ b/updater/Cargo.lock @@ -749,6 +749,30 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "impl-tools" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82c305b1081f1a99fda262883c788e50ab57d36c00830bdd7e0a82894ad965c" +dependencies = [ + "autocfg", + "impl-tools-lib", + "proc-macro-error", + "syn 2.0.33", +] + +[[package]] +name = "impl-tools-lib" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85d3946d886eaab0702fa0c6585adcced581513223fa9df7ccfabbd9fa331a88" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.33", +] + [[package]] name = "inout" version = "0.1.3" @@ -1197,6 +1221,7 @@ dependencies = [ "embedded-storage", "embedded-tls", "heapless 0.7.16", + "impl-tools", "nourl", "rand_core", "reqwless", diff --git a/updater/Cargo.toml b/updater/Cargo.toml index 68b7c24..dd4848f 100644 --- a/updater/Cargo.toml +++ b/updater/Cargo.toml @@ -45,6 +45,7 @@ embedded-tls = { version = "0.15.0", default-features = false, features = [ ] } reqwless = { version = "0.5.0", features = ["defmt"] } static_cell = { version = "1.2.0", features = ["nightly"] } +impl-tools = "0.10.0" [patch.crates-io] embassy-net = { git = "https://github.com/embassy-rs/embassy" } diff --git a/updater/src/error.rs b/updater/src/error.rs new file mode 100644 index 0000000..38178f8 --- /dev/null +++ b/updater/src/error.rs @@ -0,0 +1,107 @@ +use core::fmt::{Display, Write}; + +use heapless::String; +use defmt::{Format, Formatter}; +use embassy_boot::FirmwareUpdaterError; +use embassy_net::{dns, tcp::ConnectError}; +use embedded_io_async::ReadExactError; +use embedded_storage::nor_flash::NorFlashError; +use embedded_tls::TlsError; +use rust_mqtt::packet::v5::reason_codes::ReasonCode; + +impl_tools::impl_scope! { + #[derive(Debug)] + pub enum Error { + Mqtt(ReasonCode), + Dns(dns::Error), + Connect(ConnectError), + Tls(TlsError), + Reqwless(reqwless::Error), + FirmwareUpdater(FirmwareUpdaterError), + FlashError(FE), + UnexpectedEof, + } + + impl Self { + pub fn string(&self) -> String<256> { + let mut error = String::new(); + core::write!(error, "{}", self).expect("Formatting the error should not fail"); + error + } + } + + impl From for Self { + fn from(error: ReasonCode) -> Self { + Self::Mqtt(error) + } + } + + impl From for Self { + fn from(error: dns::Error) -> Self { + Self::Dns(error) + } + } + + impl From for Self { + fn from(error: ConnectError) -> Self { + Self::Connect(error) + } + } + + impl From for Self { + fn from(error: TlsError) -> Self { + Self::Tls(error) + } + } + + impl From for Self { + fn from(error: reqwless::Error) -> Self { + Self::Reqwless(error) + } + } + + impl From for Self { + fn from(error: FirmwareUpdaterError) -> Self { + Self::FirmwareUpdater(error) + } + } + + impl From> for Self { + fn from(error: ReadExactError) -> Self { + match error { + ReadExactError::UnexpectedEof => Self::UnexpectedEof, + ReadExactError::Other(error) => Self::Reqwless(error), + } + } + } + + impl Format for Self { + fn format(&self, f: Formatter) { + match self { + Error::Mqtt(error) => defmt::write!(f, "Mqtt: {}", error), + Error::Dns(error) => defmt::write!(f, "Dns: {}", error), + Error::Connect(error) => defmt::write!(f, "Connect: {}", error), + Error::Tls(error) => defmt::write!(f, "Tls: {}", error), + Error::Reqwless(error) => defmt::write!(f, "Reqwless: {}", error), + Error::FirmwareUpdater(error) => defmt::write!(f, "FirmwareUpdater: {}", error), + Error::FlashError(error) => defmt::write!(f, "FlashError: {:?}", error), + Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"), + } + } + } + + impl Display for Self { + fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> { + match self { + Error::Mqtt(error) => core::write!(f, "Mqtt: {}", error), + Error::Dns(error) => core::write!(f, "Dns: {:?}", error), + Error::Connect(error) => core::write!(f, "Connect: {:?}", error), + Error::Tls(error) => core::write!(f, "Tls: {:?}", error), + Error::Reqwless(error) => core::write!(f, "Reqwless: {:?}", error), + Error::FirmwareUpdater(error) => core::write!(f, "FirmwareUpdater: {:?}", error), + Error::FlashError(error) => core::write!(f, "FlashError: {:?}", error), + Error::UnexpectedEof => core::write!(f, "UnexpectedEof"), + } + } + } +} diff --git a/updater/src/lib.rs b/updater/src/lib.rs index 5b6ef1a..9476dcd 100644 --- a/updater/src/lib.rs +++ b/updater/src/lib.rs @@ -4,17 +4,12 @@ #![feature(never_type)] use defmt::*; -use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterError}; -use embassy_net::{ - dns::{self, DnsQueryType}, - driver::Driver, - tcp::{ConnectError, TcpSocket}, - Stack, -}; +use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater}; +use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack}; use embassy_time::{Duration, Timer}; -use embedded_io_async::{Read, ReadExactError, Write}; -use embedded_storage::nor_flash::{NorFlash, NorFlashError}; -use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext, TlsError}; +use embedded_io_async::{Read, Write}; +use embedded_storage::nor_flash::NorFlash; +use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; use heapless::Vec; use nourl::Url; use rand_core::{CryptoRng, RngCore}; @@ -24,11 +19,15 @@ use reqwless::{ }; use rust_mqtt::{ client::{client::MqttClient, client_config::ClientConfig}, - packet::v5::{publish_packet::QualityOfService, reason_codes::ReasonCode}, + packet::v5::publish_packet::QualityOfService, }; use serde::Serialize; use static_cell::make_static; +mod error; + +pub use crate::error::Error; + #[derive(Serialize)] #[serde(rename_all = "snake_case", tag = "status")] enum Status<'a> { @@ -38,94 +37,24 @@ enum Status<'a> { Erasing, Writing { progress: u32 }, Verifying, + UpdateFailed { error: &'a str }, UpdateComplete, } impl Status<'_> { - fn json(&self) -> Vec { + fn json(&self) -> Vec { serde_json_core::to_vec(self) .expect("This buffers size should be large enough to contain the serialized status") } } -#[derive(Debug)] -pub enum Error { - Mqtt(ReasonCode), - Dns(dns::Error), - Connect(ConnectError), - Tls(TlsError), - Reqwless(reqwless::Error), - FirmwareUpdater(FirmwareUpdaterError), - FlashError(FE), - UnexpectedEof, -} - -impl From for Error { - fn from(error: ReasonCode) -> Self { - Self::Mqtt(error) - } -} - -impl From for Error { - fn from(error: dns::Error) -> Self { - Self::Dns(error) - } -} - -impl From for Error { - fn from(error: ConnectError) -> Self { - Self::Connect(error) - } -} - -impl From for Error { - fn from(error: TlsError) -> Self { - Self::Tls(error) - } -} - -impl From for Error { - fn from(error: reqwless::Error) -> Self { - Self::Reqwless(error) - } -} - -impl From for Error { - fn from(error: FirmwareUpdaterError) -> Self { - Self::FirmwareUpdater(error) - } -} - -impl From> for Error { - fn from(error: ReadExactError) -> Self { - match error { - ReadExactError::UnexpectedEof => Self::UnexpectedEof, - ReadExactError::Other(error) => Self::Reqwless(error), - } - } -} - -impl Format for Error { - fn format(&self, f: Formatter) { - match self { - Error::Mqtt(error) => defmt::write!(f, "Mqtt: {}", error), - Error::Dns(error) => defmt::write!(f, "Dns: {}", error), - Error::Connect(error) => defmt::write!(f, "Connect: {}", error), - Error::Tls(error) => defmt::write!(f, "Tls: {}", error), - Error::Reqwless(error) => defmt::write!(f, "Reqwless: {}", error), - Error::FirmwareUpdater(error) => defmt::write!(f, "FirmwareUpdater: {}", error), - Error::FlashError(error) => defmt::write!(f, "FlashError: {:?}", error), - Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"), - } - } -} - // TODO: Make this the owner of the blocking firmware updater // TODO: When fixed, use the async firmware updater pub struct Updater<'a, DFU, STATE> where DFU: NorFlash, STATE: NorFlash, + DFU::Error: Format, { updater: BlockingFirmwareUpdater<'a, DFU, STATE>, @@ -135,7 +64,12 @@ where public_key: &'static [u8], } -impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> { +impl<'a, DFU, STATE> Updater<'a, DFU, STATE> +where + DFU: NorFlash, + STATE: NorFlash, + DFU::Error: Format, +{ pub fn new( updater: BlockingFirmwareUpdater<'a, DFU, STATE>, topic_status: &'static str, @@ -186,6 +120,29 @@ impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> { rng: &mut (impl RngCore + CryptoRng), client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, url: Url<'_>, + ) -> Result> { + let result = self._update(stack, rng, client, url).await; + + if let Err(err) = &result { + let status = Status::UpdateFailed { + error: &err.string(), + } + .json(); + + client + .send_message(self.topic_status, &status, QualityOfService::QoS1, false) + .await?; + } + + result + } + + async fn _update( + &mut self, + stack: &'static Stack, + rng: &mut (impl RngCore + CryptoRng), + client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, + url: Url<'_>, ) -> Result> { info!("Preparing for OTA..."); let status = Status::PreparingUpdate.json(); @@ -250,7 +207,9 @@ impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> { break; } debug!("Writing chunk: {}", read); - writer.write(size, &buffer.0[..read]).map_err(Error::FlashError)?; + writer + .write(size, &buffer.0[..read]) + .map_err(Error::FlashError)?; size += read as u32; let status = Status::Writing { progress: size }.json();