From 305d4115965e9c6a778e49223e21d6efe0d38a67 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Fri, 31 May 2024 23:53:38 +0200 Subject: [PATCH] WIP: Working on trait macro --- Cargo.lock | 16 +- automation_macro/Cargo.toml | 6 + automation_macro/src/lib.rs | 512 +++++++++++++++++++++++++++++++++- google-home/Cargo.toml | 1 + google-home/src/bin/expand.rs | 63 +++++ 5 files changed, 591 insertions(+), 7 deletions(-) create mode 100644 google-home/src/bin/expand.rs diff --git a/Cargo.lock b/Cargo.lock index 2fe3a6f..eac54ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,9 +73,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "async-trait" -version = "0.1.72" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -136,9 +136,12 @@ version = "0.1.0" name = "automation_macro" version = "0.1.0" dependencies = [ + "async-trait", + "automation_cast", "itertools 0.12.1", "proc-macro2", "quote", + "serde", "syn 2.0.60", ] @@ -642,6 +645,7 @@ dependencies = [ "anyhow", "async-trait", "automation_cast", + "automation_macro", "futures", "serde", "serde_json", @@ -1599,9 +1603,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.198" +version = "1.0.202" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395" dependencies = [ "serde_derive", ] @@ -1618,9 +1622,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.202" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838" dependencies = [ "proc-macro2", "quote", diff --git a/automation_macro/Cargo.toml b/automation_macro/Cargo.toml index 656ba92..cc6c9eb 100644 --- a/automation_macro/Cargo.toml +++ b/automation_macro/Cargo.toml @@ -7,7 +7,13 @@ edition = "2021" proc-macro = true [dependencies] +automation_cast = { path = "../automation_cast" } +async-trait = "0.1.80" itertools = "0.12.1" proc-macro2 = "1.0.81" quote = "1.0.36" +serde = { version = "1.0.202", features = ["derive"] } syn = { version = "2.0.60", features = ["extra-traits", "full"] } + +[dev-dependencies] +serde = { version = "1.0.202", features = ["derive"] } diff --git a/automation_macro/src/lib.rs b/automation_macro/src/lib.rs index 6e1654a..2128529 100644 --- a/automation_macro/src/lib.rs +++ b/automation_macro/src/lib.rs @@ -1,9 +1,19 @@ +#![feature(let_chains)] +#![feature(iter_intersperse)] mod lua_device; mod lua_device_config; use lua_device::impl_lua_device_macro; use lua_device_config::impl_lua_device_config_macro; -use syn::{parse_macro_input, DeriveInput}; +use proc_macro::TokenStream; +use quote::{quote, ToTokens}; +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::{ + braced, parse_macro_input, DeriveInput, GenericArgument, Ident, LitStr, Path, PathArguments, + PathSegment, ReturnType, Signature, Token, Type, TypePath, +}; #[proc_macro_derive(LuaDevice, attributes(config))] pub fn lua_device_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -18,3 +28,503 @@ pub fn lua_device_config_derive(input: proc_macro::TokenStream) -> proc_macro::T impl_lua_device_config_macro(&ast).into() } + +mod kw { + use syn::custom_keyword; + + custom_keyword!(required); +} + +#[derive(Debug)] +struct FieldAttribute { + ident: Ident, + _colon_token: Token![:], + ty: Type, +} + +impl Parse for FieldAttribute { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self { + ident: input.parse()?, + _colon_token: input.parse()?, + ty: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct FieldState { + sign: Signature, + _fat_arrow_token: Token![=>], + ident: Ident, +} + +impl Parse for FieldState { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self { + sign: input.parse()?, + _fat_arrow_token: input.parse()?, + ident: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct FieldExecute { + name: LitStr, + _fat_arrow_token: Token![=>], + sign: Signature, +} + +impl Parse for FieldExecute { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self { + name: input.parse()?, + _fat_arrow_token: input.parse()?, + sign: input.parse()?, + }) + } +} + +#[derive(Debug)] +enum Field { + Attribute(FieldAttribute), + State(FieldState), + Execute(FieldExecute), +} + +impl Parse for Field { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + if input.peek(Ident) { + Ok(Field::Attribute(input.parse()?)) + } else if input.peek(LitStr) { + Ok(Field::Execute(input.parse()?)) + } else { + Ok(Field::State(input.parse()?)) + } + } +} + +#[derive(Debug)] +struct Trait { + name: LitStr, + _fat_arrow_token: Token![=>], + _trait_token: Token![trait], + ident: Ident, + _brace_token: Brace, + fields: Punctuated, +} + +impl Parse for Trait { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let content; + Ok(Self { + name: input.parse()?, + _fat_arrow_token: input.parse()?, + _trait_token: input.parse()?, + ident: input.parse()?, + _brace_token: braced!(content in input), + fields: content.parse_terminated(Field::parse, Token![,])?, + }) + } +} + +#[derive(Debug)] +struct Input { + ty: Type, + _brack_token: Brace, + traits: Punctuated, +} + +// TODO: Error on duplicate name? +impl Parse for Input { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let content; + Ok(Self { + ty: input.parse()?, + _brack_token: braced!(content in input), + traits: content.parse_terminated(Trait::parse, Token![,])?, + }) + } +} + +fn extract_type_path(ty: &syn::Type) -> Option<&Path> { + match *ty { + Type::Path(ref typepath) if typepath.qself.is_none() => Some(&typepath.path), + _ => None, + } +} + +fn extract_segment<'a>(path: &'a Path, options: &[&str]) -> Option<&'a PathSegment> { + let idents_of_path = path + .segments + .iter() + .map(|segment| segment.ident.to_string()) + .intersperse('|'.into()) + .collect::(); + + options + .iter() + .find(|s| &idents_of_path == *s) + .and_then(|_| path.segments.last()) +} + +// Based on: https://stackoverflow.com/a/56264023 +fn extract_type_from_option(ty: &syn::Type) -> Option<&syn::Type> { + extract_type_path(ty) + .and_then(|path| { + extract_segment(path, &["Option", "std|option|Option", "core|option|Option"]) + }) + .and_then(|path_seg| { + let type_params = &path_seg.arguments; + // It should have only on angle-bracketed param (""): + match *type_params { + PathArguments::AngleBracketed(ref params) => params.args.first(), + _ => None, + } + }) + .and_then(|generic_arg| match *generic_arg { + GenericArgument::Type(ref ty) => Some(ty), + _ => None, + }) +} + +fn extract_type_from_result(ty: &syn::Type) -> Option<&syn::Type> { + extract_type_path(ty) + .and_then(|path| { + extract_segment(path, &["Result", "std|result|Result", "core|result|Result"]) + }) + .and_then(|path_seg| { + let type_params = &path_seg.arguments; + // It should have only on angle-bracketed param (""): + match *type_params { + PathArguments::AngleBracketed(ref params) => params.args.first(), + _ => None, + } + }) + .and_then(|generic_arg| match *generic_arg { + GenericArgument::Type(ref ty) => Some(ty), + _ => None, + }) +} + +fn get_attributes_struct(traits: &Punctuated) -> proc_macro2::TokenStream { + let items = traits.iter().flat_map(|t| { + t.fields.iter().filter_map(|f| match f { + Field::Attribute(attr) => { + let ident = &attr.ident; + + let ty = &attr.ty; + let ty = extract_type_from_option(ty) + .map(ToTokens::into_token_stream) + .unwrap_or(ty.into_token_stream()); + + Some(quote! { + // #[serde(skip_serializing_if = "core::option::Option::is_none")] + #ident: Option<#ty> + }) + } + _ => None, + }) + }); + + quote! { + // #[derive(Debug, Default, serde::Serialize)] + // #[serde(rename_all = "camelCase")] + pub struct Attributes { + #(#items,)* + } + } +} + +fn get_state_struct(traits: &Punctuated) -> proc_macro2::TokenStream { + let items = traits.iter().flat_map(|t| { + t.fields.iter().filter_map(|f| match f { + Field::State(state) => { + let ident = &state.ident; + + let ReturnType::Type(_, ty) = &state.sign.output else { + return None; + }; + let ty = extract_type_from_result(ty) + .map(ToTokens::into_token_stream) + .unwrap_or(ty.into_token_stream()); + + Some(quote! {#ident: Option<#ty>}) + } + _ => None, + }) + }); + + quote! { + // #[derive(Debug, Default, serde::Serialize)] + // #[serde(rename_all = "camelCase")] + pub struct State { + #(#items,)* + } + } +} + +fn get_command_enum(traits: &Punctuated) -> proc_macro2::TokenStream { + let items = traits.iter().flat_map(|t| { + t.fields.iter().filter_map(|f| match f { + Field::Execute(execute) => { + let name = execute.name.value(); + let ident = Ident::new( + name.split_at(name.rfind('.').map(|v| v + 1).unwrap_or(0)).1, + execute.name.span(), + ); + + let parameters = execute.sign.inputs.iter().skip(1); + + Some(quote! { + // #[serde(rename = #name, rename_all = "camelCase")] + #ident { + #(#parameters,)* + } + }) + } + _ => None, + }) + }); + + quote! { + // #[derive(Debug, serde::Deserialize)] + // #[serde(tag = "command", content = "params", rename_all = "camelCase")] + pub enum Command { + #(#items,)* + } + } +} + +fn get_trait_enum(traits: &Punctuated) -> proc_macro2::TokenStream { + let items = traits.iter().map(|t| { + let name = &t.name; + let ident = &t.ident; + quote! { + // #[serde(rename = #name)] + #ident + } + }); + + quote! { + // #[derive(Debug, serde::Serialize)] + pub enum Trait { + #(#items,)* + } + } +} + +#[proc_macro] +pub fn google_home_traits(item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as Input); + let traits = input.traits; + + let attributes_struct = get_attributes_struct(&traits); + let state_struct = get_state_struct(&traits); + let command_enum = get_command_enum(&traits); + let trait_enum = get_trait_enum(&traits); + + let sync = traits.iter().map(|t| { + let ident = &t.ident; + + let attr = t.fields.iter().filter_map(|f| match f { + Field::Attribute(attr) => { + let name = &attr.ident; + + let rhs = if extract_type_from_option(&attr.ty).is_some() { + quote! { + t.#name(); + } + } else { + quote! { + Some(t.#name()); + } + }; + + Some(quote! { + attrs.#name = #rhs; + }) + } + _ => None, + }); + + quote! { + if let Some(t) = device.cast() as Option<&dyn #ident> { + traits.push(Trait::#ident); + #(#attr)* + } + } + }); + + let query = traits.iter().map(|t| { + let ident = &t.ident; + + let attr = t.fields.iter().filter_map(|f| match f { + Field::State(state) => { + let ident = &state.ident; + let f_ident = &state.sign.ident; + + let asyncness = if state.sign.asyncness.is_some() { + quote! {.await} + } else { + quote! {} + }; + + let errors = if let ReturnType::Type(_, ty) = &state.sign.output + && let Type::Path(TypePath { path, .. }) = ty.as_ref() + && let Some(PathSegment { ident, .. }) = path.segments.first() + && ident == "Result" + { + quote! {?} + } else { + quote! {} + }; + + Some(quote! { + state.#ident = t.#f_ident() #asyncness #errors; + }) + } + _ => None, + }); + + quote! { + if let Some(t) = device.cast() as Option<&dyn #ident> { + #(#attr)* + } + } + }); + + let execute = traits.iter().flat_map(|t| { + t.fields.iter().filter_map(|f| match f { + Field::Execute(execute) => { + let ident = &t.ident; + let name = execute.name.value(); + let command_name = Ident::new( + name.split_at(name.rfind('.').map(|v| v + 1).unwrap_or(0)).1, + execute.name.span(), + ); + let f_name = &&execute.sign.ident; + let parameters = execute + .sign + .inputs + .iter() + .filter_map(|p| { + if let syn::FnArg::Typed(p) = p { + Some(&p.pat) + } else { + None + } + }) + .collect::>(); + + let asyncness = if execute.sign.asyncness.is_some() { + quote! {.await} + } else { + quote! {} + }; + + let errors = if let ReturnType::Type(_, ty) = &execute.sign.output + && let Type::Path(TypePath { path, .. }) = ty.as_ref() + && let Some(PathSegment { ident, .. }) = path.segments.first() + && ident == "Result" + { + quote! {?} + } else { + quote! {} + }; + + Some(quote! { + Command::#command_name {#(#parameters,)*} => { + if let Some(t) = device.cast() as Option<&dyn #ident> { + t.#f_name(#(#parameters,)*) #asyncness #errors; + } else { + todo!("Device does not support action"); + } + } + }) + } + _ => None, + }) + }); + + let traits = traits.iter().map(|t| { + let fields = t.fields.iter().map(|f| match f { + Field::Attribute(attr) => { + let name = &attr.ident; + let ty = &attr.ty; + + if let Some(ty) = extract_type_from_option(ty) { + quote! { + fn #name(&self) -> #ty { + None + } + } + } else { + quote! { + fn #name(&self) -> #ty; + } + } + } + Field::State(state) => { + let sign = &state.sign; + + quote! { + #sign; + } + } + Field::Execute(execute) => { + let sign = &execute.sign; + quote! { + #sign; + } + } + }); + + let ident = &t.ident; + + quote! { + #[async_trait::async_trait] + pub trait #ident: Sync + Send { + #(#fields)* + } + } + }); + + let ty = input.ty; + + quote! { + #attributes_struct + #state_struct + #command_enum + #trait_enum + + #(#traits)* + + async fn sync(device: &dyn #ty) -> Result<(Vec, Attributes), Box> { + let mut traits = Vec::new(); + let mut attrs = Attributes::default(); + + #(#sync)* + + Ok((traits, attrs)) + } + + async fn query(device: &dyn #ty) -> Result> { + let mut state = State::default(); + + #(#query)* + + Ok(state) + } + + async fn execute(device: &dyn #ty, command: Command) -> Result<(), Box> { + match command { + #(#execute)* + } + + Ok(()) + } + } + .into() +} diff --git a/google-home/Cargo.toml b/google-home/Cargo.toml index f2385c6..32222d6 100644 --- a/google-home/Cargo.toml +++ b/google-home/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] automation_cast = { path = "../automation_cast/" } +automation_macro = { path = "../automation_macro/" } serde = { version = "1.0.149", features = ["derive"] } serde_json = "1.0.89" thiserror = "1.0.37" diff --git a/google-home/src/bin/expand.rs b/google-home/src/bin/expand.rs new file mode 100644 index 0000000..855ccf4 --- /dev/null +++ b/google-home/src/bin/expand.rs @@ -0,0 +1,63 @@ +use automation_cast::Cast; +use automation_macro::google_home_traits; +use google_home::errors::ErrorCode; +use google_home::traits::AvailableSpeeds; + +google_home_traits! { + GoogleHomeDevice { + "action.devices.traits.OnOff" => trait OnOff { + command_only_on_off: bool, + // This one is optional + query_only_on_off: Option, + + async fn is_on(&self) -> Result => on, + + "action.devices.commands.OnOff" => async fn set_on(&self, on: bool) -> Result<(), ErrorCode>, + }, + "action.devices.traits.Scene" => trait Scene { + scene_reversible: Option, + + "action.devices.commands.ActivateScene" => async fn set_active(&self, activate: bool) -> Result<(), ErrorCode>, + }, + "action.devices.traits.FanSpeed" => trait FanSpeed { + reversible: Option, + command_only_fan_speed: Option, + available_fan_speeds: AvailableSpeeds, + + fn get_fan_speed(&self) -> Result => current_fan_speed_setting, + + "action.devices.commands.SetFanSpeed" => fn set_speed(&self, fan_speed: String), + }, + "action.devices.traits.HumiditySetting" => trait HumiditySetting { + query_only_humidity_setting: Option, + + fn get_humidity(&self) -> Result => humidity_ambient_percent, + } + } +} + +trait Casts: + Cast + Cast + Cast + Cast +{ +} + +trait GoogleHomeDevice: Casts {} + +struct Device {} + +#[async_trait::async_trait] +impl OnOff for Device { + fn command_only_on_off(&self) -> bool { + false + } + + async fn is_on(&self) -> Result { + Ok(true) + } + + async fn set_on(&self, _on: bool) -> Result<(), ErrorCode> { + Ok(()) + } +} + +fn main() {}