diff --git a/src/config.rs b/src/config.rs index f15b1bd..c2151fd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,6 +17,7 @@ use crate::{ error::{ConfigParseError, CreateDeviceError, MissingEnv}, hue_bridge::HueBridgeConfig, light_sensor::LightSensorConfig, + presence::PresenceConfig, }; #[derive(Debug, Deserialize)] @@ -27,7 +28,7 @@ pub struct Config { #[serde(default)] pub fullfillment: FullfillmentConfig, pub ntfy: Option, - pub presence: MqttDeviceConfig, + pub presence: PresenceConfig, pub light_sensor: LightSensorConfig, pub hue_bridge: Option, pub debug_bridge: Option, diff --git a/src/debug_bridge.rs b/src/debug_bridge.rs index 4891647..e618e59 100644 --- a/src/debug_bridge.rs +++ b/src/debug_bridge.rs @@ -1,12 +1,10 @@ -use async_trait::async_trait; use rumqttc::AsyncClient; use serde::Deserialize; use tracing::warn; use crate::{ - light_sensor::{self, OnDarkness}, + event::{Event, EventChannel}, mqtt::{DarknessMessage, PresenceMessage}, - presence::{self, OnPresence}, }; #[derive(Debug, Deserialize)] @@ -14,96 +12,53 @@ pub struct DebugBridgeConfig { pub topic: String, } -struct DebugBridge { - topic: String, - client: AsyncClient, -} - -impl DebugBridge { - pub fn new(config: DebugBridgeConfig, client: AsyncClient) -> Self { - Self { - topic: config.topic, - client, - } - } -} - -pub fn start( - mut presence_rx: presence::Receiver, - mut light_sensor_rx: light_sensor::Receiver, - config: DebugBridgeConfig, - client: AsyncClient, -) { - let mut debug_bridge = DebugBridge::new(config, client); +pub fn start(config: DebugBridgeConfig, event_channel: &EventChannel, client: AsyncClient) { + let mut rx = event_channel.get_rx(); tokio::spawn(async move { loop { - tokio::select! { - res = presence_rx.changed() => { - if res.is_err() { - break; - } - - let presence = *presence_rx.borrow(); - debug_bridge.on_presence(presence).await; + match rx.recv().await { + Ok(Event::Presence(presence)) => { + let message = PresenceMessage::new(presence); + let topic = format!("{}/presence", config.topic); + client + .publish( + topic, + rumqttc::QoS::AtLeastOnce, + true, + serde_json::to_string(&message).unwrap(), + ) + .await + .map_err(|err| { + warn!( + "Failed to update presence on {}/presence: {err}", + config.topic + ) + }) + .ok(); } - res = light_sensor_rx.changed() => { - if res.is_err() { - break; - } - - let darkness = *light_sensor_rx.borrow(); - debug_bridge.on_darkness(darkness).await; + Ok(Event::Darkness(dark)) => { + let message = DarknessMessage::new(dark); + let topic = format!("{}/darkness", config.topic); + client + .publish( + topic, + rumqttc::QoS::AtLeastOnce, + true, + serde_json::to_string(&message).unwrap(), + ) + .await + .map_err(|err| { + warn!( + "Failed to update presence on {}/presence: {err}", + config.topic + ) + }) + .ok(); } + Ok(_) => {} + Err(_) => todo!("Handle errors with the event channel properly"), } } - - unreachable!("Did not expect this"); }); } - -#[async_trait] -impl OnPresence for DebugBridge { - async fn on_presence(&mut self, presence: bool) { - let message = PresenceMessage::new(presence); - let topic = format!("{}/presence", self.topic); - self.client - .publish( - topic, - rumqttc::QoS::AtLeastOnce, - true, - serde_json::to_string(&message).unwrap(), - ) - .await - .map_err(|err| { - warn!( - "Failed to update presence on {}/presence: {err}", - self.topic - ) - }) - .ok(); - } -} - -#[async_trait] -impl OnDarkness for DebugBridge { - async fn on_darkness(&mut self, dark: bool) { - let message = DarknessMessage::new(dark); - let topic = format!("{}/darkness", self.topic); - self.client - .publish( - topic, - rumqttc::QoS::AtLeastOnce, - true, - serde_json::to_string(&message).unwrap(), - ) - .await - .map_err(|err| { - warn!( - "Failed to update presence on {}/presence: {err}", - self.topic - ) - }) - .ok(); - } -} diff --git a/src/devices.rs b/src/devices.rs index 311c345..e3501b7 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -12,7 +12,6 @@ pub use self::wake_on_lan::WakeOnLAN; use std::collections::HashMap; -use async_trait::async_trait; use google_home::{traits::OnOff, FullfillmentError, GoogleHome, GoogleHomeDevice}; use pollster::FutureExt; use rumqttc::{matches, AsyncClient, QoS}; @@ -21,9 +20,10 @@ use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, trace}; use crate::{ - light_sensor::{self, OnDarkness}, - mqtt::{self, OnMqtt}, - presence::{self, OnPresence}, + event::{Event, EventChannel}, + light_sensor::OnDarkness, + mqtt::OnMqtt, + presence::OnPresence, }; #[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + GoogleHomeDevice + OnOff)] @@ -92,37 +92,33 @@ impl DevicesHandle { } } -pub fn start( - mut mqtt_rx: mqtt::Receiver, - mut presence_rx: presence::Receiver, - mut light_sensor_rx: light_sensor::Receiver, - client: AsyncClient, -) -> DevicesHandle { +pub fn start(event_channel: &EventChannel, client: AsyncClient) -> DevicesHandle { let mut devices = Devices { devices: HashMap::new(), client, }; let (tx, mut rx) = mpsc::channel(100); + let mut event_rx = event_channel.get_rx(); tokio::spawn(async move { // TODO: Handle error better loop { tokio::select! { - Ok(message) = mqtt_rx.recv() => { - devices.on_mqtt(&message).await; - } - Ok(_) = presence_rx.changed() => { - let presence = *presence_rx.borrow(); - devices.on_presence(presence).await; - } - Ok(_) = light_sensor_rx.changed() => { - let darkness = *light_sensor_rx.borrow(); - devices.on_darkness(darkness).await; + event = event_rx.recv() => { + if event.is_err() { + todo!("Handle errors with the event channel properly") + } + devices.handle_event(event.unwrap()).await; } // TODO: Handle receiving None better, otherwise it might constantly run doing // nothing - Some(cmd) = rx.recv() => devices.handle_cmd(cmd).await + cmd = rx.recv() => { + if cmd.is_none() { + todo!("Handle errors with the cmd channel properly") + } + devices.handle_cmd(cmd.unwrap()).await; + } } } }); @@ -169,6 +165,43 @@ impl Devices { self.devices.insert(device.get_id().to_owned(), device); } + async fn handle_event(&mut self, event: Event) { + match event { + Event::MqttMessage(message) => { + self.get::() + .iter_mut() + .for_each(|(id, listener)| { + let subscribed = listener + .topics() + .iter() + .any(|topic| matches(&message.topic, topic)); + + if subscribed { + trace!(id, "Handling"); + listener.on_mqtt(&message).block_on(); + } + }) + } + Event::Darkness(dark) => { + self.get::() + .iter_mut() + .for_each(|(id, device)| { + trace!(id, "Handling"); + device.on_darkness(dark).block_on(); + }) + } + Event::Presence(presence) => { + self.get::() + .iter_mut() + .for_each(|(id, device)| { + trace!(id, "Handling"); + device.on_presence(presence).block_on(); + }) + } + Event::Ntfy(_) => {} + } + } + fn get(&mut self) -> HashMap<&str, &mut T> where T: ?Sized + 'static, @@ -180,53 +213,3 @@ impl Devices { .collect() } } - -#[async_trait] -impl OnMqtt for Devices { - fn topics(&self) -> Vec<&str> { - Vec::new() - } - - #[tracing::instrument(skip_all)] - async fn on_mqtt(&mut self, message: &rumqttc::Publish) { - self.get::() - .iter_mut() - .for_each(|(id, listener)| { - let subscribed = listener - .topics() - .iter() - .any(|topic| matches(&message.topic, topic)); - - if subscribed { - trace!(id, "Handling"); - listener.on_mqtt(message).block_on(); - } - }) - } -} - -#[async_trait] -impl OnPresence for Devices { - #[tracing::instrument(skip(self))] - async fn on_presence(&mut self, presence: bool) { - self.get::() - .iter_mut() - .for_each(|(id, device)| { - trace!(id, "Handling"); - device.on_presence(presence).block_on(); - }) - } -} - -#[async_trait] -impl OnDarkness for Devices { - #[tracing::instrument(skip(self))] - async fn on_darkness(&mut self, dark: bool) { - self.get::() - .iter_mut() - .for_each(|(id, device)| { - trace!(id, "Handling"); - device.on_darkness(dark).block_on(); - }) - } -} diff --git a/src/event.rs b/src/event.rs new file mode 100644 index 0000000..37b24de --- /dev/null +++ b/src/event.rs @@ -0,0 +1,39 @@ +use rumqttc::Publish; +use tokio::sync::broadcast; + +use crate::ntfy; + +#[derive(Clone)] +pub enum Event { + MqttMessage(Publish), + Darkness(bool), + Presence(bool), + Ntfy(ntfy::Notification), +} + +pub type Sender = broadcast::Sender; +pub type Receiver = broadcast::Receiver; + +pub struct EventChannel(Sender); + +impl EventChannel { + pub fn new() -> Self { + let (tx, _) = broadcast::channel(100); + + Self(tx) + } + + pub fn get_rx(&self) -> Receiver { + self.0.subscribe() + } + + pub fn get_tx(&self) -> Sender { + self.0.clone() + } +} + +impl Default for EventChannel { + fn default() -> Self { + Self::new() + } +} diff --git a/src/hue_bridge.rs b/src/hue_bridge.rs index 37e836e..39e5a0a 100644 --- a/src/hue_bridge.rs +++ b/src/hue_bridge.rs @@ -1,13 +1,9 @@ use std::net::{Ipv4Addr, SocketAddr}; -use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tracing::{error, trace, warn}; -use crate::{ - light_sensor::{self, OnDarkness}, - presence::{self, OnPresence}, -}; +use crate::event::{Event, EventChannel}; pub enum Flag { Presence, @@ -26,10 +22,11 @@ pub struct HueBridgeConfig { pub login: String, pub flags: FlagIDs, } + struct HueBridge { addr: SocketAddr, login: String, - flags: FlagIDs, + flag_ids: FlagIDs, } #[derive(Debug, Serialize)] @@ -42,18 +39,18 @@ impl HueBridge { Self { addr: (config.ip, 80).into(), login: config.login, - flags: config.flags, + flag_ids: config.flags, } } pub async fn set_flag(&self, flag: Flag, value: bool) { - let flag = match flag { - Flag::Presence => self.flags.presence, - Flag::Darkness => self.flags.darkness, + let flag_id = match flag { + Flag::Presence => self.flag_ids.presence, + Flag::Darkness => self.flag_ids.darkness, }; let url = format!( - "http://{}/api/{}/sensors/{flag}/state", + "http://{}/api/{}/sensors/{flag_id}/state", self.addr, self.login ); let res = reqwest::Client::new() @@ -66,61 +63,35 @@ impl HueBridge { Ok(res) => { let status = res.status(); if !status.is_success() { - warn!(flag, "Status code is not success: {status}"); + warn!(flag_id, "Status code is not success: {status}"); } } Err(err) => { - error!(flag, "Error: {err}"); + error!(flag_id, "Error: {err}"); } } } } -pub fn start( - mut presence_rx: presence::Receiver, - mut light_sensor_rx: light_sensor::Receiver, - config: HueBridgeConfig, -) { - let mut hue_bridge = HueBridge::new(config); +pub fn start(config: HueBridgeConfig, event_channel: &EventChannel) { + let hue_bridge = HueBridge::new(config); + + let mut rx = event_channel.get_rx(); tokio::spawn(async move { loop { - tokio::select! { - res = presence_rx.changed() => { - if res.is_err() { - break; - } - - let presence = *presence_rx.borrow(); - hue_bridge.on_presence(presence).await; + match rx.recv().await { + Ok(Event::Presence(presence)) => { + trace!("Bridging presence to hue"); + hue_bridge.set_flag(Flag::Presence, presence).await; } - res = light_sensor_rx.changed() => { - if res.is_err() { - break; - } - - let darkness = *light_sensor_rx.borrow(); - hue_bridge.on_darkness(darkness).await; + Ok(Event::Darkness(dark)) => { + trace!("Bridging darkness to hue"); + hue_bridge.set_flag(Flag::Darkness, dark).await; } + Ok(_) => {} + Err(_) => todo!("Handle errors with the event channel properly"), } } - - unreachable!("Did not expect this"); }); } - -#[async_trait] -impl OnPresence for HueBridge { - async fn on_presence(&mut self, presence: bool) { - trace!("Bridging presence to hue"); - self.set_flag(Flag::Presence, presence).await; - } -} - -#[async_trait] -impl OnDarkness for HueBridge { - async fn on_darkness(&mut self, dark: bool) { - trace!("Bridging darkness to hue"); - self.set_flag(Flag::Darkness, dark).await; - } -} diff --git a/src/lib.rs b/src/lib.rs index 6f49045..0aec4df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod config; pub mod debug_bridge; pub mod devices; pub mod error; +pub mod event; pub mod hue_bridge; pub mod light_sensor; pub mod mqtt; diff --git a/src/light_sensor.rs b/src/light_sensor.rs index d581199..a02e568 100644 --- a/src/light_sensor.rs +++ b/src/light_sensor.rs @@ -1,13 +1,13 @@ use async_trait::async_trait; use rumqttc::{matches, AsyncClient}; use serde::Deserialize; -use tokio::sync::watch; -use tracing::{debug, error, trace}; +use tracing::{debug, error, trace, warn}; use crate::{ config::MqttDeviceConfig, error::LightSensorError, - mqtt::{self, BrightnessMessage, OnMqtt}, + event::{Event, EventChannel}, + mqtt::BrightnessMessage, }; #[async_trait] @@ -15,9 +15,6 @@ pub trait OnDarkness: Sync + Send + 'static { async fn on_darkness(&mut self, dark: bool); } -pub type Receiver = watch::Receiver; -type Sender = watch::Sender; - #[derive(Debug, Clone, Deserialize)] pub struct LightSensorConfig { #[serde(flatten)] @@ -26,91 +23,72 @@ pub struct LightSensorConfig { pub max: isize, } -#[derive(Debug)] -struct LightSensor { - mqtt: MqttDeviceConfig, - min: isize, - max: isize, - tx: Sender, - is_dark: Receiver, -} - -impl LightSensor { - fn new(mqtt: MqttDeviceConfig, min: isize, max: isize) -> Self { - let (tx, is_dark) = watch::channel(false); - Self { - is_dark, - mqtt, - min, - max, - tx, - } - } -} +const DEFAULT: bool = false; pub async fn start( - mut mqtt_rx: mqtt::Receiver, config: LightSensorConfig, + event_channel: &EventChannel, client: AsyncClient, -) -> Result { +) -> Result<(), LightSensorError> { + // Subscrive to the mqtt topic client .subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce) .await?; - let mut light_sensor = LightSensor::new(config.mqtt, config.min, config.max); - let is_dark = light_sensor.is_dark.clone(); + // Create the channels + let mut rx = event_channel.get_rx(); + let tx = event_channel.get_tx(); + + // Setup default value, this is needed for hysteresis + let mut current_is_dark = DEFAULT; tokio::spawn(async move { loop { - // TODO: Handle errors, warn if lagging - if let Ok(message) = mqtt_rx.recv().await { - light_sensor.on_mqtt(&message).await; + match rx.recv().await { + Ok(Event::MqttMessage(message)) => { + if !matches(&message.topic, &config.mqtt.topic) { + continue; + } + + let illuminance = match BrightnessMessage::try_from(message) { + Ok(state) => state.illuminance(), + Err(err) => { + error!("Failed to parse message: {err}"); + continue; + } + }; + + debug!("Illuminance: {illuminance}"); + let is_dark = if illuminance <= config.min { + trace!("It is dark"); + true + } else if illuminance >= config.max { + trace!("It is light"); + false + } else { + trace!( + "In between min ({}) and max ({}) value, keeping current state: {}", + config.min, + config.max, + current_is_dark + ); + current_is_dark + }; + + if is_dark != current_is_dark { + debug!("Dark state has changed: {is_dark}"); + current_is_dark = is_dark; + + if tx.send(Event::Darkness(is_dark)).is_err() { + warn!("There are no receivers on the event channel"); + } + } + } + Ok(_) => {} + Err(_) => todo!("Handle errors with the event channel properly"), } } }); - Ok(is_dark) -} - -#[async_trait] -impl OnMqtt for LightSensor { - fn topics(&self) -> Vec<&str> { - vec![&self.mqtt.topic] - } - - async fn on_mqtt(&mut self, message: &rumqttc::Publish) { - if !matches(&message.topic, &self.mqtt.topic) { - return; - } - - let illuminance = match BrightnessMessage::try_from(message) { - Ok(state) => state.illuminance(), - Err(err) => { - error!("Failed to parse message: {err}"); - return; - } - }; - - debug!("Illuminance: {illuminance}"); - let is_dark = if illuminance <= self.min { - trace!("It is dark"); - true - } else if illuminance >= self.max { - trace!("It is light"); - false - } else { - trace!( - "In between min ({}) and max ({}) value, keeping current state: {}", - self.min, - self.max, - *self.is_dark.borrow() - ); - *self.is_dark.borrow() - }; - - if is_dark != *self.is_dark.borrow() { - debug!("Dark state has changed: {is_dark}"); - self.tx.send(is_dark).ok(); - } - } + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 0203db6..d4fad56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,9 +10,8 @@ use automation::{ config::Config, debug_bridge, devices, error::ApiError, - hue_bridge, light_sensor, - mqtt::Mqtt, - ntfy, presence, + event::EventChannel, + hue_bridge, light_sensor, mqtt, ntfy, presence, }; use dotenvy::dotenv; use futures::future::join_all; @@ -56,52 +55,39 @@ async fn app() -> anyhow::Result<()> { std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.toml".to_owned()); let config = Config::parse_file(&config_filename)?; - // Create a mqtt client and wrap the eventloop + let event_channel = EventChannel::new(); + + // Create a mqtt client let (client, eventloop) = AsyncClient::new(config.mqtt.clone(), 10); - let mqtt = Mqtt::new(eventloop); - let presence = - presence::start(config.presence.clone(), mqtt.subscribe(), client.clone()).await?; - let light_sensor = light_sensor::start( - mqtt.subscribe(), - config.light_sensor.clone(), - client.clone(), - ) - .await?; + + let presence_topic = config.presence.mqtt.topic.to_owned(); + presence::start(config.presence, &event_channel, client.clone()).await?; + light_sensor::start(config.light_sensor, &event_channel, client.clone()).await?; // Start the ntfy service if it is configured - if let Some(config) = &config.ntfy { - ntfy::start(presence.clone(), config); + if let Some(config) = config.ntfy { + ntfy::start(config, &event_channel); } // Start the hue bridge if it is configured if let Some(config) = config.hue_bridge { - hue_bridge::start(presence.clone(), light_sensor.clone(), config); + hue_bridge::start(config, &event_channel); } // Start the debug bridge if it is configured if let Some(config) = config.debug_bridge { - debug_bridge::start( - presence.clone(), - light_sensor.clone(), - config, - client.clone(), - ); + debug_bridge::start(config, &event_channel, client.clone()); } // Setup the device handler - let device_handler = devices::start( - mqtt.subscribe(), - presence.clone(), - light_sensor.clone(), - client.clone(), - ); + let device_handler = devices::start(&event_channel, client.clone()); // Create all the devices specified in the config let devices = config .devices .into_iter() .map(|(identifier, device_config)| { - device_config.create(&identifier, client.clone(), &config.presence.topic) + device_config.create(&identifier, client.clone(), &presence_topic) }) .collect::, _>>()?; @@ -118,9 +104,9 @@ async fn app() -> anyhow::Result<()> { .into_iter() .collect::>()?; - // Actually start listening for mqtt message, - // we wait until all the setup is done, as otherwise we might miss some messages - mqtt.start(); + // Wrap the mqtt eventloop and start listening for message + // NOTE: We wait until all the setup is done, as otherwise we might miss some messages + mqtt::start(eventloop, &event_channel); // Create google home fullfillment route let fullfillment = Router::new().route( diff --git a/src/mqtt.rs b/src/mqtt.rs index d9f50b0..aaf0db6 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -7,7 +7,8 @@ use thiserror::Error; use tracing::{debug, warn}; use rumqttc::{Event, EventLoop, Incoming, Publish}; -use tokio::sync::broadcast; + +use crate::event::{self, EventChannel}; #[async_trait] #[impl_cast::device_trait] @@ -16,49 +17,32 @@ pub trait OnMqtt { async fn on_mqtt(&mut self, message: &Publish); } -pub type Receiver = broadcast::Receiver; -type Sender = broadcast::Sender; - -pub struct Mqtt { - tx: Sender, - eventloop: EventLoop, -} - #[derive(Debug, Error)] pub enum ParseError { #[error("Invalid message payload received: {0:?}")] InvalidPayload(Bytes), } -impl Mqtt { - pub fn new(eventloop: EventLoop) -> Self { - let (tx, _rx) = broadcast::channel(100); - Self { tx, eventloop } - } +pub fn start(mut eventloop: EventLoop, event_channel: &EventChannel) { + let tx = event_channel.get_tx(); - pub fn subscribe(&self) -> Receiver { - self.tx.subscribe() - } - - pub fn start(mut self) { - tokio::spawn(async move { - debug!("Listening for MQTT events"); - loop { - let notification = self.eventloop.poll().await; - match notification { - Ok(Event::Incoming(Incoming::Publish(p))) => { - self.tx.send(p).ok(); - } - Ok(..) => continue, - Err(err) => { - // Something has gone wrong - // We stay in the loop as that will attempt to reconnect - warn!("{}", err); - } + tokio::spawn(async move { + debug!("Listening for MQTT events"); + loop { + let notification = eventloop.poll().await; + match notification { + Ok(Event::Incoming(Incoming::Publish(p))) => { + tx.send(event::Event::MqttMessage(p)).ok(); + } + Ok(..) => continue, + Err(err) => { + // Something has gone wrong + // We stay in the loop as that will attempt to reconnect + warn!("{}", err); } } - }); - } + } + }); } #[derive(Debug, Serialize, Deserialize)] @@ -161,10 +145,10 @@ impl PresenceMessage { } } -impl TryFrom<&Publish> for PresenceMessage { +impl TryFrom for PresenceMessage { type Error = ParseError; - fn try_from(message: &Publish) -> Result { + fn try_from(message: Publish) -> Result { serde_json::from_slice(&message.payload) .or(Err(ParseError::InvalidPayload(message.payload.clone()))) } @@ -181,10 +165,10 @@ impl BrightnessMessage { } } -impl TryFrom<&Publish> for BrightnessMessage { +impl TryFrom for BrightnessMessage { type Error = ParseError; - fn try_from(message: &Publish) -> Result { + fn try_from(message: Publish) -> Result { serde_json::from_slice(&message.payload) .or(Err(ParseError::InvalidPayload(message.payload.clone()))) } diff --git a/src/ntfy.rs b/src/ntfy.rs index f74154f..87ff7c0 100644 --- a/src/ntfy.rs +++ b/src/ntfy.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use async_trait::async_trait; use serde::Serialize; use serde_repr::*; use tokio::sync::mpsc; @@ -8,7 +7,7 @@ use tracing::{debug, error, warn}; use crate::{ config::NtfyConfig, - presence::{self, OnPresence}, + event::{Event, EventChannel}, }; pub type Sender = mpsc::Sender; @@ -17,10 +16,9 @@ pub type Receiver = mpsc::Receiver; struct Ntfy { base_url: String, topic: String, - tx: Sender, } -#[derive(Serialize_repr)] +#[derive(Serialize_repr, Clone, Copy)] #[repr(u8)] pub enum Priority { Min = 1, @@ -30,7 +28,7 @@ pub enum Priority { Max, } -#[derive(Serialize)] +#[derive(Serialize, Clone)] #[serde(rename_all = "snake_case", tag = "action")] pub enum ActionType { Broadcast { @@ -41,7 +39,7 @@ pub enum ActionType { // Http } -#[derive(Serialize)] +#[derive(Serialize, Clone)] pub struct Action { #[serde(flatten)] action: ActionType, @@ -56,7 +54,7 @@ struct NotificationFinal { inner: Notification, } -#[derive(Serialize)] +#[derive(Serialize, Clone)] pub struct Notification { #[serde(skip_serializing_if = "Option::is_none")] title: Option, @@ -121,11 +119,10 @@ impl Default for Notification { } impl Ntfy { - fn new(base_url: &str, topic: &str, tx: Sender) -> Self { + fn new(base_url: &str, topic: &str) -> Self { Self { base_url: base_url.to_owned(), topic: topic.to_owned(), - tx, } } @@ -151,52 +148,45 @@ impl Ntfy { } } -pub fn start(mut presence_rx: presence::Receiver, config: &NtfyConfig) -> Sender { - let (tx, mut rx) = mpsc::channel(10); +pub fn start(config: NtfyConfig, event_channel: &EventChannel) { + let mut rx = event_channel.get_rx(); + let tx = event_channel.get_tx(); - let mut ntfy = Ntfy::new(&config.url, &config.topic, tx.clone()); + let ntfy = Ntfy::new(&config.url, &config.topic); tokio::spawn(async move { loop { - tokio::select! { - Ok(_) = presence_rx.changed() => { - let presence = *presence_rx.borrow(); - ntfy.on_presence(presence).await; - }, - Some(notifcation) = rx.recv() => { - ntfy.send(notifcation).await; + match rx.recv().await { + Ok(Event::Presence(presence)) => { + // Setup extras for the broadcast + let extras = HashMap::from([ + ("cmd".into(), "presence".into()), + ("state".into(), if presence { "0" } else { "1" }.into()), + ]); + + // Create broadcast action + let action = Action { + action: ActionType::Broadcast { extras }, + label: if presence { "Set away" } else { "Set home" }.to_owned(), + clear: Some(true), + }; + + // Create the notification + let notification = Notification::new() + .set_title("Presence") + .set_message(if presence { "Home" } else { "Away" }) + .add_tag("house") + .add_action(action) + .set_priority(Priority::Low); + + if tx.send(Event::Ntfy(notification)).is_err() { + warn!("There are no receivers on the event channel"); + } } + Ok(Event::Ntfy(notification)) => ntfy.send(notification).await, + Ok(_) => {} + Err(_) => todo!("Handle errors with the event channel properly"), } } }); - - tx -} - -#[async_trait] -impl OnPresence for Ntfy { - async fn on_presence(&mut self, presence: bool) { - // Setup extras for the broadcast - let extras = HashMap::from([ - ("cmd".into(), "presence".into()), - ("state".into(), if presence { "0" } else { "1" }.into()), - ]); - - // Create broadcast action - let action = Action { - action: ActionType::Broadcast { extras }, - label: if presence { "Set away" } else { "Set home" }.to_owned(), - clear: Some(true), - }; - - // Create the notification - let notification = Notification::new() - .set_title("Presence") - .set_message(if presence { "Home" } else { "Away" }) - .add_tag("house") - .add_action(action) - .set_priority(Priority::Low); - - self.tx.send(notification).await.ok(); - } } diff --git a/src/presence.rs b/src/presence.rs index 64f45c8..a988415 100644 --- a/src/presence.rs +++ b/src/presence.rs @@ -2,13 +2,17 @@ use std::collections::HashMap; use async_trait::async_trait; use rumqttc::{has_wildcards, matches, AsyncClient}; -use tokio::sync::watch; -use tracing::{debug, error}; +use serde::Deserialize; +use tracing::{debug, warn}; use crate::{ config::MqttDeviceConfig, error::{MissingWildcard, PresenceError}, - mqtt::{self, OnMqtt, PresenceMessage}, + event::{ + Event::{self, MqttMessage}, + EventChannel, + }, + mqtt::PresenceMessage, }; #[async_trait] @@ -16,98 +20,79 @@ pub trait OnPresence: Sync + Send + 'static { async fn on_presence(&mut self, presence: bool); } -pub type Receiver = watch::Receiver; -type Sender = watch::Sender; - -#[derive(Debug)] -struct Presence { - devices: HashMap, - mqtt: MqttDeviceConfig, - tx: Sender, - overall_presence: Receiver, +#[derive(Debug, Deserialize)] +pub struct PresenceConfig { + #[serde(flatten)] + pub mqtt: MqttDeviceConfig, } -impl Presence { - fn build(mqtt: MqttDeviceConfig) -> Result { - if !has_wildcards(&mqtt.topic) { - return Err(MissingWildcard::new(&mqtt.topic)); - } - - let (tx, overall_presence) = watch::channel(false); - Ok(Self { - devices: HashMap::new(), - overall_presence, - mqtt, - tx, - }) - } -} +const DEFAULT: bool = false; pub async fn start( - mqtt: MqttDeviceConfig, - mut mqtt_rx: mqtt::Receiver, + config: PresenceConfig, + event_channel: &EventChannel, client: AsyncClient, -) -> Result { +) -> Result<(), PresenceError> { + if !has_wildcards(&config.mqtt.topic) { + return Err(MissingWildcard::new(&config.mqtt.topic).into()); + } + // Subscribe to the relevant topics on mqtt client - .subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce) + .subscribe(config.mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce) .await?; - let mut presence = Presence::build(mqtt)?; - let overall_presence = presence.overall_presence.clone(); + let mut rx = event_channel.get_rx(); + let tx = event_channel.get_tx(); + + let mut devices = HashMap::::new(); + let mut current_overall_presence = DEFAULT; tokio::spawn(async move { loop { // TODO: Handle errors, warn if lagging - if let Ok(message) = mqtt_rx.recv().await { - presence.on_mqtt(&message).await; + if let Ok(MqttMessage(message)) = rx.recv().await { + if !matches(&message.topic, &config.mqtt.topic) { + continue; + } + + let offset = config + .mqtt + .topic + .find('+') + .or(config.mqtt.topic.find('#')) + .expect("Presence::new fails if it does not contain wildcards"); + let device_name = message.topic[offset..].to_owned(); + + if message.payload.is_empty() { + // Remove the device from the map + debug!("State of device [{device_name}] has been removed"); + devices.remove(&device_name); + } else { + let present = match PresenceMessage::try_from(message) { + Ok(state) => state.present(), + Err(err) => { + warn!("Failed to parse message: {err}"); + continue; + } + }; + + debug!("State of device [{device_name}] has changed: {}", present); + devices.insert(device_name, present); + } + + let overall_presence = devices.iter().any(|(_, v)| *v); + if overall_presence != current_overall_presence { + debug!("Overall presence updated: {overall_presence}"); + current_overall_presence = overall_presence; + + if tx.send(Event::Presence(overall_presence)).is_err() { + warn!("There are no receivers on the event channel"); + } + } } } }); - Ok(overall_presence) -} - -#[async_trait] -impl OnMqtt for Presence { - fn topics(&self) -> Vec<&str> { - vec![&self.mqtt.topic] - } - - async fn on_mqtt(&mut self, message: &rumqttc::Publish) { - if !matches(&message.topic, &self.mqtt.topic) { - return; - } - - let offset = self - .mqtt - .topic - .find('+') - .or(self.mqtt.topic.find('#')) - .expect("Presence::new fails if it does not contain wildcards"); - let device_name = &message.topic[offset..]; - - if message.payload.is_empty() { - // Remove the device from the map - debug!("State of device [{device_name}] has been removed"); - self.devices.remove(device_name); - } else { - let present = match PresenceMessage::try_from(message) { - Ok(state) => state.present(), - Err(err) => { - error!("Failed to parse message: {err}"); - return; - } - }; - - debug!("State of device [{device_name}] has changed: {}", present); - self.devices.insert(device_name.to_owned(), present); - } - - let overall_presence = self.devices.iter().any(|(_, v)| *v); - if overall_presence != *self.overall_presence.borrow() { - debug!("Overall presence updated: {overall_presence}"); - self.tx.send(overall_presence).ok(); - } - } + Ok(()) }