From 220c68cd6521ff8fcd5c1654a09f315600a41cb1 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Tue, 10 Jan 2023 00:37:13 +0100 Subject: [PATCH] Converted more of the codebase to async --- Cargo.lock | 13 +++++++++++++ Cargo.toml | 2 ++ src/config.rs | 22 +++++++++++++--------- src/devices.rs | 23 ++++++++++++++--------- src/devices/audio_setup.rs | 8 ++++---- src/devices/contact_sensor.rs | 10 +++++----- src/devices/ikea_outlet.rs | 7 ++++--- src/devices/wake_on_lan.rs | 8 +++++--- src/hue_bridge.rs | 9 +++++---- src/light_sensor.rs | 18 ++++++++++-------- src/main.rs | 30 +++++++++++++++--------------- src/mqtt.rs | 4 +++- src/presence.rs | 14 ++++++++------ 13 files changed, 101 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18c9a6b..25860e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,17 @@ version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2cb2f989d18dd141ab8ae82f64d1a8cdd37e0840f73a406896cf5e99502fab61" +[[package]] +name = "async-recursion" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cda8f4bcc10624c4e85bc66b3f452cca98cfa5ca002dc83a16aad2367641bea" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.61" @@ -39,10 +50,12 @@ name = "automation" version = "0.1.0" dependencies = [ "anyhow", + "async-recursion", "async-trait", "axum", "bytes", "dotenvy", + "futures", "google-home", "impl_cast", "paste", diff --git a/Cargo.toml b/Cargo.toml index 5d2c6ea..4ef6c81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,8 @@ bytes = "1.3.0" pollster = "0.2.5" regex = "1.7.0" async-trait = "0.1.61" +async-recursion = "1.0.0" +futures = "0.3.25" [profile.release] lto=true diff --git a/src/config.rs b/src/config.rs index 38c5506..ea31135 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,6 @@ use std::{fs, error::Error, net::{Ipv4Addr, SocketAddr}, collections::HashMap}; +use async_recursion::async_recursion; use regex::{Regex, Captures}; use tracing::{debug, trace, error}; use rumqttc::{AsyncClient, has_wildcards}; @@ -207,15 +208,16 @@ impl Config { } impl Device { - pub fn into(self, identifier: String, config: &Config, client: AsyncClient) -> DeviceBox { - match self { + #[async_recursion] + pub async fn into(self, identifier: String, config: &Config, client: AsyncClient) -> DeviceBox { + let device: DeviceBox = match self { Device::IkeaOutlet { info, mqtt, kettle } => { trace!(id = identifier, "IkeaOutlet [{} in {:?}]", info.name, info.room); - Box::new(IkeaOutlet::new(identifier, info, mqtt, kettle, client)) + Box::new(IkeaOutlet::new(identifier, info, mqtt, kettle, client).await) }, Device::WakeOnLAN { info, mqtt, mac_address } => { trace!(id = identifier, "WakeOnLan [{} in {:?}]", info.name, info.room); - Box::new(WakeOnLAN::new(identifier, info, mqtt, mac_address, client)) + Box::new(WakeOnLAN::new(identifier, info, mqtt, mac_address, client).await) }, Device::KasaOutlet { ip } => { trace!(id = identifier, "KasaOutlet [{}]", identifier); @@ -224,8 +226,8 @@ impl Device { Device::AudioSetup { mqtt, mixer, speakers } => { trace!(id = identifier, "AudioSetup [{}]", identifier); // Create the child devices - let mixer = (*mixer).into(identifier.clone() + ".mixer", config, client.clone()); - let speakers = (*speakers).into(identifier.clone() + ".speakers", config, client.clone()); + let mixer = (*mixer).into(identifier.clone() + ".mixer", config, client.clone()).await; + let speakers = (*speakers).into(identifier.clone() + ".speakers", config, client.clone()).await; // The AudioSetup expects the children to be something that implements the OnOff trait // So let's convert the children and make sure OnOff is implemented @@ -238,15 +240,17 @@ impl Device { None => todo!("Handle this properly"), }; - Box::new(AudioSetup::new(identifier, mqtt, mixer, speakers, client)) + Box::new(AudioSetup::new(identifier, mqtt, mixer, speakers, client).await) }, Device::ContactSensor { mqtt, mut presence } => { trace!(id = identifier, "ContactSensor [{}]", identifier); if let Some(presence) = &mut presence { presence.generate_topic("contact", &identifier, &config); } - Box::new(ContactSensor::new(identifier, mqtt, presence, client)) + Box::new(ContactSensor::new(identifier, mqtt, presence, client).await) }, - } + }; + + return device; } } diff --git a/src/devices.rs b/src/devices.rs index 4465077..bed37ba 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -80,8 +80,8 @@ impl DeviceHandle { rx.await } - pub fn add_device(&self, device: DeviceBox) { - self.tx.send(Command::AddDevice { device }).block_on().unwrap(); + pub async fn add_device(&self, device: DeviceBox) { + self.tx.send(Command::AddDevice { device }).await.unwrap(); } } @@ -99,8 +99,10 @@ pub fn start(mut mqtt_rx: mqtt::Receiver, mut presence_rx: presence::Receiver, m break; } - if let Some(message) = &*mqtt_rx.borrow() { - devices.on_mqtt(message); + // @TODO Not ideal that we have to clone here, but not sure how to work around that + let message = mqtt_rx.borrow().clone(); + if let Some(message) = message { + devices.on_mqtt(&message).await; } } res = presence_rx.changed() => { @@ -116,7 +118,8 @@ pub fn start(mut mqtt_rx: mqtt::Receiver, mut presence_rx: presence::Receiver, m break; } - devices.on_darkness(*light_sensor_rx.borrow()); + let darkness = *light_sensor_rx.borrow(); + devices.on_darkness(darkness).await; } Some(cmd) = rx.recv() => devices.handle_cmd(cmd) } @@ -150,12 +153,13 @@ impl Devices { get_cast!(GoogleHomeDevice); } +#[async_trait] impl OnMqtt for Devices { - fn on_mqtt(&mut self, message: &rumqttc::Publish) { + async fn on_mqtt(&mut self, message: &rumqttc::Publish) { self.as_on_mqtts().iter_mut().for_each(|(id, listener)| { let _span = span!(Level::TRACE, "on_mqtt").entered(); trace!(id, "Handling"); - listener.on_mqtt(message); + listener.on_mqtt(message).block_on(); }) } } @@ -171,12 +175,13 @@ impl OnPresence for Devices { } } +#[async_trait] impl OnDarkness for Devices { - fn on_darkness(&mut self, dark: bool) { + async fn on_darkness(&mut self, dark: bool) { self.as_on_darknesss().iter_mut().for_each(|(id, device)| { let _span = span!(Level::TRACE, "on_darkness").entered(); trace!(id, "Handling"); - device.on_darkness(dark); + device.on_darkness(dark).block_on(); }) } } diff --git a/src/devices/audio_setup.rs b/src/devices/audio_setup.rs index ffd0cbe..cf739fc 100644 --- a/src/devices/audio_setup.rs +++ b/src/devices/audio_setup.rs @@ -2,7 +2,6 @@ use async_trait::async_trait; use google_home::traits; use rumqttc::{AsyncClient, matches}; use tracing::{error, warn, debug}; -use pollster::FutureExt as _; use crate::config::MqttDeviceConfig; use crate::mqtt::{OnMqtt, RemoteMessage, RemoteAction}; @@ -21,8 +20,8 @@ pub struct AudioSetup { } impl AudioSetup { - pub fn new(identifier: String, mqtt: MqttDeviceConfig, mixer: Box, speakers: Box, client: AsyncClient) -> Self { - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + pub async fn new(identifier: String, mqtt: MqttDeviceConfig, mixer: Box, speakers: Box, client: AsyncClient) -> Self { + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); Self { identifier, mqtt, mixer, speakers } } @@ -34,8 +33,9 @@ impl Device for AudioSetup { } } +#[async_trait] impl OnMqtt for AudioSetup { - fn on_mqtt(&mut self, message: &rumqttc::Publish) { + async fn on_mqtt(&mut self, message: &rumqttc::Publish) { if !matches(&message.topic, &self.mqtt.topic) { return; } diff --git a/src/devices/contact_sensor.rs b/src/devices/contact_sensor.rs index 1bf02ac..330294c 100644 --- a/src/devices/contact_sensor.rs +++ b/src/devices/contact_sensor.rs @@ -1,7 +1,6 @@ use std::time::Duration; use async_trait::async_trait; -use pollster::FutureExt; use rumqttc::{AsyncClient, matches}; use tokio::task::JoinHandle; use tracing::{error, debug, warn}; @@ -23,8 +22,8 @@ pub struct ContactSensor { } impl ContactSensor { - pub fn new(identifier: String, mqtt: MqttDeviceConfig, presence: Option, client: AsyncClient) -> Self { - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + pub async fn new(identifier: String, mqtt: MqttDeviceConfig, presence: Option, client: AsyncClient) -> Self { + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); Self { identifier, @@ -51,8 +50,9 @@ impl OnPresence for ContactSensor { } } +#[async_trait] impl OnMqtt for ContactSensor { - fn on_mqtt(&mut self, message: &rumqttc::Publish) { + async fn on_mqtt(&mut self, message: &rumqttc::Publish) { if !matches(&message.topic, &self.mqtt.topic) { return; } @@ -97,7 +97,7 @@ impl OnMqtt for ContactSensor { // This is to prevent the house from being marked as present for however long the // timeout is set when leaving the house if !self.overall_presence { - self.client.publish(topic, rumqttc::QoS::AtLeastOnce, false, serde_json::to_string(&PresenceMessage::new(true)).unwrap()).block_on().unwrap(); + self.client.publish(topic, rumqttc::QoS::AtLeastOnce, false, serde_json::to_string(&PresenceMessage::new(true)).unwrap()).await.unwrap(); } } else { // Once the door is closed again we start a timeout for removing the presence diff --git a/src/devices/ikea_outlet.rs b/src/devices/ikea_outlet.rs index c19cb9c..7d4b3c6 100644 --- a/src/devices/ikea_outlet.rs +++ b/src/devices/ikea_outlet.rs @@ -26,9 +26,9 @@ pub struct IkeaOutlet { } impl IkeaOutlet { - pub fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, kettle: Option, client: AsyncClient) -> Self { + pub async fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, kettle: Option, client: AsyncClient) -> Self { // @TODO Handle potential errors here - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); Self{ identifier, info, mqtt, kettle, client, last_known_state: false, handle: None } } @@ -47,8 +47,9 @@ impl Device for IkeaOutlet { } } +#[async_trait] impl OnMqtt for IkeaOutlet { - fn on_mqtt(&mut self, message: &Publish) { + async fn on_mqtt(&mut self, message: &Publish) { // Update the internal state based on what the device has reported if !matches(&message.topic, &self.mqtt.topic) { return; diff --git a/src/devices/wake_on_lan.rs b/src/devices/wake_on_lan.rs index bf391b1..493fdac 100644 --- a/src/devices/wake_on_lan.rs +++ b/src/devices/wake_on_lan.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use google_home::{GoogleHomeDevice, types::Type, device, traits::{self, Scene}, errors::{ErrorCode, DeviceError}}; use tracing::{debug, error}; use rumqttc::{AsyncClient, Publish, matches}; @@ -16,9 +17,9 @@ pub struct WakeOnLAN { } impl WakeOnLAN { - pub fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, mac_address: String, client: AsyncClient) -> Self { + pub async fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, mac_address: String, client: AsyncClient) -> Self { // @TODO Handle potential errors here - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); Self { identifier, info, mqtt, mac_address } } @@ -30,8 +31,9 @@ impl Device for WakeOnLAN { } } +#[async_trait] impl OnMqtt for WakeOnLAN { - fn on_mqtt(&mut self, message: &Publish) { + async fn on_mqtt(&mut self, message: &Publish) { if !matches(&message.topic, &self.mqtt.topic) { return; } diff --git a/src/hue_bridge.rs b/src/hue_bridge.rs index 2649a21..06ab8f5 100644 --- a/src/hue_bridge.rs +++ b/src/hue_bridge.rs @@ -1,7 +1,6 @@ use std::net::SocketAddr; use async_trait::async_trait; -use pollster::FutureExt; use serde::Serialize; use tracing::{warn, error, trace}; @@ -47,7 +46,8 @@ impl HueBridge { break; } - hue_bridge.on_darkness(*light_sensor_rx.borrow()); + let darkness = *light_sensor_rx.borrow(); + hue_bridge.on_darkness(darkness).await; } } } @@ -91,9 +91,10 @@ impl OnPresence for HueBridge { } } +#[async_trait] impl OnDarkness for HueBridge { - fn on_darkness(&mut self, dark: bool) { + async fn on_darkness(&mut self, dark: bool) { trace!("Bridging darkness to hue"); - self.set_flag(Flag::Darkness, dark).block_on(); + self.set_flag(Flag::Darkness, dark).await; } } diff --git a/src/light_sensor.rs b/src/light_sensor.rs index cf2fabb..90f0f36 100644 --- a/src/light_sensor.rs +++ b/src/light_sensor.rs @@ -1,13 +1,13 @@ -use pollster::FutureExt as _; +use async_trait::async_trait; use rumqttc::{matches, AsyncClient}; use tokio::sync::watch; use tracing::{error, trace, debug}; use crate::{config::{MqttDeviceConfig, LightSensorConfig}, mqtt::{self, OnMqtt, BrightnessMessage}}; - +#[async_trait] pub trait OnDarkness { - fn on_darkness(&mut self, dark: bool); + async fn on_darkness(&mut self, dark: bool); } pub type Receiver = watch::Receiver; @@ -21,16 +21,17 @@ struct LightSensor { tx: Sender, } -pub fn start(mut mqtt_rx: mqtt::Receiver, config: LightSensorConfig, client: AsyncClient) -> Receiver { - client.subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); +pub async fn start(mut mqtt_rx: mqtt::Receiver, config: LightSensorConfig, client: AsyncClient) -> Receiver { + client.subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); let (tx, is_dark) = watch::channel(false); let mut light_sensor = LightSensor { is_dark: is_dark.clone(), mqtt: config.mqtt, min: config.min, max: config.max, tx }; tokio::spawn(async move { while mqtt_rx.changed().await.is_ok() { - if let Some(message) = &*mqtt_rx.borrow() { - light_sensor.on_mqtt(message); + let message = mqtt_rx.borrow().clone(); + if let Some(message) = message { + light_sensor.on_mqtt(&message).await; } } @@ -40,8 +41,9 @@ pub fn start(mut mqtt_rx: mqtt::Receiver, config: LightSensorConfig, client: Asy return is_dark; } +#[async_trait] impl OnMqtt for LightSensor { - fn on_mqtt(&mut self, message: &rumqttc::Publish) { + async fn on_mqtt(&mut self, message: &rumqttc::Publish) { if !matches(&message.topic, &self.mqtt.topic) { return; } diff --git a/src/main.rs b/src/main.rs index 59f345e..74f8e5c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ use automation::{ use dotenvy::dotenv; use rumqttc::{AsyncClient, MqttOptions, Transport}; use tracing::{debug, error, info, metadata::LevelFilter}; +use futures::future::join_all; use google_home::{GoogleHome, Request}; use tracing_subscriber::EnvFilter; @@ -58,23 +59,22 @@ async fn main() { // Create a mqtt client and wrap the eventloop let (client, eventloop) = AsyncClient::new(mqttoptions, 10); let mqtt = mqtt::start(eventloop); - let presence = presence::start(mqtt.clone(), config.presence.clone(), client.clone()); - let light_sensor = - light_sensor::start(mqtt.clone(), config.light_sensor.clone(), client.clone()); + let presence = presence::start(mqtt.clone(), config.presence.clone(), client.clone()).await; + let light_sensor = light_sensor::start(mqtt.clone(), config.light_sensor.clone(), client.clone()).await; let devices = devices::start(mqtt, presence.clone(), light_sensor.clone()); - config - .devices - .clone() - .into_iter() - .map(|(identifier, device_config)| { - // This can technically block, but this only happens during start-up, so should not be - // a problem - device_config.into(identifier, &config, client.clone()) - }) - .for_each(|device| { - devices.add_device(device); - }); + join_all( + config + .devices + .clone() + .into_iter() + .map(|(identifier, device_config)| async { + // This can technically block, but this only happens during start-up, so should not be + // a problem + let device = device_config.into(identifier, &config, client.clone()).await; + devices.add_device(device).await; + }) + ).await; // Start the ntfy service if it is configured if let Some(ntfy_config) = config.ntfy { diff --git a/src/mqtt.rs b/src/mqtt.rs index 1b3067f..2c36ce8 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,11 +1,13 @@ +use async_trait::async_trait; use serde::{Serialize, Deserialize}; use tracing::{error, debug}; use rumqttc::{Publish, Event, Incoming, EventLoop}; use tokio::sync::watch; +#[async_trait] pub trait OnMqtt { - fn on_mqtt(&mut self, message: &Publish); + async fn on_mqtt(&mut self, message: &Publish); } pub type Receiver = watch::Receiver>; diff --git a/src/presence.rs b/src/presence.rs index 419b9c6..54a01c4 100644 --- a/src/presence.rs +++ b/src/presence.rs @@ -4,7 +4,6 @@ use async_trait::async_trait; use tokio::sync::watch; use tracing::{debug, error}; use rumqttc::{AsyncClient, matches}; -use pollster::FutureExt as _; use crate::{mqtt::{OnMqtt, PresenceMessage, self}, config::MqttDeviceConfig}; @@ -23,17 +22,19 @@ struct Presence { tx: Sender, } -pub fn start(mut mqtt_rx: mqtt::Receiver, mqtt: MqttDeviceConfig, client: AsyncClient) -> Receiver { +pub async fn start(mut mqtt_rx: mqtt::Receiver, mqtt: MqttDeviceConfig, client: AsyncClient) -> Receiver { // Subscribe to the relevant topics on mqtt - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).await.unwrap(); let (tx, overall_presence) = watch::channel(false); let mut presence = Presence { devices: HashMap::new(), overall_presence: overall_presence.clone(), mqtt, tx }; tokio::spawn(async move { while mqtt_rx.changed().await.is_ok() { - if let Some(message) = &*mqtt_rx.borrow() { - presence.on_mqtt(message); + // @TODO Not ideal that we have to clone here, but not sure how to work around that + let message = mqtt_rx.borrow().clone(); + if let Some(message) = message { + presence.on_mqtt(&message).await; } } @@ -43,8 +44,9 @@ pub fn start(mut mqtt_rx: mqtt::Receiver, mqtt: MqttDeviceConfig, client: AsyncC return overall_presence; } +#[async_trait] impl OnMqtt for Presence { - fn on_mqtt(&mut self, message: &rumqttc::Publish) { + async fn on_mqtt(&mut self, message: &rumqttc::Publish) { if !matches(&message.topic, &self.mqtt.topic) { return; }