diff --git a/Cargo.lock b/Cargo.lock index b870811..18c9a6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,12 +39,12 @@ name = "automation" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "bytes", "dotenvy", "google-home", "impl_cast", - "parking_lot", "paste", "pollster", "regex", @@ -593,29 +593,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff9f3fef3968a3ec5945535ed654cb38ff72d7495a25619e2247fb15a2ed9ba" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-sys 0.42.0", -] - [[package]] name = "paste" version = "1.0.10" @@ -684,15 +661,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags", -] - [[package]] name = "regex" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 0e16c9a..5d2c6ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } bytes = "1.3.0" pollster = "0.2.5" regex = "1.7.0" -parking_lot = "0.12.1" +async-trait = "0.1.61" [profile.release] lto=true diff --git a/google-home/src/fullfillment.rs b/google-home/src/fullfillment.rs index adaa7c2..1c109c9 100644 --- a/google-home/src/fullfillment.rs +++ b/google-home/src/fullfillment.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use crate::{request::{Request, Intent, self}, device::GoogleHomeDevice, response::{sync, ResponsePayload, query, execute, Response, self, State}, errors::{DeviceError, ErrorCode}}; +#[derive(Debug)] pub struct GoogleHome { user_id: String, // Add credentials so we can notify google home of actions diff --git a/google-home/src/lib.rs b/google-home/src/lib.rs index f22d1b1..fd218c4 100644 --- a/google-home/src/lib.rs +++ b/google-home/src/lib.rs @@ -12,4 +12,5 @@ mod attributes; pub use fullfillment::GoogleHome; pub use request::Request; +pub use response::Response; pub use device::GoogleHomeDevice; diff --git a/google-home/src/traits.rs b/google-home/src/traits.rs index df5c538..f891539 100644 --- a/google-home/src/traits.rs +++ b/google-home/src/traits.rs @@ -10,7 +10,7 @@ pub enum Trait { Scene, } -pub trait OnOff { +pub trait OnOff: std::fmt::Debug { fn is_command_only(&self) -> Option { None } @@ -25,7 +25,7 @@ pub trait OnOff { } impl_cast::impl_cast!(GoogleHomeDevice, OnOff); -pub trait Scene { +pub trait Scene: std::fmt::Debug { fn is_scene_reversible(&self) -> Option { None } diff --git a/src/config.rs b/src/config.rs index ccc8a8f..38c5506 100644 --- a/src/config.rs +++ b/src/config.rs @@ -75,7 +75,7 @@ fn default_ntfy_url() -> String { "https://ntfy.sh".into() } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct LightSensorConfig { #[serde(flatten)] pub mqtt: MqttDeviceConfig, diff --git a/src/devices.rs b/src/devices.rs index f2ed2f0..4465077 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -12,10 +12,13 @@ pub use self::contact_sensor::ContactSensor; use std::collections::HashMap; -use google_home::{GoogleHomeDevice, traits::OnOff}; +use async_trait::async_trait; +use google_home::{GoogleHomeDevice, traits::OnOff, GoogleHome}; +use pollster::FutureExt; +use tokio::sync::{oneshot, mpsc}; use tracing::{trace, debug, span, Level}; -use crate::{mqtt::OnMqtt, presence::OnPresence, light_sensor::OnDarkness}; +use crate::{mqtt::{OnMqtt, self}, presence::{OnPresence, self}, light_sensor::{OnDarkness, self}}; impl_cast::impl_cast!(Device, OnMqtt); impl_cast::impl_cast!(Device, OnPresence); @@ -23,13 +26,13 @@ impl_cast::impl_cast!(Device, OnDarkness); impl_cast::impl_cast!(Device, GoogleHomeDevice); impl_cast::impl_cast!(Device, OnOff); -pub trait Device: AsGoogleHomeDevice + AsOnMqtt + AsOnPresence + AsOnDarkness + AsOnOff { +pub trait Device: AsGoogleHomeDevice + AsOnMqtt + AsOnPresence + AsOnDarkness + AsOnOff + std::fmt::Debug { fn get_id(&self) -> String; } // @TODO Add an inner type that we can wrap with Arc> to make this type a little bit nicer // to work with -pub struct Devices { +struct Devices { devices: HashMap, } @@ -50,14 +53,93 @@ macro_rules! get_cast { }; } +#[derive(Debug)] +enum Command { + Fullfillment { + google_home: GoogleHome, + payload: google_home::Request, + tx: oneshot::Sender + }, + AddDevice { + device: DeviceBox, + } +} + pub type DeviceBox = Box; -impl Devices { - pub fn new() -> Self { - Self { devices: HashMap::new() } +#[derive(Clone)] +pub struct DeviceHandle { + tx: mpsc::Sender +} + +impl DeviceHandle { + // @TODO Improve error type + pub async fn fullfillment(&self, google_home: GoogleHome, payload: google_home::Request) -> Result { + let (tx, rx) = oneshot::channel(); + self.tx.send(Command::Fullfillment { google_home, payload, tx }).await.unwrap(); + rx.await } - pub fn add_device(&mut self, device: DeviceBox) { + pub fn add_device(&self, device: DeviceBox) { + self.tx.send(Command::AddDevice { device }).block_on().unwrap(); + } +} + +pub fn start(mut mqtt_rx: mqtt::Receiver, mut presence_rx: presence::Receiver, mut light_sensor_rx: light_sensor::Receiver) -> DeviceHandle { + + let mut devices = Devices { devices: HashMap::new() }; + + let (tx, mut rx) = mpsc::channel(100); + + tokio::spawn(async move { + loop { + tokio::select! { + res = mqtt_rx.changed() => { + if !res.is_ok() { + break; + } + + if let Some(message) = &*mqtt_rx.borrow() { + devices.on_mqtt(message); + } + } + res = presence_rx.changed() => { + if !res.is_ok() { + break; + } + + let presence = *presence_rx.borrow(); + devices.on_presence(presence).await; + } + res = light_sensor_rx.changed() => { + if !res.is_ok() { + break; + } + + devices.on_darkness(*light_sensor_rx.borrow()); + } + Some(cmd) = rx.recv() => devices.handle_cmd(cmd) + } + } + + unreachable!("Did not expect this"); + }); + + return DeviceHandle { tx }; +} + +impl Devices { + fn handle_cmd(&mut self, cmd: Command) { + match cmd { + Command::Fullfillment { google_home, payload, tx } => { + let result = google_home.handle_request(payload, &mut self.as_google_home_devices()).unwrap(); + tx.send(result).ok(); + }, + Command::AddDevice { device } => self.add_device(device), + } + } + + fn add_device(&mut self, device: DeviceBox) { debug!(id = device.get_id(), "Adding device"); self.devices.insert(device.get_id(), device); } @@ -66,14 +148,6 @@ impl Devices { get_cast!(OnPresence); get_cast!(OnDarkness); get_cast!(GoogleHomeDevice); - get_cast!(OnOff); - - pub fn get_device(&mut self, name: &str) -> Option<&mut dyn Device> { - if let Some(device) = self.devices.get_mut(name) { - return Some(device.as_mut()); - } - return None; - } } impl OnMqtt for Devices { @@ -86,12 +160,13 @@ impl OnMqtt for Devices { } } +#[async_trait] impl OnPresence for Devices { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { self.as_on_presences().iter_mut().for_each(|(id, device)| { let _span = span!(Level::TRACE, "on_presence").entered(); trace!(id, "Handling"); - device.on_presence(presence); + device.on_presence(presence).block_on(); }) } } diff --git a/src/devices/audio_setup.rs b/src/devices/audio_setup.rs index 28cf8fa..ffd0cbe 100644 --- a/src/devices/audio_setup.rs +++ b/src/devices/audio_setup.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use google_home::traits; use rumqttc::{AsyncClient, matches}; use tracing::{error, warn, debug}; @@ -11,6 +12,7 @@ use super::Device; // @TODO Ideally we store am Arc to the childern devices, // that way they hook into everything just like all other devices +#[derive(Debug)] pub struct AudioSetup { identifier: String, mqtt: MqttDeviceConfig, @@ -71,8 +73,9 @@ impl OnMqtt for AudioSetup { } } +#[async_trait] impl OnPresence for AudioSetup { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { // Turn off the audio setup when we leave the house if !presence { debug!(id = self.identifier, "Turning devices off"); diff --git a/src/devices/contact_sensor.rs b/src/devices/contact_sensor.rs index c0cfe8b..1bf02ac 100644 --- a/src/devices/contact_sensor.rs +++ b/src/devices/contact_sensor.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use async_trait::async_trait; use pollster::FutureExt; use rumqttc::{AsyncClient, matches}; use tokio::task::JoinHandle; @@ -9,6 +10,7 @@ use crate::{config::{MqttDeviceConfig, PresenceDeviceConfig}, mqtt::{OnMqtt, Con use super::Device; +#[derive(Debug)] pub struct ContactSensor { identifier: String, mqtt: MqttDeviceConfig, @@ -42,8 +44,9 @@ impl Device for ContactSensor { } } +#[async_trait] impl OnPresence for ContactSensor { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { self.overall_presence = presence; } } diff --git a/src/devices/ikea_outlet.rs b/src/devices/ikea_outlet.rs index 6d40f18..c19cb9c 100644 --- a/src/devices/ikea_outlet.rs +++ b/src/devices/ikea_outlet.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use async_trait::async_trait; use google_home::errors::ErrorCode; use google_home::{GoogleHomeDevice, device, types::Type, traits}; use rumqttc::{AsyncClient, Publish, matches}; @@ -12,6 +13,7 @@ use crate::devices::Device; use crate::mqtt::{OnMqtt, OnOffMessage}; use crate::presence::OnPresence; +#[derive(Debug)] pub struct IkeaOutlet { identifier: String, info: InfoConfig, @@ -108,12 +110,13 @@ impl OnMqtt for IkeaOutlet { } } +#[async_trait] impl OnPresence for IkeaOutlet { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { // Turn off the outlet when we leave the house if !presence { debug!(id = self.identifier, "Turning device off"); - set_on(self.client.clone(), self.mqtt.topic.clone(), false).block_on(); + set_on(self.client.clone(), self.mqtt.topic.clone(), false).await; } } } diff --git a/src/devices/kasa_outlet.rs b/src/devices/kasa_outlet.rs index 6a58649..abbe956 100644 --- a/src/devices/kasa_outlet.rs +++ b/src/devices/kasa_outlet.rs @@ -6,6 +6,7 @@ use serde::{Serialize, Deserialize}; use super::Device; +#[derive(Debug)] pub struct KasaOutlet { identifier: String, addr: SocketAddr, diff --git a/src/devices/wake_on_lan.rs b/src/devices/wake_on_lan.rs index ebb4dc5..bf391b1 100644 --- a/src/devices/wake_on_lan.rs +++ b/src/devices/wake_on_lan.rs @@ -7,6 +7,7 @@ use crate::{config::{InfoConfig, MqttDeviceConfig}, mqtt::{OnMqtt, ActivateMessa use super::Device; +#[derive(Debug)] pub struct WakeOnLAN { identifier: String, info: InfoConfig, diff --git a/src/hue_bridge.rs b/src/hue_bridge.rs index a38d29a..2649a21 100644 --- a/src/hue_bridge.rs +++ b/src/hue_bridge.rs @@ -1,10 +1,11 @@ use std::net::SocketAddr; +use async_trait::async_trait; use pollster::FutureExt; use serde::Serialize; use tracing::{warn, error, trace}; -use crate::{config::{HueBridgeConfig, Flags}, presence::OnPresence, light_sensor::OnDarkness}; +use crate::{config::{HueBridgeConfig, Flags}, presence::{OnPresence, self}, light_sensor::{OnDarkness, self}}; pub enum Flag { Presence, @@ -23,15 +24,39 @@ struct FlagMessage { } impl HueBridge { - pub fn new(config: HueBridgeConfig) -> Self { - Self { + pub fn create(mut presence_rx: presence::Receiver, mut light_sensor_rx: light_sensor::Receiver, config: HueBridgeConfig) { + let mut hue_bridge = Self { addr: (config.ip, 80).into(), login: config.login, flags: config.flags, - } + }; + + tokio::spawn(async move { + loop { + tokio::select! { + res = presence_rx.changed() => { + if !res.is_ok() { + break; + } + + let presence = *presence_rx.borrow(); + hue_bridge.on_presence(presence).await; + } + res = light_sensor_rx.changed() => { + if !res.is_ok() { + break; + } + + hue_bridge.on_darkness(*light_sensor_rx.borrow()); + } + } + } + + unreachable!("Did not expect this"); + }); } - pub fn set_flag(&self, flag: Flag, value: bool) { + pub async fn set_flag(&self, flag: Flag, value: bool) { let flag = match flag { Flag::Presence => self.flags.presence, Flag::Darkness => self.flags.darkness, @@ -42,7 +67,7 @@ impl HueBridge { .put(url) .json(&FlagMessage { flag: value }) .send() - .block_on(); + .await; match res { Ok(res) => { @@ -58,16 +83,17 @@ impl HueBridge { } } +#[async_trait] impl OnPresence for HueBridge { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { trace!("Bridging presence to hue"); - self.set_flag(Flag::Presence, presence); + self.set_flag(Flag::Presence, presence).await; } } impl OnDarkness for HueBridge { fn on_darkness(&mut self, dark: bool) { trace!("Bridging darkness to hue"); - self.set_flag(Flag::Darkness, dark); + self.set_flag(Flag::Darkness, dark).block_on(); } } diff --git a/src/light_sensor.rs b/src/light_sensor.rs index 144a67d..cf2fabb 100644 --- a/src/light_sensor.rs +++ b/src/light_sensor.rs @@ -1,44 +1,43 @@ -use std::sync::Weak; - -use parking_lot::RwLock; use pollster::FutureExt as _; -use rumqttc::{AsyncClient, matches}; -use tracing::{span, Level, error, trace, debug}; +use rumqttc::{matches, AsyncClient}; +use tokio::sync::watch; +use tracing::{error, trace, debug}; -use crate::{config::{MqttDeviceConfig, LightSensorConfig}, mqtt::{OnMqtt, BrightnessMessage}}; +use crate::{config::{MqttDeviceConfig, LightSensorConfig}, mqtt::{self, OnMqtt, BrightnessMessage}}; pub trait OnDarkness { fn on_darkness(&mut self, dark: bool); } -pub struct LightSensor { - listeners: Vec>>, - is_dark: bool, +pub type Receiver = watch::Receiver; +type Sender = watch::Sender; + +struct LightSensor { + is_dark: Receiver, mqtt: MqttDeviceConfig, min: isize, max: isize, + tx: Sender, } -impl LightSensor { - pub fn new(config: LightSensorConfig, client: AsyncClient) -> Self { - client.subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); +pub fn start(mut mqtt_rx: mqtt::Receiver, config: LightSensorConfig, client: AsyncClient) -> Receiver { + client.subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); - Self { listeners: Vec::new(), is_dark: false, mqtt: config.mqtt, min: config.min, max: config.max } - } + let (tx, is_dark) = watch::channel(false); + let mut light_sensor = LightSensor { is_dark: is_dark.clone(), mqtt: config.mqtt, min: config.min, max: config.max, tx }; - pub fn add_listener(&mut self, listener: Weak>) { - self.listeners.push(listener); - } - - pub fn notify(dark: bool, listeners: Vec>>) { - let _span = span!(Level::TRACE, "darkness_update").entered(); - listeners.into_iter().for_each(|listener| { - if let Some(listener) = listener.upgrade() { - listener.write().on_darkness(dark); + tokio::spawn(async move { + while mqtt_rx.changed().await.is_ok() { + if let Some(message) = &*mqtt_rx.borrow() { + light_sensor.on_mqtt(message); } - }) - } + } + + unreachable!("Did not expect this"); + }); + + return is_dark; } impl OnMqtt for LightSensor { @@ -63,19 +62,13 @@ impl OnMqtt for LightSensor { trace!("It is light"); false } else { - trace!("In between min ({}) and max ({}) value, keeping current state: {}", self.min, self.max, self.is_dark); - self.is_dark + trace!("In between min ({}) and max ({}) value, keeping current state: {}", self.min, self.max, *self.is_dark.borrow()); + *self.is_dark.borrow() }; - if is_dark != self.is_dark { + if is_dark != *self.is_dark.borrow() { debug!("Dark state has changed: {is_dark}"); - self.is_dark = is_dark; - self.listeners.retain(|listener| listener.strong_count() > 0); - let listeners = self.listeners.clone(); - - tokio::task::spawn_blocking(move || { - LightSensor::notify(is_dark, listeners) - }); + self.tx.send(is_dark).ok(); } } } diff --git a/src/main.rs b/src/main.rs index 2d51bc9..59f345e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,27 @@ #![feature(async_closure)] -use std::{time::Duration, sync::Arc, process}; -use parking_lot::RwLock; +use std::{process, time::Duration}; -use axum::{Router, Json, routing::post, http::StatusCode, extract::FromRef}; +use axum::{extract::FromRef, http::StatusCode, routing::post, Json, Router}; -use automation::{config::{Config, OpenIDConfig}, presence::Presence, ntfy::Ntfy, light_sensor::LightSensor, hue_bridge::HueBridge, auth::User}; +use automation::{ + auth::User, + config::{Config, OpenIDConfig}, + devices, + hue_bridge::HueBridge, + light_sensor, mqtt, + ntfy::Ntfy, + presence, +}; use dotenvy::dotenv; -use rumqttc::{MqttOptions, Transport, AsyncClient}; -use tracing::{error, info, metadata::LevelFilter}; +use rumqttc::{AsyncClient, MqttOptions, Transport}; +use tracing::{debug, error, info, metadata::LevelFilter}; -use automation::{devices::Devices, mqtt::Mqtt}; use google_home::{GoogleHome, Request}; use tracing_subscriber::EnvFilter; #[derive(Clone)] struct AppState { - pub openid: OpenIDConfig + pub openid: OpenIDConfig, } impl FromRef for automation::config::OpenIDConfig { @@ -32,9 +38,7 @@ async fn main() { .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(); - tracing_subscriber::fmt() - .with_env_filter(filter) - .init(); + tracing_subscriber::fmt().with_env_filter(filter).init(); let config = std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.toml".to_owned()); let config = Config::build(&config).unwrap_or_else(|err| { @@ -53,14 +57,15 @@ async fn main() { // Create a mqtt client and wrap the eventloop let (client, eventloop) = AsyncClient::new(mqttoptions, 10); - let mut mqtt = Mqtt::new(eventloop); + let mqtt = mqtt::start(eventloop); + let presence = presence::start(mqtt.clone(), config.presence.clone(), client.clone()); + let light_sensor = + light_sensor::start(mqtt.clone(), config.light_sensor.clone(), client.clone()); - // Create device holder and register it as listener for mqtt - let devices = Arc::new(RwLock::new(Devices::new())); - mqtt.add_listener(Arc::downgrade(&devices)); - - // Turn the config into actual devices and add them - config.devices.clone() + let devices = devices::start(mqtt, presence.clone(), light_sensor.clone()); + config + .devices + .clone() .into_iter() .map(|(identifier, device_config)| { // This can technically block, but this only happens during start-up, so should not be @@ -68,57 +73,38 @@ async fn main() { device_config.into(identifier, &config, client.clone()) }) .for_each(|device| { - devices.write().add_device(device); + devices.add_device(device); }); - // Setup presence system - let mut presence = Presence::new(config.presence, client.clone()); - // Register devices as presence listener - presence.add_listener(Arc::downgrade(&devices)); - - let mut light_sensor = LightSensor::new(config.light_sensor, client.clone()); - light_sensor.add_listener(Arc::downgrade(&devices)); - - let ntfy; + // Start the ntfy service if it is configured if let Some(ntfy_config) = config.ntfy { - ntfy = Arc::new(RwLock::new(Ntfy::new(ntfy_config))); - presence.add_listener(Arc::downgrade(&ntfy)); + Ntfy::create(presence.clone(), ntfy_config); } - let hue_bridge; + // Start he hue bridge if it is configured if let Some(hue_bridge_config) = config.hue_bridge { - hue_bridge = Arc::new(RwLock::new(HueBridge::new(hue_bridge_config))); - presence.add_listener(Arc::downgrade(&hue_bridge)); - light_sensor.add_listener(Arc::downgrade(&hue_bridge)); + HueBridge::create(presence.clone(), light_sensor.clone(), hue_bridge_config); } - // Register presence as mqtt listener - let presence = Arc::new(RwLock::new(presence)); - mqtt.add_listener(Arc::downgrade(&presence)); - - let light_sensor = Arc::new(RwLock::new(light_sensor)); - mqtt.add_listener(Arc::downgrade(&light_sensor)); - - // Start mqtt, this spawns a seperate async task - mqtt.start(); - // Create google home fullfillment route - let fullfillment = Router::new() - .route("/google_home", post(async move |user: User, Json(payload): Json| { - // Handle request might block, so we need to spawn a blocking task - tokio::task::spawn_blocking(move || { - let gc = GoogleHome::new(&user.preferred_username); - let result = gc.handle_request(payload, &mut devices.write().as_google_home_devices()).unwrap(); + let fullfillment = Router::new().route( + "/google_home", + post(async move |user: User, Json(payload): Json| { + debug!(username = user.preferred_username, "{payload:?}"); + let gc = GoogleHome::new(&user.preferred_username); + let result = devices.fullfillment(gc, payload).await.unwrap(); - return (StatusCode::OK, Json(result)); - }).await.unwrap() - })); + debug!(username = user.preferred_username, "{result:?}"); + + return (StatusCode::OK, Json(result)); + }), + ); // Combine together all the routes let app = Router::new() .nest("/fullfillment", fullfillment) .with_state(AppState { - openid: config.openid + openid: config.openid, }); // Start the web server diff --git a/src/mqtt.rs b/src/mqtt.rs index 1e99468..1b3067f 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,67 +1,37 @@ -use std::sync::Weak; -use parking_lot::RwLock; use serde::{Serialize, Deserialize}; -use tracing::{error, debug, span, Level}; +use tracing::{error, debug}; use rumqttc::{Publish, Event, Incoming, EventLoop}; -use tokio::task::JoinHandle; +use tokio::sync::watch; pub trait OnMqtt { fn on_mqtt(&mut self, message: &Publish); } -// @TODO Maybe rename this to make it clear it has to do with mqtt -pub struct Mqtt { - listeners: Vec>>, - eventloop: EventLoop, -} +pub type Receiver = watch::Receiver>; -impl Mqtt { - pub fn new(eventloop: EventLoop) -> Self { - return Self { listeners: Vec::new(), eventloop } - } - - fn notify(message: Publish, listeners: Vec>>) { - let _span = span!(Level::TRACE, "mqtt_message").entered(); - listeners.into_iter().for_each(|listener| { - if let Some(listener) = listener.upgrade() { - listener.write().on_mqtt(&message); +pub fn start(mut eventloop: EventLoop) -> Receiver { + let (tx, rx) = watch::channel(None); + tokio::spawn(async move { + debug!("Listening for MQTT events"); + loop { + let notification = eventloop.poll().await; + match notification { + Ok(Event::Incoming(Incoming::Publish(p))) => { + tx.send(Some(p)).ok(); + }, + Ok(..) => continue, + Err(err) => { + error!("{}", err); + break + }, } - }) - } + } - pub fn add_listener(&mut self, listener: Weak>) { - self.listeners.push(listener); - } + todo!("Error in MQTT (most likely lost connection to mqtt server), we need to handle these errors!"); + }); - pub fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { - debug!("Listening for MQTT events"); - loop { - let notification = self.eventloop.poll().await; - match notification { - Ok(Event::Incoming(Incoming::Publish(p))) => { - // Remove non-existing listeners - self.listeners.retain(|listener| listener.strong_count() > 0); - // Clone the listeners - let listeners = self.listeners.clone(); - - // Notify might block, so we spawn a blocking task - tokio::task::spawn_blocking(move || { - Mqtt::notify(p, listeners); - }); - }, - Ok(..) => continue, - Err(err) => { - error!("{}", err); - break - }, - } - } - - todo!("Error in MQTT (most likely lost connection to mqtt server), we need to handle these errors!"); - }) - } + return rx; } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/ntfy.rs b/src/ntfy.rs index e754caa..f77bf19 100644 --- a/src/ntfy.rs +++ b/src/ntfy.rs @@ -1,11 +1,11 @@ use std::collections::HashMap; +use async_trait::async_trait; use tracing::{warn, error, debug}; use serde::Serialize; use serde_repr::*; -use pollster::FutureExt as _; -use crate::{presence::OnPresence, config::NtfyConfig}; +use crate::{presence::{self, OnPresence}, config::NtfyConfig}; pub struct Ntfy { base_url: String, @@ -88,13 +88,22 @@ impl Notification { } impl Ntfy { - pub fn new(config: NtfyConfig) -> Self { - Self { base_url: config.url, topic: config.topic } + pub fn create(mut rx: presence::Receiver, config: NtfyConfig) { + let mut ntfy = Self { base_url: config.url, topic: config.topic }; + tokio::spawn(async move { + while rx.changed().await.is_ok() { + let presence = *rx.borrow(); + ntfy.on_presence(presence).await; + } + + unreachable!("Did not expect this"); + }); } } +#[async_trait] impl OnPresence for Ntfy { - fn on_presence(&mut self, presence: bool) { + async fn on_presence(&mut self, presence: bool) { // Setup extras for the broadcast let extras = HashMap::from([ ("cmd".into(), "presence".into()), @@ -123,7 +132,7 @@ impl OnPresence for Ntfy { .post(self.base_url.clone()) .json(¬ification) .send() - .block_on(); + .await; if let Err(err) = res { error!("Something went wrong while sending the notifcation: {err}"); diff --git a/src/presence.rs b/src/presence.rs index d6c570a..419b9c6 100644 --- a/src/presence.rs +++ b/src/presence.rs @@ -1,42 +1,46 @@ -use std::{sync::Weak, collections::HashMap}; +use std::collections::HashMap; -use parking_lot::RwLock; -use tracing::{debug, span, Level, error}; +use async_trait::async_trait; +use tokio::sync::watch; +use tracing::{debug, error}; use rumqttc::{AsyncClient, matches}; use pollster::FutureExt as _; -use crate::{mqtt::{OnMqtt, PresenceMessage}, config::MqttDeviceConfig}; +use crate::{mqtt::{OnMqtt, PresenceMessage, self}, config::MqttDeviceConfig}; +#[async_trait] pub trait OnPresence { - fn on_presence(&mut self, presence: bool); + async fn on_presence(&mut self, presence: bool); } -pub struct Presence { - listeners: Vec>>, +pub type Receiver = watch::Receiver; +type Sender = watch::Sender; + +struct Presence { devices: HashMap, - overall_presence: bool, + overall_presence: Receiver, mqtt: MqttDeviceConfig, + tx: Sender, } -impl Presence { - pub fn new(mqtt: MqttDeviceConfig, client: AsyncClient) -> Self { - client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); +pub fn start(mut mqtt_rx: mqtt::Receiver, mqtt: MqttDeviceConfig, client: AsyncClient) -> Receiver { + // Subscribe to the relevant topics on mqtt + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); - Self { listeners: Vec::new(), devices: HashMap::new(), overall_presence: false, mqtt } - } + let (tx, overall_presence) = watch::channel(false); + let mut presence = Presence { devices: HashMap::new(), overall_presence: overall_presence.clone(), mqtt, tx }; - pub fn add_listener(&mut self, listener: Weak>) { - self.listeners.push(listener); - } - - pub fn notify(presence: bool, listeners: Vec>>) { - let _span = span!(Level::TRACE, "presence_update").entered(); - listeners.into_iter().for_each(|listener| { - if let Some(listener) = listener.upgrade() { - listener.write().on_presence(presence); + tokio::spawn(async move { + while mqtt_rx.changed().await.is_ok() { + if let Some(message) = &*mqtt_rx.borrow() { + presence.on_mqtt(message); } - }) - } + } + + unreachable!("Did not expect this"); + }); + + return overall_presence; } impl OnMqtt for Presence { @@ -66,19 +70,9 @@ impl OnMqtt for Presence { } let overall_presence = self.devices.iter().any(|(_, v)| *v); - if overall_presence != self.overall_presence { + if overall_presence != *self.overall_presence.borrow() { debug!("Overall presence updated: {overall_presence}"); - self.overall_presence = overall_presence; - - // Remove non-existing listeners - self.listeners.retain(|listener| listener.strong_count() > 0); - // Clone the listeners - let listeners = self.listeners.clone(); - - // Notify might block, so we spawn a blocking task - tokio::task::spawn_blocking(move || { - Presence::notify(overall_presence, listeners); - }); + self.tx.send(overall_presence).ok(); } } }