#![no_std] #![no_main] #![feature(type_alias_impl_trait)] #![feature(never_type)] use defmt::*; use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater}; use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack}; use embassy_time::{Duration, Timer}; use embedded_io_async::{Read, Write}; use embedded_storage::nor_flash::NorFlash; use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; use heapless::Vec; use nourl::{Url, UrlScheme}; use rand_core::{CryptoRng, RngCore}; use reqwless::{ request::{Method, Request, RequestBuilder}, response::Response, }; use rust_mqtt::{ client::{client::MqttClient, client_config::ClientConfig}, packet::v5::publish_packet::QualityOfService, }; use serde::Serialize; use static_cell::make_static; mod error; pub use crate::error::Error; #[derive(Serialize)] #[serde(rename_all = "snake_case", tag = "status")] enum Status<'a> { Connected { version: &'a str }, Disconnected, PreparingUpdate, Erasing, Writing { progress: u32 }, Verifying, UpdateFailed { error: &'a str }, UpdateComplete, } impl Status<'_> { fn json(&self) -> Vec { serde_json_core::to_vec(self) .expect("This buffers size should be large enough to contain the serialized status") } } // TODO: Make this the owner of the blocking firmware updater // TODO: When fixed, use the async firmware updater pub struct Updater<'a, DFU, STATE> where DFU: NorFlash, STATE: NorFlash, DFU::Error: Format, { updater: BlockingFirmwareUpdater<'a, DFU, STATE>, topic_status: &'static str, topic_update: &'static str, version: &'static str, public_key: &'static [u8], } impl<'a, DFU, STATE> Updater<'a, DFU, STATE> where DFU: NorFlash, STATE: NorFlash, DFU::Error: Format, { pub fn new( updater: BlockingFirmwareUpdater<'a, DFU, STATE>, topic_status: &'static str, topic_update: &'static str, version: &'static str, public_key: &'static [u8], ) -> Self { Self { updater, topic_status, topic_update, version, public_key, } } pub fn add_will( &self, config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>, ) { let msg = make_static!(Status::Disconnected.json()); config.add_will(self.topic_status, msg, true); } pub async fn ready( &mut self, client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, ) -> Result<(), Error> { let status = Status::Connected { version: self.version, } .json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, true) .await?; client.subscribe_to_topic(self.topic_update).await?; self.updater.mark_booted()?; Ok(()) } pub async fn update( &mut self, stack: &'static Stack, rng: &mut (impl RngCore + CryptoRng), client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, url: Url<'_>, ) -> Result> { let result = self._update(stack, rng, client, url).await; if let Err(err) = &result { let status = Status::UpdateFailed { error: &err.string(), } .json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; } result } async fn _update( &mut self, stack: &'static Stack, rng: &mut (impl RngCore + CryptoRng), client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>, url: Url<'_>, ) -> Result> { info!("Preparing for OTA..."); let status = Status::PreparingUpdate.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; debug!("Making sure url is HTTPS"); if url.scheme() != UrlScheme::HTTPS { return Err(Error::InvalidScheme); } // TODO: Clear out retained update message, currently gives implementation specific error let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0]; let mut rx_buffer = [0; 1024]; let mut tx_buffer = [0; 1024]; let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); let addr = (ip, url.port_or_default()); debug!("Addr: {}", addr); socket.connect(addr).await?; let mut read_record_buffer = [0; 16384 * 2]; let mut write_record_buffer = [0; 16384]; let mut tls: TlsConnection = TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng)) .await?; debug!("Path: {}", url.path()); Request::get(url.path()) .host(url.host()) .build() .write(&mut tls) .await?; let mut headers = [0; 1024]; let resp = Response::read(&mut tls, Method::GET, &mut headers).await?; let mut body = resp.body().reader(); debug!("Erasing flash..."); let status = Status::Erasing.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; let writer = self.updater.prepare_update()?; debug!("Writing..."); let status = Status::Writing { progress: 0 }.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; // The first 64 bytes of the file contain the signature let mut signature = [0; 64]; body.read_exact(&mut signature).await?; trace!("Signature: {:?}", signature); let mut buffer = AlignedBuffer([0; 4096]); let mut size = 0; while let Ok(read) = body.read(&mut buffer.0).await { if read == 0 { break; } debug!("Writing chunk: {}", read); writer .write(size, &buffer.0[..read]) .map_err(Error::FlashError)?; size += read as u32; let status = Status::Writing { progress: size }.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; } debug!("Total size: {}", size); let status = Status::Verifying.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; self.updater .verify_and_mark_updated(self.public_key, &signature, size)?; let status = Status::UpdateComplete.json(); client .send_message(self.topic_status, &status, QualityOfService::QoS1, false) .await?; client.disconnect().await?; info!("Restarting in 5 seconds..."); Timer::after(Duration::from_secs(5)).await; cortex_m::peripheral::SCB::sys_reset() } }