Merge patch series "refactor Rust proc macros with syn"

Gary writes:

  "This series converts Rust proc macros that we have to use `syn`,
   and replace the custom `quote!` macro that we have with the vendored
   `quote!` macro. The `pin-init` macros are not converted yet; Benno
   has a work in progress in converting them. They're however converted
   to use `quote` and `proc-macro2` crate so our custom `quote!` macro
   can be removed.

   Overall this improves the robustness of the macros as we have precise
   parsing of the AST rather than relying on heuristics to extract needed
   information from there. This is also a quality-of-life improvement
   to those using language servers (e.g. Rust analyzer) as the span
   information of the proc macros are now preserved which allows the
   "jump-to-definition" feature to work, even when used on completely
   custom macros such as `module!`.

   Miguel gave a very good explanation on why `syn` is a good idea in the
   patch series that introduced it [1], which I shall not repeat here."

The `pin-init` rewrite was merged just before this one.

Link: https://lore.kernel.org/rust-for-linux/20251124151837.2184382-1-ojeda@kernel.org/ [1]
Link: https://patch.msgid.link/20260112170919.1888584-1-gary@kernel.org
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
This commit is contained in:
Miguel Ojeda 2026-01-28 13:44:42 +01:00
commit a7c013f779
11 changed files with 796 additions and 962 deletions

View file

@ -189,9 +189,6 @@ pub fn is_test_result_ok(t: impl TestResult) -> bool {
}
/// Represents an individual test case.
///
/// The [`kunit_unsafe_test_suite!`] macro expects a `NULL`-terminated list of valid test cases.
/// Use [`kunit_case_null`] to generate such a delimiter.
#[doc(hidden)]
pub const fn kunit_case(
name: &'static kernel::str::CStr,
@ -212,27 +209,6 @@ pub const fn kunit_case(
}
}
/// Represents the `NULL` test case delimiter.
///
/// The [`kunit_unsafe_test_suite!`] macro expects a `NULL`-terminated list of test cases. This
/// function returns such a delimiter.
#[doc(hidden)]
pub const fn kunit_case_null() -> kernel::bindings::kunit_case {
kernel::bindings::kunit_case {
run_case: None,
name: core::ptr::null_mut(),
generate_params: None,
attr: kernel::bindings::kunit_attributes {
speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL,
},
status: kernel::bindings::kunit_status_KUNIT_SUCCESS,
module_name: core::ptr::null_mut(),
log: core::ptr::null_mut(),
param_init: None,
param_exit: None,
}
}
/// Registers a KUnit test suite.
///
/// # Safety
@ -251,7 +227,7 @@ pub const fn kunit_case_null() -> kernel::bindings::kunit_case {
///
/// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [
/// kernel::kunit::kunit_case(c"name", test_fn),
/// kernel::kunit::kunit_case_null(),
/// pin_init::zeroed(),
/// ];
/// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES);
/// ```

View file

@ -1,23 +1,36 @@
// SPDX-License-Identifier: GPL-2.0
use proc_macro::{token_stream, Ident, TokenStream, TokenTree};
use proc_macro2::{
Ident,
TokenStream,
TokenTree, //
};
use syn::{
parse::{
Parse,
ParseStream, //
},
Result,
Token, //
};
use crate::helpers::expect_punct;
pub(crate) struct Input {
a: Ident,
_comma: Token![,],
b: Ident,
}
fn expect_ident(it: &mut token_stream::IntoIter) -> Ident {
if let Some(TokenTree::Ident(ident)) = it.next() {
ident
} else {
panic!("Expected Ident")
impl Parse for Input {
fn parse(input: ParseStream<'_>) -> Result<Self> {
Ok(Self {
a: input.parse()?,
_comma: input.parse()?,
b: input.parse()?,
})
}
}
pub(crate) fn concat_idents(ts: TokenStream) -> TokenStream {
let mut it = ts.into_iter();
let a = expect_ident(&mut it);
assert_eq!(expect_punct(&mut it), ',');
let b = expect_ident(&mut it);
assert!(it.next().is_none(), "only two idents can be concatenated");
pub(crate) fn concat_idents(Input { a, b, .. }: Input) -> TokenStream {
let res = Ident::new(&format!("{a}{b}"), b.span());
TokenStream::from_iter([TokenTree::Ident(res)])
}

View file

@ -1,19 +1,16 @@
// SPDX-License-Identifier: GPL-2.0
use crate::helpers::function_name;
use proc_macro::TokenStream;
use proc_macro2::TokenStream;
use quote::quote;
/// Please see [`crate::export`] for documentation.
pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream {
let Some(name) = function_name(ts.clone()) else {
return "::core::compile_error!(\"The #[export] attribute must be used on a function.\");"
.parse::<TokenStream>()
.unwrap();
};
pub(crate) fn export(f: syn::ItemFn) -> TokenStream {
let name = &f.sig.ident;
// This verifies that the function has the same signature as the declaration generated by
// bindgen. It makes use of the fact that all branches of an if/else must have the same type.
let signature_check = quote!(
quote! {
// This verifies that the function has the same signature as the declaration generated by
// bindgen. It makes use of the fact that all branches of an if/else must have the same
// type.
const _: () = {
if true {
::kernel::bindings::#name
@ -21,9 +18,8 @@ pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream {
#name
};
};
);
let no_mangle = quote!(#[no_mangle]);
TokenStream::from_iter([signature_check, no_mangle, ts])
#[no_mangle]
#f
}
}

View file

@ -1,8 +1,10 @@
// SPDX-License-Identifier: GPL-2.0
use proc_macro::{Ident, TokenStream, TokenTree};
use std::collections::BTreeSet;
use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::quote_spanned;
/// Please see [`crate::fmt`] for documentation.
pub(crate) fn fmt(input: TokenStream) -> TokenStream {
let mut input = input.into_iter();

View file

@ -1,103 +1,43 @@
// SPDX-License-Identifier: GPL-2.0
use proc_macro::{token_stream, Group, Ident, TokenStream, TokenTree};
use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{
parse::{
Parse,
ParseStream, //
},
Attribute,
Error,
LitStr,
Result, //
};
pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> {
if let Some(TokenTree::Ident(ident)) = it.next() {
Some(ident.to_string())
} else {
None
}
}
/// A string literal that is required to have ASCII value only.
pub(crate) struct AsciiLitStr(LitStr);
pub(crate) fn try_sign(it: &mut token_stream::IntoIter) -> Option<char> {
let peek = it.clone().next();
match peek {
Some(TokenTree::Punct(punct)) if punct.as_char() == '-' => {
let _ = it.next();
Some(punct.as_char())
impl Parse for AsciiLitStr {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let s: LitStr = input.parse()?;
if !s.value().is_ascii() {
return Err(Error::new_spanned(s, "expected ASCII-only string literal"));
}
_ => None,
Ok(Self(s))
}
}
pub(crate) fn try_literal(it: &mut token_stream::IntoIter) -> Option<String> {
if let Some(TokenTree::Literal(literal)) = it.next() {
Some(literal.to_string())
} else {
None
impl ToTokens for AsciiLitStr {
fn to_tokens(&self, ts: &mut TokenStream) {
self.0.to_tokens(ts);
}
}
pub(crate) fn try_string(it: &mut token_stream::IntoIter) -> Option<String> {
try_literal(it).and_then(|string| {
if string.starts_with('\"') && string.ends_with('\"') {
let content = &string[1..string.len() - 1];
if content.contains('\\') {
panic!("Escape sequences in string literals not yet handled");
}
Some(content.to_string())
} else if string.starts_with("r\"") {
panic!("Raw string literals are not yet handled");
} else {
None
}
})
}
pub(crate) fn expect_ident(it: &mut token_stream::IntoIter) -> String {
try_ident(it).expect("Expected Ident")
}
pub(crate) fn expect_punct(it: &mut token_stream::IntoIter) -> char {
if let TokenTree::Punct(punct) = it.next().expect("Reached end of token stream for Punct") {
punct.as_char()
} else {
panic!("Expected Punct");
impl AsciiLitStr {
pub(crate) fn value(&self) -> String {
self.0.value()
}
}
pub(crate) fn expect_string(it: &mut token_stream::IntoIter) -> String {
try_string(it).expect("Expected string")
}
pub(crate) fn expect_string_ascii(it: &mut token_stream::IntoIter) -> String {
let string = try_string(it).expect("Expected string");
assert!(string.is_ascii(), "Expected ASCII string");
string
}
pub(crate) fn expect_group(it: &mut token_stream::IntoIter) -> Group {
if let TokenTree::Group(group) = it.next().expect("Reached end of token stream for Group") {
group
} else {
panic!("Expected Group");
}
}
pub(crate) fn expect_end(it: &mut token_stream::IntoIter) {
if it.next().is_some() {
panic!("Expected end");
}
}
/// Given a function declaration, finds the name of the function.
pub(crate) fn function_name(input: TokenStream) -> Option<Ident> {
let mut input = input.into_iter();
while let Some(token) = input.next() {
match token {
TokenTree::Ident(i) if i.to_string() == "fn" => {
if let Some(TokenTree::Ident(i)) = input.next() {
return Some(i);
}
return None;
}
_ => continue,
}
}
None
}
pub(crate) fn file() -> String {
#[cfg(not(CONFIG_RUSTC_HAS_SPAN_FILE))]
{
@ -115,16 +55,7 @@ pub(crate) fn file() -> String {
}
}
/// Parse a token stream of the form `expected_name: "value",` and return the
/// string in the position of "value".
///
/// # Panics
///
/// - On parse error.
pub(crate) fn expect_string_field(it: &mut token_stream::IntoIter, expected_name: &str) -> String {
assert_eq!(expect_ident(it), expected_name);
assert_eq!(expect_punct(it), ':');
let string = expect_string(it);
assert_eq!(expect_punct(it), ',');
string
/// Obtain all `#[cfg]` attributes.
pub(crate) fn gather_cfg_attrs(attr: &[Attribute]) -> impl Iterator<Item = &Attribute> + '_ {
attr.iter().filter(|a| a.path().is_ident("cfg"))
}

View file

@ -4,80 +4,50 @@
//!
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashMap;
use std::fmt::Write;
use std::ffi::CString;
pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
let attr = attr.to_string();
use proc_macro2::TokenStream;
use quote::{
format_ident,
quote,
ToTokens, //
};
use syn::{
parse_quote,
Error,
Ident,
Item,
ItemMod,
LitCStr,
Result, //
};
if attr.is_empty() {
panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
if test_suite.to_string().len() > 255 {
return Err(Error::new_spanned(
test_suite,
"test suite names cannot exceed the maximum length of 255 bytes",
));
}
if attr.len() > 255 {
panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
}
let mut tokens: Vec<_> = ts.into_iter().collect();
// Scan for the `mod` keyword.
tokens
.iter()
.find_map(|token| match token {
TokenTree::Ident(ident) => match ident.to_string().as_str() {
"mod" => Some(true),
_ => None,
},
_ => None,
})
.expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
// Retrieve the main body. The main body should be the last token tree.
let body = match tokens.pop() {
Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
_ => panic!("Cannot locate main body of module"),
// We cannot handle modules that defer to another file (e.g. `mod foo;`).
let Some((module_brace, module_items)) = module.content.take() else {
Err(Error::new_spanned(
module,
"`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
))?
};
// Get the functions set as tests. Search for `[test]` -> `fn`.
let mut body_it = body.stream().into_iter();
let mut tests = Vec::new();
let mut attributes: HashMap<String, TokenStream> = HashMap::new();
while let Some(token) = body_it.next() {
match token {
TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
// Collect attributes because we need to find which are tests. We also
// need to copy `cfg` attributes so tests can be conditionally enabled.
attributes
.entry(name.to_string())
.or_default()
.extend([token, TokenTree::Group(g)]);
}
continue;
}
_ => (),
},
TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
if let Some(TokenTree::Ident(test_name)) = body_it.next() {
tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
}
}
// Make the entire module gated behind `CONFIG_KUNIT`.
module
.attrs
.insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
_ => (),
}
attributes.clear();
}
// Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
tokens.insert(
0,
TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
);
let mut processed_items = Vec::new();
let mut test_cases = Vec::new();
// Generate the test KUnit test suite and a test case for each `#[test]`.
//
// The code generated for the following test module:
//
// ```
@ -104,103 +74,98 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
// static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
// ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
// ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
// ::kernel::kunit::kunit_case_null(),
// ::pin_init::zeroed(),
// ];
//
// ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
// ```
let mut kunit_macros = "".to_owned();
let mut test_cases = "".to_owned();
let mut assert_macros = "".to_owned();
let path = crate::helpers::file();
let num_tests = tests.len();
for (test, cfg_attr) in tests {
let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
// Append any `cfg` attributes the user might have written on their tests so we don't
// attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
// the length of the assert message.
let kunit_wrapper = format!(
r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
{{
//
// Non-function items (e.g. imports) are preserved.
for item in module_items {
let Item::Fn(mut f) = item else {
processed_items.push(item);
continue;
};
// TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
let before_len = f.attrs.len();
f.attrs.retain(|attr| !attr.path().is_ident("test"));
if f.attrs.len() == before_len {
processed_items.push(Item::Fn(f));
continue;
}
let test = f.sig.ident.clone();
// Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
let cfg_attrs: Vec<_> = f
.attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
.cloned()
.collect();
// Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
// KUnit instead.
let test_str = test.to_string();
let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
processed_items.push(parse_quote! {
#[allow(unused)]
macro_rules! assert {
($cond:expr $(,)?) => {{
kernel::kunit_assert!(#test_str, #path, 0, $cond);
}}
}
});
processed_items.push(parse_quote! {
#[allow(unused)]
macro_rules! assert_eq {
($left:expr, $right:expr $(,)?) => {{
kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
}}
}
});
// Add back the test item.
processed_items.push(Item::Fn(f));
let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
let test_cstr = LitCStr::new(
&CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
test.span(),
);
processed_items.push(parse_quote! {
unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
{cfg_attr} {{
// Append any `cfg` attributes the user might have written on their tests so we
// don't attempt to call them when they are `cfg`'d out. An extra `use` is used
// here to reduce the length of the assert message.
#(#cfg_attrs)*
{
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
use ::kernel::kunit::is_test_result_ok;
assert!(is_test_result_ok({test}()));
}}
}}"#,
);
writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
writeln!(
test_cases,
" ::kernel::kunit::kunit_case(c\"{test}\", {kunit_wrapper_fn_name}),"
)
.unwrap();
writeln!(
assert_macros,
r#"
/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
#[allow(unused)]
macro_rules! assert {{
($cond:expr $(,)?) => {{{{
kernel::kunit_assert!("{test}", c"{path}", 0, $cond);
}}}}
}}
/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
#[allow(unused)]
macro_rules! assert_eq {{
($left:expr, $right:expr $(,)?) => {{{{
kernel::kunit_assert_eq!("{test}", c"{path}", 0, $left, $right);
}}}}
}}
"#
)
.unwrap();
}
writeln!(kunit_macros).unwrap();
writeln!(
kunit_macros,
"static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",
num_tests + 1
)
.unwrap();
writeln!(
kunit_macros,
"::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
)
.unwrap();
// Remove the `#[test]` macros.
// We do this at a token level, in order to preserve span information.
let mut new_body = vec![];
let mut body_it = body.stream().into_iter();
while let Some(token) = body_it.next() {
match token {
TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
Some(next) => {
new_body.extend([token, next]);
assert!(is_test_result_ok(#test()));
}
_ => {
new_body.push(token);
}
},
_ => {
new_body.push(token);
}
}
});
test_cases.push(quote!(
::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
));
}
let mut final_body = TokenStream::new();
final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
final_body.extend(new_body);
final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
let num_tests_plus_1 = test_cases.len() + 1;
processed_items.push(parse_quote! {
static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
#(#test_cases,)*
::pin_init::zeroed(),
];
});
processed_items.push(parse_quote! {
::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
});
tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
tokens.into_iter().collect()
module.content = Some((module_brace, processed_items));
Ok(module.to_token_stream())
}

View file

@ -11,8 +11,6 @@
// to avoid depending on the full `proc_macro_span` on Rust >= 1.88.0.
#![cfg_attr(not(CONFIG_RUSTC_HAS_SPAN_FILE), feature(proc_macro_span))]
#[macro_use]
mod quote;
mod concat_idents;
mod export;
mod fmt;
@ -24,6 +22,8 @@ mod vtable;
use proc_macro::TokenStream;
use syn::parse_macro_input;
/// Declares a kernel module.
///
/// The `type` argument should be a type which implements the [`Module`]
@ -131,8 +131,10 @@ use proc_macro::TokenStream;
/// - `firmware`: array of ASCII string literals of the firmware files of
/// the kernel module.
#[proc_macro]
pub fn module(ts: TokenStream) -> TokenStream {
module::module(ts)
pub fn module(input: TokenStream) -> TokenStream {
module::module(parse_macro_input!(input))
.unwrap_or_else(|e| e.into_compile_error())
.into()
}
/// Declares or implements a vtable trait.
@ -206,8 +208,11 @@ pub fn module(ts: TokenStream) -> TokenStream {
///
/// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html
#[proc_macro_attribute]
pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream {
vtable::vtable(attr, ts)
pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(attr as syn::parse::Nothing);
vtable::vtable(parse_macro_input!(input))
.unwrap_or_else(|e| e.into_compile_error())
.into()
}
/// Export a function so that C code can call it via a header file.
@ -229,8 +234,9 @@ pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream {
/// This macro is *not* the same as the C macros `EXPORT_SYMBOL_*`. All Rust symbols are currently
/// automatically exported with `EXPORT_SYMBOL_GPL`.
#[proc_macro_attribute]
pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream {
export::export(attr, ts)
pub fn export(attr: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(attr as syn::parse::Nothing);
export::export(parse_macro_input!(input)).into()
}
/// Like [`core::format_args!`], but automatically wraps arguments in [`kernel::fmt::Adapter`].
@ -248,7 +254,7 @@ pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream {
/// [`pr_info!`]: ../kernel/macro.pr_info.html
#[proc_macro]
pub fn fmt(input: TokenStream) -> TokenStream {
fmt::fmt(input)
fmt::fmt(input.into()).into()
}
/// Concatenate two identifiers.
@ -305,8 +311,8 @@ pub fn fmt(input: TokenStream) -> TokenStream {
/// assert_eq!(BR_OK, binder_driver_return_protocol_BR_OK);
/// ```
#[proc_macro]
pub fn concat_idents(ts: TokenStream) -> TokenStream {
concat_idents::concat_idents(ts)
pub fn concat_idents(input: TokenStream) -> TokenStream {
concat_idents::concat_idents(parse_macro_input!(input)).into()
}
/// Paste identifiers together.
@ -444,9 +450,12 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream {
/// [`paste`]: https://docs.rs/paste/
#[proc_macro]
pub fn paste(input: TokenStream) -> TokenStream {
let mut tokens = input.into_iter().collect();
let mut tokens = proc_macro2::TokenStream::from(input).into_iter().collect();
paste::expand(&mut tokens);
tokens.into_iter().collect()
tokens
.into_iter()
.collect::<proc_macro2::TokenStream>()
.into()
}
/// Registers a KUnit test suite and its test cases using a user-space like syntax.
@ -472,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream {
/// }
/// ```
#[proc_macro_attribute]
pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
kunit::kunit_tests(attr, ts)
pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream {
kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input))
.unwrap_or_else(|e| e.into_compile_error())
.into()
}

View file

@ -1,32 +1,42 @@
// SPDX-License-Identifier: GPL-2.0
use std::ffi::CString;
use proc_macro2::{
Literal,
TokenStream, //
};
use quote::{
format_ident,
quote, //
};
use syn::{
braced,
bracketed,
ext::IdentExt,
parse::{
Parse,
ParseStream, //
},
parse_quote,
punctuated::Punctuated,
Error,
Expr,
Ident,
LitStr,
Path,
Result,
Token,
Type, //
};
use crate::helpers::*;
use proc_macro::{token_stream, Delimiter, Literal, TokenStream, TokenTree};
use std::fmt::Write;
fn expect_string_array(it: &mut token_stream::IntoIter) -> Vec<String> {
let group = expect_group(it);
assert_eq!(group.delimiter(), Delimiter::Bracket);
let mut values = Vec::new();
let mut it = group.stream().into_iter();
while let Some(val) = try_string(&mut it) {
assert!(val.is_ascii(), "Expected ASCII string");
values.push(val);
match it.next() {
Some(TokenTree::Punct(punct)) => assert_eq!(punct.as_char(), ','),
None => break,
_ => panic!("Expected ',' or end of array"),
}
}
values
}
struct ModInfoBuilder<'a> {
module: &'a str,
counter: usize,
buffer: String,
param_buffer: String,
ts: TokenStream,
param_ts: TokenStream,
}
impl<'a> ModInfoBuilder<'a> {
@ -34,8 +44,8 @@ impl<'a> ModInfoBuilder<'a> {
ModInfoBuilder {
module,
counter: 0,
buffer: String::new(),
param_buffer: String::new(),
ts: TokenStream::new(),
param_ts: TokenStream::new(),
}
}
@ -52,33 +62,31 @@ impl<'a> ModInfoBuilder<'a> {
// Loadable modules' modinfo strings go as-is.
format!("{field}={content}\0")
};
let buffer = if param {
&mut self.param_buffer
let length = string.len();
let string = Literal::byte_string(string.as_bytes());
let cfg = if builtin {
quote!(#[cfg(not(MODULE))])
} else {
&mut self.buffer
quote!(#[cfg(MODULE)])
};
write!(
buffer,
"
{cfg}
#[doc(hidden)]
#[cfg_attr(not(target_os = \"macos\"), link_section = \".modinfo\")]
#[used(compiler)]
pub static __{module}_{counter}: [u8; {length}] = *{string};
",
cfg = if builtin {
"#[cfg(not(MODULE))]"
} else {
"#[cfg(MODULE)]"
},
let counter = format_ident!(
"__{module}_{counter}",
module = self.module.to_uppercase(),
counter = self.counter,
length = string.len(),
string = Literal::byte_string(string.as_bytes()),
)
.unwrap();
counter = self.counter
);
let item = quote! {
#cfg
#[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
#[used(compiler)]
pub static #counter: [u8; #length] = *#string;
};
if param {
self.param_ts.extend(item);
} else {
self.ts.extend(item);
}
self.counter += 1;
}
@ -111,201 +119,160 @@ impl<'a> ModInfoBuilder<'a> {
};
for param in params {
let ops = param_ops_path(&param.ptype);
let param_name_str = param.name.to_string();
let param_type_str = param.ptype.to_string();
let ops = param_ops_path(&param_type_str);
// Note: The spelling of these fields is dictated by the user space
// tool `modinfo`.
self.emit_param("parmtype", &param.name, &param.ptype);
self.emit_param("parm", &param.name, &param.description);
self.emit_param("parmtype", &param_name_str, &param_type_str);
self.emit_param("parm", &param_name_str, &param.description.value());
write!(
self.param_buffer,
"
pub(crate) static {param_name}:
::kernel::module_param::ModuleParamAccess<{param_type}> =
::kernel::module_param::ModuleParamAccess::new({param_default});
let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
let param_name_cstr =
CString::new(param_name_str).expect("name contains NUL-terminator");
let param_name_cstr_with_module =
CString::new(format!("{}.{}", self.module, param.name))
.expect("name contains NUL-terminator");
const _: () = {{
#[link_section = \"__param\"]
#[used]
static __{module_name}_{param_name}_struct:
let param_name = &param.name;
let param_type = &param.ptype;
let param_default = &param.default;
self.param_ts.extend(quote! {
#[allow(non_upper_case_globals)]
pub(crate) static #param_name:
::kernel::module_param::ModuleParamAccess<#param_type> =
::kernel::module_param::ModuleParamAccess::new(#param_default);
const _: () = {
#[allow(non_upper_case_globals)]
#[link_section = "__param"]
#[used(compiler)]
static #static_name:
::kernel::module_param::KernelParam =
::kernel::module_param::KernelParam::new(
::kernel::bindings::kernel_param {{
name: if ::core::cfg!(MODULE) {{
::kernel::c_str!(\"{param_name}\").to_bytes_with_nul()
}} else {{
::kernel::c_str!(\"{module_name}.{param_name}\")
.to_bytes_with_nul()
}}.as_ptr(),
::kernel::bindings::kernel_param {
name: kernel::str::as_char_ptr_in_const_context(
if ::core::cfg!(MODULE) {
#param_name_cstr
} else {
#param_name_cstr_with_module
}
),
// SAFETY: `__this_module` is constructed by the kernel at load
// time and will not be freed until the module is unloaded.
#[cfg(MODULE)]
mod_: unsafe {{
mod_: unsafe {
core::ptr::from_ref(&::kernel::bindings::__this_module)
.cast_mut()
}},
},
#[cfg(not(MODULE))]
mod_: ::core::ptr::null_mut(),
ops: core::ptr::from_ref(&{ops}),
ops: core::ptr::from_ref(&#ops),
perm: 0, // Will not appear in sysfs
level: -1,
flags: 0,
__bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {{
arg: {param_name}.as_void_ptr()
}},
}}
__bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
arg: #param_name.as_void_ptr()
},
}
);
}};
",
module_name = info.name,
param_type = param.ptype,
param_default = param.default,
param_name = param.name,
ops = ops,
)
.unwrap();
};
});
}
}
}
fn param_ops_path(param_type: &str) -> &'static str {
fn param_ops_path(param_type: &str) -> Path {
match param_type {
"i8" => "::kernel::module_param::PARAM_OPS_I8",
"u8" => "::kernel::module_param::PARAM_OPS_U8",
"i16" => "::kernel::module_param::PARAM_OPS_I16",
"u16" => "::kernel::module_param::PARAM_OPS_U16",
"i32" => "::kernel::module_param::PARAM_OPS_I32",
"u32" => "::kernel::module_param::PARAM_OPS_U32",
"i64" => "::kernel::module_param::PARAM_OPS_I64",
"u64" => "::kernel::module_param::PARAM_OPS_U64",
"isize" => "::kernel::module_param::PARAM_OPS_ISIZE",
"usize" => "::kernel::module_param::PARAM_OPS_USIZE",
"i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
"u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
"i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
"u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
"i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
"u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
"i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
"u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
"isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
"usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
t => panic!("Unsupported parameter type {}", t),
}
}
fn expect_param_default(param_it: &mut token_stream::IntoIter) -> String {
assert_eq!(expect_ident(param_it), "default");
assert_eq!(expect_punct(param_it), ':');
let sign = try_sign(param_it);
let default = try_literal(param_it).expect("Expected default param value");
assert_eq!(expect_punct(param_it), ',');
let mut value = sign.map(String::from).unwrap_or_default();
value.push_str(&default);
value
}
/// Parse fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However the error message generated when implementing that way is not very friendly.
///
/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing,
/// and if the wrong order is used, the proper order is communicated to the user with error message.
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
/// from input;
///
/// // This will extract "foo: <field>" into a variable named "foo".
/// // The variable will have type `Option<_>`.
/// foo => <expression that parses the field>,
///
/// // If you need the variable name to be different than the key name.
/// // This extracts "baz: <field>" into a variable named "bar".
/// // You might want this if "baz" is a keyword.
/// baz as bar => <expression that parse the field>,
///
/// // You can mark a key as required, and the variable will no longer be `Option`.
/// // foobar will be of type `Expr` instead of `Option<Expr>`.
/// foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
(@gen
[$input:expr]
[$([$name:ident; $key:ident; $parser:expr])*]
[$([$req_name:ident; $req_key:ident])*]
) => {
$(let mut $name = None;)*
#[derive(Debug, Default)]
struct ModuleInfo {
type_: String,
license: String,
name: String,
authors: Option<Vec<String>>,
description: Option<String>,
alias: Option<Vec<String>>,
firmware: Option<Vec<String>>,
imports_ns: Option<Vec<String>>,
params: Option<Vec<Parameter>>,
}
const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];
#[derive(Debug)]
struct Parameter {
name: String,
ptype: String,
default: String,
description: String,
}
fn expect_params(it: &mut token_stream::IntoIter) -> Vec<Parameter> {
let params = expect_group(it);
assert_eq!(params.delimiter(), Delimiter::Brace);
let mut it = params.stream().into_iter();
let mut parsed = Vec::new();
loop {
let param_name = match it.next() {
Some(TokenTree::Ident(ident)) => ident.to_string(),
Some(_) => panic!("Expected Ident or end"),
None => break,
};
assert_eq!(expect_punct(&mut it), ':');
let param_type = expect_ident(&mut it);
let group = expect_group(&mut it);
assert_eq!(group.delimiter(), Delimiter::Brace);
assert_eq!(expect_punct(&mut it), ',');
let mut param_it = group.stream().into_iter();
let param_default = expect_param_default(&mut param_it);
let param_description = expect_string_field(&mut param_it, "description");
expect_end(&mut param_it);
parsed.push(Parameter {
name: param_name,
ptype: param_type,
default: param_default,
description: param_description,
})
}
parsed
}
impl ModuleInfo {
fn parse(it: &mut token_stream::IntoIter) -> Self {
let mut info = ModuleInfo::default();
const EXPECTED_KEYS: &[&str] = &[
"type",
"name",
"authors",
"description",
"license",
"alias",
"firmware",
"imports_ns",
"params",
];
const REQUIRED_KEYS: &[&str] = &["type", "name", "license"];
let span = $input.span();
let mut seen_keys = Vec::new();
loop {
let key = match it.next() {
Some(TokenTree::Ident(ident)) => ident.to_string(),
Some(_) => panic!("Expected Ident or end"),
None => break,
};
while !$input.is_empty() {
let key = $input.call(Ident::parse_any)?;
if seen_keys.contains(&key) {
panic!("Duplicated key \"{key}\". Keys can only be specified once.");
Err(Error::new_spanned(
&key,
format!(r#"duplicated key "{key}". Keys can only be specified once."#),
))?
}
assert_eq!(expect_punct(it), ':');
$input.parse::<Token![:]>()?;
match key.as_str() {
"type" => info.type_ = expect_ident(it),
"name" => info.name = expect_string_ascii(it),
"authors" => info.authors = Some(expect_string_array(it)),
"description" => info.description = Some(expect_string(it)),
"license" => info.license = expect_string_ascii(it),
"alias" => info.alias = Some(expect_string_array(it)),
"firmware" => info.firmware = Some(expect_string_array(it)),
"imports_ns" => info.imports_ns = Some(expect_string_array(it)),
"params" => info.params = Some(expect_params(it)),
_ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."),
match &*key.to_string() {
$(
stringify!($key) => $name = Some($parser),
)*
_ => {
Err(Error::new_spanned(
&key,
format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
))?
}
}
assert_eq!(expect_punct(it), ',');
$input.parse::<Token![,]>()?;
seen_keys.push(key);
}
expect_end(it);
for key in REQUIRED_KEYS {
if !seen_keys.iter().any(|e| e == key) {
panic!("Missing required key \"{key}\".");
Err(Error::new(span, format!(r#"missing required key "{key}""#)))?
}
}
@ -317,43 +284,190 @@ impl ModuleInfo {
}
if seen_keys != ordered_keys {
panic!("Keys are not ordered as expected. Order them like: {ordered_keys:?}.");
Err(Error::new(
span,
format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
))?
}
info
$(let $req_name = $req_name.expect("required field");)*
};
// Handle required fields.
(@gen
[$input:expr] [$($tok:tt)*] [$($req:tt)*]
$key:ident as $name:ident [required] => $parser:expr,
$($rest:tt)*
) => {
parse_ordered_fields!(
@gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
)
};
(@gen
[$input:expr] [$($tok:tt)*] [$($req:tt)*]
$name:ident [required] => $parser:expr,
$($rest:tt)*
) => {
parse_ordered_fields!(
@gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
)
};
// Handle optional fields.
(@gen
[$input:expr] [$($tok:tt)*] [$($req:tt)*]
$key:ident as $name:ident => $parser:expr,
$($rest:tt)*
) => {
parse_ordered_fields!(
@gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
)
};
(@gen
[$input:expr] [$($tok:tt)*] [$($req:tt)*]
$name:ident => $parser:expr,
$($rest:tt)*
) => {
parse_ordered_fields!(
@gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
)
};
(from $input:expr; $($tok:tt)*) => {
parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
}
}
pub(crate) fn module(ts: TokenStream) -> TokenStream {
let mut it = ts.into_iter();
struct Parameter {
name: Ident,
ptype: Ident,
default: Expr,
description: LitStr,
}
let info = ModuleInfo::parse(&mut it);
impl Parse for Parameter {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let name = input.parse()?;
input.parse::<Token![:]>()?;
let ptype = input.parse()?;
let fields;
braced!(fields in input);
parse_ordered_fields! {
from fields;
default [required] => fields.parse()?,
description [required] => fields.parse()?,
}
Ok(Self {
name,
ptype,
default,
description,
})
}
}
pub(crate) struct ModuleInfo {
type_: Type,
license: AsciiLitStr,
name: AsciiLitStr,
authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
description: Option<LitStr>,
alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
params: Option<Punctuated<Parameter, Token![,]>>,
}
impl Parse for ModuleInfo {
fn parse(input: ParseStream<'_>) -> Result<Self> {
parse_ordered_fields!(
from input;
type as type_ [required] => input.parse()?,
name [required] => input.parse()?,
authors => {
let list;
bracketed!(list in input);
Punctuated::parse_terminated(&list)?
},
description => input.parse()?,
license [required] => input.parse()?,
alias => {
let list;
bracketed!(list in input);
Punctuated::parse_terminated(&list)?
},
firmware => {
let list;
bracketed!(list in input);
Punctuated::parse_terminated(&list)?
},
imports_ns => {
let list;
bracketed!(list in input);
Punctuated::parse_terminated(&list)?
},
params => {
let list;
braced!(list in input);
Punctuated::parse_terminated(&list)?
},
);
Ok(ModuleInfo {
type_,
license,
name,
authors,
description,
alias,
firmware,
imports_ns,
params,
})
}
}
pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
let ModuleInfo {
type_,
license,
name,
authors,
description,
alias,
firmware,
imports_ns,
params: _,
} = &info;
// Rust does not allow hyphens in identifiers, use underscore instead.
let ident = info.name.replace('-', "_");
let ident = name.value().replace('-', "_");
let mut modinfo = ModInfoBuilder::new(ident.as_ref());
if let Some(authors) = &info.authors {
if let Some(authors) = authors {
for author in authors {
modinfo.emit("author", author);
modinfo.emit("author", &author.value());
}
}
if let Some(description) = &info.description {
modinfo.emit("description", description);
if let Some(description) = description {
modinfo.emit("description", &description.value());
}
modinfo.emit("license", &info.license);
if let Some(aliases) = &info.alias {
modinfo.emit("license", &license.value());
if let Some(aliases) = alias {
for alias in aliases {
modinfo.emit("alias", alias);
modinfo.emit("alias", &alias.value());
}
}
if let Some(firmware) = &info.firmware {
if let Some(firmware) = firmware {
for fw in firmware {
modinfo.emit("firmware", fw);
modinfo.emit("firmware", &fw.value());
}
}
if let Some(imports) = &info.imports_ns {
if let Some(imports) = imports_ns {
for ns in imports {
modinfo.emit("import_ns", ns);
modinfo.emit("import_ns", &ns.value());
}
}
@ -364,182 +478,181 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
modinfo.emit_params(&info);
format!(
"
/// The module name.
///
/// Used by the printing macros, e.g. [`info!`].
const __LOG_PREFIX: &[u8] = b\"{name}\\0\";
let modinfo_ts = modinfo.ts;
let params_ts = modinfo.param_ts;
// SAFETY: `__this_module` is constructed by the kernel at load time and will not be
// freed until the module is unloaded.
#[cfg(MODULE)]
static THIS_MODULE: ::kernel::ThisModule = unsafe {{
extern \"C\" {{
static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
}}
let ident_init = format_ident!("__{ident}_init");
let ident_exit = format_ident!("__{ident}_exit");
let ident_initcall = format_ident!("__{ident}_initcall");
let initcall_section = ".initcall6.init";
::kernel::ThisModule::from_ptr(__this_module.get())
}};
#[cfg(not(MODULE))]
static THIS_MODULE: ::kernel::ThisModule = unsafe {{
::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
}};
let global_asm = format!(
r#".section "{initcall_section}", "a"
__{ident}_initcall:
.long __{ident}_init - .
.previous
"#
);
/// The `LocalModule` type is the type of the module created by `module!`,
/// `module_pci_driver!`, `module_platform_driver!`, etc.
type LocalModule = {type_};
let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
impl ::kernel::ModuleMetadata for {type_} {{
const NAME: &'static ::kernel::str::CStr = c\"{name}\";
}}
Ok(quote! {
/// The module name.
///
/// Used by the printing macros, e.g. [`info!`].
const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
// Double nested modules, since then nobody can access the public items inside.
mod __module_init {{
mod __module_init {{
use super::super::{type_};
use pin_init::PinInit;
// SAFETY: `__this_module` is constructed by the kernel at load time and will not be
// freed until the module is unloaded.
#[cfg(MODULE)]
static THIS_MODULE: ::kernel::ThisModule = unsafe {
extern "C" {
static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
};
/// The \"Rust loadable module\" mark.
//
// This may be best done another way later on, e.g. as a new modinfo
// key or a new section. For the moment, keep it simple.
#[cfg(MODULE)]
#[doc(hidden)]
#[used(compiler)]
static __IS_RUST_MODULE: () = ();
::kernel::ThisModule::from_ptr(__this_module.get())
};
static mut __MOD: ::core::mem::MaybeUninit<{type_}> =
::core::mem::MaybeUninit::uninit();
#[cfg(not(MODULE))]
static THIS_MODULE: ::kernel::ThisModule = unsafe {
::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
};
// Loadable modules need to export the `{{init,cleanup}}_module` identifiers.
/// # Safety
///
/// This function must not be called after module initialization, because it may be
/// freed after that completes.
#[cfg(MODULE)]
#[doc(hidden)]
#[no_mangle]
#[link_section = \".init.text\"]
pub unsafe extern \"C\" fn init_module() -> ::kernel::ffi::c_int {{
// SAFETY: This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name.
unsafe {{ __init() }}
}}
/// The `LocalModule` type is the type of the module created by `module!`,
/// `module_pci_driver!`, `module_platform_driver!`, etc.
type LocalModule = #type_;
#[cfg(MODULE)]
#[doc(hidden)]
#[used(compiler)]
#[link_section = \".init.data\"]
static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module;
impl ::kernel::ModuleMetadata for #type_ {
const NAME: &'static ::kernel::str::CStr = #name_cstr;
}
#[cfg(MODULE)]
#[doc(hidden)]
#[no_mangle]
#[link_section = \".exit.text\"]
pub extern \"C\" fn cleanup_module() {{
// SAFETY:
// - This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name,
// - furthermore it is only called after `init_module` has returned `0`
// (which delegates to `__init`).
unsafe {{ __exit() }}
}}
// Double nested modules, since then nobody can access the public items inside.
#[doc(hidden)]
mod __module_init {
mod __module_init {
use pin_init::PinInit;
#[cfg(MODULE)]
#[doc(hidden)]
#[used(compiler)]
#[link_section = \".exit.data\"]
static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module;
/// The "Rust loadable module" mark.
//
// This may be best done another way later on, e.g. as a new modinfo
// key or a new section. For the moment, keep it simple.
#[cfg(MODULE)]
#[used(compiler)]
static __IS_RUST_MODULE: () = ();
// Built-in modules are initialized through an initcall pointer
// and the identifiers need to be unique.
#[cfg(not(MODULE))]
#[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
#[doc(hidden)]
#[link_section = \"{initcall_section}\"]
#[used(compiler)]
pub static __{ident}_initcall: extern \"C\" fn() ->
::kernel::ffi::c_int = __{ident}_init;
static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
::core::mem::MaybeUninit::uninit();
#[cfg(not(MODULE))]
#[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
::core::arch::global_asm!(
r#\".section \"{initcall_section}\", \"a\"
__{ident}_initcall:
.long __{ident}_init - .
.previous
\"#
// Loadable modules need to export the `{init,cleanup}_module` identifiers.
/// # Safety
///
/// This function must not be called after module initialization, because it may be
/// freed after that completes.
#[cfg(MODULE)]
#[no_mangle]
#[link_section = ".init.text"]
pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
// SAFETY: This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name.
unsafe { __init() }
}
#[cfg(MODULE)]
#[used(compiler)]
#[link_section = ".init.data"]
static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
init_module;
#[cfg(MODULE)]
#[no_mangle]
#[link_section = ".exit.text"]
pub extern "C" fn cleanup_module() {
// SAFETY:
// - This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name,
// - furthermore it is only called after `init_module` has returned `0`
// (which delegates to `__init`).
unsafe { __exit() }
}
#[cfg(MODULE)]
#[used(compiler)]
#[link_section = ".exit.data"]
static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
// Built-in modules are initialized through an initcall pointer
// and the identifiers need to be unique.
#[cfg(not(MODULE))]
#[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
#[link_section = #initcall_section]
#[used(compiler)]
pub static #ident_initcall: extern "C" fn() ->
::kernel::ffi::c_int = #ident_init;
#[cfg(not(MODULE))]
#[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
::core::arch::global_asm!(#global_asm);
#[cfg(not(MODULE))]
#[no_mangle]
pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
// SAFETY: This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// placement above in the initcall section.
unsafe { __init() }
}
#[cfg(not(MODULE))]
#[no_mangle]
pub extern "C" fn #ident_exit() {
// SAFETY:
// - This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name,
// - furthermore it is only called after `#ident_init` has
// returned `0` (which delegates to `__init`).
unsafe { __exit() }
}
/// # Safety
///
/// This function must only be called once.
unsafe fn __init() -> ::kernel::ffi::c_int {
let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
&super::super::THIS_MODULE
);
// SAFETY: No data race, since `__MOD` can only be accessed by this module
// and there only `__init` and `__exit` access it. These functions are only
// called once and `__exit` cannot be called before or during `__init`.
match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
Ok(m) => 0,
Err(e) => e.to_errno(),
}
}
#[cfg(not(MODULE))]
#[doc(hidden)]
#[no_mangle]
pub extern \"C\" fn __{ident}_init() -> ::kernel::ffi::c_int {{
// SAFETY: This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// placement above in the initcall section.
unsafe {{ __init() }}
}}
/// # Safety
///
/// This function must
/// - only be called once,
/// - be called after `__init` has been called and returned `0`.
unsafe fn __exit() {
// SAFETY: No data race, since `__MOD` can only be accessed by this module
// and there only `__init` and `__exit` access it. These functions are only
// called once and `__init` was already called.
unsafe {
// Invokes `drop()` on `__MOD`, which should be used for cleanup.
__MOD.assume_init_drop();
}
}
#[cfg(not(MODULE))]
#[doc(hidden)]
#[no_mangle]
pub extern \"C\" fn __{ident}_exit() {{
// SAFETY:
// - This function is inaccessible to the outside due to the double
// module wrapping it. It is called exactly once by the C side via its
// unique name,
// - furthermore it is only called after `__{ident}_init` has
// returned `0` (which delegates to `__init`).
unsafe {{ __exit() }}
}}
#modinfo_ts
}
}
/// # Safety
///
/// This function must only be called once.
unsafe fn __init() -> ::kernel::ffi::c_int {{
let initer =
<{type_} as ::kernel::InPlaceModule>::init(&super::super::THIS_MODULE);
// SAFETY: No data race, since `__MOD` can only be accessed by this module
// and there only `__init` and `__exit` access it. These functions are only
// called once and `__exit` cannot be called before or during `__init`.
match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{
Ok(m) => 0,
Err(e) => e.to_errno(),
}}
}}
/// # Safety
///
/// This function must
/// - only be called once,
/// - be called after `__init` has been called and returned `0`.
unsafe fn __exit() {{
// SAFETY: No data race, since `__MOD` can only be accessed by this module
// and there only `__init` and `__exit` access it. These functions are only
// called once and `__init` was already called.
unsafe {{
// Invokes `drop()` on `__MOD`, which should be used for cleanup.
__MOD.assume_init_drop();
}}
}}
{modinfo}
}}
}}
mod module_parameters {{
{params}
}}
",
type_ = info.type_,
name = info.name,
ident = ident,
modinfo = modinfo.buffer,
params = modinfo.param_buffer,
initcall_section = ".initcall6.init"
)
.parse()
.expect("Error parsing formatted string into token stream.")
mod module_parameters {
#params_ts
}
})
}

View file

@ -1,6 +1,6 @@
// SPDX-License-Identifier: GPL-2.0
use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree};
use proc_macro2::{Delimiter, Group, Ident, Spacing, Span, TokenTree};
fn concat_helper(tokens: &[TokenTree]) -> Vec<(String, Span)> {
let mut tokens = tokens.iter();

View file

@ -1,182 +0,0 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT
use proc_macro::{TokenStream, TokenTree};
pub(crate) trait ToTokens {
fn to_tokens(&self, tokens: &mut TokenStream);
}
impl<T: ToTokens> ToTokens for Option<T> {
fn to_tokens(&self, tokens: &mut TokenStream) {
if let Some(v) = self {
v.to_tokens(tokens);
}
}
}
impl ToTokens for proc_macro::Group {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.extend([TokenTree::from(self.clone())]);
}
}
impl ToTokens for proc_macro::Ident {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.extend([TokenTree::from(self.clone())]);
}
}
impl ToTokens for TokenTree {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.extend([self.clone()]);
}
}
impl ToTokens for TokenStream {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.extend(self.clone());
}
}
/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with
/// the given span.
///
/// This is a similar to the
/// [`quote_spanned!`](https://docs.rs/quote/latest/quote/macro.quote_spanned.html) macro from the
/// `quote` crate but provides only just enough functionality needed by the current `macros` crate.
macro_rules! quote_spanned {
($span:expr => $($tt:tt)*) => {{
let mut tokens = ::proc_macro::TokenStream::new();
{
#[allow(unused_variables)]
let span = $span;
quote_spanned!(@proc tokens span $($tt)*);
}
tokens
}};
(@proc $v:ident $span:ident) => {};
(@proc $v:ident $span:ident #$id:ident $($tt:tt)*) => {
$crate::quote::ToTokens::to_tokens(&$id, &mut $v);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident #(#$id:ident)* $($tt:tt)*) => {
for token in $id {
$crate::quote::ToTokens::to_tokens(&token, &mut $v);
}
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident ( $($inner:tt)* ) $($tt:tt)*) => {
#[allow(unused_mut)]
let mut tokens = ::proc_macro::TokenStream::new();
quote_spanned!(@proc tokens $span $($inner)*);
$v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new(
::proc_macro::Delimiter::Parenthesis,
tokens,
))]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident [ $($inner:tt)* ] $($tt:tt)*) => {
let mut tokens = ::proc_macro::TokenStream::new();
quote_spanned!(@proc tokens $span $($inner)*);
$v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new(
::proc_macro::Delimiter::Bracket,
tokens,
))]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident { $($inner:tt)* } $($tt:tt)*) => {
let mut tokens = ::proc_macro::TokenStream::new();
quote_spanned!(@proc tokens $span $($inner)*);
$v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new(
::proc_macro::Delimiter::Brace,
tokens,
))]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident :: $($tt:tt)*) => {
$v.extend([::proc_macro::Spacing::Joint, ::proc_macro::Spacing::Alone].map(|spacing| {
::proc_macro::TokenTree::Punct(::proc_macro::Punct::new(':', spacing))
}));
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident : $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new(':', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident , $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new(',', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident @ $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('@', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident ! $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('!', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident ; $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident + $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident = $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('=', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident # $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('#', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident & $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Punct(
::proc_macro::Punct::new('&', ::proc_macro::Spacing::Alone),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident _ $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Ident(
::proc_macro::Ident::new("_", $span),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
(@proc $v:ident $span:ident $id:ident $($tt:tt)*) => {
$v.extend([::proc_macro::TokenTree::Ident(
::proc_macro::Ident::new(stringify!($id), $span),
)]);
quote_spanned!(@proc $v $span $($tt)*);
};
}
/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with
/// mixed site span ([`Span::mixed_site()`]).
///
/// This is a similar to the [`quote!`](https://docs.rs/quote/latest/quote/macro.quote.html) macro
/// from the `quote` crate but provides only just enough functionality needed by the current
/// `macros` crate.
///
/// [`Span::mixed_site()`]: https://doc.rust-lang.org/proc_macro/struct.Span.html#method.mixed_site
macro_rules! quote {
($($tt:tt)*) => {
quote_spanned!(::proc_macro::Span::mixed_site() => $($tt)*)
}
}

View file

@ -1,96 +1,105 @@
// SPDX-License-Identifier: GPL-2.0
use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashSet;
use std::fmt::Write;
use std::{
collections::HashSet,
iter::Extend, //
};
pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
let mut tokens: Vec<_> = ts.into_iter().collect();
use proc_macro2::{
Ident,
TokenStream, //
};
use quote::ToTokens;
use syn::{
parse_quote,
Error,
ImplItem,
Item,
ItemImpl,
ItemTrait,
Result,
TraitItem, //
};
// Scan for the `trait` or `impl` keyword.
let is_trait = tokens
.iter()
.find_map(|token| match token {
TokenTree::Ident(ident) => match ident.to_string().as_str() {
"trait" => Some(true),
"impl" => Some(false),
_ => None,
},
_ => None,
})
.expect("#[vtable] attribute should only be applied to trait or impl block");
fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> {
let mut gen_items = Vec::new();
// Retrieve the main body. The main body should be the last token tree.
let body = match tokens.pop() {
Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
_ => panic!("cannot locate main body of trait or impl block"),
};
gen_items.push(parse_quote! {
/// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
/// attribute when implementing this trait.
const USE_VTABLE_ATTR: ();
});
let mut body_it = body.stream().into_iter();
let mut functions = Vec::new();
let mut consts = HashSet::new();
while let Some(token) = body_it.next() {
match token {
TokenTree::Ident(ident) if ident.to_string() == "fn" => {
let fn_name = match body_it.next() {
Some(TokenTree::Ident(ident)) => ident.to_string(),
// Possibly we've encountered a fn pointer type instead.
_ => continue,
};
functions.push(fn_name);
}
TokenTree::Ident(ident) if ident.to_string() == "const" => {
let const_name = match body_it.next() {
Some(TokenTree::Ident(ident)) => ident.to_string(),
// Possibly we've encountered an inline const block instead.
_ => continue,
};
consts.insert(const_name);
}
_ => (),
}
}
for item in &item.items {
if let TraitItem::Fn(fn_item) = item {
let name = &fn_item.sig.ident;
let gen_const_name = Ident::new(
&format!("HAS_{}", name.to_string().to_uppercase()),
name.span(),
);
let mut const_items;
if is_trait {
const_items = "
/// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
/// attribute when implementing this trait.
const USE_VTABLE_ATTR: ();
"
.to_owned();
for f in functions {
let gen_const_name = format!("HAS_{}", f.to_uppercase());
// Skip if it's declared already -- this allows user override.
if consts.contains(&gen_const_name) {
continue;
}
// We don't know on the implementation-site whether a method is required or provided
// so we have to generate a const for all methods.
write!(
const_items,
"/// Indicates if the `{f}` method is overridden by the implementor.
const {gen_const_name}: bool = false;",
)
.unwrap();
consts.insert(gen_const_name);
}
} else {
const_items = "const USE_VTABLE_ATTR: () = ();".to_owned();
for f in functions {
let gen_const_name = format!("HAS_{}", f.to_uppercase());
if consts.contains(&gen_const_name) {
continue;
}
write!(const_items, "const {gen_const_name}: bool = true;").unwrap();
let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs);
let comment =
format!("Indicates if the `{name}` method is overridden by the implementor.");
gen_items.push(parse_quote! {
#(#cfg_attrs)*
#[doc = #comment]
const #gen_const_name: bool = false;
});
}
}
let new_body = vec![const_items.parse().unwrap(), body.stream()]
.into_iter()
.collect();
tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
tokens.into_iter().collect()
item.items.extend(gen_items);
Ok(item)
}
fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> {
let mut gen_items = Vec::new();
let mut defined_consts = HashSet::new();
// Iterate over all user-defined constants to gather any possible explicit overrides.
for item in &item.items {
if let ImplItem::Const(const_item) = item {
defined_consts.insert(const_item.ident.clone());
}
}
gen_items.push(parse_quote! {
const USE_VTABLE_ATTR: () = ();
});
for item in &item.items {
if let ImplItem::Fn(fn_item) = item {
let name = &fn_item.sig.ident;
let gen_const_name = Ident::new(
&format!("HAS_{}", name.to_string().to_uppercase()),
name.span(),
);
// Skip if it's declared already -- this allows user override.
if defined_consts.contains(&gen_const_name) {
continue;
}
let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs);
gen_items.push(parse_quote! {
#(#cfg_attrs)*
const #gen_const_name: bool = true;
});
}
}
item.items.extend(gen_items);
Ok(item)
}
pub(crate) fn vtable(input: Item) -> Result<TokenStream> {
match input {
Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()),
Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()),
_ => Err(Error::new_spanned(
input,
"`#[vtable]` attribute should only be applied to trait or impl block",
))?,
}
}