From 6d2dbd37f147ae24daa84aaad7974b6912cdaf84 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Fri, 26 Apr 2024 05:18:46 +0200 Subject: [PATCH] Slight macro cleanup --- automation_macro/src/lib.rs | 320 +--------------------- automation_macro/src/lua_device.rs | 46 ++++ automation_macro/src/lua_device_config.rs | 280 +++++++++++++++++++ 3 files changed, 332 insertions(+), 314 deletions(-) create mode 100644 automation_macro/src/lua_device.rs create mode 100644 automation_macro/src/lua_device_config.rs diff --git a/automation_macro/src/lib.rs b/automation_macro/src/lib.rs index 2caa40b..6e1654a 100644 --- a/automation_macro/src/lib.rs +++ b/automation_macro/src/lib.rs @@ -1,11 +1,9 @@ -use itertools::Itertools; -use proc_macro2::TokenStream; -use quote::{quote, quote_spanned}; -use syn::parse::{Parse, ParseStream}; -use syn::punctuated::Punctuated; -use syn::spanned::Spanned; -use syn::token::Paren; -use syn::{parenthesized, parse_macro_input, DeriveInput, Expr, LitStr, Result, Token}; +mod lua_device; +mod lua_device_config; + +use lua_device::impl_lua_device_macro; +use lua_device_config::impl_lua_device_config_macro; +use syn::{parse_macro_input, DeriveInput}; #[proc_macro_derive(LuaDevice, attributes(config))] pub fn lua_device_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -14,315 +12,9 @@ pub fn lua_device_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStr impl_lua_device_macro(&ast).into() } -fn impl_lua_device_macro(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - // TODO: Handle errors properly - // This includes making sure one, and only one config is specified - let config = if let syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }), - .. - }) = ast.data - { - named - .iter() - .find(|&field| { - field - .attrs - .iter() - .any(|attr| attr.path().is_ident("config")) - }) - .map(|field| field.ty.clone()) - .unwrap() - } else { - unimplemented!() - }; - - let gen = quote! { - impl #name { - pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> { - lua.globals().set(stringify!(#name), lua.create_proxy::<#name>()?) - } - } - impl mlua::UserData for #name { - fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_function("new", |lua, config: mlua::Value| { - let config: #config = mlua::FromLua::from_lua(config, lua)?; - let config: Box = Box::new(config); - Ok(config) - }); - } - } - }; - - gen -} - -mod kw { - syn::custom_keyword!(device_config); - syn::custom_keyword!(flatten); - syn::custom_keyword!(from_lua); - syn::custom_keyword!(rename); - syn::custom_keyword!(with); - syn::custom_keyword!(from); - syn::custom_keyword!(default); -} - -#[derive(Debug)] -enum Argument { - Flatten { - _keyword: kw::flatten, - }, - FromLua { - _keyword: kw::from_lua, - }, - Rename { - _keyword: kw::rename, - _paren: Paren, - ident: LitStr, - }, - With { - _keyword: kw::with, - _paren: Paren, - // TODO: Ideally we capture this better - expr: Expr, - }, - From { - _keyword: kw::from, - _paren: Paren, - ty: syn::Type, - }, - Default { - _keyword: kw::default, - }, - DefaultExpr { - _keyword: kw::default, - _paren: Paren, - expr: Expr, - }, -} - -impl Parse for Argument { - fn parse(input: ParseStream) -> Result { - let lookahead = input.lookahead1(); - if lookahead.peek(kw::flatten) { - Ok(Self::Flatten { - _keyword: input.parse()?, - }) - } else if lookahead.peek(kw::from_lua) { - Ok(Self::FromLua { - _keyword: input.parse()?, - }) - } else if lookahead.peek(kw::rename) { - let content; - Ok(Self::Rename { - _keyword: input.parse()?, - _paren: parenthesized!(content in input), - ident: content.parse()?, - }) - } else if lookahead.peek(kw::with) { - let content; - Ok(Self::With { - _keyword: input.parse()?, - _paren: parenthesized!(content in input), - expr: content.parse()?, - }) - } else if lookahead.peek(kw::from) { - let content; - Ok(Self::From { - _keyword: input.parse()?, - _paren: parenthesized!(content in input), - ty: content.parse()?, - }) - } else if lookahead.peek(kw::default) { - let keyword = input.parse()?; - if input.peek(Paren) { - let content; - Ok(Self::DefaultExpr { - _keyword: keyword, - _paren: parenthesized!(content in input), - expr: content.parse()?, - }) - } else { - Ok(Self::Default { _keyword: keyword }) - } - } else { - Err(lookahead.error()) - } - } -} - -#[derive(Debug)] -struct Args { - args: Punctuated, -} - -impl Parse for Args { - fn parse(input: ParseStream) -> Result { - Ok(Self { - args: input.parse_terminated(Argument::parse, Token![,])?, - }) - } -} - #[proc_macro_derive(LuaDeviceConfig, attributes(device_config))] pub fn lua_device_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = parse_macro_input!(input as DeriveInput); impl_lua_device_config_macro(&ast).into() } - -fn impl_lua_device_config_macro(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - let fields = if let syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }), - .. - }) = ast.data - { - named - } else { - return quote_spanned! {ast.span() => compile_error!("This macro only works on named structs")}; - }; - - let field_names: Vec<_> = fields - .iter() - .map(|field| field.ident.clone().unwrap()) - .collect(); - - let fields: Vec<_> = fields - .iter() - .map(|field| { - let field_name = field.ident.clone().unwrap(); - let (args, errors): (Vec<_>, Vec<_>) = field - .attrs - .iter() - .filter_map(|attr| { - if attr.path().is_ident("device_config") { - Some(attr.parse_args::().map(|args| args.args)) - } else { - None - } - }) - .partition_result(); - - let errors: Vec<_> = errors - .iter() - .map(|error| error.to_compile_error()) - .collect(); - - if !errors.is_empty() { - return quote! { #(#errors)* }; - } - - let args: Vec<_> = args.into_iter().flatten().collect(); - - let table_name = match args - .iter() - .filter_map(|arg| match arg { - Argument::Rename { ident, .. } => Some(ident.value()), - _ => None, - }) - .collect::>() - .as_slice() - { - [] => field_name.to_string(), - [rename] => rename.to_owned(), - _ => return quote_spanned! {field.span() => compile_error!("Field contains duplicate 'rename'")}, - }; - - // TODO: Detect Option<_> properly and use Default::default() as fallback automatically - let missing = format!("Missing field '{table_name}'"); - let default = match args - .iter() - .filter_map(|arg| match arg { - Argument::Default { .. } => Some(quote! { Default::default() }), - Argument::DefaultExpr { expr, .. } => Some(quote! { (#expr) }), - _ => None, - }) - .collect::>() - .as_slice() - { - [] => quote! {panic!(#missing)}, - [default] => default.to_owned(), - _ => return quote_spanned! {field.span() => compile_error!("Field contains duplicate 'default'")}, - }; - - - let value = match args - .iter() - .filter_map(|arg| match arg { - Argument::Flatten { .. } => Some(quote! { - mlua::LuaSerdeExt::from_value_with(lua, value.clone(), mlua::DeserializeOptions::new().deny_unsupported_types(false))? - }), - Argument::FromLua { .. } => Some(quote! { - if table.contains_key(#table_name)? { - table.get(#table_name)? - } else { - #default - } - }), - _ => None, - }) - .collect::>() - .as_slice() { - [] => quote! { - { - let #field_name: mlua::Value = table.get(#table_name)?; - if !#field_name.is_nil() { - mlua::LuaSerdeExt::from_value(lua, #field_name)? - } else { - #default - } - } - }, - [value] => value.to_owned(), - _ => return quote_spanned! {field.span() => compile_error!("Only one of either 'flatten' or 'from_lua' is allowed")}, - }; - - let value = match args - .iter() - .filter_map(|arg| match arg { - Argument::From { ty, .. } => Some(quote! { - { - let temp: #ty = #value; - temp.into() - } - }), - Argument::With { expr, .. } => Some(quote! { - { - let temp = #value; - (#expr)(temp) - } - }), - _ => None, - }) - .collect::>() - .as_slice() { - [] => value, - [value] => value.to_owned(), - _ => return quote_spanned! {field.span() => compile_error!("Field contains duplicate 'as'")}, - }; - - quote! { #value } - }) - .zip(field_names) - .map(|(value, name)| quote! { #name: #value }) - .collect(); - - let gen = quote! { - impl<'lua> mlua::FromLua<'lua> for #name { - fn from_lua(value: mlua::Value<'lua>, lua: &'lua mlua::Lua) -> mlua::Result { - if !value.is_table() { - panic!("Expected table"); - } - let table = value.as_table().unwrap(); - - Ok(#name { - #(#fields,)* - }) - - } - } - }; - - gen -} diff --git a/automation_macro/src/lua_device.rs b/automation_macro/src/lua_device.rs new file mode 100644 index 0000000..6d9ce59 --- /dev/null +++ b/automation_macro/src/lua_device.rs @@ -0,0 +1,46 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed}; + +pub fn impl_lua_device_macro(ast: &DeriveInput) -> TokenStream { + let name = &ast.ident; + // TODO: Handle errors properly + // This includes making sure one, and only one config is specified + let config = if let Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { ref named, .. }), + .. + }) = ast.data + { + named + .iter() + .find(|&field| { + field + .attrs + .iter() + .any(|attr| attr.path().is_ident("config")) + }) + .map(|field| field.ty.clone()) + .unwrap() + } else { + unimplemented!() + }; + + let gen = quote! { + impl #name { + pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> { + lua.globals().set(stringify!(#name), lua.create_proxy::<#name>()?) + } + } + impl mlua::UserData for #name { + fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_function("new", |lua, config: mlua::Value| { + let config: #config = mlua::FromLua::from_lua(config, lua)?; + let config: Box = Box::new(config); + Ok(config) + }); + } + } + }; + + gen +} diff --git a/automation_macro/src/lua_device_config.rs b/automation_macro/src/lua_device_config.rs new file mode 100644 index 0000000..96f83c9 --- /dev/null +++ b/automation_macro/src/lua_device_config.rs @@ -0,0 +1,280 @@ +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::token::Paren; +use syn::{ + parenthesized, Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, LitStr, Result, + Token, Type, +}; + +mod kw { + use syn::custom_keyword; + + custom_keyword!(device_config); + custom_keyword!(flatten); + custom_keyword!(from_lua); + custom_keyword!(rename); + custom_keyword!(with); + custom_keyword!(from); + custom_keyword!(default); +} + +#[derive(Debug)] +enum Argument { + Flatten { + _keyword: kw::flatten, + }, + FromLua { + _keyword: kw::from_lua, + }, + Rename { + _keyword: kw::rename, + _paren: Paren, + ident: LitStr, + }, + With { + _keyword: kw::with, + _paren: Paren, + // TODO: Ideally we capture this better + expr: Expr, + }, + From { + _keyword: kw::from, + _paren: Paren, + ty: Type, + }, + Default { + _keyword: kw::default, + }, + DefaultExpr { + _keyword: kw::default, + _paren: Paren, + expr: Expr, + }, +} + +impl Parse for Argument { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::flatten) { + Ok(Self::Flatten { + _keyword: input.parse()?, + }) + } else if lookahead.peek(kw::from_lua) { + Ok(Self::FromLua { + _keyword: input.parse()?, + }) + } else if lookahead.peek(kw::rename) { + let content; + Ok(Self::Rename { + _keyword: input.parse()?, + _paren: parenthesized!(content in input), + ident: content.parse()?, + }) + } else if lookahead.peek(kw::with) { + let content; + Ok(Self::With { + _keyword: input.parse()?, + _paren: parenthesized!(content in input), + expr: content.parse()?, + }) + } else if lookahead.peek(kw::from) { + let content; + Ok(Self::From { + _keyword: input.parse()?, + _paren: parenthesized!(content in input), + ty: content.parse()?, + }) + } else if lookahead.peek(kw::default) { + let keyword = input.parse()?; + if input.peek(Paren) { + let content; + Ok(Self::DefaultExpr { + _keyword: keyword, + _paren: parenthesized!(content in input), + expr: content.parse()?, + }) + } else { + Ok(Self::Default { _keyword: keyword }) + } + } else { + Err(lookahead.error()) + } + } +} + +#[derive(Debug)] +struct Args { + args: Punctuated, +} + +impl Parse for Args { + fn parse(input: ParseStream) -> Result { + Ok(Self { + args: input.parse_terminated(Argument::parse, Token![,])?, + }) + } +} + +fn field_from_lua(field: &Field) -> TokenStream { + let (args, errors): (Vec<_>, Vec<_>) = field + .attrs + .iter() + .filter_map(|attr| { + if attr.path().is_ident("device_config") { + Some(attr.parse_args::().map(|args| args.args)) + } else { + None + } + }) + .partition_result(); + + let errors: Vec<_> = errors + .iter() + .map(|error| error.to_compile_error()) + .collect(); + + if !errors.is_empty() { + return quote! { #(#errors)* }; + } + + let args: Vec<_> = args.into_iter().flatten().collect(); + + let table_name = match args + .iter() + .filter_map(|arg| match arg { + Argument::Rename { ident, .. } => Some(ident.value()), + _ => None, + }) + .collect::>() + .as_slice() + { + [] => field.ident.clone().unwrap().to_string(), + [rename] => rename.to_owned(), + _ => { + return quote_spanned! {field.span() => compile_error!("Field contains duplicate 'rename'")} + } + }; + + // TODO: Detect Option<_> properly and use Default::default() as fallback automatically + let missing = format!("Missing field '{table_name}'"); + let default = match args + .iter() + .filter_map(|arg| match arg { + Argument::Default { .. } => Some(quote! { Default::default() }), + Argument::DefaultExpr { expr, .. } => Some(quote! { (#expr) }), + _ => None, + }) + .collect::>() + .as_slice() + { + [] => quote! {panic!(#missing)}, + [default] => default.to_owned(), + _ => { + return quote_spanned! {field.span() => compile_error!("Field contains duplicate 'default'")} + } + }; + + let value = match args + .iter() + .filter_map(|arg| match arg { + Argument::Flatten { .. } => Some(quote! { + mlua::LuaSerdeExt::from_value_with(lua, value.clone(), mlua::DeserializeOptions::new().deny_unsupported_types(false))? + }), + Argument::FromLua { .. } => Some(quote! { + if table.contains_key(#table_name)? { + table.get(#table_name)? + } else { + #default + } + }), + _ => None, + }) + .collect::>() + .as_slice() { + [] => quote! { + { + let value: mlua::Value = table.get(#table_name)?; + if !value.is_nil() { + mlua::LuaSerdeExt::from_value(lua, value)? + } else { + #default + } + } + }, + [value] => value.to_owned(), + _ => return quote_spanned! {field.span() => compile_error!("Only one of either 'flatten' or 'from_lua' is allowed")}, + }; + + let value = match args + .iter() + .filter_map(|arg| match arg { + Argument::From { ty, .. } => Some(quote! { + { + let temp: #ty = #value; + temp.into() + } + }), + Argument::With { expr, .. } => Some(quote! { + { + let temp = #value; + (#expr)(temp) + } + }), + _ => None, + }) + .collect::>() + .as_slice() + { + [] => value, + [value] => value.to_owned(), + _ => { + return quote_spanned! {field.span() => compile_error!("Only one of either 'from' or 'with' is allowed")} + } + }; + + quote! { #value } +} + +pub fn impl_lua_device_config_macro(ast: &DeriveInput) -> TokenStream { + let name = &ast.ident; + let fields = if let Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { ref named, .. }), + .. + }) = ast.data + { + named + } else { + return quote_spanned! {ast.span() => compile_error!("This macro only works on named structs")}; + }; + + let lua_fields: Vec<_> = fields + .iter() + .map(|field| { + let name = field.ident.clone().unwrap(); + let value = field_from_lua(field); + quote! { #name: #value } + }) + .collect(); + + let impl_from_lua = quote! { + impl<'lua> mlua::FromLua<'lua> for #name { + fn from_lua(value: mlua::Value<'lua>, lua: &'lua mlua::Lua) -> mlua::Result { + if !value.is_table() { + panic!("Expected table"); + } + let table = value.as_table().unwrap(); + + Ok(#name { + #(#lua_fields,)* + }) + + } + } + }; + + impl_from_lua +}