当前位置: 首页 > article >正文

Rust从入门到精通之精通篇:25.过程宏高级应用

过程宏高级应用

在 Rust 精通篇中,我们将深入探索 Rust 的过程宏系统。过程宏是 Rust 元编程的强大工具,允许你在编译时生成代码。在本章中,我们将学习如何创建各种类型的过程宏,包括派生宏、属性宏和函数宏,并探索它们的高级应用场景。

过程宏基础回顾

在深入高级主题之前,让我们简要回顾 Rust 的过程宏系统:

// 在 Cargo.toml 中声明过程宏 crate
// [lib]
// proc-macro = true

use proc_macro::TokenStream;

#[proc_macro_derive(MyDerive)]
pub fn my_derive(input: TokenStream) -> TokenStream {
    // 解析输入的 TokenStream
    // 生成新的代码
    // 返回生成的代码作为 TokenStream
    "fn generated_function() { println!(\"Hello from generated function!\"); }".parse().unwrap()
}

Rust 支持三种类型的过程宏:

  1. 派生宏(Derive Macros):使用 #[derive(MacroName)] 语法,为结构体或枚举自动实现特征
  2. 属性宏(Attribute Macros):使用 #[macro_name] 语法,修改或扩展带注解的项
  3. 函数宏(Function-like Macros):使用 macro_name!(...) 语法,类似于声明宏但功能更强大

过程宏开发工具

syn 和 quote 库

开发过程宏通常需要使用两个关键库:

  • syn:用于解析 Rust 代码为语法树
  • quote:用于将语法树转换回 Rust 代码
// Cargo.toml
// [dependencies]
// syn = { version = "1.0", features = ["full"] }
// quote = "1.0"
// proc-macro2 = "1.0"

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(HelloWorld)]
pub fn hello_world_derive(input: TokenStream) -> TokenStream {
    // 解析输入为语法树
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    
    // 使用 quote! 生成代码
    let expanded = quote! {
        impl HelloWorld for #name {
            fn hello_world() {
                println!("Hello, World! My name is {}", stringify!(#name));
            }
        }
    };
    
    // 将生成的代码转换为 TokenStream
    TokenStream::from(expanded)
}

proc-macro2 库

proc-macro2 提供了与标准库 proc_macro 兼容的类型,但可以在过程宏 crate 之外使用,便于测试:

use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::Ident;

// 可以在非过程宏 crate 中测试
fn generate_impl(name: &str) -> TokenStream2 {
    let ident = Ident::new(name, Span::call_site());
    quote! {
        impl #ident {
            fn new() -> Self {
                Self {}
            }
        }
    }
}

高级派生宏

自定义派生宏实现序列化

下面是一个实现自定义序列化的派生宏示例:

use proc_macro::TokenStream;
use quote::{quote, format_ident};
use syn::{parse_macro_input, Data, DeriveInput, Fields};

#[proc_macro_derive(Serialize)]
pub fn serialize_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    
    // 获取结构体字段
    let fields = match input.data {
        Data::Struct(data) => match data.fields {
            Fields::Named(fields) => fields.named,
            _ => panic!("Serialize only supports structs with named fields"),
        },
        _ => panic!("Serialize only supports structs"),
    };
    
    // 为每个字段生成序列化代码
    let field_serializations = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_name_str = field_name.to_string();
        quote! {
            serialized.push_str(&format!("\"{}\": {}, ", #field_name_str, self.#field_name.serialize()));
        }
    });
    
    // 生成实现代码
    let expanded = quote! {
        impl Serialize for #name {
            fn serialize(&self) -> String {
                let mut serialized = String::from("{");
                #(#field_serializations)*
                // 移除最后的逗号和空格
                if serialized.len() > 1 {
                    serialized.truncate(serialized.len() - 2);
                }
                serialized.push_str("}");
                serialized
            }
        }
    };
    
    TokenStream::from(expanded)
}

带参数的派生宏

我们可以使用属性参数扩展派生宏的功能:

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Meta, NestedMeta, Lit};

#[proc_macro_derive(Builder, attributes(builder))]
pub fn builder_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    let builder_name = format_ident!("{}Builder", name);
    
    // 处理结构体字段
    let fields = match input.data {
        syn::Data::Struct(data) => match data.fields {
            syn::Fields::Named(fields) => fields.named,
            _ => panic!("Builder only supports structs with named fields"),
        },
        _ => panic!("Builder only supports structs"),
    };
    
    // 提取字段信息和属性
    let field_defs = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;
        
        // 检查字段是否有 #[builder(default = "...")] 属性
        let default_value = field.attrs.iter()
            .filter(|attr| attr.path.is_ident("builder"))
            .filter_map(|attr| attr.parse_meta().ok())
            .filter_map(|meta| match meta {
                Meta::List(list) => Some(list.nested),
                _ => None,
            })
            .flat_map(|nested| nested.into_iter())
            .filter_map(|nested| match nested {
                NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("default") => {
                    match nv.lit {
                        Lit::Str(lit) => Some(lit.value()),
                        _ => None,
                    }
                },
                _ => None,
            })
            .next();
        
        // 根据是否有默认值生成不同的字段定义
        if let Some(default) = default_value {
            quote! {
                #field_name: Option<#field_type>,
            }
        } else {
            quote! {
                #field_name: Option<#field_type>,
            }
        }
    });
    
    // 生成 setter 方法
    let setters = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;
        
        quote! {
            pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
                self.#field_name = Some(value);
                self
            }
        }
    });
    
    // 生成 build 方法
    let build_fields = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        
        // 检查字段是否有默认值
        let default_value = field.attrs.iter()
            .filter(|attr| attr.path.is_ident("builder"))
            .filter_map(|attr| attr.parse_meta().ok())
            .filter_map(|meta| match meta {
                Meta::List(list) => Some(list.nested),
                _ => None,
            })
            .flat_map(|nested| nested.into_iter())
            .filter_map(|nested| match nested {
                NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("default") => {
                    match nv.lit {
                        Lit::Str(lit) => Some(lit.value()),
                        _ => None,
                    }
                },
                _ => None,
            })
            .next();
        
        if let Some(default) = default_value {
            quote! {
                #field_name: self.#field_name.clone().unwrap_or_else(|| #default),
            }
        } else {
            quote! {
                #field_name: self.#field_name.clone().ok_or(format!("Field {} is required", stringify!(#field_name)))?,
            }
        }
    });
    
    // 生成完整的 Builder 实现
    let expanded = quote! {
        #[derive(Clone, Default)]
        pub struct #builder_name {
            #(#field_defs)*
        }
        
        impl #builder_name {
            pub fn new() -> Self {
                Default::default()
            }
            
            #(#setters)*
            
            pub fn build(&self) -> Result<#name, String> {
                Ok(#name {
                    #(#build_fields)*
                })
            }
        }
        
        impl #name {
            pub fn builder() -> #builder_name {
                #builder_name::new()
            }
        }
    };
    
    TokenStream::from(expanded)
}

高级属性宏

自定义路由属性宏

下面是一个用于 Web 框架的路由属性宏示例:

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn, AttributeArgs, NestedMeta, Lit, LitStr};

#[proc_macro_attribute]
pub fn route(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as AttributeArgs);
    let input_fn = parse_macro_input!(input as ItemFn);
    
    // 提取函数信息
    let fn_name = &input_fn.sig.ident;
    let fn_block = &input_fn.block;
    
    // 解析路由参数
    let mut method = String::from("GET");
    let mut path = String::new();
    
    for arg in args {
        match arg {
            NestedMeta::Meta(meta) => {
                // 处理 method = "POST" 形式
                if let syn::Meta::NameValue(nv) = meta {
                    if nv.path.is_ident("method") {
                        if let Lit::Str(lit) = nv.lit {
                            method = lit.value();
                        }
                    }
                }
            },
            NestedMeta::Lit(lit) => {
                // 处理 "/users" 形式
                if let Lit::Str(lit) = lit {
                    path = lit.value();
                }
            },
        }
    }
    
    // 生成路由注册代码
    let expanded = quote! {
        #[allow(non_camel_case_types)]
        pub struct #fn_name;
        
        impl Route for #fn_name {
            fn method(&self) -> &'static str {
                #method
            }
            
            fn path(&self) -> &'static str {
                #path
            }
            
            fn handler(&self, req: Request) -> Response {
                let handler = || #fn_block;
                handler()
            }
        }
        
        // 注册路由
        inventory::submit! {
            RouteItem {
                route: Box::new(#fn_name)
            }
        }
    };
    
    TokenStream::from(expanded)
}

条件编译属性宏

创建一个用于条件编译的属性宏:

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn, AttributeArgs, NestedMeta, Lit};

#[proc_macro_attribute]
pub fn platform(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as AttributeArgs);
    let input_fn = parse_macro_input!(input as ItemFn);
    
    // 提取目标平台
    let mut target_platforms = Vec::new();
    
    for arg in args {
        if let NestedMeta::Lit(Lit::Str(lit)) = arg {
            target_platforms.push(lit.value());
        }
    }
    
    // 获取当前平台
    let current_platform = if cfg!(target_os = "windows") {
        "windows"
    } else if cfg!(target_os = "macos") {
        "macos"
    } else if cfg!(target_os = "linux") {
        "linux"
    } else {
        "unknown"
    };
    
    // 检查当前平台是否在目标平台列表中
    let should_include = target_platforms.iter().any(|p| p == current_platform);
    
    // 根据条件生成代码
    let output = if should_include {
        // 包含原始函数
        quote! { #input_fn }
    } else {
        // 生成一个空的存根函数
        let fn_name = &input_fn.sig.ident;
        let fn_args = &input_fn.sig.inputs;
        let fn_output = &input_fn.sig.output;
        
        quote! {
            #[allow(unused_variables)]
            fn #fn_name(#fn_args) #fn_output {
                panic!("Function not available on this platform");
            }
        }
    };
    
    TokenStream::from(output)
}

高级函数宏

SQL 查询构建宏

创建一个用于构建类型安全 SQL 查询的函数宏:

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, LitStr, parse::Parse, parse::ParseStream, Token, Ident, Result as SynResult};

// 定义查询参数解析器
struct SqlQuery {
    query: LitStr,
    params: Vec<Ident>,
}

impl Parse for SqlQuery {
    fn parse(input: ParseStream) -> SynResult<Self> {
        let query = input.parse::<LitStr>()?;
        let mut params = Vec::new();
        
        // 解析参数列表
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            
            while !input.is_empty() {
                params.push(input.parse::<Ident>()?);
                
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                } else {
                    break;
                }
            }
        }
        
        Ok(SqlQuery { query, params })
    }
}

#[proc_macro]
pub fn sql(input: TokenStream) -> TokenStream {
    let SqlQuery { query, params } = parse_macro_input!(input as SqlQuery);
    let query_string = query.value();
    
    // 解析查询字符串,查找参数占位符
    let mut param_positions = Vec::new();
    let mut modified_query = String::new();
    let mut current_pos = 0;
    
    for (i, c) in query_string.char_indices() {
        if c == '?' && i + 1 < query_string.len() {
            if let Some(param_index) = query_string[i+1..].chars().next().and_then(|c| c.to_digit(10)) {
                param_positions.push((current_pos, param_index as usize - 1));
                modified_query.push('?');
                current_pos += 1;
                // 跳过数字
                continue;
            }
        }
        modified_query.push(c);
    }
    
    // 生成参数绑定代码
    let param_bindings = param_positions.iter().map(|(pos, idx)| {
        if *idx < params.len() {
            let param = &params[*idx];
            quote! {
                query.bind_param(#pos, &#param);
            }
        } else {
            quote! {
                compile_error!("Parameter index out of bounds");
            }
        }
    });
    
    // 生成最终代码
    let expanded = quote! {
        {
            let mut query = Query::new(#modified_query);
            #(#param_bindings)*
            query
        }
    };
    
    TokenStream::from(expanded)
}

测试生成宏

创建一个自动生成测试用例的函数宏:

use proc_macro::TokenStream;
use quote::{quote, format_ident};
use syn::{parse_macro_input, LitStr, parse::Parse, parse::ParseStream, Token, Ident, Result as SynResult, Expr};

// 定义测试用例结构
struct TestCase {
    name: Ident,
    inputs: Vec<Expr>,
    expected: Expr,
}

struct TestCases {
    function_name: Ident,
    cases: Vec<TestCase>,
}

impl Parse for TestCases {
    fn parse(input: ParseStream) -> SynResult<Self> {
        let function_name = input.parse::<Ident>()?;
        input.parse::<Token![,]>()?;
        
        let mut cases = Vec::new();
        
        while !input.is_empty() {
            // 解析测试名称
            let name = input.parse::<Ident>()?;
            input.parse::<Token![:]>()?;
            
            // 解析输入参数
            let content;
            syn::parenthesized!(content in input);
            let mut inputs = Vec::new();
            
            while !content.is_empty() {
                inputs.push(content.parse::<Expr>()?);
                
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                } else {
                    break;
                }
            }
            
            // 解析期望输出
            input.parse::<Token![=>]>()?;
            let expected = input.parse::<Expr>()?;
            
            cases.push(TestCase { name, inputs, expected });
            
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            } else {
                break;
            }
        }
        
        Ok(TestCases { function_name, cases })
    }
}

#[proc_macro]
pub fn test_cases(input: TokenStream) -> TokenStream {
    let TestCases { function_name, cases } = parse_macro_input!(input as TestCases);
    
    // 为每个测试用例生成测试函数
    let test_functions = cases.iter().map(|case| {
        let test_name = format_ident!("test_{}_{}"

http://www.kler.cn/a/611998.html

相关文章:

  • MySQL 设置允许远程连接完整指南:安全与效率并重
  • 《Python实战进阶》No37: 强化学习入门:Q-Learning 与 DQN-加餐版1 Q-Learning算法可视化
  • 【前端vue】理解VUE前端框架中src下的api文件夹与views文件夹
  • 蓝桥杯(电子类)嵌入式第十一届设计与开发科目模拟试题
  • AI辅助下基于ArcGIS Pro的SWAT模型全流程高效建模实践与深度进阶应用
  • 面试题:RocketMQ 如何保证消息的顺序性
  • 04 单目标定实战示例
  • HarmonyOS之深入解析如何根据url下载pdf文件并且在本地显示和预览
  • ubuntu24 部署vnc server 使用VNC Viewer连接
  • Scala基础语法和简介
  • Cent OS7+Docker+Dify
  • SpringBoot实战——详解JdbcTemplate操作存储过程
  • 第十六届蓝桥杯模拟二(串口通信)
  • 数据结构每日一题day3(顺序表)★★★★★
  • 国际机构Gartner发布2025年网络安全趋势
  • 微软KBLaM:当语言模型学会“查字典”的下一代AI革命
  • 信息系统安全保护等级详解
  • 一文读懂Python之json模块(33)
  • Axure RP设计软件中的各种函数:包括数字、数学、字符串、时间及中继器函数,详细解释了各函数的用途、参数及其应用场景。
  • SpringMVC请求与响应深度解析:从核心原理到高级实践