diff --git a/automation_lib/src/action_callback.rs b/automation_lib/src/action_callback.rs index 1dc08e9..bed3c99 100644 --- a/automation_lib/src/action_callback.rs +++ b/automation_lib/src/action_callback.rs @@ -1,11 +1,12 @@ use std::marker::PhantomData; +use futures::future::try_join_all; use mlua::{FromLua, IntoLua, LuaSerdeExt}; use serde::Serialize; #[derive(Debug, Clone)] struct Internal { - value: mlua::Value, + callbacks: Vec, lua: mlua::Lua, } @@ -28,9 +29,28 @@ impl Default for ActionCallback { impl FromLua for ActionCallback { fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result { + let callbacks = match value { + mlua::Value::Function(f) => vec![f], + mlua::Value::Table(table) => table + .pairs::() + .map(|pair| { + let (_, f) = pair?; + + Ok::<_, mlua::Error>(f) + }) + .try_collect()?, + _ => { + return Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "ActionCallback".into(), + message: Some("expected function or table of functions".into()), + }); + } + }; + Ok(ActionCallback { internal: Some(Internal { - value, + callbacks, lua: lua.clone(), }), _this: PhantomData::, @@ -52,10 +72,14 @@ where let state = internal.lua.to_value(state).unwrap(); - match &internal.value { - mlua::Value::Function(f) => f.call_async::<()>((this.clone(), state)).await.unwrap(), - _ => todo!("Only functions are currently supported"), - } + try_join_all( + internal + .callbacks + .iter() + .map(async |f| f.call_async::<()>((this.clone(), state.clone())).await), + ) + .await + .unwrap(); } pub fn is_set(&self) -> bool { diff --git a/automation_lib/src/lib.rs b/automation_lib/src/lib.rs index 4634704..771da64 100644 --- a/automation_lib/src/lib.rs +++ b/automation_lib/src/lib.rs @@ -1,4 +1,5 @@ #![allow(incomplete_features)] +#![feature(iterator_try_collect)] pub mod action_callback; pub mod config;