diff --git a/src/device_manager.rs b/src/device_manager.rs new file mode 100644 index 0000000..e0db3a7 --- /dev/null +++ b/src/device_manager.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use futures::future::join_all; +use rumqttc::{matches, AsyncClient, QoS}; +use tokio::sync::{RwLock, RwLockReadGuard}; +use tracing::{debug, error, instrument, trace}; + +use crate::{ + devices::{As, Device}, + event::OnDarkness, + event::OnNotification, + event::OnPresence, + event::{Event, EventChannel, OnMqtt}, +}; + +pub type DeviceMap = HashMap>>>; + +#[derive(Debug, Clone)] +pub struct DeviceManager { + devices: Arc>, + client: AsyncClient, +} + +impl DeviceManager { + pub fn new(client: AsyncClient) -> Self { + Self { + devices: Arc::new(RwLock::new(HashMap::new())), + client, + } + } + + pub fn start(&self) -> EventChannel { + let (event_channel, mut event_rx) = EventChannel::new(); + + let devices = self.clone(); + tokio::spawn(async move { + loop { + if let Some(event) = event_rx.recv().await { + devices.handle_event(event).await; + } else { + todo!("Handle errors with the event channel properly") + } + } + }); + + event_channel + } + + pub async fn add(&self, device: Box) { + let id = device.get_id().to_owned(); + + debug!(id, "Adding device"); + + // If the device listens to mqtt, subscribe to the topics + if let Some(device) = As::::cast(device.as_ref()) { + for topic in device.topics() { + trace!(id, topic, "Subscribing to topic"); + if let Err(err) = self.client.subscribe(topic, QoS::AtLeastOnce).await { + // NOTE: Pretty sure that this can only happen if the mqtt client if no longer + // running + error!(id, topic, "Failed to subscribe to topic: {err}"); + } + } + } + + // Wrap the device + let device = Arc::new(RwLock::new(device)); + + self.devices.write().await.insert(id, device); + } + + pub async fn devices(&self) -> RwLockReadGuard { + self.devices.read().await + } + + #[instrument(skip(self))] + async fn handle_event(&self, event: Event) { + match event { + Event::MqttMessage(message) => { + let devices = self.devices.read().await; + let iter = devices.iter().map(|(id, device)| { + let message = message.clone(); + async move { + let mut device = device.write().await; + let device = device.as_mut(); + if let Some(device) = As::::cast_mut(device) { + let subscribed = device + .topics() + .iter() + .any(|topic| matches(&message.topic, topic)); + + if subscribed { + trace!(id, "Handling"); + device.on_mqtt(message).await; + } + } + } + }); + + join_all(iter).await; + } + Event::Darkness(dark) => { + let devices = self.devices.read().await; + let iter = devices.iter().map(|(id, device)| async move { + let mut device = device.write().await; + let device = device.as_mut(); + if let Some(device) = As::::cast_mut(device) { + trace!(id, "Handling"); + device.on_darkness(dark).await; + } + }); + + join_all(iter).await; + } + Event::Presence(presence) => { + let devices = self.devices.read().await; + let iter = devices.iter().map(|(id, device)| async move { + let mut device = device.write().await; + let device = device.as_mut(); + if let Some(device) = As::::cast_mut(device) { + trace!(id, "Handling"); + device.on_presence(presence).await; + } + }); + + join_all(iter).await; + } + Event::Ntfy(notification) => { + let devices = self.devices.read().await; + let iter = devices.iter().map(|(id, device)| { + let notification = notification.clone(); + async move { + let mut device = device.write().await; + let device = device.as_mut(); + if let Some(device) = As::::cast_mut(device) { + trace!(id, "Handling"); + device.on_notification(notification).await; + } + } + }); + + join_all(iter).await; + } + } + } +} diff --git a/src/devices.rs b/src/devices.rs deleted file mode 100644 index 961c6ee..0000000 --- a/src/devices.rs +++ /dev/null @@ -1,237 +0,0 @@ -mod audio_setup; -mod contact_sensor; -mod debug_bridge; -mod hue_bridge; -mod ikea_outlet; -mod kasa_outlet; -mod light_sensor; -mod ntfy; -mod presence; -mod wake_on_lan; - -pub use self::audio_setup::AudioSetup; -pub use self::contact_sensor::ContactSensor; -pub use self::debug_bridge::{DebugBridge, DebugBridgeConfig}; -pub use self::hue_bridge::{HueBridge, HueBridgeConfig}; -pub use self::ikea_outlet::IkeaOutlet; -pub use self::kasa_outlet::KasaOutlet; -pub use self::light_sensor::{LightSensor, LightSensorConfig}; -pub use self::ntfy::{Notification, Ntfy}; -pub use self::presence::{Presence, PresenceConfig, DEFAULT_PRESENCE}; -pub use self::wake_on_lan::WakeOnLAN; - -use std::collections::HashMap; -use std::sync::Arc; - -use futures::future::join_all; -use google_home::device::AsGoogleHomeDevice; -use google_home::{traits::OnOff, FullfillmentError}; -use rumqttc::{matches, AsyncClient, QoS}; -use thiserror::Error; -use tokio::sync::{mpsc, oneshot, RwLock}; -use tracing::{debug, error, instrument, trace}; - -use crate::{ - event::OnDarkness, - event::OnMqtt, - event::OnNotification, - event::OnPresence, - event::{Event, EventChannel}, -}; - -#[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + OnNotification + OnOff)] -pub trait Device: AsGoogleHomeDevice + std::fmt::Debug + Sync + Send { - fn get_id(&self) -> &str; -} - -pub type DeviceMap = HashMap>>>; - -// TODO: Add an inner type that we can wrap with Arc> to make this type a little bit nicer -// to work with -#[derive(Debug)] -struct Devices { - devices: DeviceMap, - client: AsyncClient, -} - -#[derive(Debug)] -pub enum Command { - Fullfillment { - tx: oneshot::Sender, - }, - AddDevice { - device: Box, - tx: oneshot::Sender<()>, - }, -} - -#[derive(Clone)] -pub struct DevicesHandle { - tx: mpsc::Sender, -} - -#[derive(Debug, Error)] -pub enum DevicesError { - #[error(transparent)] - FullfillmentError(#[from] FullfillmentError), - #[error(transparent)] - SendError(#[from] tokio::sync::mpsc::error::SendError), - #[error(transparent)] - RecvError(#[from] tokio::sync::oneshot::error::RecvError), -} - -impl DevicesHandle { - // TODO: Improve error type - pub async fn fullfillment(&self) -> Result { - let (tx, rx) = oneshot::channel(); - self.tx.send(Command::Fullfillment { tx }).await?; - Ok(rx.await?) - } - - pub async fn add_device(&self, device: Box) -> Result<(), DevicesError> { - let (tx, rx) = oneshot::channel(); - self.tx.send(Command::AddDevice { device, tx }).await?; - Ok(rx.await?) - } -} - -pub fn start(client: AsyncClient) -> (DevicesHandle, EventChannel) { - let mut devices = Devices { - devices: HashMap::new(), - client, - }; - - let (event_channel, mut event_rx) = EventChannel::new(); - let (tx, mut rx) = mpsc::channel(100); - - tokio::spawn(async move { - // TODO: Handle error better - loop { - tokio::select! { - event = event_rx.recv() => { - if event.is_none() { - todo!("Handle errors with the event channel properly") - } - devices.handle_event(event.unwrap()).await; - } - // TODO: Handle receiving None better, otherwise it might constantly run doing - // nothing - cmd = rx.recv() => { - if cmd.is_none() { - todo!("Handle errors with the cmd channel properly") - } - devices.handle_cmd(cmd.unwrap()).await; - } - } - } - }); - - (DevicesHandle { tx }, event_channel) -} - -impl Devices { - async fn handle_cmd(&mut self, cmd: Command) { - match cmd { - Command::Fullfillment { tx } => { - tx.send(self.devices.clone()).ok(); - } - Command::AddDevice { device, tx } => { - self.add_device(device).await; - - tx.send(()).ok(); - } - } - } - - async fn add_device(&mut self, device: Box) { - let id = device.get_id().to_owned(); - - let device = Arc::new(RwLock::new(device)); - { - let device = device.read().await; - - debug!(id, "Adding device"); - - // If the device listens to mqtt, subscribe to the topics - if let Some(device) = As::::cast(device.as_ref()) { - for topic in device.topics() { - trace!(id, topic, "Subscribing to topic"); - if let Err(err) = self.client.subscribe(topic, QoS::AtLeastOnce).await { - // NOTE: Pretty sure that this can only happen if the mqtt client if no longer - // running - error!(id, topic, "Failed to subscribe to topic: {err}"); - } - } - } - } - - self.devices.insert(id, device); - } - - #[instrument(skip(self))] - async fn handle_event(&mut self, event: Event) { - match event { - Event::MqttMessage(message) => { - let iter = self.devices.iter().map(|(id, device)| { - let message = message.clone(); - async move { - let mut device = device.write().await; - let device = device.as_mut(); - if let Some(device) = As::::cast_mut(device) { - let subscribed = device - .topics() - .iter() - .any(|topic| matches(&message.topic, topic)); - - if subscribed { - trace!(id, "Handling"); - device.on_mqtt(message).await; - } - } - } - }); - - join_all(iter).await; - } - Event::Darkness(dark) => { - let iter = self.devices.iter().map(|(id, device)| async move { - let mut device = device.write().await; - let device = device.as_mut(); - if let Some(device) = As::::cast_mut(device) { - trace!(id, "Handling"); - device.on_darkness(dark).await; - } - }); - - join_all(iter).await; - } - Event::Presence(presence) => { - let iter = self.devices.iter().map(|(id, device)| async move { - let mut device = device.write().await; - let device = device.as_mut(); - if let Some(device) = As::::cast_mut(device) { - trace!(id, "Handling"); - device.on_presence(presence).await; - } - }); - - join_all(iter).await; - } - Event::Ntfy(notification) => { - let iter = self.devices.iter().map(|(id, device)| { - let notification = notification.clone(); - async move { - let mut device = device.write().await; - let device = device.as_mut(); - if let Some(device) = As::::cast_mut(device) { - trace!(id, "Handling"); - device.on_notification(notification).await; - } - } - }); - - join_all(iter).await; - } - } - } -} diff --git a/src/devices/mod.rs b/src/devices/mod.rs new file mode 100644 index 0000000..e146106 --- /dev/null +++ b/src/devices/mod.rs @@ -0,0 +1,30 @@ +mod audio_setup; +mod contact_sensor; +mod debug_bridge; +mod hue_bridge; +mod ikea_outlet; +mod kasa_outlet; +mod light_sensor; +mod ntfy; +mod presence; +mod wake_on_lan; + +pub use self::audio_setup::AudioSetup; +pub use self::contact_sensor::ContactSensor; +pub use self::debug_bridge::{DebugBridge, DebugBridgeConfig}; +pub use self::hue_bridge::{HueBridge, HueBridgeConfig}; +pub use self::ikea_outlet::IkeaOutlet; +pub use self::kasa_outlet::KasaOutlet; +pub use self::light_sensor::{LightSensor, LightSensorConfig}; +pub use self::ntfy::{Notification, Ntfy}; +pub use self::presence::{Presence, PresenceConfig, DEFAULT_PRESENCE}; +pub use self::wake_on_lan::WakeOnLAN; + +use google_home::{device::AsGoogleHomeDevice, traits::OnOff}; + +use crate::{event::OnDarkness, event::OnMqtt, event::OnNotification, event::OnPresence}; + +#[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + OnNotification + OnOff)] +pub trait Device: AsGoogleHomeDevice + std::fmt::Debug + Sync + Send { + fn get_id(&self) -> &str; +} diff --git a/src/lib.rs b/src/lib.rs index e4a5239..d4a3720 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![feature(specialization)] pub mod auth; pub mod config; +pub mod device_manager; pub mod devices; pub mod error; pub mod event; diff --git a/src/main.rs b/src/main.rs index f6f9266..e948ea3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,20 +4,18 @@ use std::process; use axum::{ extract::FromRef, http::StatusCode, response::IntoResponse, routing::post, Json, Router, }; +use dotenvy::dotenv; +use rumqttc::AsyncClient; +use tracing::{debug, error, info}; use automation::{ auth::{OpenIDConfig, User}, config::Config, - devices, + device_manager::DeviceManager, devices::{DebugBridge, HueBridge, LightSensor, Ntfy, Presence}, error::ApiError, mqtt, }; -use dotenvy::dotenv; -use futures::future::join_all; -use rumqttc::AsyncClient; -use tracing::{debug, error, info}; - use google_home::{GoogleHome, Request}; #[derive(Clone)] @@ -59,65 +57,47 @@ async fn app() -> anyhow::Result<()> { let (client, eventloop) = AsyncClient::new(config.mqtt.clone(), 10); // Setup the device handler - let (device_handler, event_channel) = devices::start(client.clone()); + let device_manager = DeviceManager::new(client.clone()); + let event_channel = device_manager.start(); // Create all the devices specified in the config - let mut devices = config - .devices - .into_iter() - .map(|(identifier, device_config)| { - device_config.create( - &identifier, - &event_channel, - &client, - &config.presence.mqtt.topic, - ) - }) - .collect::, _>>()?; + for (id, device_config) in config.devices { + let device = + device_config.create(&id, &event_channel, &client, &config.presence.mqtt.topic)?; + + device_manager.add(device).await; + } // Create and add the light sensor { let light_sensor = LightSensor::new(config.light_sensor, &event_channel); - devices.push(Box::new(light_sensor)); + device_manager.add(Box::new(light_sensor)).await; } // Create and add the presence system { let presence = Presence::new(config.presence, &event_channel); - devices.push(Box::new(presence)); + device_manager.add(Box::new(presence)).await; } // If configured, create and add the hue bridge if let Some(config) = config.hue_bridge { let hue_bridge = HueBridge::new(config); - devices.push(Box::new(hue_bridge)); + device_manager.add(Box::new(hue_bridge)).await; } // Start the debug bridge if it is configured if let Some(config) = config.debug_bridge { let debug_bridge = DebugBridge::new(config, &client)?; - devices.push(Box::new(debug_bridge)); + device_manager.add(Box::new(debug_bridge)).await; } // Start the ntfy service if it is configured if let Some(config) = config.ntfy { let ntfy = Ntfy::new(config, &event_channel); - devices.push(Box::new(ntfy)); + device_manager.add(Box::new(ntfy)).await; } - // Can even add some more devices here - // devices.push(device) - - // Register all the devices to the device_handler - join_all( - devices - .into_iter() - .map(|device| async { device_handler.add_device(device).await }), - ) - .await - .into_iter() - .collect::>()?; - // Wrap the mqtt eventloop and start listening for message // NOTE: We wait until all the setup is done, as otherwise we might miss some messages mqtt::start(eventloop, &event_channel); @@ -128,14 +108,9 @@ async fn app() -> anyhow::Result<()> { post(async move |user: User, Json(payload): Json| { debug!(username = user.preferred_username, "{payload:#?}"); let gc = GoogleHome::new(&user.preferred_username); - let result = match device_handler.fullfillment().await { - Ok(devices) => match gc.handle_request(payload, &devices).await { - Ok(result) => result, - Err(err) => { - return ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, err.into()) - .into_response() - } - }, + let devices = device_manager.devices().await; + let result = match gc.handle_request(payload, &devices).await { + Ok(result) => result, Err(err) => { return ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, err.into()) .into_response()