diff --git a/Cargo.lock b/Cargo.lock index 06f896a..4beef55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1778,6 +1778,7 @@ dependencies = [ "serde-json-core", "smoltcp", "static_cell", + "updater", ] [[package]] @@ -2131,6 +2132,30 @@ dependencies = [ "subtle", ] +[[package]] +name = "updater" +version = "0.1.0" +dependencies = [ + "cortex-m", + "cortex-m-rt", + "defmt", + "defmt-rtt", + "embassy-boot", + "embassy-net", + "embassy-time", + "embedded-io-async", + "embedded-storage", + "embedded-tls", + "heapless 0.7.16", + "nourl", + "rand_core", + "reqwless", + "rust-mqtt", + "serde", + "serde-json-core", + "static_cell", +] + [[package]] name = "vcell" version = "0.1.3" diff --git a/Cargo.toml b/Cargo.toml index d307c25..86ddcdc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,8 @@ embedded-tls = { version = "0.15.0", default-features = false, features = [ "defmt", ] } +updater = { path = "../iot_tools/updater" } + [patch.crates-io] embassy-executor = { git = "https://github.com/embassy-rs/embassy" } embassy-rp = { git = "https://github.com/embassy-rs/embassy" } diff --git a/src/main.rs b/src/main.rs index f51ebdc..1d164be 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,6 @@ use core::cell::RefCell; use embassy_boot_rp::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterConfig}; -use embedded_storage::nor_flash::NorFlash; -use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; use heapless::{String, Vec}; use rand::{ rngs::{SmallRng, StdRng}, @@ -35,14 +33,9 @@ use embassy_sync::{ channel::{Channel, Sender}, }; use embassy_time::{Duration, Ticker, Timer}; -use embedded_io_async::Read; use dsmr5::Readout; use nourl::Url; -use reqwless::{ - request::{Method, Request, RequestBuilder}, - response::Response, -}; use rust_mqtt::{ client::{ client::MqttClient, @@ -50,11 +43,11 @@ use rust_mqtt::{ }, packet::v5::publish_packet::QualityOfService, }; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use static_cell::make_static; use const_format::formatcp; -use defmt::{debug, error, info, trace, warn, Debug2Format}; +use defmt::*; use {defmt_rtt as _, panic_probe as _}; @@ -70,6 +63,8 @@ const TOPIC_UPDATE: &str = formatcp!("{}/update", TOPIC_BASE); const VERSION: &str = git_version::git_version!(); const PUBLIC_SIGNING_KEY: &[u8] = include_bytes!("../key.pub"); +const FLASH_SIZE: usize = 2 * 1024 * 1024; + #[derive(Deserialize)] struct UpdateMessage<'a> { url: &'a str, @@ -81,25 +76,6 @@ impl UpdateMessage<'_> { } } -#[derive(Serialize)] -#[serde(rename_all = "snake_case", tag = "status")] -enum Status<'a> { - Connected { version: &'a str }, - Disconnected, - PreparingUpdate, - Erasing, - Writing { progress: u32 }, - Verifying, - UpdateComplete, -} - -impl Status<'_> { - fn vec(&self) -> Vec { - serde_json_core::to_vec(self) - .expect("The buffer should be large enough to contain all the data") - } -} - #[embassy_executor::task] async fn wifi_task( runner: cyw43::Runner< @@ -300,6 +276,8 @@ async fn main(spawner: Spawner) { } info!("TCP Connected!"); + let up = updater::Updater::new(TOPIC_STATUS, TOPIC_UPDATE, VERSION, PUBLIC_SIGNING_KEY); + let mut config = ClientConfig::new( MqttVersion::MQTTv5, // Use fast and simple PRNG to generate packet identifiers, there is no need for this to be @@ -311,9 +289,7 @@ async fn main(spawner: Spawner) { config.add_password(env!("MQTT_PASSWORD")); config.add_max_subscribe_qos(QualityOfService::QoS1); config.add_client_id(ID); - // Leads to InsufficientBufferSize error - let msg: &Vec<_, 1024> = make_static!(Status::Disconnected.vec()); - config.add_will(TOPIC_STATUS, &msg, true); + up.add_will(&mut config); let mut recv_buffer = [0; 1024]; let mut write_buffer = [0; 1024]; @@ -337,13 +313,7 @@ async fn main(spawner: Spawner) { // We wait with marking as booted until everything is connected updater.mark_booted().unwrap(); - let status = Status::Connected { version: VERSION }.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, true) - .await - .unwrap(); - - client.subscribe_to_topic(TOPIC_UPDATE).await.unwrap(); + up.ready(&mut client).await; // Turn LED off when connected control.gpio_set(0, false).await; @@ -391,7 +361,7 @@ async fn main(spawner: Spawner) { let url = message.get_url(); let url = Url::parse(url.as_str()).unwrap(); - attempt_update(stack, &mut updater, &mut rng, &mut client, url).await; + up.update(stack, &mut updater, &mut rng, &mut client, url).await; } } } @@ -410,126 +380,3 @@ async fn wait_for_config( } } -const FLASH_SIZE: usize = 2 * 1024 * 1024; - -async fn attempt_update( - stack: &'static Stack>, - updater: &mut BlockingFirmwareUpdater<'_, F, F>, - rng: &mut StdRng, - client: &mut MqttClient<'_, T, MAX_PROPERTIES, R>, - url: Url<'_>, -) where - T: embedded_io_async::Write + embedded_io_async::Read, - R: rand::RngCore, - F: NorFlash, -{ - info!("Preparing for OTA..."); - let status = Status::PreparingUpdate.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - - let ip = stack.dns_query(url.host(), DnsQueryType::A).await.unwrap()[0]; - - let mut rx_buffer = [0; 1024]; - let mut tx_buffer = [0; 1024]; - - let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); - - let addr = (ip, url.port_or_default()); - debug!("Addr: {}", addr); - socket.connect(addr).await.unwrap(); - - let mut read_record_buffer = [0; 16384 * 2]; - let mut write_record_buffer = [0; 16384]; - let mut tls: TlsConnection = - TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); - tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng)) - .await - .unwrap(); - - debug!("Path: {}", url.path()); - Request::get(url.path()) - .host(url.host()) - .build() - .write(&mut tls) - .await - .unwrap(); - - let mut headers = [0; 1024]; - let resp = Response::read(&mut tls, Method::GET, &mut headers) - .await - .unwrap(); - - let mut body = resp.body().reader(); - - debug!("Erasing flash..."); - let status = Status::Erasing.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - - let writer = updater - .prepare_update() - .map_err(|e| warn!("E: {:?}", Debug2Format(&e))) - .unwrap(); - - debug!("Writing..."); - let status = Status::Writing { progress: 0 }.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - - // The first 64 bytes of the file contain the signature - let mut signature = [0; 64]; - body.read_exact(&mut signature).await.unwrap(); - - trace!("Signature: {:?}", signature); - - let mut buffer = AlignedBuffer([0; 4096]); - let mut size = 0; - while let Ok(read) = body.read(&mut buffer.0).await { - if read == 0 { - break; - } - debug!("Writing chunk: {}", read); - writer.write(size, &buffer.0[..read]).unwrap(); - size += read as u32; - - let status = Status::Writing { progress: size }.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - } - debug!("Total size: {}", size); - - let status = Status::Verifying.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - - updater - .verify_and_mark_updated(PUBLIC_SIGNING_KEY, &signature, size) - .unwrap(); - - // Update mqtt message should be send using retain - // TODO: Clear the message - - let status = Status::UpdateComplete.vec(); - client - .send_message(TOPIC_STATUS, &status, QualityOfService::QoS1, false) - .await - .unwrap(); - - client.disconnect().await.unwrap(); - - info!("Restarting in 5 seconds..."); - Timer::after(Duration::from_secs(5)).await; - - cortex_m::peripheral::SCB::sys_reset(); -}