Compare commits

..

3 Commits

Author SHA1 Message Date
f7ba602762 chore: Removed dotenvy
All checks were successful
Build and deploy / build (push) Successful in 10m21s
Build and deploy / Deploy container (push) Has been skipped
Since secrets can now be set from automation.toml the .env file was no
longer used, so dotenvy can be removed.
2025-10-15 01:18:09 +02:00
929007d8c2 feat: Use Typed type_name for registering proxy 2025-10-15 01:01:10 +02:00
b191d67ae2 feat: Expanded add_methods to extra_user_data
All checks were successful
Build and deploy / build (push) Successful in 10m29s
Build and deploy / Deploy container (push) Has been skipped
Instead of being a function it now expects a struct with the
PartialUserData trait implemented. This in part ensures the correct
function signature.

It also adds another optional function to PartialUserData that returns
definitions for the added methods.
diff --git a/automation_devices/src/hue_bridge.rs b/automation_devices/src/hue_bridge.rs
index b08ab51..0c548c8 100644
--- a/automation_devices/src/hue_bridge.rs
+++ b/automation_devices/src/hue_bridge.rs
@@ -3,18 +3,21 @@ use std::net::SocketAddr;

 use async_trait::async_trait;
 use automation_lib::device::{Device, LuaDeviceCreate};
+use automation_lib::lua::traits::PartialUserData;
 use automation_macro::{Device, LuaDeviceConfig};
 use lua_typed::Typed;
 use mlua::LuaSerdeExt;
 use serde::{Deserialize, Serialize};
 use tracing::{error, trace, warn};

-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Typed)]
 #[serde(rename_all = "snake_case")]
+#[typed(rename_all = "snake_case")]
 pub enum Flag {
     Presence,
     Darkness,
 }
+crate::register_type!(Flag);

 #[derive(Debug, Clone, Deserialize, Typed)]
 pub struct FlagIDs {
@@ -36,12 +39,36 @@ pub struct Config {
 crate::register_type!(Config);

 #[derive(Debug, Clone, Device)]
-#[device(add_methods = Self::add_methods)]
+#[device(extra_user_data = SetFlag)]
 pub struct HueBridge {
     config: Config,
 }
 crate::register_device!(HueBridge);

+struct SetFlag;
+impl PartialUserData<HueBridge> for SetFlag {
+    fn add_methods<M: mlua::UserDataMethods<HueBridge>>(methods: &mut M) {
+        methods.add_async_method(
+            "set_flag",
+            async |lua, this, (flag, value): (mlua::Value, bool)| {
+                let flag: Flag = lua.from_value(flag)?;
+
+                this.set_flag(flag, value).await;
+
+                Ok(())
+            },
+        );
+    }
+
+    fn definitions() -> Option<String> {
+        Some(format!(
+            "---@async\n---@param flag {}\n---@param value boolean\nfunction {}:set_flag(flag, value) end\n",
+            <Flag as Typed>::type_name(),
+            <HueBridge as Typed>::type_name(),
+        ))
+    }
+}
+
 #[derive(Debug, Serialize)]
 struct FlagMessage {
     flag: bool,
@@ -89,19 +116,6 @@ impl HueBridge {
             }
         }
     }
-
-    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
-        methods.add_async_method(
-            "set_flag",
-            async |lua, this, (flag, value): (mlua::Value, bool)| {
-                let flag: Flag = lua.from_value(flag)?;
-
-                this.set_flag(flag, value).await;
-
-                Ok(())
-            },
-        );
-    }
 }

 impl Device for HueBridge {
diff --git a/automation_devices/src/ntfy.rs b/automation_devices/src/ntfy.rs
index 8060ced..1be2874 100644
--- a/automation_devices/src/ntfy.rs
+++ b/automation_devices/src/ntfy.rs
@@ -3,6 +3,7 @@ use std::convert::Infallible;

 use async_trait::async_trait;
 use automation_lib::device::{Device, LuaDeviceCreate};
+use automation_lib::lua::traits::PartialUserData;
 use automation_macro::{Device, LuaDeviceConfig};
 use lua_typed::Typed;
 use mlua::LuaSerdeExt;
@@ -90,14 +91,15 @@ pub struct Config {
 crate::register_type!(Config);

 #[derive(Debug, Clone, Device)]
-#[device(add_methods = Self::add_methods)]
+#[device(extra_user_data = SendNotification)]
 pub struct Ntfy {
     config: Config,
 }
 crate::register_device!(Ntfy);

-impl Ntfy {
-    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
+struct SendNotification;
+impl PartialUserData<Ntfy> for SendNotification {
+    fn add_methods<M: mlua::UserDataMethods<Ntfy>>(methods: &mut M) {
         methods.add_async_method(
             "send_notification",
             async |lua, this, notification: mlua::Value| {
@@ -109,6 +111,14 @@ impl Ntfy {
             },
         );
     }
+
+    fn definitions() -> Option<String> {
+        Some(format!(
+            "---@async\n---@param notification {}\nfunction {}:send_notification(notification) end\n",
+            <Notification as Typed>::type_name(),
+            <Ntfy as Typed>::type_name(),
+        ))
+    }
 }

 #[async_trait]
diff --git a/automation_devices/src/presence.rs b/automation_devices/src/presence.rs
index 72391ab..a77327c 100644
--- a/automation_devices/src/presence.rs
+++ b/automation_devices/src/presence.rs
@@ -6,6 +6,7 @@ use automation_lib::action_callback::ActionCallback;
 use automation_lib::config::MqttDeviceConfig;
 use automation_lib::device::{Device, LuaDeviceCreate};
 use automation_lib::event::OnMqtt;
+use automation_lib::lua::traits::PartialUserData;
 use automation_lib::messages::PresenceMessage;
 use automation_lib::mqtt::WrappedAsyncClient;
 use automation_macro::{Device, LuaDeviceConfig};
@@ -39,13 +40,29 @@ pub struct State {
 }

 #[derive(Debug, Clone, Device)]
-#[device(add_methods = Self::add_methods)]
+#[device(extra_user_data = OverallPresence)]
 pub struct Presence {
     config: Config,
     state: Arc<RwLock<State>>,
 }
 crate::register_device!(Presence);

+struct OverallPresence;
+impl PartialUserData<Presence> for OverallPresence {
+    fn add_methods<M: mlua::UserDataMethods<Presence>>(methods: &mut M) {
+        methods.add_async_method("overall_presence", async |_lua, this, ()| {
+            Ok(this.state().await.current_overall_presence)
+        });
+    }
+
+    fn definitions() -> Option<String> {
+        Some(format!(
+            "---@async\n---@return boolean\nfunction {}:overall_presence() end\n",
+            <Presence as Typed>::type_name(),
+        ))
+    }
+}
+
 impl Presence {
     async fn state(&self) -> RwLockReadGuard<'_, State> {
         self.state.read().await
@@ -54,12 +71,6 @@ impl Presence {
     async fn state_mut(&self) -> RwLockWriteGuard<'_, State> {
         self.state.write().await
     }
-
-    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
-        methods.add_async_method("overall_presence", async |_lua, this, ()| {
-            Ok(this.state().await.current_overall_presence)
-        });
-    }
 }

 #[async_trait]
diff --git a/automation_lib/src/lua/traits.rs b/automation_lib/src/lua/traits.rs
index 6f61841..adae7df 100644
--- a/automation_lib/src/lua/traits.rs
+++ b/automation_lib/src/lua/traits.rs
@@ -8,6 +8,10 @@ pub trait PartialUserData<T> {
     fn interface_name() -> Option<&'static str> {
         None
     }
+
+    fn definitions() -> Option<String> {
+        None
+    }
 }

 pub struct Device;
diff --git a/automation_macro/src/device.rs b/automation_macro/src/device.rs
index 874765f..d66e0bd 100644
--- a/automation_macro/src/device.rs
+++ b/automation_macro/src/device.rs
@@ -1,7 +1,7 @@
 use std::collections::HashMap;

 use proc_macro2::TokenStream as TokenStream2;
-use quote::{ToTokens, quote};
+use quote::quote;
 use syn::parse::{Parse, ParseStream};
 use syn::punctuated::Punctuated;
 use syn::spanned::Spanned;
@@ -9,7 +9,7 @@ use syn::{Attribute, DeriveInput, Token, parenthesized};

 enum Attr {
     Trait(TraitAttr),
-    AddMethods(AddMethodsAttr),
+    ExtraUserData(ExtraUserDataAttr),
 }

 impl Attr {
@@ -20,9 +20,9 @@ impl Attr {
                 let input;
                 _ = parenthesized!(input in meta.input);
                 parsed = Some(Attr::Trait(input.parse()?));
-            } else if meta.path.is_ident("add_methods") {
+            } else if meta.path.is_ident("extra_user_data") {
                 let value = meta.value()?;
-                parsed = Some(Attr::AddMethods(value.parse()?));
+                parsed = Some(Attr::ExtraUserData(value.parse()?));
             } else {
                 return Err(syn::Error::new(meta.path.span(), "Unknown attribute"));
             }
@@ -95,28 +95,18 @@ impl Parse for Aliases {
 }

 #[derive(Clone)]
-struct AddMethodsAttr(syn::Path);
+struct ExtraUserDataAttr(syn::Ident);

-impl Parse for AddMethodsAttr {
+impl Parse for ExtraUserDataAttr {
     fn parse(input: ParseStream) -> syn::Result<Self> {
         Ok(Self(input.parse()?))
     }
 }

-impl ToTokens for AddMethodsAttr {
-    fn to_tokens(&self, tokens: &mut TokenStream2) {
-        let Self(path) = self;
-
-        tokens.extend(quote! {
-            #path
-        });
-    }
-}
-
 struct Implementation {
     name: syn::Ident,
     traits: Traits,
-    add_methods: Vec<AddMethodsAttr>,
+    extra_user_data: Vec<ExtraUserDataAttr>,
 }

 impl quote::ToTokens for Implementation {
@@ -124,9 +114,10 @@ impl quote::ToTokens for Implementation {
         let Self {
             name,
             traits,
-            add_methods,
+            extra_user_data,
         } = &self;
         let Traits(traits) = traits;
+        let extra_user_data: Vec<_> = extra_user_data.iter().map(|tr| tr.0.clone()).collect();

         tokens.extend(quote! {
             impl mlua::UserData for #name {
@@ -151,7 +142,7 @@ impl quote::ToTokens for Implementation {
                     )*

                     #(
-                        #add_methods(methods);
+                        <#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::add_methods(methods);
                     )*
                 }
             }
@@ -178,7 +169,7 @@ impl quote::ToTokens for Implementation {
                         format!(": {interfaces}")
                     };

-                    Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}"))
+                    Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}\n"))
                 }

                 fn generate_members() -> Option<String> {
@@ -191,6 +182,15 @@ impl quote::ToTokens for Implementation {
                     output += &format!("---@return {type_name}\n");
                     output += &format!("function devices.{type_name}.new(config) end\n");

+                    output += &<::automation_lib::lua::traits::Device as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
+
+                    #(
+                        output += &<::automation_lib::lua::traits::#traits as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
+                    )*
+                    #(
+                        output += &<#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
+                    )*
+

                     Some(output)
                 }
@@ -220,7 +220,7 @@ impl Implementations {
                         all.extend(&attribute.traits);
                     }
                 }
-                Attr::AddMethods(attribute) => add_methods.push(attribute),
+                Attr::ExtraUserData(attribute) => add_methods.push(attribute),
             }
         }

@@ -238,7 +238,7 @@ impl Implementations {
                 .map(|(alias, traits)| Implementation {
                     name: alias.unwrap_or(name.clone()),
                     traits,
-                    add_methods: add_methods.clone(),
+                    extra_user_data: add_methods.clone(),
                 })
                 .collect(),
         )
2025-10-15 00:45:38 +02:00
11 changed files with 27 additions and 122 deletions

4
Cargo.lock generated
View File

@@ -1101,7 +1101,7 @@ dependencies = [
[[package]] [[package]]
name = "lua_typed" name = "lua_typed"
version = "0.1.0" version = "0.1.0"
source = "git+https://git.huizinga.dev/Dreaded_X/lua_typed#08f5c4533a93131e8eda6702c062fb841d14d4e1" source = "git+https://git.huizinga.dev/Dreaded_X/lua_typed#d5d6fc1638bd108514899a792ee64335af50fc8b"
dependencies = [ dependencies = [
"eui48", "eui48",
"lua_typed_macro", "lua_typed_macro",
@@ -1110,7 +1110,7 @@ dependencies = [
[[package]] [[package]]
name = "lua_typed_macro" name = "lua_typed_macro"
version = "0.1.0" version = "0.1.0"
source = "git+https://git.huizinga.dev/Dreaded_X/lua_typed#08f5c4533a93131e8eda6702c062fb841d14d4e1" source = "git+https://git.huizinga.dev/Dreaded_X/lua_typed#d5d6fc1638bd108514899a792ee64335af50fc8b"
dependencies = [ dependencies = [
"convert_case", "convert_case",
"itertools", "itertools",

View File

@@ -71,7 +71,7 @@ pub fn create_module(lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
Ok(devices) Ok(devices)
} }
inventory::submit! {Module::new("automation:devices", create_module)} inventory::submit! {Module::new("devices", create_module)}
macro_rules! register_type { macro_rules! register_type {
($ty:ty) => { ($ty:ty) => {

View File

@@ -10,12 +10,6 @@ pub struct ActionCallback<P> {
_parameters: PhantomData<P>, _parameters: PhantomData<P>,
} }
impl Typed for ActionCallback<()> {
fn type_name() -> String {
"fun() | fun()[]".into()
}
}
impl<A: Typed> Typed for ActionCallback<A> { impl<A: Typed> Typed for ActionCallback<A> {
fn type_name() -> String { fn type_name() -> String {
let type_name = A::type_name(); let type_name = A::type_name();

View File

@@ -5,7 +5,7 @@ use lua_typed::Typed;
use rumqttc::{MqttOptions, Transport}; use rumqttc::{MqttOptions, Transport};
use serde::Deserialize; use serde::Deserialize;
#[derive(Debug, Clone, Deserialize, Typed)] #[derive(Debug, Clone, Deserialize)]
pub struct MqttConfig { pub struct MqttConfig {
pub host: String, pub host: String,
pub port: u16, pub port: u16,

View File

@@ -4,8 +4,6 @@ use std::sync::Arc;
use futures::Future; use futures::Future;
use futures::future::join_all; use futures::future::join_all;
use lua_typed::Typed;
use mlua::FromLua;
use tokio::sync::{RwLock, RwLockReadGuard}; use tokio::sync::{RwLock, RwLockReadGuard};
use tokio_cron_scheduler::{Job, JobScheduler}; use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{debug, instrument, trace}; use tracing::{debug, instrument, trace};
@@ -15,7 +13,7 @@ use crate::event::{Event, EventChannel, OnMqtt};
pub type DeviceMap = HashMap<String, Box<dyn Device>>; pub type DeviceMap = HashMap<String, Box<dyn Device>>;
#[derive(Clone, FromLua)] #[derive(Clone)]
pub struct DeviceManager { pub struct DeviceManager {
devices: Arc<RwLock<DeviceMap>>, devices: Arc<RwLock<DeviceMap>>,
event_channel: EventChannel, event_channel: EventChannel,
@@ -144,9 +142,3 @@ impl mlua::UserData for DeviceManager {
methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel())) methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel()))
} }
} }
impl Typed for DeviceManager {
fn type_name() -> String {
"DeviceManager".into()
}
}

View File

@@ -41,7 +41,7 @@ pub fn load_modules(lua: &mlua::Lua) -> mlua::Result<()> {
for module in inventory::iter::<Module> { for module in inventory::iter::<Module> {
debug!(name = module.get_name(), "Loading module"); debug!(name = module.get_name(), "Loading module");
let table = module.register(lua)?; let table = module.register(lua)?;
lua.register_module(module.get_name(), table)?; lua.register_module(&format!("automation:{}", module.get_name()), table)?;
} }
Ok(()) Ok(())

View File

@@ -28,4 +28,4 @@ fn create_module(lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
Ok(utils) Ok(utils)
} }
inventory::submit! {Module::new("automation:utils", create_module)} inventory::submit! {Module::new("utils", create_module)}

View File

@@ -1,7 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use lua_typed::Typed;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tracing::debug; use tracing::debug;
@@ -75,44 +74,3 @@ impl mlua::UserData for Timeout {
}); });
} }
} }
impl Typed for Timeout {
fn type_name() -> String {
"Timeout".into()
}
fn generate_header() -> Option<String> {
let type_name = Self::type_name();
Some(format!("---@class {type_name}\nlocal {type_name}\n"))
}
fn generate_members() -> Option<String> {
let mut output = String::new();
let type_name = Self::type_name();
output += &format!(
"---@async\n---@param timeout number\n---@param callback {}\nfunction {type_name}:start(timeout, callback) end\n",
ActionCallback::<()>::type_name()
);
output += &format!("---@async\nfunction {type_name}:cancel() end\n",);
output +=
&format!("---@async\n---@return boolean\nfunction {type_name}:is_waiting() end\n",);
Some(output)
}
fn generate_footer() -> Option<String> {
let mut output = String::new();
let type_name = Self::type_name();
output += &format!("utils.{type_name} = {{}}\n");
output += &format!("---@return {type_name}\n");
output += &format!("function utils.{type_name}.new() end\n");
Some(output)
}
}

View File

@@ -1,13 +1,10 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use lua_typed::Typed; use lua_typed::Typed;
use mlua::{FromLua, LuaSerdeExt}; use mlua::FromLua;
use rumqttc::{AsyncClient, Event, EventLoop, Incoming}; use rumqttc::{AsyncClient, Event, EventLoop, Incoming};
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::Module;
use crate::config::MqttConfig;
use crate::device_manager::DeviceManager;
use crate::event::{self, EventChannel}; use crate::event::{self, EventChannel};
#[derive(Debug, Clone, FromLua)] #[derive(Debug, Clone, FromLua)]
@@ -17,37 +14,6 @@ impl Typed for WrappedAsyncClient {
fn type_name() -> String { fn type_name() -> String {
"AsyncClient".into() "AsyncClient".into()
} }
fn generate_header() -> Option<String> {
let type_name = Self::type_name();
Some(format!("---@class {type_name}\nlocal {type_name}\n"))
}
fn generate_members() -> Option<String> {
let mut output = String::new();
let type_name = Self::type_name();
output += &format!(
"---@async\n---@param topic string\n---@param message table?\nfunction {type_name}:send_message(topic, message) end\n"
);
Some(output)
}
fn generate_footer() -> Option<String> {
let mut output = String::new();
let type_name = Self::type_name();
output += &format!("mqtt.{type_name} = {{}}\n");
output += &format!("---@param device_manager {}\n", DeviceManager::type_name());
output += &format!("---@param config {}\n", MqttConfig::type_name());
output += &format!("---@return {type_name}\n");
output += "function mqtt.new(device_manager, config) end\n";
Some(output)
}
} }
impl Deref for WrappedAsyncClient { impl Deref for WrappedAsyncClient {
@@ -111,25 +77,3 @@ pub fn start(mut eventloop: EventLoop, event_channel: &EventChannel) {
} }
}); });
} }
fn create_module(lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
let mqtt = lua.create_table()?;
let mqtt_new = lua.create_function(
move |lua, (device_manager, config): (DeviceManager, mlua::Value)| {
let event_channel = device_manager.event_channel();
let config: MqttConfig = lua.from_value(config)?;
// Create a mqtt client
// TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync
let (client, eventloop) = AsyncClient::new(config.into(), 100);
start(eventloop, &event_channel);
Ok(WrappedAsyncClient(client))
},
)?;
mqtt.set("new", mqtt_new)?;
Ok(mqtt)
}
inventory::submit! {Module::new("automation:mqtt", create_module)}

View File

@@ -21,7 +21,7 @@ local fulfillment = {
openid_url = "https://login.huizinga.dev/api/oidc", openid_url = "https://login.huizinga.dev/api/oidc",
} }
local mqtt_client = require("automation:mqtt").new(device_manager, { local mqtt_client = require("automation:mqtt").new({
host = ((host == "zeus" or host == "hephaestus") and "olympus.lan.huizinga.dev") or "mosquitto", host = ((host == "zeus" or host == "hephaestus") and "olympus.lan.huizinga.dev") or "mosquitto",
port = 8883, port = 8883,
client_name = "automation-" .. host, client_name = "automation-" .. host,

View File

@@ -9,8 +9,9 @@ use std::path::Path;
use std::process; use std::process;
use ::config::{Environment, File}; use ::config::{Environment, File};
use automation_lib::config::FulfillmentConfig; use automation_lib::config::{FulfillmentConfig, MqttConfig};
use automation_lib::device_manager::DeviceManager; use automation_lib::device_manager::DeviceManager;
use automation_lib::mqtt::{self, WrappedAsyncClient};
use axum::extract::{FromRef, State}; use axum::extract::{FromRef, State};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::post; use axum::routing::post;
@@ -18,6 +19,7 @@ use axum::{Json, Router};
use config::Config; use config::Config;
use google_home::{GoogleHome, Request, Response}; use google_home::{GoogleHome, Request, Response};
use mlua::LuaSerdeExt; use mlua::LuaSerdeExt;
use rumqttc::AsyncClient;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use web::{ApiError, User}; use web::{ApiError, User};
@@ -136,6 +138,21 @@ async fn app() -> anyhow::Result<()> {
automation_lib::load_modules(&lua)?; automation_lib::load_modules(&lua)?;
let mqtt = lua.create_table()?;
let event_channel = device_manager.event_channel();
let mqtt_new = lua.create_function(move |lua, config: mlua::Value| {
let config: MqttConfig = lua.from_value(config)?;
// Create a mqtt client
// TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync
let (client, eventloop) = AsyncClient::new(config.into(), 100);
mqtt::start(eventloop, &event_channel);
Ok(WrappedAsyncClient(client))
})?;
mqtt.set("new", mqtt_new)?;
lua.register_module("automation:mqtt", mqtt)?;
lua.register_module("automation:device_manager", device_manager.clone())?; lua.register_module("automation:device_manager", device_manager.clone())?;
lua.register_module("automation:variables", lua.to_value(&config.variables)?)?; lua.register_module("automation:variables", lua.to_value(&config.variables)?)?;