From 1b8566e593669d934d8a76b1e1c83dfc441cfcd5 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Mon, 1 Sep 2025 03:18:56 +0200 Subject: [PATCH] refactor: Switch to async closures --- automation_devices/src/hue_bridge.rs | 2 +- automation_devices/src/ntfy.rs | 2 +- automation_lib/src/device_manager.rs | 33 ++--- automation_lib/src/helpers/timeout.rs | 6 +- automation_lib/src/lua/traits.rs | 25 ++-- automation_lib/src/mqtt.rs | 2 +- automation_macro/src/impl_device.rs | 4 +- google_home/google_home/src/fulfillment.rs | 159 ++++++++++----------- 8 files changed, 109 insertions(+), 124 deletions(-) diff --git a/automation_devices/src/hue_bridge.rs b/automation_devices/src/hue_bridge.rs index e981259..c0e6187 100644 --- a/automation_devices/src/hue_bridge.rs +++ b/automation_devices/src/hue_bridge.rs @@ -99,7 +99,7 @@ impl AddAdditionalMethods for HueBridge { { methods.add_async_method( "set_flag", - |lua, this, (flag, value): (mlua::Value, bool)| async move { + async |lua, this, (flag, value): (mlua::Value, bool)| { let flag: Flag = lua.from_value(flag)?; this.set_flag(flag, value).await; diff --git a/automation_devices/src/ntfy.rs b/automation_devices/src/ntfy.rs index 4206bd8..565969b 100644 --- a/automation_devices/src/ntfy.rs +++ b/automation_devices/src/ntfy.rs @@ -170,7 +170,7 @@ impl AddAdditionalMethods for Ntfy { { methods.add_async_method( "send_notification", - |lua, this, notification: mlua::Value| async move { + async |lua, this, notification: mlua::Value| { let notification: Notification = lua.from_value(notification)?; this.send(notification).await; diff --git a/automation_lib/src/device_manager.rs b/automation_lib/src/device_manager.rs index bbad8a6..405d382 100644 --- a/automation_lib/src/device_manager.rs +++ b/automation_lib/src/device_manager.rs @@ -73,22 +73,19 @@ impl DeviceManager { match event { Event::MqttMessage(message) => { let devices = self.devices.read().await; - let iter = devices.iter().map(|(id, device)| { - let message = message.clone(); - async move { - let device: Option<&dyn OnMqtt> = device.cast(); - if let Some(device) = device { - // let subscribed = device - // .topics() - // .iter() - // .any(|topic| matches(&message.topic, topic)); - // - // if subscribed { - trace!(id, "Handling"); - device.on_mqtt(message).await; - trace!(id, "Done"); - // } - } + let iter = devices.iter().map(async |(id, device)| { + let device: Option<&dyn OnMqtt> = device.cast(); + if let Some(device) = device { + // let subscribed = device + // .topics() + // .iter() + // .any(|topic| matches(&message.topic, topic)); + // + // if subscribed { + trace!(id, "Handling"); + device.on_mqtt(message.clone()).await; + trace!(id, "Done"); + // } } }); @@ -100,7 +97,7 @@ impl DeviceManager { impl mlua::UserData for DeviceManager { fn add_methods>(methods: &mut M) { - methods.add_async_method("add", |_lua, this, device: Box| async move { + methods.add_async_method("add", async |_lua, this, device: Box| { this.add(device).await; Ok(()) @@ -108,7 +105,7 @@ impl mlua::UserData for DeviceManager { methods.add_async_method( "schedule", - |lua, this, (schedule, f): (String, mlua::Function)| async move { + async |lua, this, (schedule, f): (String, mlua::Function)| { debug!("schedule = {schedule}"); // This creates a function, that returns the actual job we want to run let create_job = { diff --git a/automation_lib/src/helpers/timeout.rs b/automation_lib/src/helpers/timeout.rs index bb3b456..4b5ae7e 100644 --- a/automation_lib/src/helpers/timeout.rs +++ b/automation_lib/src/helpers/timeout.rs @@ -29,7 +29,7 @@ impl mlua::UserData for Timeout { methods.add_async_method( "start", - |_lua, this, (timeout, callback): (f32, ActionCallback)| async move { + async |_lua, this, (timeout, callback): (f32, ActionCallback)| { if let Some(handle) = this.state.write().await.handle.take() { handle.abort(); } @@ -50,7 +50,7 @@ impl mlua::UserData for Timeout { }, ); - methods.add_async_method("cancel", |_lua, this, ()| async move { + methods.add_async_method("cancel", async |_lua, this, ()| { debug!("Canceling timeout callback"); if let Some(handle) = this.state.write().await.handle.take() { @@ -60,7 +60,7 @@ impl mlua::UserData for Timeout { Ok(()) }); - methods.add_async_method("is_waiting", |_lua, this, ()| async move { + methods.add_async_method("is_waiting", async |_lua, this, ()| { debug!("Canceling timeout callback"); if let Some(handle) = this.state.read().await.handle.as_ref() { diff --git a/automation_lib/src/lua/traits.rs b/automation_lib/src/lua/traits.rs index 69abcc3..55c2cab 100644 --- a/automation_lib/src/lua/traits.rs +++ b/automation_lib/src/lua/traits.rs @@ -7,13 +7,13 @@ pub trait OnOff { where Self: Sized + google_home::traits::OnOff + 'static, { - methods.add_async_method("set_on", |_lua, this, on: bool| async move { + methods.add_async_method("set_on", async |_lua, this, on: bool| { this.deref().set_on(on).await.unwrap(); Ok(()) }); - methods.add_async_method("on", |_lua, this, ()| async move { + methods.add_async_method("on", async |_lua, this, ()| { Ok(this.deref().on().await.unwrap()) }); } @@ -25,13 +25,13 @@ pub trait Brightness { where Self: Sized + google_home::traits::Brightness + 'static, { - methods.add_async_method("set_brightness", |_lua, this, brightness: u8| async move { + methods.add_async_method("set_brightness", async |_lua, this, brightness: u8| { this.set_brightness(brightness).await.unwrap(); Ok(()) }); - methods.add_async_method("brightness", |_lua, this, _: ()| async move { + methods.add_async_method("brightness", async |_lua, this, _: ()| { Ok(this.brightness().await.unwrap()) }); } @@ -45,7 +45,7 @@ pub trait ColorSetting { { methods.add_async_method( "set_color_temperature", - |_lua, this, temperature: u32| async move { + async |_lua, this, temperature: u32| { this.set_color(google_home::traits::Color { temperature }) .await .unwrap(); @@ -54,7 +54,7 @@ pub trait ColorSetting { }, ); - methods.add_async_method("color_temperature", |_lua, this, ()| async move { + methods.add_async_method("color_temperature", async |_lua, this, ()| { Ok(this.color().await.temperature) }); } @@ -66,16 +66,13 @@ pub trait OpenClose { where Self: Sized + google_home::traits::OpenClose + 'static, { - methods.add_async_method( - "set_open_percent", - |_lua, this, open_percent: u8| async move { - this.set_open_percent(open_percent).await.unwrap(); + methods.add_async_method("set_open_percent", async |_lua, this, open_percent: u8| { + this.set_open_percent(open_percent).await.unwrap(); - Ok(()) - }, - ); + Ok(()) + }); - methods.add_async_method("open_percent", |_lua, this, _: ()| async move { + methods.add_async_method("open_percent", async |_lua, this, _: ()| { Ok(this.open_percent().await.unwrap()) }); } diff --git a/automation_lib/src/mqtt.rs b/automation_lib/src/mqtt.rs index 8830ab2..7e71d43 100644 --- a/automation_lib/src/mqtt.rs +++ b/automation_lib/src/mqtt.rs @@ -27,7 +27,7 @@ impl mlua::UserData for WrappedAsyncClient { fn add_methods>(methods: &mut M) { methods.add_async_method( "send_message", - |_lua, this, (topic, message): (String, mlua::Value)| async move { + async |_lua, this, (topic, message): (String, mlua::Value)| { let message = serde_json::to_string(&message).unwrap(); debug!("message = {message}"); diff --git a/automation_macro/src/impl_device.rs b/automation_macro/src/impl_device.rs index f3c3d67..3c001b2 100644 --- a/automation_macro/src/impl_device.rs +++ b/automation_macro/src/impl_device.rs @@ -45,7 +45,7 @@ impl Impl { quote! { impl mlua::UserData for #name #generics { fn add_methods>(methods: &mut M) { - methods.add_async_function("new", |_lua, config| async { + methods.add_async_function("new", async |_lua, config| { let device: Self = LuaDeviceCreate::create(config) .await .map_err(mlua::ExternalError::into_lua_err)?; @@ -58,7 +58,7 @@ impl Impl { Ok(b) }); - methods.add_async_method("get_id", |_lua, this, _: ()| async move { Ok(this.get_id()) }); + methods.add_async_method("get_id", async |_lua, this, _: ()| { Ok(this.get_id()) }); #( #traits::add_methods(methods); diff --git a/google_home/google_home/src/fulfillment.rs b/google_home/google_home/src/fulfillment.rs index 89828ff..ebcb7cd 100644 --- a/google_home/google_home/src/fulfillment.rs +++ b/google_home/google_home/src/fulfillment.rs @@ -40,15 +40,13 @@ impl GoogleHome { let intent = request.inputs.into_iter().next(); let payload: OptionFuture<_> = intent - .map(|intent| async move { - match intent { - Intent::Sync => ResponsePayload::Sync(self.sync(devices).await), - Intent::Query(payload) => { - ResponsePayload::Query(self.query(payload, devices).await) - } - Intent::Execute(payload) => { - ResponsePayload::Execute(self.execute(payload, devices).await) - } + .map(async |intent| match intent { + Intent::Sync => ResponsePayload::Sync(self.sync(devices).await), + Intent::Query(payload) => { + ResponsePayload::Query(self.query(payload, devices).await) + } + Intent::Execute(payload) => { + ResponsePayload::Execute(self.execute(payload, devices).await) } }) .into(); @@ -64,7 +62,7 @@ impl GoogleHome { devices: &HashMap>, ) -> sync::Payload { let mut resp_payload = sync::Payload::new(&self.user_id); - let f = devices.values().map(|device| async move { + let f = devices.values().map(async |device| { if let Some(device) = device.as_ref().cast() { Some(Device::sync(device).await) } else { @@ -86,7 +84,7 @@ impl GoogleHome { .devices .into_iter() .map(|device| device.id) - .map(|id| async move { + .map(async |id| { // NOTE: Requires let_chains feature let device = if let Some(device) = devices.get(id.as_str()) && let Some(device) = device.as_ref().cast() @@ -115,84 +113,77 @@ impl GoogleHome { ) -> execute::Payload { let resp_payload = Arc::new(Mutex::new(response::execute::Payload::new())); - let f = payload.commands.into_iter().map(|command| { - let resp_payload = resp_payload.clone(); - async move { - let mut success = response::execute::Command::new(execute::Status::Success); - success.states = Some(execute::States { - online: true, - state: Default::default(), - }); - let mut offline = response::execute::Command::new(execute::Status::Offline); - offline.states = Some(execute::States { - online: false, - state: Default::default(), - }); - let mut errors: HashMap = HashMap::new(); + let f = payload.commands.into_iter().map(async |command| { + let mut success = response::execute::Command::new(execute::Status::Success); + success.states = Some(execute::States { + online: true, + state: Default::default(), + }); + let mut offline = response::execute::Command::new(execute::Status::Offline); + offline.states = Some(execute::States { + online: false, + state: Default::default(), + }); + let mut errors: HashMap = HashMap::new(); - let f = command - .devices - .into_iter() - .map(|device| device.id) - .map(|id| { - let execution = command.execution.clone(); - async move { - if let Some(device) = devices.get(id.as_str()) - && let Some(device) = device.as_ref().cast() - { - if !device.is_online().await { - return (id, Ok(false)); - } - - // NOTE: We can not use .map here because async =( - let mut results = Vec::new(); - for cmd in &execution { - results.push(Device::execute(device, cmd.clone()).await); - } - - // Convert vec of results to a result with a vec and the first - // encountered error - let results = - results.into_iter().collect::, ErrorCode>>(); - - // TODO: We only get one error not all errors - if let Err(err) = results { - (id, Err(err)) - } else { - (id, Ok(true)) - } - } else { - (id.clone(), Err(DeviceError::DeviceNotFound.into())) - } + let f = command + .devices + .into_iter() + .map(|device| device.id) + .map(async |id| { + if let Some(device) = devices.get(id.as_str()) + && let Some(device) = device.as_ref().cast() + { + if !device.is_online().await { + return (id, Ok(false)); } - }); - let a = join_all(f).await; - a.into_iter().for_each(|(id, state)| { - match state { - Ok(true) => success.add_id(&id), - Ok(false) => offline.add_id(&id), - Err(err) => errors - .entry(err) - .or_insert_with(|| match &err { - ErrorCode::DeviceError(_) => { - response::execute::Command::new(execute::Status::Error) - } - ErrorCode::DeviceException(_) => { - response::execute::Command::new(execute::Status::Exceptions) - } - }) - .add_id(&id), - }; + // NOTE: We can not use .map here because async =( + let mut results = Vec::new(); + for cmd in &command.execution { + results.push(Device::execute(device, cmd.clone()).await); + } + + // Convert vec of results to a result with a vec and the first + // encountered error + let results = results.into_iter().collect::, ErrorCode>>(); + + // TODO: We only get one error not all errors + if let Err(err) = results { + (id, Err(err)) + } else { + (id, Ok(true)) + } + } else { + (id.clone(), Err(DeviceError::DeviceNotFound.into())) + } }); - let mut resp_payload = resp_payload.lock().await; - resp_payload.add_command(success); - resp_payload.add_command(offline); - for (error, mut cmd) in errors { - cmd.error_code = Some(error); - resp_payload.add_command(cmd); - } + let a = join_all(f).await; + a.into_iter().for_each(|(id, state)| { + match state { + Ok(true) => success.add_id(&id), + Ok(false) => offline.add_id(&id), + Err(err) => errors + .entry(err) + .or_insert_with(|| match &err { + ErrorCode::DeviceError(_) => { + response::execute::Command::new(execute::Status::Error) + } + ErrorCode::DeviceException(_) => { + response::execute::Command::new(execute::Status::Exceptions) + } + }) + .add_id(&id), + }; + }); + + let mut resp_payload = resp_payload.lock().await; + resp_payload.add_command(success); + resp_payload.add_command(offline); + for (error, mut cmd) in errors { + cmd.error_code = Some(error); + resp_payload.add_command(cmd); } });