From eed8db4863a4ac191a929168eb2624064301c539 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Wed, 13 Sep 2023 01:51:47 +0200 Subject: [PATCH] Firmware can now only be downloaded over TLS --- Cargo.lock | 1 + Cargo.toml | 4 ++++ src/main.rs | 37 ++++++++++++++++++++++++------------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28b48e4..06f896a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1766,6 +1766,7 @@ dependencies = [ "embassy-time", "embedded-io-async", "embedded-storage", + "embedded-tls", "git-version", "heapless 0.7.16", "nourl", diff --git a/Cargo.toml b/Cargo.toml index 9e57078..aff25b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,10 @@ reqwless = { version = "0.5.0", features = ["defmt"] } embedded-storage = "0.3.0" const_format = "0.2.31" git-version = "0.3.5" +embedded-tls = { version = "0.15.0", default-features = false, features = [ + "async", + "defmt", +] } [patch.crates-io] embassy-executor = { git = "https://github.com/embassy-rs/embassy" } diff --git a/src/main.rs b/src/main.rs index d84b857..eec83a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use core::cell::RefCell; use embassy_boot_rp::{AlignedBuffer, BlockingFirmwareUpdater, FirmwareUpdaterConfig}; use embedded_storage::nor_flash::NorFlash; +use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; use heapless::{String, Vec}; use rand::{ rngs::{SmallRng, StdRng}, @@ -251,8 +252,7 @@ async fn main(spawner: Spawner) { // Use the Ring Oscillator of the RP2040 as a source of true randomness to seed the // cryptographically secure PRNG - let mut rng_rosc = RoscRng; - let mut rng = StdRng::from_rng(&mut rng_rosc).unwrap(); + let mut rng = StdRng::from_rng(&mut RoscRng).unwrap(); let stack = make_static!(Stack::new( net_device, @@ -281,7 +281,7 @@ async fn main(spawner: Spawner) { info!("IP Address: {}", cfg.address.address()); let mut rx_buffer = [0; 1024]; - let mut tx_buffer = [0; 4096]; + let mut tx_buffer = [0; 1024]; let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); // socket.set_timeout(Some(Duration::from_secs(10))); @@ -301,7 +301,7 @@ async fn main(spawner: Spawner) { MqttVersion::MQTTv5, // Use fast and simple PRNG to generate packet identifiers, there is no need for this to be // cryptographically secure - SmallRng::from_rng(&mut rng_rosc).unwrap(), + SmallRng::from_rng(&mut RoscRng).unwrap(), ); config.add_username(env!("MQTT_USERNAME")); @@ -313,7 +313,7 @@ async fn main(spawner: Spawner) { config.add_will(TOPIC_STATUS, &msg, true); let mut recv_buffer = [0; 1024]; - let mut write_buffer = [0; 4096]; + let mut write_buffer = [0; 1024]; let mut client = MqttClient::<_, 5, _>::new(socket, &mut write_buffer, &mut recv_buffer, config); @@ -388,7 +388,7 @@ async fn main(spawner: Spawner) { let url = message.get_url(); let url = Url::parse(url.as_str()).unwrap(); - attempt_update(stack, &mut updater, &mut client, url).await; + attempt_update(stack, &mut updater, &mut rng, &mut client, url).await; } } } @@ -412,6 +412,7 @@ const FLASH_SIZE: usize = 2 * 1024 * 1024; async fn attempt_update( stack: &'static Stack>, updater: &mut BlockingFirmwareUpdater<'_, F, F>, + rng: &mut StdRng, client: &mut MqttClient<'_, T, MAX_PROPERTIES, R>, url: Url<'_>, ) where @@ -428,23 +429,33 @@ async fn attempt_update( let ip = stack.dns_query(url.host(), DnsQueryType::A).await.unwrap()[0]; - let mut rx_buffer = [0; 4096 * 2]; + let mut rx_buffer = [0; 1024]; let mut tx_buffer = [0; 1024]; let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); + let addr = (ip, url.port_or_default()); debug!("Addr: {}", addr); socket.connect(addr).await.unwrap(); - debug!("Path: {}", url.path()); - Request::get(url.path()) - .build() - .write(&mut socket) + let mut read_record_buffer = [0; 16384 * 2]; + let mut write_record_buffer = [0; 16384]; + let mut tls: TlsConnection = + TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); + tls.open::<_, NoVerify>(TlsContext::new(&TlsConfig::new(), rng)) .await .unwrap(); - let mut headers = [0; 4096]; - let resp = Response::read(&mut socket, Method::GET, &mut headers) + debug!("Path: {}", url.path()); + Request::get(url.path()) + .host(url.host()) + .build() + .write(&mut tls) + .await + .unwrap(); + + let mut headers = [0; 1024]; + let resp = Response::read(&mut tls, Method::GET, &mut headers) .await .unwrap();