diff --git a/updater/src/lib.rs b/updater/src/lib.rs index ee95e05..5b6ef1a 100644 --- a/updater/src/lib.rs +++ b/updater/src/lib.rs @@ -1,14 +1,20 @@ #![no_std] #![no_main] #![feature(type_alias_impl_trait)] +#![feature(never_type)] use defmt::*; -use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater}; -use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack}; +use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterError}; +use embassy_net::{ + dns::{self, DnsQueryType}, + driver::Driver, + tcp::{ConnectError, TcpSocket}, + Stack, +}; use embassy_time::{Duration, Timer}; -use embedded_io_async::{Read, Write}; -use embedded_storage::nor_flash::NorFlash; -use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; +use embedded_io_async::{Read, ReadExactError, Write}; +use embedded_storage::nor_flash::{NorFlash, NorFlashError}; +use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext, TlsError}; use heapless::Vec; use nourl::Url; use rand_core::{CryptoRng, RngCore}; @@ -18,7 +24,7 @@ use reqwless::{ }; use rust_mqtt::{ client::{client::MqttClient, client_config::ClientConfig}, - packet::v5::publish_packet::QualityOfService, + packet::v5::{publish_packet::QualityOfService, reason_codes::ReasonCode}, }; use serde::Serialize; use static_cell::make_static; @@ -36,29 +42,109 @@ enum Status<'a> { } impl Status<'_> { - fn vec(&self) -> Vec { + fn json(&self) -> Vec { serde_json_core::to_vec(self) - .expect("The buffer should be large enough to contain all the data") + .expect("This buffers size should be large enough to contain the serialized status") + } +} + +#[derive(Debug)] +pub enum Error { + Mqtt(ReasonCode), + Dns(dns::Error), + Connect(ConnectError), + Tls(TlsError), + Reqwless(reqwless::Error), + FirmwareUpdater(FirmwareUpdaterError), + FlashError(FE), + UnexpectedEof, +} + +impl From for Error { + fn from(error: ReasonCode) -> Self { + Self::Mqtt(error) + } +} + +impl From for Error { + fn from(error: dns::Error) -> Self { + Self::Dns(error) + } +} + +impl From for Error { + fn from(error: ConnectError) -> Self { + Self::Connect(error) + } +} + +impl From for Error { + fn from(error: TlsError) -> Self { + Self::Tls(error) + } +} + +impl From for Error { + fn from(error: reqwless::Error) -> Self { + Self::Reqwless(error) + } +} + +impl From for Error { + fn from(error: FirmwareUpdaterError) -> Self { + Self::FirmwareUpdater(error) + } +} + +impl From> for Error { + fn from(error: ReadExactError) -> Self { + match error { + ReadExactError::UnexpectedEof => Self::UnexpectedEof, + ReadExactError::Other(error) => Self::Reqwless(error), + } + } +} + +impl Format for Error { + fn format(&self, f: Formatter) { + match self { + Error::Mqtt(error) => defmt::write!(f, "Mqtt: {}", error), + Error::Dns(error) => defmt::write!(f, "Dns: {}", error), + Error::Connect(error) => defmt::write!(f, "Connect: {}", error), + Error::Tls(error) => defmt::write!(f, "Tls: {}", error), + Error::Reqwless(error) => defmt::write!(f, "Reqwless: {}", error), + Error::FirmwareUpdater(error) => defmt::write!(f, "FirmwareUpdater: {}", error), + Error::FlashError(error) => defmt::write!(f, "FlashError: {:?}", error), + Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"), + } } } // TODO: Make this the owner of the blocking firmware updater // TODO: When fixed, use the async firmware updater -pub struct Updater { +pub struct Updater<'a, DFU, STATE> +where + DFU: NorFlash, + STATE: NorFlash, +{ + updater: BlockingFirmwareUpdater<'a, DFU, STATE>, + topic_status: &'static str, topic_update: &'static str, version: &'static str, public_key: &'static [u8], } -impl Updater { +impl<'a, DFU: NorFlash, STATE: NorFlash> Updater<'a, DFU, STATE> { pub fn new( + updater: BlockingFirmwareUpdater<'a, DFU, STATE>, topic_status: &'static str, topic_update: &'static str, version: &'static str, public_key: &'static [u8], ) -> Self { Self { + updater, topic_status, topic_update, version, @@ -70,42 +156,44 @@ impl Updater { &self, config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>, ) { - let msg = make_static!(Status::Disconnected.vec()); + let msg = make_static!(Status::Disconnected.json()); config.add_will(self.topic_status, msg, true); } pub async fn ready( - &self, + &mut self, client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, - ) { + ) -> Result<(), Error> { let status = Status::Connected { version: self.version, } - .vec(); + .json(); + client .send_message(self.topic_status, &status, QualityOfService::QoS1, true) - .await - .unwrap(); + .await?; - client.subscribe_to_topic(self.topic_update).await.unwrap(); + client.subscribe_to_topic(self.topic_update).await?; + + self.updater.mark_booted()?; + + Ok(()) } pub async fn update( - &self, + &mut self, stack: &'static Stack, - updater: &mut BlockingFirmwareUpdater<'_, impl NorFlash, impl NorFlash>, rng: &mut (impl RngCore + CryptoRng), client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, url: Url<'_>, - ) { + ) -> Result> { info!("Preparing for OTA..."); - let status = Status::PreparingUpdate.vec(); + let status = Status::PreparingUpdate.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; - let ip = stack.dns_query(url.host(), DnsQueryType::A).await.unwrap()[0]; + let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0]; let mut rx_buffer = [0; 1024]; let mut tx_buffer = [0; 1024]; @@ -114,53 +202,44 @@ impl Updater { let addr = (ip, url.port_or_default()); debug!("Addr: {}", addr); - socket.connect(addr).await.unwrap(); + socket.connect(addr).await?; let mut read_record_buffer = [0; 16384 * 2]; let mut write_record_buffer = [0; 16384]; let mut tls: TlsConnection = TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng)) - .await - .unwrap(); + .await?; debug!("Path: {}", url.path()); Request::get(url.path()) .host(url.host()) .build() .write(&mut tls) - .await - .unwrap(); + .await?; let mut headers = [0; 1024]; - let resp = Response::read(&mut tls, Method::GET, &mut headers) - .await - .unwrap(); + let resp = Response::read(&mut tls, Method::GET, &mut headers).await?; let mut body = resp.body().reader(); debug!("Erasing flash..."); - let status = Status::Erasing.vec(); + let status = Status::Erasing.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; - let writer = updater - .prepare_update() - .map_err(|e| warn!("E: {:?}", Debug2Format(&e))) - .unwrap(); + let writer = self.updater.prepare_update()?; debug!("Writing..."); - let status = Status::Writing { progress: 0 }.vec(); + let status = Status::Writing { progress: 0 }.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; // The first 64 bytes of the file contain the signature let mut signature = [0; 64]; - body.read_exact(&mut signature).await.unwrap(); + body.read_exact(&mut signature).await?; trace!("Signature: {:?}", signature); @@ -171,41 +250,37 @@ impl Updater { break; } debug!("Writing chunk: {}", read); - writer.write(size, &buffer.0[..read]).unwrap(); + writer.write(size, &buffer.0[..read]).map_err(Error::FlashError)?; size += read as u32; - let status = Status::Writing { progress: size }.vec(); + let status = Status::Writing { progress: size }.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; } debug!("Total size: {}", size); - let status = Status::Verifying.vec(); + let status = Status::Verifying.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; - updater - .verify_and_mark_updated(self.public_key, &signature, size) - .unwrap(); + self.updater + .verify_and_mark_updated(self.public_key, &signature, size)?; // Update mqtt message should be send using retain // TODO: Clear the message - let status = Status::UpdateComplete.vec(); + let status = Status::UpdateComplete.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) - .await - .unwrap(); + .await?; - client.disconnect().await.unwrap(); + client.disconnect().await?; info!("Restarting in 5 seconds..."); Timer::after(Duration::from_secs(5)).await; - cortex_m::peripheral::SCB::sys_reset(); + cortex_m::peripheral::SCB::sys_reset() } }