Reworked to serve API using picoserve
This commit is contained in:
@@ -4,24 +4,16 @@
|
||||
|
||||
use defmt::*;
|
||||
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater};
|
||||
use embassy_net::{dns::DnsQueryType, driver::Driver, tcp::TcpSocket, Stack};
|
||||
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};
|
||||
use embassy_time::{Duration, Timer};
|
||||
use embedded_io_async::{Read, Write};
|
||||
use embedded_io_async::Read;
|
||||
use embedded_storage::nor_flash::NorFlash;
|
||||
use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext};
|
||||
use heapless::Vec;
|
||||
use nourl::{Url, UrlScheme};
|
||||
use rand_core::{CryptoRng, RngCore};
|
||||
use reqwless::{
|
||||
request::{Method, Request, RequestBuilder},
|
||||
response::Response,
|
||||
};
|
||||
use rust_mqtt::{
|
||||
client::{client::MqttClient, client_config::ClientConfig},
|
||||
packet::v5::publish_packet::QualityOfService,
|
||||
use picoserve::{
|
||||
response::{self, IntoResponse, StatusCode},
|
||||
routing::{get, put_service, PathRouter},
|
||||
Router,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use static_cell::StaticCell;
|
||||
|
||||
mod error;
|
||||
|
||||
@@ -40,213 +32,141 @@ pub enum Status<'a> {
|
||||
UpdateComplete,
|
||||
}
|
||||
|
||||
impl Status<'_> {
|
||||
fn json(&self) -> Vec<u8, 512> {
|
||||
serde_json_core::to_vec(self)
|
||||
.expect("This buffers size should be large enough to contain the serialized status")
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a wrapper around `BlockingFirmwareUpdater` that downloads signed updates
|
||||
/// from a HTTPS url.
|
||||
/// It also provides the current device status over MQTT
|
||||
// TODO: Make this the owner of the blocking firmware updater
|
||||
// TODO: When fixed, use the async firmware updater
|
||||
pub struct Updater<'a, DFU, STATE>
|
||||
#[derive(Clone, Copy)]
|
||||
struct UpdaterService<DFU, STATE>
|
||||
where
|
||||
DFU: NorFlash,
|
||||
STATE: NorFlash,
|
||||
DFU: NorFlash + 'static,
|
||||
STATE: NorFlash + 'static,
|
||||
DFU::Error: Format,
|
||||
{
|
||||
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
|
||||
|
||||
topic_status: &'static str,
|
||||
version: &'static str,
|
||||
updater: &'static Mutex<CriticalSectionRawMutex, BlockingFirmwareUpdater<'static, DFU, STATE>>,
|
||||
public_key: &'static [u8; 32],
|
||||
}
|
||||
|
||||
impl<'a, DFU, STATE> Updater<'a, DFU, STATE>
|
||||
impl<DFU, STATE> UpdaterService<DFU, STATE>
|
||||
where
|
||||
DFU: NorFlash,
|
||||
STATE: NorFlash,
|
||||
DFU::Error: Format,
|
||||
{
|
||||
/// Wrap the `BlockingFirmwareUpdater`
|
||||
pub fn new(
|
||||
updater: BlockingFirmwareUpdater<'a, DFU, STATE>,
|
||||
topic_status: &'static str,
|
||||
version: &'static str,
|
||||
fn new(
|
||||
updater: &'static Mutex<
|
||||
CriticalSectionRawMutex,
|
||||
BlockingFirmwareUpdater<'static, DFU, STATE>,
|
||||
>,
|
||||
public_key: &'static [u8; 32],
|
||||
) -> Self {
|
||||
Self {
|
||||
updater,
|
||||
topic_status,
|
||||
version,
|
||||
public_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set MQTT connection up to notify over MQTT when the device loses connection
|
||||
pub fn add_will<const MAX_PROPERTIES: usize>(
|
||||
impl<S, DFU, STATE> picoserve::routing::RequestHandlerService<S> for UpdaterService<DFU, STATE>
|
||||
where
|
||||
DFU: NorFlash + 'static,
|
||||
STATE: NorFlash + 'static,
|
||||
DFU::Error: Format,
|
||||
{
|
||||
async fn call_request_handler_service<
|
||||
R: Read,
|
||||
W: picoserve::response::ResponseWriter<Error = R::Error>,
|
||||
>(
|
||||
&self,
|
||||
config: &mut ClientConfig<'_, MAX_PROPERTIES, impl RngCore>,
|
||||
) {
|
||||
static MSG: StaticCell<Vec<u8, 512>> = StaticCell::new();
|
||||
let msg = MSG.init(Status::Disconnected.json());
|
||||
config.add_will(self.topic_status, msg, true);
|
||||
}
|
||||
|
||||
/// Mark the device is ready and booted, will notify over MQTT that the device is connected and the
|
||||
/// currently running firmware version
|
||||
pub async fn ready<const MAX_PROPERTIES: usize>(
|
||||
&mut self,
|
||||
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
|
||||
) -> Result<(), Error<DFU::Error>> {
|
||||
let status = Status::Connected {
|
||||
version: self.version,
|
||||
}
|
||||
.json();
|
||||
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, true)
|
||||
.await?;
|
||||
|
||||
self.updater.mark_booted()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Download signed update from specified url and notify progress over MQTT
|
||||
pub async fn update<const MAX_PROPERTIES: usize>(
|
||||
&mut self,
|
||||
url: Url<'_>,
|
||||
stack: &'static Stack<impl Driver>,
|
||||
rng: &mut (impl RngCore + CryptoRng),
|
||||
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
|
||||
) -> Result<!, Error<DFU::Error>> {
|
||||
let result = self._update(url, stack, rng, client).await;
|
||||
|
||||
if let Err(err) = &result {
|
||||
let status = Status::UpdateFailed {
|
||||
error: &err.string(),
|
||||
_state: &S,
|
||||
_path_parameters: (),
|
||||
mut request: picoserve::request::Request<'_, R>,
|
||||
response_writer: W,
|
||||
) -> Result<picoserve::ResponseSent, W::Error> {
|
||||
let mut updater = self.updater.lock().await;
|
||||
let writer = match updater.prepare_update() {
|
||||
Ok(writer) => writer,
|
||||
Err(err) => {
|
||||
return response::Response::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format_args!("{err:?}"),
|
||||
)
|
||||
.write_to(request.body_connection.finalize().await?, response_writer)
|
||||
.await;
|
||||
}
|
||||
.json();
|
||||
};
|
||||
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
}
|
||||
let mut reader = request.body_connection.body().reader();
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn _update<const MAX_PROPERTIES: usize>(
|
||||
&mut self,
|
||||
url: Url<'_>,
|
||||
stack: &'static Stack<impl Driver>,
|
||||
rng: &mut (impl RngCore + CryptoRng),
|
||||
client: &mut MqttClient<'_, impl Write + Read, MAX_PROPERTIES, impl RngCore>,
|
||||
) -> Result<!, Error<DFU::Error>> {
|
||||
info!("Preparing for OTA...");
|
||||
let status = Status::PreparingUpdate.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
|
||||
debug!("Making sure url is HTTPS");
|
||||
if url.scheme() != UrlScheme::HTTPS {
|
||||
return Err(Error::InvalidScheme);
|
||||
}
|
||||
|
||||
// TODO: Clear out retained update message, currently gives implementation specific error
|
||||
|
||||
let ip = stack.dns_query(url.host(), DnsQueryType::A).await?[0];
|
||||
|
||||
let mut rx_buffer = [0; 1024];
|
||||
let mut tx_buffer = [0; 1024];
|
||||
|
||||
let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer);
|
||||
|
||||
let addr = (ip, url.port_or_default());
|
||||
debug!("Addr: {}", addr);
|
||||
socket.connect(addr).await?;
|
||||
|
||||
let mut read_record_buffer = [0; 16384 * 2];
|
||||
let mut write_record_buffer = [0; 16384];
|
||||
let mut tls: TlsConnection<TcpSocket, Aes128GcmSha256> =
|
||||
TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer);
|
||||
tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng))
|
||||
.await?;
|
||||
|
||||
debug!("Path: {}", url.path());
|
||||
Request::get(url.path())
|
||||
.host(url.host())
|
||||
.build()
|
||||
.write(&mut tls)
|
||||
.await?;
|
||||
|
||||
let mut headers = [0; 1024];
|
||||
let resp = Response::read(&mut tls, Method::GET, &mut headers).await?;
|
||||
|
||||
let mut body = resp.body().reader();
|
||||
|
||||
debug!("Erasing flash...");
|
||||
let status = Status::Erasing.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
|
||||
let writer = self.updater.prepare_update()?;
|
||||
|
||||
debug!("Writing...");
|
||||
let status = Status::Writing { progress: 0 }.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
|
||||
// The first 64 bytes of the file contain the signature
|
||||
let mut signature = [0; 64];
|
||||
body.read_exact(&mut signature).await?;
|
||||
if let Err(err) = reader.read_exact(&mut signature).await {
|
||||
return response::Response::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format_args!("{err:?}"),
|
||||
)
|
||||
.write_to(request.body_connection.finalize().await?, response_writer)
|
||||
.await;
|
||||
}
|
||||
|
||||
trace!("Signature: {:?}", signature);
|
||||
|
||||
let mut buffer = AlignedBuffer([0; 4096]);
|
||||
let mut size = 0;
|
||||
while let Ok(read) = body.read(&mut buffer.0).await {
|
||||
while let Ok(read) = reader.read(&mut buffer.0).await {
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
debug!("Writing chunk: {}", read);
|
||||
writer
|
||||
if let Err(err) = writer
|
||||
.write(size, &buffer.0[..read])
|
||||
.map_err(Error::FlashError)?;
|
||||
.map_err(Error::FlashError)
|
||||
{
|
||||
return response::Response::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format_args!("{err:?}"),
|
||||
)
|
||||
.write_to(request.body_connection.finalize().await?, response_writer)
|
||||
.await;
|
||||
}
|
||||
size += read as u32;
|
||||
|
||||
let status = Status::Writing { progress: size }.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
}
|
||||
debug!("Total size: {}", size);
|
||||
|
||||
let status = Status::Verifying.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
let public_key = self.public_key;
|
||||
if let Err(err) = updater.verify_and_mark_updated(public_key, &signature, size) {
|
||||
return response::Response::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format_args!("{err:?}"),
|
||||
)
|
||||
.write_to(request.body_connection.finalize().await?, response_writer)
|
||||
.await;
|
||||
}
|
||||
|
||||
"Update complete"
|
||||
.write_to(request.body_connection.finalize().await?, response_writer)
|
||||
.await?;
|
||||
|
||||
self.updater
|
||||
.verify_and_mark_updated(self.public_key, &signature, size)?;
|
||||
Timer::after(Duration::from_secs(1)).await;
|
||||
|
||||
let status = Status::UpdateComplete.json();
|
||||
client
|
||||
.send_message(self.topic_status, &status, QualityOfService::QoS1, false)
|
||||
.await?;
|
||||
|
||||
client.disconnect().await?;
|
||||
|
||||
info!("Restarting in 5 seconds...");
|
||||
Timer::after(Duration::from_secs(5)).await;
|
||||
|
||||
cortex_m::peripheral::SCB::sys_reset()
|
||||
cortex_m::peripheral::SCB::sys_reset();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn firmware_router<S, DFU, STATE>(
|
||||
version: &'static str,
|
||||
updater: &'static Mutex<CriticalSectionRawMutex, BlockingFirmwareUpdater<'static, DFU, STATE>>,
|
||||
public_key: &'static [u8; 32],
|
||||
) -> Router<impl PathRouter<S>, S>
|
||||
where
|
||||
DFU: NorFlash + 'static,
|
||||
STATE: NorFlash + 'static,
|
||||
DFU::Error: Format,
|
||||
{
|
||||
let updater_service = UpdaterService::new(updater, public_key);
|
||||
|
||||
Router::new()
|
||||
.route("/update", put_service(updater_service))
|
||||
.route("/version", get(move || async move { version }))
|
||||
.route(
|
||||
"/reset",
|
||||
get(|| async {
|
||||
cortex_m::peripheral::SCB::sys_reset();
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user