From 3be11b0c6a3b5352f26d3bc3da05012c2c93f67b Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Mon, 8 Sep 2025 02:11:05 +0200 Subject: [PATCH] feat: Allow for multiple callbacks inside of an ActionCallback This also results in the conversion being performed when the ActionCallback is instantiated instead of when it is called, this should make it easier to catch errors. --- automation_lib/src/action_callback.rs | 36 ++++++++++++++++++++++----- automation_lib/src/lib.rs | 1 + 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/automation_lib/src/action_callback.rs b/automation_lib/src/action_callback.rs index 1dc08e9..bed3c99 100644 --- a/automation_lib/src/action_callback.rs +++ b/automation_lib/src/action_callback.rs @@ -1,11 +1,12 @@ use std::marker::PhantomData; +use futures::future::try_join_all; use mlua::{FromLua, IntoLua, LuaSerdeExt}; use serde::Serialize; #[derive(Debug, Clone)] struct Internal { - value: mlua::Value, + callbacks: Vec, lua: mlua::Lua, } @@ -28,9 +29,28 @@ impl Default for ActionCallback { impl FromLua for ActionCallback { fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result { + let callbacks = match value { + mlua::Value::Function(f) => vec![f], + mlua::Value::Table(table) => table + .pairs::() + .map(|pair| { + let (_, f) = pair?; + + Ok::<_, mlua::Error>(f) + }) + .try_collect()?, + _ => { + return Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "ActionCallback".into(), + message: Some("expected function or table of functions".into()), + }); + } + }; + Ok(ActionCallback { internal: Some(Internal { - value, + callbacks, lua: lua.clone(), }), _this: PhantomData::, @@ -52,10 +72,14 @@ where let state = internal.lua.to_value(state).unwrap(); - match &internal.value { - mlua::Value::Function(f) => f.call_async::<()>((this.clone(), state)).await.unwrap(), - _ => todo!("Only functions are currently supported"), - } + try_join_all( + internal + .callbacks + .iter() + .map(async |f| f.call_async::<()>((this.clone(), state.clone())).await), + ) + .await + .unwrap(); } pub fn is_set(&self) -> bool { diff --git a/automation_lib/src/lib.rs b/automation_lib/src/lib.rs index 4634704..771da64 100644 --- a/automation_lib/src/lib.rs +++ b/automation_lib/src/lib.rs @@ -1,4 +1,5 @@ #![allow(incomplete_features)] +#![feature(iterator_try_collect)] pub mod action_callback; pub mod config;