Files
iot_tools/updater/src/lib.rs
2025-01-11 05:25:31 +01:00

146 lines
4.3 KiB
Rust

#![no_std]
#![no_main]
#![feature(never_type)]
use defmt::*;
use embassy_boot::{AlignedBuffer, BlockingFirmwareUpdater};
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};
use embassy_time::{Duration, Timer};
use embedded_io_async::Read;
use embedded_storage::nor_flash::NorFlash;
use picoserve::{
response::{self, IntoResponse, StatusCode},
routing::{get, put_service, PathRouter},
Router,
};
/// picoserve request-handler service that receives a firmware image over
/// HTTP and hands it to a shared `BlockingFirmwareUpdater`.
struct UpdaterService<DFU, STATE>
where
    DFU: NorFlash + 'static,
    STATE: NorFlash + 'static,
    DFU::Error: Format,
{
    /// Shared firmware updater; locked for the duration of one upload.
    updater: &'static Mutex<CriticalSectionRawMutex, BlockingFirmwareUpdater<'static, DFU, STATE>>,
    /// 32-byte public key passed to `verify_and_mark_updated` together with
    /// the signature taken from the request body.
    public_key: &'static [u8; 32],
}

// `Clone`/`Copy` are implemented by hand instead of derived: the derive would
// add `DFU: Clone + Copy` and `STATE: Clone + Copy` bounds, which flash
// drivers generally do not satisfy, even though the struct only stores
// `&'static` references (which are always `Copy`).
impl<DFU, STATE> Clone for UpdaterService<DFU, STATE>
where
    DFU: NorFlash + 'static,
    STATE: NorFlash + 'static,
    DFU::Error: Format,
{
    fn clone(&self) -> Self {
        *self
    }
}

impl<DFU, STATE> Copy for UpdaterService<DFU, STATE>
where
    DFU: NorFlash + 'static,
    STATE: NorFlash + 'static,
    DFU::Error: Format,
{
}
impl<DFU, STATE> UpdaterService<DFU, STATE>
where
DFU: NorFlash,
STATE: NorFlash,
DFU::Error: Format,
{
fn new(
updater: &'static Mutex<
CriticalSectionRawMutex,
BlockingFirmwareUpdater<'static, DFU, STATE>,
>,
public_key: &'static [u8; 32],
) -> Self {
Self {
updater,
public_key,
}
}
}
impl<S, DFU, STATE> picoserve::routing::RequestHandlerService<S> for UpdaterService<DFU, STATE>
where
DFU: NorFlash + 'static,
STATE: NorFlash + 'static,
DFU::Error: Format,
{
async fn call_request_handler_service<
R: Read,
W: picoserve::response::ResponseWriter<Error = R::Error>,
>(
&self,
_state: &S,
_path_parameters: (),
mut request: picoserve::request::Request<'_, R>,
response_writer: W,
) -> Result<picoserve::ResponseSent, W::Error> {
let mut updater = self.updater.lock().await;
let writer = match updater.prepare_update() {
Ok(writer) => writer,
Err(err) => {
return response::Response::new(
StatusCode::INTERNAL_SERVER_ERROR,
format_args!("{err:?}"),
)
.write_to(request.body_connection.finalize().await?, response_writer)
.await;
}
};
let mut reader = request.body_connection.body().reader();
let mut signature = [0; 64];
if let Err(err) = reader.read_exact(&mut signature).await {
return response::Response::new(
StatusCode::INTERNAL_SERVER_ERROR,
format_args!("{err:?}"),
)
.write_to(request.body_connection.finalize().await?, response_writer)
.await;
}
trace!("Signature: {:?}", signature);
let mut buffer = AlignedBuffer([0; 4096]);
let mut size = 0;
while let Ok(read) = reader.read(&mut buffer.0).await {
if read == 0 {
break;
}
debug!("Writing chunk: {}", read);
if let Err(err) = writer.write(size, &buffer.0[..read]) {
return response::Response::new(
StatusCode::INTERNAL_SERVER_ERROR,
format_args!("{err:?}"),
)
.write_to(request.body_connection.finalize().await?, response_writer)
.await;
}
size += read as u32;
}
let public_key = self.public_key;
if let Err(err) = updater.verify_and_mark_updated(public_key, &signature, size) {
return response::Response::new(
StatusCode::INTERNAL_SERVER_ERROR,
format_args!("{err:?}"),
)
.write_to(request.body_connection.finalize().await?, response_writer)
.await;
}
"Update complete"
.write_to(request.body_connection.finalize().await?, response_writer)
.await?;
Timer::after(Duration::from_secs(1)).await;
cortex_m::peripheral::SCB::sys_reset();
}
}
/// Builds the firmware-management router.
///
/// Routes:
/// - `PUT /update`: handled by an `UpdaterService` built from `updater` and
///   `public_key`.
/// - `GET /version`: responds with the `version` string.
pub fn firmware_router<S, DFU, STATE>(
    version: &'static str,
    updater: &'static Mutex<CriticalSectionRawMutex, BlockingFirmwareUpdater<'static, DFU, STATE>>,
    public_key: &'static [u8; 32],
) -> Router<impl PathRouter<S>, S>
where
    DFU: NorFlash + 'static,
    STATE: NorFlash + 'static,
    DFU::Error: Format,
{
    // Handler for `/version`: simply echoes the static version string.
    let report_version = move || async move { version };

    Router::new()
        .route(
            "/update",
            put_service(UpdaterService::new(updater, public_key)),
        )
        .route("/version", get(report_version))
}