diff --git a/Cargo.lock b/Cargo.lock index 2a2c450..bb09b53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,6 +100,7 @@ dependencies = [ "bytes", "console-subscriber", "dotenvy", + "dyn-clone", "enum_dispatch", "eui48", "futures", @@ -428,6 +429,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "either" version = "1.9.0" diff --git a/Cargo.toml b/Cargo.toml index fa901b1..033babd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ once_cell = "1.19.0" hostname = "0.4.0" tokio-util = { version = "0.7.11", features = ["full"] } uuid = "1.8.0" +dyn-clone = "1.0.17" [patch.crates-io] wakey = { git = "https://git.huizinga.dev/Dreaded_X/wakey" } diff --git a/google_home/google_home/src/fulfillment.rs b/google_home/google_home/src/fulfillment.rs index e9746db..5953057 100644 --- a/google_home/google_home/src/fulfillment.rs +++ b/google_home/google_home/src/fulfillment.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use automation_cast::Cast; use futures::future::{join_all, OptionFuture}; use thiserror::Error; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use crate::errors::{DeviceError, ErrorCode}; use crate::request::{self, Intent, Request}; @@ -33,7 +33,7 @@ impl GoogleHome { pub async fn handle_request + ?Sized + 'static>( &self, request: Request, - devices: &HashMap>>>, + devices: &HashMap>, ) -> Result { // TODO: What do we do if we actually get more then one thing in the input array, right now // we only respond to the first thing @@ -61,11 +61,11 @@ impl GoogleHome { async fn sync + ?Sized + 'static>( &self, - devices: &HashMap>>>, + devices: &HashMap>, ) -> sync::Payload { let mut resp_payload = sync::Payload::new(&self.user_id); let f = devices.iter().map(|(_, device)| async move { - if let Some(device) = device.read().await.as_ref().cast() { + if let Some(device) = device.as_ref().cast() { Some(Device::sync(device).await) } else { None @@ -79,7 +79,7 @@ impl GoogleHome { async fn query + ?Sized + 'static>( &self, payload: request::query::Payload, - devices: &HashMap>>>, + devices: &HashMap>, ) -> query::Payload { let mut resp_payload = query::Payload::new(); let f = payload @@ -89,7 +89,7 @@ impl GoogleHome { .map(|id| async move { // NOTE: Requires let_chains feature let device = if let Some(device) = devices.get(id.as_str()) - && let Some(device) = device.read().await.as_ref().cast() + && let Some(device) = device.as_ref().cast() { Device::query(device).await } else { @@ -111,7 +111,7 @@ impl GoogleHome { async fn execute + ?Sized + 'static>( &self, payload: request::execute::Payload, - devices: &HashMap>>>, + devices: &HashMap>, ) -> execute::Payload { let resp_payload = Arc::new(Mutex::new(response::execute::Payload::new())); @@ -138,7 +138,7 @@ impl GoogleHome { let execution = command.execution.clone(); async move { if let Some(device) = devices.get(id.as_str()) - && let Some(device) = device.write().await.as_ref().cast() + && let Some(device) = device.as_ref().cast() { if !device.is_online() { return (id, Ok(false)); diff --git a/src/device_manager.rs b/src/device_manager.rs index 603e6ee..adec2fc 100644 --- a/src/device_manager.rs +++ b/src/device_manager.rs @@ -17,16 +17,16 @@ use crate::event::{Event, EventChannel, OnDarkness, OnMqtt, OnNotification, OnPr use crate::LUA; #[derive(Debug, FromLua, Clone)] -pub struct WrappedDevice(Arc>>); +pub struct WrappedDevice(Box); impl WrappedDevice { - pub fn new(device: Box) -> Self { - Self(Arc::new(RwLock::new(device))) + pub fn new(device: impl Device + 'static) -> Self { + Self(Box::new(device)) } } impl Deref for WrappedDevice { - type Target = Arc>>; + type Target = Box; fn deref(&self) -> &Self::Target { &self.0 @@ -38,17 +38,13 @@ impl DerefMut for WrappedDevice { &mut self.0 } } + impl mlua::UserData for WrappedDevice { fn add_methods<'lua, M: mlua::prelude::LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("get_id", |_lua, this, _: ()| async { - Ok(crate::devices::Device::get_id(this.0.read().await.as_ref())) - }); + methods.add_async_method("get_id", |_lua, this, _: ()| async { Ok(this.get_id()) }); methods.add_async_method("set_on", |_lua, this, on: bool| async move { - let device = this.0.write().await; - let device = device.as_ref(); - - if let Some(device) = device.cast() as Option<&dyn OnOff> { + if let Some(device) = this.cast() as Option<&dyn OnOff> { device.set_on(on).await.unwrap() }; @@ -57,7 +53,7 @@ impl mlua::UserData for WrappedDevice { } } -pub type DeviceMap = HashMap>>>; +pub type DeviceMap = HashMap>; #[derive(Clone)] pub struct DeviceManager { @@ -94,25 +90,20 @@ impl DeviceManager { device_manager } - pub async fn add(&self, device: &WrappedDevice) { - let id = device.read().await.get_id().to_owned(); + pub async fn add(&self, device: Box) { + let id = device.get_id(); debug!(id, "Adding device"); - self.devices.write().await.insert(id, device.0.clone()); + self.devices.write().await.insert(id, device); } pub fn event_channel(&self) -> EventChannel { self.event_channel.clone() } - pub async fn get(&self, name: &str) -> Option { - self.devices - .read() - .await - .get(name) - .cloned() - .map(WrappedDevice) + pub async fn get(&self, name: &str) -> Option> { + self.devices.read().await.get(name).cloned() } pub async fn devices(&self) -> RwLockReadGuard { @@ -127,8 +118,7 @@ impl DeviceManager { let iter = devices.iter().map(|(id, device)| { let message = message.clone(); async move { - let device = device.write().await; - let device: Option<&dyn OnMqtt> = device.as_ref().cast(); + let device: Option<&dyn OnMqtt> = device.cast(); if let Some(device) = device { // let subscribed = device // .topics() @@ -149,8 +139,7 @@ impl DeviceManager { Event::Darkness(dark) => { let devices = self.devices.read().await; let iter = devices.iter().map(|(id, device)| async move { - let device = device.write().await; - let device: Option<&dyn OnDarkness> = device.as_ref().cast(); + let device: Option<&dyn OnDarkness> = device.cast(); if let Some(device) = device { trace!(id, "Handling"); device.on_darkness(dark).await; @@ -163,8 +152,7 @@ impl DeviceManager { Event::Presence(presence) => { let devices = self.devices.read().await; let iter = devices.iter().map(|(id, device)| async move { - let device = device.write().await; - let device: Option<&dyn OnPresence> = device.as_ref().cast(); + let device: Option<&dyn OnPresence> = device.cast(); if let Some(device) = device { trace!(id, "Handling"); device.on_presence(presence).await; @@ -179,8 +167,7 @@ impl DeviceManager { let iter = devices.iter().map(|(id, device)| { let notification = notification.clone(); async move { - let device = device.write().await; - let device: Option<&dyn OnNotification> = device.as_ref().cast(); + let device: Option<&dyn OnNotification> = device.cast(); if let Some(device) = device { trace!(id, "Handling"); device.on_notification(notification).await; @@ -215,7 +202,7 @@ fn run_schedule( impl mlua::UserData for DeviceManager { fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_async_method("add", |_lua, this, device: WrappedDevice| async move { - this.add(&device).await; + this.add(device.0).await; Ok(()) }); diff --git a/src/devices/audio_setup.rs b/src/devices/audio_setup.rs index 6f3385d..4caebfe 100644 --- a/src/devices/audio_setup.rs +++ b/src/devices/audio_setup.rs @@ -38,15 +38,13 @@ impl LuaDeviceCreate for AudioSetup { trace!(id = config.identifier, "Setting up AudioSetup"); { - let mixer = config.mixer.read().await; - let mixer_id = mixer.get_id().to_owned(); - if (mixer.as_ref().cast() as Option<&dyn OnOff>).is_none() { + let mixer_id = config.mixer.get_id().to_owned(); + if (config.mixer.cast() as Option<&dyn OnOff>).is_none() { return Err(DeviceConfigError::MissingTrait(mixer_id, "OnOff".into())); } - let speakers = config.speakers.read().await; - let speakers_id = speakers.get_id().to_owned(); - if (speakers.as_ref().cast() as Option<&dyn OnOff>).is_none() { + let speakers_id = config.speakers.get_id().to_owned(); + if (config.speakers.cast() as Option<&dyn OnOff>).is_none() { return Err(DeviceConfigError::MissingTrait(speakers_id, "OnOff".into())); } } @@ -81,11 +79,9 @@ impl OnMqtt for AudioSetup { } }; - let mixer = self.config.mixer.write().await; - let speakers = self.config.speakers.write().await; if let (Some(mixer), Some(speakers)) = ( - mixer.as_ref().cast() as Option<&dyn OnOff>, - speakers.as_ref().cast() as Option<&dyn OnOff>, + self.config.mixer.cast() as Option<&dyn OnOff>, + self.config.speakers.cast() as Option<&dyn OnOff>, ) { match action { RemoteAction::On => { @@ -116,12 +112,9 @@ impl OnMqtt for AudioSetup { #[async_trait] impl OnPresence for AudioSetup { async fn on_presence(&self, presence: bool) { - let mixer = self.config.mixer.write().await; - let speakers = self.config.speakers.write().await; - if let (Some(mixer), Some(speakers)) = ( - mixer.as_ref().cast() as Option<&dyn OnOff>, - speakers.as_ref().cast() as Option<&dyn OnOff>, + self.config.mixer.cast() as Option<&dyn OnOff>, + self.config.speakers.cast() as Option<&dyn OnOff>, ) { // Turn off the audio setup when we leave the house if !presence { diff --git a/src/devices/contact_sensor.rs b/src/devices/contact_sensor.rs index 591b100..c0e82c3 100644 --- a/src/devices/contact_sensor.rs +++ b/src/devices/contact_sensor.rs @@ -85,14 +85,13 @@ impl LuaDeviceCreate for ContactSensor { if let Some(trigger) = &config.trigger { for device in &trigger.devices { { - let device = device.read().await; let id = device.get_id().to_owned(); - if (device.as_ref().cast() as Option<&dyn OnOff>).is_none() { + if (device.cast() as Option<&dyn OnOff>).is_none() { return Err(DeviceConfigError::MissingTrait(id, "OnOff".into())); } if trigger.timeout.is_none() - && (device.as_ref().cast() as Option<&dyn Timeout>).is_none() + && (device.cast() as Option<&dyn Timeout>).is_none() { return Err(DeviceConfigError::MissingTrait(id, "Timeout".into())); } @@ -160,8 +159,7 @@ impl OnMqtt for ContactSensor { .iter() .zip(self.state_mut().await.previous.iter_mut()) { - let light = light.write().await; - if let Some(light) = light.as_ref().cast() as Option<&dyn OnOff> { + if let Some(light) = light.cast() as Option<&dyn OnOff> { *previous = light.on().await.unwrap(); light.set_on(true).await.ok(); } @@ -172,15 +170,14 @@ impl OnMqtt for ContactSensor { .iter() .zip(self.state_mut().await.previous.iter()) { - let light = light.write().await; if !previous { // If the timeout is zero just turn the light off directly if trigger.timeout.is_none() - && let Some(light) = light.as_ref().cast() as Option<&dyn OnOff> + && let Some(light) = light.cast() as Option<&dyn OnOff> { light.set_on(false).await.ok(); } else if let Some(timeout) = trigger.timeout - && let Some(light) = light.as_ref().cast() as Option<&dyn Timeout> + && let Some(light) = light.cast() as Option<&dyn Timeout> { light.start_timeout(timeout).await.unwrap(); } diff --git a/src/devices/mod.rs b/src/devices/mod.rs index 6024dcb..d2f1e98 100644 --- a/src/devices/mod.rs +++ b/src/devices/mod.rs @@ -16,6 +16,7 @@ use std::fmt::Debug; use async_trait::async_trait; use automation_cast::Cast; +use dyn_clone::DynClone; use google_home::traits::OnOff; pub use self::air_filter::AirFilter; @@ -63,7 +64,7 @@ macro_rules! impl_device { .await .map_err(mlua::ExternalError::into_lua_err)?; - Ok(crate::device_manager::WrappedDevice::new(Box::new(device))) + Ok(crate::device_manager::WrappedDevice::new(device)) }); } } @@ -104,6 +105,7 @@ pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> { pub trait Device: Debug + + DynClone + Sync + Send + Cast @@ -117,3 +119,5 @@ pub trait Device: { fn get_id(&self) -> String; } + +dyn_clone::clone_trait_object!(Device);