feat!: Expanded add_methods to extra_user_data

Instead of being a function it now expects a struct with the
PartialUserData trait implemented. This in part ensures the correct
function signature.

It also adds another optional function to PartialUserData that returns
definitions for the added methods.
This commit is contained in:
2025-10-15 00:44:06 +02:00
parent 4b76bde2a6
commit cd470cadaf
5 changed files with 86 additions and 47 deletions

View File

@@ -3,18 +3,21 @@ use std::net::SocketAddr;
use async_trait::async_trait;
use automation_lib::device::{Device, LuaDeviceCreate};
use automation_lib::lua::traits::PartialUserData;
use automation_macro::{Device, LuaDeviceConfig};
use lua_typed::Typed;
use mlua::LuaSerdeExt;
use serde::{Deserialize, Serialize};
use tracing::{error, trace, warn};
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Typed)]
#[serde(rename_all = "snake_case")]
#[typed(rename_all = "snake_case")]
pub enum Flag {
Presence,
Darkness,
}
crate::register_type!(Flag);
#[derive(Debug, Clone, Deserialize, Typed)]
pub struct FlagIDs {
@@ -36,12 +39,36 @@ pub struct Config {
crate::register_type!(Config);
#[derive(Debug, Clone, Device)]
#[device(add_methods = Self::add_methods)]
#[device(extra_user_data = SetFlag)]
pub struct HueBridge {
config: Config,
}
crate::register_device!(HueBridge);
struct SetFlag;
impl PartialUserData<HueBridge> for SetFlag {
fn add_methods<M: mlua::UserDataMethods<HueBridge>>(methods: &mut M) {
methods.add_async_method(
"set_flag",
async |lua, this, (flag, value): (mlua::Value, bool)| {
let flag: Flag = lua.from_value(flag)?;
this.set_flag(flag, value).await;
Ok(())
},
);
}
fn definitions() -> Option<String> {
Some(format!(
"---@async\n---@param flag {}\n---@param value boolean\nfunction {}:set_flag(flag, value) end\n",
<Flag as Typed>::type_name(),
<HueBridge as Typed>::type_name(),
))
}
}
#[derive(Debug, Serialize)]
struct FlagMessage {
flag: bool,
@@ -89,19 +116,6 @@ impl HueBridge {
}
}
}
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_async_method(
"set_flag",
async |lua, this, (flag, value): (mlua::Value, bool)| {
let flag: Flag = lua.from_value(flag)?;
this.set_flag(flag, value).await;
Ok(())
},
);
}
}
impl Device for HueBridge {

View File

@@ -3,6 +3,7 @@ use std::convert::Infallible;
use async_trait::async_trait;
use automation_lib::device::{Device, LuaDeviceCreate};
use automation_lib::lua::traits::PartialUserData;
use automation_macro::{Device, LuaDeviceConfig};
use lua_typed::Typed;
use mlua::LuaSerdeExt;
@@ -90,14 +91,15 @@ pub struct Config {
crate::register_type!(Config);
#[derive(Debug, Clone, Device)]
#[device(add_methods = Self::add_methods)]
#[device(extra_user_data = SendNotification)]
pub struct Ntfy {
config: Config,
}
crate::register_device!(Ntfy);
impl Ntfy {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
struct SendNotification;
impl PartialUserData<Ntfy> for SendNotification {
fn add_methods<M: mlua::UserDataMethods<Ntfy>>(methods: &mut M) {
methods.add_async_method(
"send_notification",
async |lua, this, notification: mlua::Value| {
@@ -109,6 +111,14 @@ impl Ntfy {
},
);
}
fn definitions() -> Option<String> {
Some(format!(
"---@async\n---@param notification {}\nfunction {}:send_notification(notification) end\n",
<Notification as Typed>::type_name(),
<Ntfy as Typed>::type_name(),
))
}
}
#[async_trait]

View File

@@ -6,6 +6,7 @@ use automation_lib::action_callback::ActionCallback;
use automation_lib::config::MqttDeviceConfig;
use automation_lib::device::{Device, LuaDeviceCreate};
use automation_lib::event::OnMqtt;
use automation_lib::lua::traits::PartialUserData;
use automation_lib::messages::PresenceMessage;
use automation_lib::mqtt::WrappedAsyncClient;
use automation_macro::{Device, LuaDeviceConfig};
@@ -39,13 +40,29 @@ pub struct State {
}
#[derive(Debug, Clone, Device)]
#[device(add_methods = Self::add_methods)]
#[device(extra_user_data = OverallPresence)]
pub struct Presence {
config: Config,
state: Arc<RwLock<State>>,
}
crate::register_device!(Presence);
struct OverallPresence;
impl PartialUserData<Presence> for OverallPresence {
fn add_methods<M: mlua::UserDataMethods<Presence>>(methods: &mut M) {
methods.add_async_method("overall_presence", async |_lua, this, ()| {
Ok(this.state().await.current_overall_presence)
});
}
fn definitions() -> Option<String> {
Some(format!(
"---@async\n---@return boolean\nfunction {}:overall_presence() end\n",
<Presence as Typed>::type_name(),
))
}
}
impl Presence {
async fn state(&self) -> RwLockReadGuard<'_, State> {
self.state.read().await
@@ -54,12 +71,6 @@ impl Presence {
async fn state_mut(&self) -> RwLockWriteGuard<'_, State> {
self.state.write().await
}
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_async_method("overall_presence", async |_lua, this, ()| {
Ok(this.state().await.current_overall_presence)
});
}
}
#[async_trait]

View File

@@ -8,6 +8,10 @@ pub trait PartialUserData<T> {
fn interface_name() -> Option<&'static str> {
None
}
fn definitions() -> Option<String> {
None
}
}
pub struct Device;

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use proc_macro2::TokenStream as TokenStream2;
use quote::{ToTokens, quote};
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
@@ -9,7 +9,7 @@ use syn::{Attribute, DeriveInput, Token, parenthesized};
enum Attr {
Trait(TraitAttr),
AddMethods(AddMethodsAttr),
ExtraUserData(ExtraUserDataAttr),
}
impl Attr {
@@ -20,9 +20,9 @@ impl Attr {
let input;
_ = parenthesized!(input in meta.input);
parsed = Some(Attr::Trait(input.parse()?));
} else if meta.path.is_ident("add_methods") {
} else if meta.path.is_ident("extra_user_data") {
let value = meta.value()?;
parsed = Some(Attr::AddMethods(value.parse()?));
parsed = Some(Attr::ExtraUserData(value.parse()?));
} else {
return Err(syn::Error::new(meta.path.span(), "Unknown attribute"));
}
@@ -95,28 +95,18 @@ impl Parse for Aliases {
}
#[derive(Clone)]
struct AddMethodsAttr(syn::Path);
struct ExtraUserDataAttr(syn::Ident);
impl Parse for AddMethodsAttr {
impl Parse for ExtraUserDataAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(Self(input.parse()?))
}
}
impl ToTokens for AddMethodsAttr {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let Self(path) = self;
tokens.extend(quote! {
#path
});
}
}
struct Implementation {
name: syn::Ident,
traits: Traits,
add_methods: Vec<AddMethodsAttr>,
extra_user_data: Vec<ExtraUserDataAttr>,
}
impl quote::ToTokens for Implementation {
@@ -124,9 +114,10 @@ impl quote::ToTokens for Implementation {
let Self {
name,
traits,
add_methods,
extra_user_data,
} = &self;
let Traits(traits) = traits;
let extra_user_data: Vec<_> = extra_user_data.iter().map(|tr| tr.0.clone()).collect();
tokens.extend(quote! {
impl mlua::UserData for #name {
@@ -151,7 +142,7 @@ impl quote::ToTokens for Implementation {
)*
#(
#add_methods(methods);
<#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::add_methods(methods);
)*
}
}
@@ -178,7 +169,7 @@ impl quote::ToTokens for Implementation {
format!(": {interfaces}")
};
Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}"))
Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}\n"))
}
fn generate_members() -> Option<String> {
@@ -191,6 +182,15 @@ impl quote::ToTokens for Implementation {
output += &format!("---@return {type_name}\n");
output += &format!("function devices.{type_name}.new(config) end\n");
output += &<::automation_lib::lua::traits::Device as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
#(
output += &<::automation_lib::lua::traits::#traits as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
)*
#(
output += &<#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into());
)*
Some(output)
}
@@ -220,7 +220,7 @@ impl Implementations {
all.extend(&attribute.traits);
}
}
Attr::AddMethods(attribute) => add_methods.push(attribute),
Attr::ExtraUserData(attribute) => add_methods.push(attribute),
}
}
@@ -238,7 +238,7 @@ impl Implementations {
.map(|(alias, traits)| Implementation {
name: alias.unwrap_or(name.clone()),
traits,
add_methods: add_methods.clone(),
extra_user_data: add_methods.clone(),
})
.collect(),
)