// iot_tools/updater/src/lib.rs

#![no_std]
#![no_main]
#![feature(type_alias_impl_trait)]
#![feature(never_type)]
use defmt::*;
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater};
use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack};
use embassy_time::{Duration, Timer};
use embedded_io_async::{Read, Write};
use embedded_storage::nor_flash::NorFlash;
use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext};
use heapless::Vec;
use nourl::{Url, UrlScheme};
use rand_core::{CryptoRng, RngCore};
use reqwless::{
request::{Method, Request, RequestBuilder},
response::Response,
};
use rust_mqtt::{
client::{client::MqttClient, client_config::ClientConfig},
packet::v5::publish_packet::QualityOfService,
};
use serde::Serialize;
use static_cell::make_static;
mod error;
pub use crate::error::Error;
/// Update status published as JSON on the status topic.
///
/// Internally tagged: the snake_case variant name is stored under the `status` key next to
/// the variant's own fields, e.g. `{"status":"writing","progress":4096}`.
#[derive(Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum Status<'a> {
    /// Published by `Updater::ready`, carrying the configured firmware version.
    Connected { version: &'a str },
    /// Registered as the MQTT last-will message by `Updater::add_will`.
    Disconnected,
    /// Published when an update starts.
    PreparingUpdate,
    /// Published right before the DFU partition is prepared for writing.
    Erasing,
    /// `progress` is the number of image bytes written so far.
    Writing { progress: u32 },
    /// Published before the image signature is checked.
    Verifying,
    /// `error` describes the error that aborted the update.
    UpdateFailed { error: &'a str },
    /// Published after successful verification, before the device restarts.
    UpdateComplete,
}
impl<'a> Status<'a> {
    /// Serializes this status into a fixed-capacity (512 byte) JSON buffer.
    ///
    /// # Panics
    ///
    /// Panics if the serialized status does not fit into 512 bytes.
    fn json(&self) -> Vec<u8, 512> {
        let serialized = serde_json_core::to_vec(self);
        serialized
            .expect("This buffers size should be large enough to contain the serialized status")
    }
}
// TODO: Make this the owner of the blocking firmware updater
// TODO: When fixed, use the async firmware updater
/// MQTT-driven over-the-air firmware updater.
///
/// Publishes JSON status messages on `topic_status`, listens on `topic_update`, and writes
/// downloaded images through an embassy-boot [`BlockingFirmwareUpdater`].
pub struct Updater<'a, DFU: NorFlash, STATE: NorFlash>
where
    DFU::Error: Format,
{
    /// Used to mark boots, write the downloaded image and verify it.
    updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
    /// Topic that JSON status messages are published to.
    topic_status: &'static str,
    /// Topic subscribed to when the updater becomes ready.
    topic_update: &'static str,
    /// Firmware version reported in the `connected` status.
    version: &'static str,
    /// Key the signature of a downloaded image is verified against.
    public_key: &'static [u8],
}
impl<'a, DFU, STATE> Updater<'a, DFU, STATE>
where
DFU: NorFlash,
STATE: NorFlash,
DFU::Error: Format,
{
pub fn new(
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
topic_status: &'static str,
topic_update: &'static str,
version: &'static str,
public_key: &'static [u8],
) -> Self {
Self {
updater,
topic_status,
topic_update,
version,
public_key,
}
}
pub fn add_will<const MAX_PROPERTIES: usize>(
&self,
config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>,
) {
let msg = make_static!(Status::Disconnected.json());
config.add_will(self.topic_status, msg, true);
}
pub async fn ready<const MAX_PROPERTIES: usize>(
&mut self,
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
) -> Result<(), Error<DFU::Error>> {
let status = Status::Connected {
version: self.version,
}
.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, true)
.await?;
client.subscribe_to_topic(self.topic_update).await?;
self.updater.mark_booted()?;
Ok(())
}
pub async fn update<const MAX_PROPERTIES: usize>(
&mut self,
stack: &'static Stack<impl Driver>,
rng: &mut (impl RngCore + CryptoRng),
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
url: Url<'_>,
) -> Result<!, Error<DFU::Error>> {
let result = self._update(stack, rng, client, url).await;
if let Err(err) = &result {
let status = Status::UpdateFailed {
error: &err.string(),
}
.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
}
result
}
async fn _update<const MAX_PROPERTIES: usize>(
&mut self,
stack: &'static Stack<impl Driver>,
rng: &mut (impl RngCore + CryptoRng),
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
url: Url<'_>,
) -> Result<!, Error<DFU::Error>> {
info!("Preparing for OTA...");
let status = Status::PreparingUpdate.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
debug!("Making sure url is HTTPS");
if url.scheme() != UrlScheme::HTTPS {
return Err(Error::InvalidScheme);
}
// TODO: Clear out retained update message, currently gives implementation specific error
let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0];
let mut rx_buffer = [0; 1024];
let mut tx_buffer = [0; 1024];
let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer);
let addr = (ip, url.port_or_default());
debug!("Addr: {}", addr);
socket.connect(addr).await?;
let mut read_record_buffer = [0; 16384 * 2];
let mut write_record_buffer = [0; 16384];
let mut tls: TlsConnection<TcpSocket, Aes128GcmSha256> =
TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer);
tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng))
.await?;
debug!("Path: {}", url.path());
Request::get(url.path())
.host(url.host())
.build()
.write(&mut tls)
.await?;
let mut headers = [0; 1024];
let resp = Response::read(&mut tls, Method::GET, &mut headers).await?;
let mut body = resp.body().reader();
debug!("Erasing flash...");
let status = Status::Erasing.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
let writer = self.updater.prepare_update()?;
debug!("Writing...");
let status = Status::Writing { progress: 0 }.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
// The first 64 bytes of the file contain the signature
let mut signature = [0; 64];
body.read_exact(&mut signature).await?;
trace!("Signature: {:?}", signature);
let mut buffer = AlignedBuffer([0; 4096]);
let mut size = 0;
while let Ok(read) = body.read(&mut buffer.0).await {
if read == 0 {
break;
}
debug!("Writing chunk: {}", read);
writer
.write(size, &buffer.0[..read])
.map_err(Error::FlashError)?;
size += read as u32;
let status = Status::Writing { progress: size }.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
}
debug!("Total size: {}", size);
let status = Status::Verifying.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
self.updater
.verify_and_mark_updated(self.public_key, &signature, size)?;
let status = Status::UpdateComplete.json();
client
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
.await?;
client.disconnect().await?;
info!("Restarting in 5 seconds...");
Timer::after(Duration::from_secs(5)).await;
cortex_m::peripheral::SCB::sys_reset()
}
}