diff --git a/Cargo.lock b/Cargo.lock index 0153116..c8b3419 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -144,6 +144,7 @@ version = "0.1.0" dependencies = [ "async-trait", "automation_cast", + "automation_macro", "bytes", "dyn-clone", "futures", diff --git a/automation_lib/Cargo.toml b/automation_lib/Cargo.toml index 04dea30..b05bee5 100644 --- a/automation_lib/Cargo.toml +++ b/automation_lib/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +automation_macro = { workspace = true } async-trait = { workspace = true } automation_cast = { workspace = true } bytes = { workspace = true } diff --git a/automation_lib/src/config.rs b/automation_lib/src/config.rs index 0848ee0..995ef52 100644 --- a/automation_lib/src/config.rs +++ b/automation_lib/src/config.rs @@ -1,34 +1,6 @@ -use std::time::Duration; - use lua_typed::Typed; -use rumqttc::{MqttOptions, Transport}; use serde::Deserialize; -#[derive(Debug, Clone, Deserialize, Typed)] -pub struct MqttConfig { - pub host: String, - pub port: u16, - pub client_name: String, - pub username: String, - pub password: String, - #[serde(default)] - pub tls: bool, -} - -impl From for MqttOptions { - fn from(value: MqttConfig) -> Self { - let mut mqtt_options = MqttOptions::new(value.client_name, value.host, value.port); - mqtt_options.set_credentials(value.username, value.password); - mqtt_options.set_keep_alive(Duration::from_secs(5)); - - if value.tls { - mqtt_options.set_transport(Transport::tls_with_default_config()); - } - - mqtt_options - } -} - #[derive(Debug, Clone, Deserialize, Typed)] pub struct InfoConfig { pub name: String, diff --git a/automation_lib/src/mqtt.rs b/automation_lib/src/mqtt.rs index 038251c..bd1b48f 100644 --- a/automation_lib/src/mqtt.rs +++ b/automation_lib/src/mqtt.rs @@ -1,15 +1,41 @@ use std::ops::{Deref, DerefMut}; +use std::time::Duration; +use automation_macro::LuaDeviceConfig; use lua_typed::Typed; -use mlua::{FromLua, LuaSerdeExt}; -use rumqttc::{AsyncClient, Event, EventLoop, Incoming}; +use mlua::FromLua; +use rumqttc::{AsyncClient, Event, Incoming, MqttOptions, Transport}; +use serde::Deserialize; use tracing::{debug, warn}; -use crate::Module; -use crate::config::MqttConfig; -use crate::device_manager::DeviceManager; use crate::event::{self, EventChannel}; +#[derive(Debug, Clone, LuaDeviceConfig, Deserialize, Typed)] +pub struct MqttConfig { + pub host: String, + pub port: u16, + pub client_name: String, + pub username: String, + pub password: String, + #[serde(default)] + #[typed(default)] + pub tls: bool, +} + +impl From for MqttOptions { + fn from(value: MqttConfig) -> Self { + let mut mqtt_options = MqttOptions::new(value.client_name, value.host, value.port); + mqtt_options.set_credentials(value.username, value.password); + mqtt_options.set_keep_alive(Duration::from_secs(5)); + + if value.tls { + mqtt_options.set_transport(Transport::tls_with_default_config()); + } + + mqtt_options + } +} + #[derive(Debug, Clone, FromLua)] pub struct WrappedAsyncClient(pub AsyncClient); @@ -34,20 +60,6 @@ impl Typed for WrappedAsyncClient { Some(output) } - - fn generate_footer() -> Option { - let mut output = String::new(); - - let type_name = Self::type_name(); - - output += &format!("mqtt.{type_name} = {{}}\n"); - output += &format!("---@param device_manager {}\n", DeviceManager::type_name()); - output += &format!("---@param config {}\n", MqttConfig::type_name()); - output += &format!("---@return {type_name}\n"); - output += "function mqtt.new(device_manager, config) end\n"; - - Some(output) - } } impl Deref for WrappedAsyncClient { @@ -90,8 +102,9 @@ impl mlua::UserData for WrappedAsyncClient { } } -pub fn start(mut eventloop: EventLoop, event_channel: &EventChannel) { +pub fn start(config: MqttConfig, event_channel: &EventChannel) -> WrappedAsyncClient { let tx = event_channel.get_tx(); + let (client, mut eventloop) = AsyncClient::new(config.into(), 100); tokio::spawn(async move { debug!("Listening for MQTT events"); @@ -110,42 +123,6 @@ pub fn start(mut eventloop: EventLoop, event_channel: &EventChannel) { } } }); + + WrappedAsyncClient(client) } - -fn create_module(lua: &mlua::Lua) -> mlua::Result { - let mqtt = lua.create_table()?; - let mqtt_new = lua.create_function( - move |lua, (device_manager, config): (DeviceManager, mlua::Value)| { - let event_channel = device_manager.event_channel(); - let config: MqttConfig = lua.from_value(config)?; - - // Create a mqtt client - // TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync - let (client, eventloop) = AsyncClient::new(config.into(), 100); - start(eventloop, &event_channel); - - Ok(WrappedAsyncClient(client)) - }, - )?; - mqtt.set("new", mqtt_new)?; - - Ok(mqtt) -} - -fn generate_definitions() -> String { - let mut output = String::new(); - - output += "---@meta\n\nlocal mqtt\n\n"; - - output += &MqttConfig::generate_full().expect("WrappedAsyncClient should have generate_full"); - output += "\n"; - output += - &WrappedAsyncClient::generate_full().expect("WrappedAsyncClient should have generate_full"); - output += "\n"; - - output += "return mqtt"; - - output -} - -inventory::submit! {Module::new("automation:mqtt", create_module, Some(generate_definitions))} diff --git a/config.lua b/config.lua index 24b9566..5f916ef 100644 --- a/config.lua +++ b/config.lua @@ -734,22 +734,21 @@ local function create_devs(mqtt_client) return devs end --- TODO: Pass the mqtt config to the output config, instead of constructing the client here -local mqtt_client = require("automation:mqtt").new(device_manager, { +local mqtt_config = { host = ((host == "zeus" or host == "hephaestus") and "olympus.lan.huizinga.dev") or "mosquitto", port = 8883, client_name = "automation-" .. host, username = "mqtt", password = secrets.mqtt_password, tls = host == "zeus" or host == "hephaestus", -}) +} ---@type Config return { fulfillment = { openid_url = "https://login.huizinga.dev/api/oidc", }, - mqtt = mqtt_client, + mqtt = mqtt_config, devices = { create_devs, ntfy, diff --git a/definitions/automation:mqtt.lua b/definitions/automation:mqtt.lua deleted file mode 100644 index 58c83ae..0000000 --- a/definitions/automation:mqtt.lua +++ /dev/null @@ -1,27 +0,0 @@ --- DO NOT MODIFY, FILE IS AUTOMATICALLY GENERATED ----@meta - -local mqtt - ----@class MqttConfig ----@field host string ----@field port integer ----@field client_name string ----@field username string ----@field password string ----@field tls boolean -local MqttConfig - ----@class AsyncClient -local AsyncClient ----@async ----@param topic string ----@param message table? -function AsyncClient:send_message(topic, message) end -mqtt.AsyncClient = {} ----@param device_manager DeviceManager ----@param config MqttConfig ----@return AsyncClient -function mqtt.new(device_manager, config) end - -return mqtt diff --git a/definitions/config.lua b/definitions/config.lua index c271e12..ad89353 100644 --- a/definitions/config.lua +++ b/definitions/config.lua @@ -10,8 +10,24 @@ local FulfillmentConfig ---@class Config ---@field fulfillment FulfillmentConfig ---@field devices Devices? ----@field mqtt AsyncClient +---@field mqtt MqttConfig ---@field schedule table? local Config ---@alias Devices (DeviceInterface | fun(client: AsyncClient): Devices)[] + +---@class MqttConfig +---@field host string +---@field port integer +---@field client_name string +---@field username string +---@field password string +---@field tls boolean? +local MqttConfig + +---@class AsyncClient +local AsyncClient +---@async +---@param topic string +---@param message table? +function AsyncClient:send_message(topic, message) end diff --git a/src/bin/automation.rs b/src/bin/automation.rs index 5baca23..1965a09 100644 --- a/src/bin/automation.rs +++ b/src/bin/automation.rs @@ -11,6 +11,7 @@ use automation::secret::EnvironmentSecretFile; use automation::version::VERSION; use automation::web::{ApiError, User}; use automation_lib::device_manager::DeviceManager; +use automation_lib::mqtt; use axum::extract::{FromRef, State}; use axum::http::StatusCode; use axum::routing::post; @@ -139,8 +140,10 @@ async fn app() -> anyhow::Result<()> { let entrypoint = Path::new(&setup.entrypoint); let config: Config = lua.load(entrypoint).eval_async().await?; + let mqtt_client = mqtt::start(config.mqtt, &device_manager.event_channel()); + if let Some(devices) = config.devices { - for device in devices.get(&lua, &config.mqtt).await? { + for device in devices.get(&lua, &mqtt_client).await? { device_manager.add(device).await; } } diff --git a/src/bin/generate_definitions.rs b/src/bin/generate_definitions.rs index 41262f5..ac6a3de 100644 --- a/src/bin/generate_definitions.rs +++ b/src/bin/generate_definitions.rs @@ -3,6 +3,7 @@ use std::io::Write; use automation::config::{Config, Devices, FulfillmentConfig}; use automation_lib::Module; +use automation_lib::mqtt::{MqttConfig, WrappedAsyncClient}; use lua_typed::Typed; use tracing::{info, warn}; @@ -35,6 +36,11 @@ fn config_definitions() -> String { output += &Config::generate_full().expect("Config should have a definition"); output += "\n"; output += &Devices::generate_full().expect("Devices should have a definition"); + output += "\n"; + output += &MqttConfig::generate_full().expect("MqttConfig should have a definition"); + output += "\n"; + output += + &WrappedAsyncClient::generate_full().expect("WrappedAsyncClient should have a definition"); output } diff --git a/src/config.rs b/src/config.rs index daac095..9d692e2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use automation_lib::action_callback::ActionCallback; use automation_lib::device::Device; -use automation_lib::mqtt::WrappedAsyncClient; +use automation_lib::mqtt::{MqttConfig, WrappedAsyncClient}; use automation_macro::LuaDeviceConfig; use lua_typed::Typed; use mlua::FromLua; @@ -105,7 +105,7 @@ pub struct Config { #[device_config(from_lua, default)] pub devices: Option, #[device_config(from_lua)] - pub mqtt: WrappedAsyncClient, + pub mqtt: MqttConfig, #[device_config(from_lua, default)] #[typed(default)] pub schedule: HashMap>,