diff --git a/config.lua b/config.lua index 90493b9..2920bf6 100644 --- a/config.lua +++ b/config.lua @@ -742,11 +742,15 @@ devs:add(devices.ContactSensor.new({ battery_callback = check_battery, })) +-- HACK: If the devices config contains a function it will call it so we have to remove it +devs.add = nil + ---@type Config return { fulfillment = { openid_url = "https://login.huizinga.dev/api/oidc", }, + mqtt = mqtt_client, devices = devs, schedule = { ["0 0 19 * * *"] = function() diff --git a/definitions/config.lua b/definitions/config.lua index c6cc30f..c271e12 100644 --- a/definitions/config.lua +++ b/definitions/config.lua @@ -9,6 +9,9 @@ local FulfillmentConfig ---@class Config ---@field fulfillment FulfillmentConfig ----@field devices DeviceInterface[]? +---@field devices Devices? +---@field mqtt AsyncClient ---@field schedule table? local Config + +---@alias Devices (DeviceInterface | fun(client: AsyncClient): Devices)[] diff --git a/src/bin/automation.rs b/src/bin/automation.rs index 644e94a..5baca23 100644 --- a/src/bin/automation.rs +++ b/src/bin/automation.rs @@ -139,8 +139,10 @@ async fn app() -> anyhow::Result<()> { let entrypoint = Path::new(&setup.entrypoint); let config: Config = lua.load(entrypoint).eval_async().await?; - for device in config.devices { - device_manager.add(device).await; + if let Some(devices) = config.devices { + for device in devices.get(&lua, &config.mqtt).await? { + device_manager.add(device).await; + } } start_scheduler(config.schedule).await?; diff --git a/src/bin/generate_definitions.rs b/src/bin/generate_definitions.rs index aab0e20..41262f5 100644 --- a/src/bin/generate_definitions.rs +++ b/src/bin/generate_definitions.rs @@ -1,7 +1,7 @@ use std::fs::{self, File}; use std::io::Write; -use automation::config::{Config, FulfillmentConfig}; +use automation::config::{Config, Devices, FulfillmentConfig}; use automation_lib::Module; use lua_typed::Typed; use tracing::{info, warn}; @@ -33,6 +33,8 @@ fn config_definitions() -> String { &FulfillmentConfig::generate_full().expect("FulfillmentConfig should have a definition"); output += "\n"; output += &Config::generate_full().expect("Config should have a definition"); + output += "\n"; + output += &Devices::generate_full().expect("Devices should have a definition"); output } diff --git a/src/config.rs b/src/config.rs index 38d732c..daac095 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,10 +1,12 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::net::{Ipv4Addr, SocketAddr}; use automation_lib::action_callback::ActionCallback; use automation_lib::device::Device; +use automation_lib::mqtt::WrappedAsyncClient; use automation_macro::LuaDeviceConfig; use lua_typed::Typed; +use mlua::FromLua; use serde::Deserialize; #[derive(Debug, Deserialize)] @@ -32,12 +34,78 @@ pub struct FulfillmentConfig { pub port: u16, } +#[derive(Debug, Default)] +pub struct Devices(mlua::Value); + +impl Devices { + pub async fn get( + self, + lua: &mlua::Lua, + client: &WrappedAsyncClient, + ) -> mlua::Result>> { + let mut devices = Vec::new(); + let initial_table = match self.0 { + mlua::Value::Table(table) => table, + mlua::Value::Function(f) => f.call_async(client.clone()).await?, + _ => Err(mlua::Error::runtime(format!( + "Expected table or function, instead found: {}", + self.0.type_name() + )))?, + }; + + let mut queue: VecDeque = [initial_table].into(); + loop { + let Some(table) = queue.pop_front() else { + break; + }; + + for pair in table.pairs() { + let (_, value): (mlua::Value, _) = pair?; + + match value { + mlua::Value::UserData(_) => devices.push(Box::from_lua(value, lua)?), + mlua::Value::Function(f) => { + queue.push_back(f.call_async(client.clone()).await?); + } + _ => Err(mlua::Error::runtime(format!( + "Expected a device, table, or function, instead found: {}", + value.type_name() + )))?, + } + } + } + + Ok(devices) + } +} + +impl FromLua for Devices { + fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result { + Ok(Devices(value)) + } +} + +impl Typed for Devices { + fn type_name() -> String { + "Devices".into() + } + + fn generate_header() -> Option { + Some(format!( + "---@alias {} (DeviceInterface | fun(client: {}): Devices)[]\n", + ::type_name(), + ::type_name() + )) + } +} + #[derive(Debug, LuaDeviceConfig, Typed)] pub struct Config { pub fulfillment: FulfillmentConfig, #[device_config(from_lua, default)] - #[typed(default)] - pub devices: Vec>, + pub devices: Option, + #[device_config(from_lua)] + pub mqtt: WrappedAsyncClient, #[device_config(from_lua, default)] #[typed(default)] pub schedule: HashMap>,