diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index b884ea17391b..0ecbb2e16da3 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -22,6 +22,8 @@ mod vtable; use proc_macro::TokenStream; +use syn::parse_macro_input; + /// Declares a kernel module. /// /// The `type` argument should be a type which implements the [`Module`] @@ -204,8 +206,11 @@ pub fn module(ts: TokenStream) -> TokenStream { /// /// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html #[proc_macro_attribute] -pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { - vtable::vtable(attr.into(), ts.into()).into() +pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + vtable::vtable(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Export a function so that C code can call it via a header file. diff --git a/rust/macros/vtable.rs b/rust/macros/vtable.rs index a67d1cc81a2d..72ae0a1816a0 100644 --- a/rust/macros/vtable.rs +++ b/rust/macros/vtable.rs @@ -1,97 +1,110 @@ // SPDX-License-Identifier: GPL-2.0 -use std::collections::HashSet; -use std::fmt::Write; +use std::{ + collections::HashSet, + iter::Extend, // +}; -use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; +use proc_macro2::{ + Ident, + TokenStream, // +}; +use quote::ToTokens; +use syn::{ + parse_quote, + Error, + ImplItem, + Item, + ItemImpl, + ItemTrait, + Result, + TraitItem, // +}; -pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let mut tokens: Vec<_> = ts.into_iter().collect(); +fn handle_trait(mut item: ItemTrait) -> Result { + let mut gen_items = Vec::new(); + let mut gen_consts = HashSet::new(); - // Scan for the `trait` or `impl` keyword. - let is_trait = tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "trait" => Some(true), - "impl" => Some(false), - _ => None, - }, - _ => None, - }) - .expect("#[vtable] attribute should only be applied to trait or impl block"); + gen_items.push(parse_quote! { + /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) + /// attribute when implementing this trait. + const USE_VTABLE_ATTR: (); + }); - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("cannot locate main body of trait or impl block"), - }; - - let mut body_it = body.stream().into_iter(); - let mut functions = Vec::new(); - let mut consts = HashSet::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Ident(ident) if ident == "fn" => { - let fn_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered a fn pointer type instead. - _ => continue, - }; - functions.push(fn_name); - } - TokenTree::Ident(ident) if ident == "const" => { - let const_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered an inline const block instead. - _ => continue, - }; - consts.insert(const_name); - } - _ => (), - } - } - - let mut const_items; - if is_trait { - const_items = " - /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) - /// attribute when implementing this trait. - const USE_VTABLE_ATTR: (); - " - .to_owned(); - - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - // Skip if it's declared already -- this allows user override. - if consts.contains(&gen_const_name) { + for item in &item.items { + if let TraitItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + // Skip if it's declared already -- this can happen if `#[cfg]` is used to selectively + // define functions. + // FIXME: `#[cfg]` should be copied and propagated to the generated consts. + if gen_consts.contains(&gen_const_name) { continue; } + // We don't know on the implementation-site whether a method is required or provided // so we have to generate a const for all methods. - write!( - const_items, - "/// Indicates if the `{f}` method is overridden by the implementor. - const {gen_const_name}: bool = false;", - ) - .unwrap(); - consts.insert(gen_const_name); - } - } else { - const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); - - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - if consts.contains(&gen_const_name) { - continue; - } - write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); + let comment = + format!("Indicates if the `{name}` method is overridden by the implementor."); + gen_items.push(parse_quote! { + #[doc = #comment] + const #gen_const_name: bool = false; + }); + gen_consts.insert(gen_const_name); } } - let new_body = vec![const_items.parse().unwrap(), body.stream()] - .into_iter() - .collect(); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); - tokens.into_iter().collect() + item.items.extend(gen_items); + Ok(item) +} + +fn handle_impl(mut item: ItemImpl) -> Result { + let mut gen_items = Vec::new(); + let mut defined_consts = HashSet::new(); + + // Iterate over all user-defined constants to gather any possible explicit overrides. + for item in &item.items { + if let ImplItem::Const(const_item) = item { + defined_consts.insert(const_item.ident.clone()); + } + } + + gen_items.push(parse_quote! { + const USE_VTABLE_ATTR: () = (); + }); + + for item in &item.items { + if let ImplItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + // Skip if it's declared already -- this allows user override. + if defined_consts.contains(&gen_const_name) { + continue; + } + gen_items.push(parse_quote! { + const #gen_const_name: bool = true; + }); + defined_consts.insert(gen_const_name); + } + } + + item.items.extend(gen_items); + Ok(item) +} + +pub(crate) fn vtable(input: Item) -> Result { + match input { + Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()), + Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()), + _ => Err(Error::new_spanned( + input, + "`#[vtable]` attribute should only be applied to trait or impl block", + ))?, + } }