diff --git a/automation_devices/src/hue_bridge.rs b/automation_devices/src/hue_bridge.rs index b08ab51..0c548c8 100644 --- a/automation_devices/src/hue_bridge.rs +++ b/automation_devices/src/hue_bridge.rs @@ -3,18 +3,21 @@ use std::net::SocketAddr; use async_trait::async_trait; use automation_lib::device::{Device, LuaDeviceCreate}; +use automation_lib::lua::traits::PartialUserData; use automation_macro::{Device, LuaDeviceConfig}; use lua_typed::Typed; use mlua::LuaSerdeExt; use serde::{Deserialize, Serialize}; use tracing::{error, trace, warn}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Typed)] #[serde(rename_all = "snake_case")] +#[typed(rename_all = "snake_case")] pub enum Flag { Presence, Darkness, } +crate::register_type!(Flag); #[derive(Debug, Clone, Deserialize, Typed)] pub struct FlagIDs { @@ -36,12 +39,36 @@ pub struct Config { crate::register_type!(Config); #[derive(Debug, Clone, Device)] -#[device(add_methods = Self::add_methods)] +#[device(extra_user_data = SetFlag)] pub struct HueBridge { config: Config, } crate::register_device!(HueBridge); +struct SetFlag; +impl PartialUserData for SetFlag { + fn add_methods>(methods: &mut M) { + methods.add_async_method( + "set_flag", + async |lua, this, (flag, value): (mlua::Value, bool)| { + let flag: Flag = lua.from_value(flag)?; + + this.set_flag(flag, value).await; + + Ok(()) + }, + ); + } + + fn definitions() -> Option { + Some(format!( + "---@async\n---@param flag {}\n---@param value boolean\nfunction {}:set_flag(flag, value) end\n", + ::type_name(), + ::type_name(), + )) + } +} + #[derive(Debug, Serialize)] struct FlagMessage { flag: bool, @@ -89,19 +116,6 @@ impl HueBridge { } } } - - fn add_methods>(methods: &mut M) { - methods.add_async_method( - "set_flag", - async |lua, this, (flag, value): (mlua::Value, bool)| { - let flag: Flag = lua.from_value(flag)?; - - this.set_flag(flag, value).await; - - Ok(()) - }, - ); - } } impl Device for HueBridge { diff --git a/automation_devices/src/ntfy.rs b/automation_devices/src/ntfy.rs index e873f2b..d41df47 100644 --- a/automation_devices/src/ntfy.rs +++ b/automation_devices/src/ntfy.rs @@ -3,6 +3,7 @@ use std::convert::Infallible; use async_trait::async_trait; use automation_lib::device::{Device, LuaDeviceCreate}; +use automation_lib::lua::traits::PartialUserData; use automation_macro::{Device, LuaDeviceConfig}; use lua_typed::Typed; use mlua::LuaSerdeExt; @@ -90,14 +91,15 @@ pub struct Config { crate::register_type!(Config); #[derive(Debug, Clone, Device)] -#[device(add_methods = Self::add_methods)] +#[device(extra_user_data = SendNotification)] pub struct Ntfy { config: Config, } crate::register_device!(Ntfy); -impl Ntfy { - fn add_methods>(methods: &mut M) { +struct SendNotification; +impl PartialUserData for SendNotification { + fn add_methods>(methods: &mut M) { methods.add_async_method( "send_notification", async |lua, this, notification: mlua::Value| { @@ -109,6 +111,14 @@ impl Ntfy { }, ); } + + fn definitions() -> Option { + Some(format!( + "---@async\n---@param notification {}\nfunction {}:send_notification(notification) end\n", + ::type_name(), + ::type_name(), + )) + } } #[async_trait] diff --git a/automation_devices/src/presence.rs b/automation_devices/src/presence.rs index 72391ab..a77327c 100644 --- a/automation_devices/src/presence.rs +++ b/automation_devices/src/presence.rs @@ -6,6 +6,7 @@ use automation_lib::action_callback::ActionCallback; use automation_lib::config::MqttDeviceConfig; use automation_lib::device::{Device, LuaDeviceCreate}; use automation_lib::event::OnMqtt; +use automation_lib::lua::traits::PartialUserData; use automation_lib::messages::PresenceMessage; use automation_lib::mqtt::WrappedAsyncClient; use automation_macro::{Device, LuaDeviceConfig}; @@ -39,13 +40,29 @@ pub struct State { } #[derive(Debug, Clone, Device)] -#[device(add_methods = Self::add_methods)] +#[device(extra_user_data = OverallPresence)] pub struct Presence { config: Config, state: Arc>, } crate::register_device!(Presence); +struct OverallPresence; +impl PartialUserData for OverallPresence { + fn add_methods>(methods: &mut M) { + methods.add_async_method("overall_presence", async |_lua, this, ()| { + Ok(this.state().await.current_overall_presence) + }); + } + + fn definitions() -> Option { + Some(format!( + "---@async\n---@return boolean\nfunction {}:overall_presence() end\n", + ::type_name(), + )) + } +} + impl Presence { async fn state(&self) -> RwLockReadGuard<'_, State> { self.state.read().await @@ -54,12 +71,6 @@ impl Presence { async fn state_mut(&self) -> RwLockWriteGuard<'_, State> { self.state.write().await } - - fn add_methods>(methods: &mut M) { - methods.add_async_method("overall_presence", async |_lua, this, ()| { - Ok(this.state().await.current_overall_presence) - }); - } } #[async_trait] diff --git a/automation_lib/src/lua/traits.rs b/automation_lib/src/lua/traits.rs index 6f61841..adae7df 100644 --- a/automation_lib/src/lua/traits.rs +++ b/automation_lib/src/lua/traits.rs @@ -8,6 +8,10 @@ pub trait PartialUserData { fn interface_name() -> Option<&'static str> { None } + + fn definitions() -> Option { + None + } } pub struct Device; diff --git a/automation_macro/src/device.rs b/automation_macro/src/device.rs index 874765f..d66e0bd 100644 --- a/automation_macro/src/device.rs +++ b/automation_macro/src/device.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use proc_macro2::TokenStream as TokenStream2; -use quote::{ToTokens, quote}; +use quote::quote; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; @@ -9,7 +9,7 @@ use syn::{Attribute, DeriveInput, Token, parenthesized}; enum Attr { Trait(TraitAttr), - AddMethods(AddMethodsAttr), + ExtraUserData(ExtraUserDataAttr), } impl Attr { @@ -20,9 +20,9 @@ impl Attr { let input; _ = parenthesized!(input in meta.input); parsed = Some(Attr::Trait(input.parse()?)); - } else if meta.path.is_ident("add_methods") { + } else if meta.path.is_ident("extra_user_data") { let value = meta.value()?; - parsed = Some(Attr::AddMethods(value.parse()?)); + parsed = Some(Attr::ExtraUserData(value.parse()?)); } else { return Err(syn::Error::new(meta.path.span(), "Unknown attribute")); } @@ -95,28 +95,18 @@ impl Parse for Aliases { } #[derive(Clone)] -struct AddMethodsAttr(syn::Path); +struct ExtraUserDataAttr(syn::Ident); -impl Parse for AddMethodsAttr { +impl Parse for ExtraUserDataAttr { fn parse(input: ParseStream) -> syn::Result { Ok(Self(input.parse()?)) } } -impl ToTokens for AddMethodsAttr { - fn to_tokens(&self, tokens: &mut TokenStream2) { - let Self(path) = self; - - tokens.extend(quote! { - #path - }); - } -} - struct Implementation { name: syn::Ident, traits: Traits, - add_methods: Vec, + extra_user_data: Vec, } impl quote::ToTokens for Implementation { @@ -124,9 +114,10 @@ impl quote::ToTokens for Implementation { let Self { name, traits, - add_methods, + extra_user_data, } = &self; let Traits(traits) = traits; + let extra_user_data: Vec<_> = extra_user_data.iter().map(|tr| tr.0.clone()).collect(); tokens.extend(quote! { impl mlua::UserData for #name { @@ -151,7 +142,7 @@ impl quote::ToTokens for Implementation { )* #( - #add_methods(methods); + <#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::add_methods(methods); )* } } @@ -178,7 +169,7 @@ impl quote::ToTokens for Implementation { format!(": {interfaces}") }; - Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}")) + Some(format!("---@class {type_name}{interfaces}\nlocal {type_name}\n")) } fn generate_members() -> Option { @@ -191,6 +182,15 @@ impl quote::ToTokens for Implementation { output += &format!("---@return {type_name}\n"); output += &format!("function devices.{type_name}.new(config) end\n"); + output += &<::automation_lib::lua::traits::Device as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into()); + + #( + output += &<::automation_lib::lua::traits::#traits as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into()); + )* + #( + output += &<#extra_user_data as ::automation_lib::lua::traits::PartialUserData<#name>>::definitions().unwrap_or("".into()); + )* + Some(output) } @@ -220,7 +220,7 @@ impl Implementations { all.extend(&attribute.traits); } } - Attr::AddMethods(attribute) => add_methods.push(attribute), + Attr::ExtraUserData(attribute) => add_methods.push(attribute), } } @@ -238,7 +238,7 @@ impl Implementations { .map(|(alias, traits)| Implementation { name: alias.unwrap_or(name.clone()), traits, - add_methods: add_methods.clone(), + extra_user_data: add_methods.clone(), }) .collect(), )