diff --git a/config/zeus.dev.toml b/config/zeus.dev.toml index adaff86..965f8ed 100644 --- a/config/zeus.dev.toml +++ b/config/zeus.dev.toml @@ -7,6 +7,9 @@ username="mqtt" port=7878 username="Dreaded_X" +[presence] +topic = "automation/presence" + [devices.kitchen_kettle] type = "IkeaOutlet" info = { name = "Kettle", room = "Kitchen" } diff --git a/src/config.rs b/src/config.rs index 1711604..179075d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,6 +10,7 @@ use crate::devices::{DeviceBox, IkeaOutlet, WakeOnLAN}; pub struct Config { pub mqtt: MQTTConfig, pub fullfillment: FullfillmentConfig, + pub presence: MqttDeviceConfig, #[serde(default)] pub devices: HashMap } diff --git a/src/devices.rs b/src/devices.rs index d09f660..1a34879 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -7,14 +7,16 @@ pub use self::wake_on_lan::WakeOnLAN; use std::collections::HashMap; use google_home::{GoogleHomeDevice, traits::OnOff}; +use log::trace; -use crate::mqtt::Listener; +use crate::{mqtt::Listener, presence::OnPresence}; impl_cast::impl_cast!(Device, Listener); +impl_cast::impl_cast!(Device, OnPresence); impl_cast::impl_cast!(Device, GoogleHomeDevice); impl_cast::impl_cast!(Device, OnOff); -pub trait Device: AsGoogleHomeDevice + AsListener + AsOnOff { +pub trait Device: AsGoogleHomeDevice + AsListener + AsOnPresence + AsOnOff { fn get_id(&self) -> String; } @@ -53,6 +55,7 @@ impl Devices { } get_cast!(Listener); + get_cast!(OnPresence); get_cast!(GoogleHomeDevice); get_cast!(OnOff); @@ -71,3 +74,13 @@ impl Listener for Devices { }) } } + +impl OnPresence for Devices { + fn on_presence(&mut self, presence: bool) { + trace!("OnPresence for devices"); + self.as_on_presences().iter_mut().for_each(|(name, device)| { + trace!("OnPresence: {name}"); + device.on_presence(presence); + }) + } +} diff --git a/src/devices/ikea_outlet.rs b/src/devices/ikea_outlet.rs index 3fc9cee..edef75e 100644 --- a/src/devices/ikea_outlet.rs +++ b/src/devices/ikea_outlet.rs @@ -10,6 +10,7 @@ use tokio::task::JoinHandle; use crate::config::{KettleConfig, InfoConfig, MqttDeviceConfig}; use crate::devices::Device; use crate::mqtt::Listener; +use crate::presence::OnPresence; pub struct IkeaOutlet { identifier: String, @@ -63,12 +64,8 @@ impl TryFrom<&Publish> for StateMessage { type Error = anyhow::Error; fn try_from(message: &Publish) -> Result { - match serde_json::from_slice(&message.payload) { - Ok(message) => Ok(message), - Err(..) => { - Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload)) - } - } + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) } } @@ -134,6 +131,19 @@ impl Listener for IkeaOutlet { } } +impl OnPresence for IkeaOutlet { + fn on_presence(&mut self, presence: bool) { + // Turn off the outlet when we leave the house + if !presence { + let client = self.client.clone(); + let topic = self.mqtt.topic.clone(); + tokio::spawn(async move { + set_on(client, topic, false).await; + }); + } + } +} + impl GoogleHomeDevice for IkeaOutlet { fn get_device_type(&self) -> Type { if self.kettle.is_some() { diff --git a/src/devices/wake_on_lan.rs b/src/devices/wake_on_lan.rs index 813a18a..7889bb4 100644 --- a/src/devices/wake_on_lan.rs +++ b/src/devices/wake_on_lan.rs @@ -16,11 +16,10 @@ pub struct WakeOnLAN { impl WakeOnLAN { pub fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, mac_address: String, client: AsyncClient) -> Self { - let c = client.clone(); let t = mqtt.topic.clone(); // @TODO Handle potential errors here tokio::spawn(async move { - c.subscribe(t, rumqttc::QoS::AtLeastOnce).await.unwrap(); + client.subscribe(t, rumqttc::QoS::AtLeastOnce).await.unwrap(); }); Self { identifier, info, mqtt, mac_address } @@ -42,12 +41,8 @@ impl TryFrom<&Publish> for StateMessage { type Error = anyhow::Error; fn try_from(message: &Publish) -> Result { - match serde_json::from_slice(&message.payload) { - Ok(message) => Ok(message), - Err(..) => { - Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload)) - } - } + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) } } diff --git a/src/lib.rs b/src/lib.rs index b97b8bd..089d821 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,3 +2,4 @@ pub mod devices; pub mod mqtt; pub mod config; +pub mod presence; diff --git a/src/main.rs b/src/main.rs index 3a77041..64a0da7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use std::{time::Duration, sync::{Arc, RwLock}, process, net::SocketAddr}; use axum::{Router, Json, routing::post, http::StatusCode}; -use automation::config::Config; +use automation::{config::Config, presence::Presence}; use dotenv::dotenv; use rumqttc::{MqttOptions, Transport, AsyncClient}; use env_logger::Builder; @@ -18,7 +18,7 @@ async fn main() { // Setup logger Builder::new() - .filter_module("automation", LevelFilter::Info) + .filter_module("automation", LevelFilter::Trace) .parse_default_env() .init(); @@ -43,7 +43,14 @@ async fn main() { // Create a notifier and start it in a seperate task let (client, eventloop) = AsyncClient::new(mqttoptions, 10); let mut notifier = Notifier::new(eventloop); + notifier.add_listener(Arc::downgrade(&devices)); + + let mut presence = Presence::new(config.presence, client.clone()); + presence.add_listener(Arc::downgrade(&devices)); + let presence = Arc::new(RwLock::new(presence)); + notifier.add_listener(Arc::downgrade(&presence)); + notifier.start(); // Create devices based on config diff --git a/src/mqtt.rs b/src/mqtt.rs index 83301d7..7bafecd 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -21,10 +21,14 @@ impl Notifier { } fn notify(&mut self, message: Publish) { + trace!("Listener count: {}", self.listeners.len()); + self.listeners.retain(|listener| { if let Some(listener) = listener.upgrade() { listener.write().unwrap().notify(&message); return true; + } else { + trace!("Removing listener..."); } return false; diff --git a/src/presence.rs b/src/presence.rs new file mode 100644 index 0000000..a5a723c --- /dev/null +++ b/src/presence.rs @@ -0,0 +1,93 @@ +use std::{sync::{Weak, RwLock}, collections::HashMap}; + +use log::{debug, warn, trace}; +use rumqttc::{AsyncClient, Publish}; +use serde::{Serialize, Deserialize}; + +use crate::{mqtt::Listener, config::MqttDeviceConfig}; + +pub trait OnPresence { + fn on_presence(&mut self, presence: bool); +} + +pub struct Presence { + listeners: Vec>>, + devices: HashMap, + overall_presence: bool, + mqtt: MqttDeviceConfig, +} + +impl Presence { + pub fn new(mqtt: MqttDeviceConfig, client: AsyncClient) -> Self { + // @TODO Handle potential errors here + let topic = mqtt.topic.clone() + "/+"; + tokio::spawn(async move { + client.subscribe(topic, rumqttc::QoS::AtLeastOnce).await.unwrap(); + }); + + Self { listeners: Vec::new(), devices: HashMap::new(), overall_presence: false, mqtt } + } + + pub fn add_listener(&mut self, listener: Weak>) { + self.listeners.push(listener); + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct StateMessage { + state: bool +} + +impl TryFrom<&Publish> for StateMessage { + type Error = anyhow::Error; + + fn try_from(message: &Publish) -> Result { + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) + } +} + +impl Listener for Presence { + fn notify(&mut self, message: &rumqttc::Publish) { + if message.topic.starts_with(&(self.mqtt.topic.clone() + "/")) { + let device_name = message.topic.rsplit_once("/").unwrap().1; + + if message.payload.len() == 0 { + // Remove the device from the map + debug!("State of device [{device_name}] has been removed"); + self.devices.remove(device_name); + return; + } else { + let state = match StateMessage::try_from(message) { + Ok(state) => state, + Err(err) => { + warn!("Failed to parse message: {err}"); + return; + } + }; + + debug!("State of device [{device_name}] has changed: {}", state.state); + self.devices.insert(device_name.to_owned(), state.state); + } + + let overall_presence = self.devices.iter().any(|(_, v)| *v); + if overall_presence != self.overall_presence { + debug!("Overall presence updated: {overall_presence}"); + self.overall_presence = overall_presence; + + trace!("Listener count: {}", self.listeners.len()); + + self.listeners.retain(|listener| { + if let Some(listener) = listener.upgrade() { + listener.write().unwrap().on_presence(overall_presence); + return true; + } else { + trace!("Removing listener..."); + } + + return false; + }) + } + } + } +}