proc-macro-workshop:debug-7

godme發表於2022-07-04
// This test case covers one more heuristic that is often worth incorporating
// into derive macros that infer trait bounds. Here we look for the use of an
// associated type of a type parameter.
//
// The generated impl will need to look like:
//
//     impl<T: Trait> Debug for Field<T>
//     where
//         T::Value: Debug,
//     {...}
//
// You can identify associated types as any syn::TypePath in which the first
// path segment is one of the type parameters and there is more than one
// segment.
//
//
// Resources:
//
//   - The relevant types in the input will be represented in this syntax tree
//     node: https://docs.rs/syn/1.0/syn/struct.TypePath.html

use derive_debug::CustomDebug;
use std::fmt::Debug;

// Trait with an associated type. `Field<T>` below stores `T::Value`, never a
// bare `T`, which is what this test case is about.
pub trait Trait {
    type Value;
}

// Derive input under test: `T` never appears as a field type on its own; only
// its associated type `T::Value` does. The derive must therefore bound
// `T::Value: Debug` rather than `T: Debug`.
#[derive(CustomDebug)]
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}

// Compiles only when `F: Debug`; used in `main` as a compile-time assertion.
fn assert_debug<F: Debug>() {}

fn main() {
    // A type that is NOT Debug itself; only its associated `Value` (u8) is.
    struct Opaque;

    impl Trait for Opaque {
        type Value = u8;
    }

    // Type-checks only if the derived impl requires `T::Value: Debug`
    // instead of `T: Debug`, because `Opaque` has no Debug impl.
    assert_debug::<Field<Opaque>>();
}

這裡出現了一個奇葩的情況:`A::B::C` 這種多段路徑。

之前我們對泛型的限定是:只要泛型參數不是僅出現在 PhantomData 中,就會為它加上 `T: Debug` 約束。
但是這裡出現了一個奇葩的情況 `T::Value`:欄位涉及到了 T,卻沒有直接使用 T,實際使用的是 `T::Value` 這種關聯型別。

回顧一下我們對於過濾後的泛型的處理方式

    // Excerpt of the previous step's bound inference: every type parameter is
    // given a `std::fmt::Debug` bound, except one whose name is in
    // `phantom_generic_type_names` (presumably the PhantomData-wrapped
    // parameters) but not in `fields_type_names`.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_param_name = t.ident.to_string();
            // Skip the bound when the parameter only shows up via the phantom list.
            if phantom_generic_type_names.contains(&type_param_name)
                && !fields_type_names.contains(&type_param_name)
            {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

很明顯地出現了一個問題:除了 PhantomData 的情況之外,我們會把泛型參數全部限定為 `T: Debug`,即使它只出現在泛型宣告中、並未直接參與欄位型別宣告。

因此,我們當前的任務是:更細粒度的控制泛型約束。
這道題中,我們只約束 `T::Value`,而不約束 `T` 本身,因為 `T` 並沒有直接出現在欄位型別中。

其中涉及兩點

  • 關聯型別的提取
  • where_clause修改

使用 `syn::visit` 自動遍歷語法樹

// common.rs

/// `syn` visitor state used to collect associated-type paths of type
/// parameters (e.g. `T::Value`).
struct TypePathVisitor {
    /// Paths found so far, grouped by the name of the type parameter they are
    /// rooted at.
    associated_type_names: std::collections::HashMap<String, Vec<syn::TypePath>>,
    /// Names of the type parameters to look for. (The misspelling `interst`
    /// is kept because the struct is constructed by field name.)
    interst_generic_type_names: Vec<String>,
}
// Requires syn's `visit` feature: syn = { version = "1.0.84", features = ["visit"] }
impl<'ast> syn::visit::Visit<'ast> for TypePathVisitor {
    /// Records `node` when it names an associated type of one of the type
    /// parameters of interest, keyed by that parameter's name:
    ///
    /// * `T::Value`: an unqualified path of two or more segments whose first
    ///   segment is the parameter. A leading `::` (as in `::T::Value`) names
    ///   an external crate, not a type parameter, so such paths are ignored.
    /// * `<T as Trait>::Value` / `<T>::Value`: a qualified path whose self type
    ///   is the bare parameter identifier. Here the first path segment is the
    ///   trait, so checking `segments[0]` alone would miss it.
    ///
    /// The walk then continues into the path's generic arguments, so nested
    /// occurrences such as the `U::Item` in `Box<Vec<U::Item>>` are found too.
    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
        let root = match &node.qself {
            Some(qself) => match &*qself.ty {
                syn::Type::Path(self_ty) if self_ty.qself.is_none() => self_ty.path.get_ident(),
                _ => None,
            },
            None if node.path.leading_colon.is_none() && node.path.segments.len() > 1 => {
                Some(&node.path.segments[0].ident)
            }
            None => None,
        };
        if let Some(root) = root {
            let generic_type_name = root.to_string();
            if self.interst_generic_type_names.contains(&generic_type_name) {
                self.associated_type_names
                    .entry(generic_type_name)
                    .or_default()
                    .push(node.clone());
            }
        }
        // Keep walking into generic arguments (e.g. the `T::Value` in `Vec<T::Value>`).
        syn::visit::visit_type_path(self, node);
    }
}

`Visit` 會自動遍歷整棵語法樹,我們只需覆寫 `visit_type_path` 來收集感興趣的資料,細節可以參考 syn 官方文件中的 `syn::visit` 模組。

// common.rs

/// Collects the associated-type paths (such as `T::Value`) that are rooted at
/// one of the input's type parameters and appear in the input's fields.
///
/// Returns a map from type-parameter name (e.g. `"T"`) to every such path, in
/// visiting order.
///
/// Only the fields (`ast.data`) are visited. Paths that appear solely in the
/// generic parameter list or the `where` clause do not affect how the fields
/// are formatted, so they must not produce `Debug` bounds. Nor should they
/// suppress the `Debug` bound on their root parameter.
pub(crate) fn parse_generic_associated_types(
    ast: &syn::DeriveInput,
) -> std::collections::HashMap<String, Vec<syn::TypePath>> {
    let origin_generic_type_names: Vec<String> = ast
        .generics
        .type_params()
        .map(|param| param.ident.to_string())
        .collect();
    let mut visitor = TypePathVisitor {
        interst_generic_type_names: origin_generic_type_names,
        associated_type_names: std::collections::HashMap::new(),
    };
    // Fully qualified so the call does not depend on `syn::visit::Visit`
    // being imported in this module.
    syn::visit::Visit::visit_data(&mut visitor, &ast.data);
    visitor.associated_type_names
}

到這裡,我們已經收集到了所有以泛型參數開頭的關聯型別路徑(例如 `T::Value`),接下來就是定製 `where_clause`。

// solution7.rs

/// Generates `impl std::fmt::Debug` for the derive input, with fine-grained
/// trait-bound inference.
///
/// Each type parameter `T` is handled as follows:
/// * No bound if `T` is reported by `parse_phantom_generic_type_name` and is
///   never a field type name (`parse_field_type_name`).
/// * No bound if `T` roots an associated-type path found by
///   `parse_generic_associated_types` (e.g. `T::Value`) and is never a field
///   type name.
/// * Otherwise, `T: std::fmt::Debug`.
///
/// Every collected associated-type path is then bounded in the `where` clause
/// (`T::Value: std::fmt::Debug`), each distinct path once and in a
/// deterministic order.
///
/// # Errors
/// Propagates any `syn::Error` from the field-parsing helpers or from
/// `generate_field_stream_vec`.
pub(super) fn soution(
    fields: &crate::common::FieldsType,
    origin_ident: &syn::Ident,
    ast: &syn::DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
    let mut origin_field_type_names = vec![];
    let mut phantom_generic_type_names = vec![];
    for field in fields.iter() {
        if let Some(name) = crate::common::parse_field_type_name(field)? {
            origin_field_type_names.push(name);
        }
        if let Some(name) = crate::common::parse_phantom_generic_type_name(field)? {
            phantom_generic_type_names.push(name);
        }
    }
    let associated_generics_type_map = crate::common::parse_generic_associated_types(ast);
    let mut generics = crate::common::parse_generic_type(ast);

    // Bound `T: Debug` on every type parameter that is used directly. A
    // parameter used only through PhantomData, or only through associated
    // types, does not need to be Debug itself.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_name = t.ident.to_string();
            let used_directly = origin_field_type_names.contains(&type_name);
            let phantom_only = phantom_generic_type_names.contains(&type_name) && !used_directly;
            let associated_only =
                associated_generics_type_map.contains_key(&type_name) && !used_directly;
            if phantom_only || associated_only {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

    // Bound each associated-type path instead. Iterate in sorted key order:
    // `HashMap` iteration order is randomized per process, which would make
    // the emitted `where` clause (and so the macro output) nondeterministic.
    let ordered: std::collections::BTreeMap<_, _> = associated_generics_type_map.iter().collect();
    // The same path (e.g. `T::Value` used by two fields) is bounded only once.
    let mut already_bounded = std::collections::HashSet::new();
    let predicates = &mut generics.make_where_clause().predicates;
    for associated_type in ordered.values().flat_map(|paths| paths.iter()) {
        let rendered = quote::ToTokens::to_token_stream(associated_type).to_string();
        if already_bounded.insert(rendered) {
            predicates.push(syn::parse_quote!(#associated_type: std::fmt::Debug));
        }
    }

    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
    let field_streams = generate_field_stream_vec(fields)?;
    let origin_ident_string = origin_ident.to_string();
    // The impl body itself is unchanged from the previous steps.
    Ok(quote::quote! {
        impl #impl_generics std::fmt::Debug for #origin_ident #type_generics #where_clause {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                fmt.debug_struct(#origin_ident_string)
                #(
                    #field_streams
                )*
                .finish()
            }
        }
    })
}
// Per-field `.field(...)` calls (unchanged from the previous steps).

/// Builds one `.field(name, &format_args!(fmt, &self.field))` token stream per
/// field, in declaration order.
///
/// The format string is `"{:?}"` unless `crate::common::parse_format` returns a
/// custom one (presumably from a `#[debug = "..."]` field attribute).
///
/// # Panics
/// Panics if a field has no identifier (tuple-struct field).
///
/// # Errors
/// Returns the first `syn::Error` produced by `parse_format`.
fn generate_field_stream_vec(
    fields: &crate::common::FieldsType,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
    let mut streams = Vec::new();
    for field in fields.iter() {
        let name = &field.ident;
        let label = name.as_ref().unwrap().to_string();
        let fmt_str = match crate::common::parse_format(field)? {
            Some(custom) => custom,
            None => "{:?}".to_string(),
        };
        streams.push(quote::quote! {
            .field(#label, &format_args!(#fmt_str, &self.#name))
        });
    }
    Ok(streams)
}
mod common;
mod solution2;
mod solution3;
mod solution4;
mod solution56;
mod solution7;
mod solution8;

#[proc_macro_derive(CustomDebug, attributes(debug))]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    match solution1(&ast) {
        syn::Result::Ok(token_stream) => {
            return proc_macro::TokenStream::from(token_stream);
        }
        syn::Result::Err(e) => {
            return e.into_compile_error().into();
        }
    }
}

/// Runs every step's solution on the parsed input and returns the tokens
/// produced by the step-7 solution.
///
/// The earlier steps still run: any error they return is propagated via `?`,
/// but their generated tokens are discarded.
fn solution1(ast: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let fields = crate::common::parse_fields(&ast)?;
    let name = &ast.ident;

    let _ = solution2::solution(fields, name)?;
    let _ = solution3::solution(fields, name)?;
    let _ = solution4::solution(fields, name, ast)?;
    let _ = solution56::solution(fields, name, ast)?;

    solution7::soution(fields, name, ast)
}

可以用 `cargo expand` 觀察一下展開結果:

#![feature(prelude_import)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use derive_debug::CustomDebug;
use std::fmt::Debug;
// Re-emitted unchanged from the test input.
pub trait Trait {
    type Value;
}
// The input struct, re-emitted without its `#[derive(CustomDebug)]` attribute.
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}
// Generated by `CustomDebug`: `T` keeps only its declared `Trait` bound (no
// `Debug`), while the associated type is bounded separately in the `where` clause.
impl<T: Trait> std::fmt::Debug for Field<T>
where
    T::Value: std::fmt::Debug,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("Field")
            // One `.field(...)` call per struct field. The `format_args!("{:?}", ..)`
            // call appears in the compiler's expanded form, whose exact shape
            // depends on the toolchain that ran `cargo expand`.
            .field(
                "values",
                &::core::fmt::Arguments::new_v1(
                    &[""],
                    &match (&&self.values,) {
                        _args => [::core::fmt::ArgumentV1::new(
                            _args.0,
                            ::core::fmt::Debug::fmt,
                        )],
                    },
                ),
            )
            .finish()
    }
}
// Compile-time assertion helper from the test input: compiles only when `F: Debug`.
fn assert_debug<F: Debug>() {}
// The test's `main` (its source comment is dropped by the expansion). `Id` is
// not Debug, but `Field<Id>` is, because the impl above only requires
// `<Id as Trait>::Value` (here `u8`) to be Debug.
fn main() {
    struct Id;
    impl Trait for Id {
        type Value = u8;
    }
    assert_debug::<Field<Id>>();
}

可以看到 `where T::Value: std::fmt::Debug` 單獨地限定了 `T::Value`,而 `T` 本身只保留了宣告時的 `Trait` 約束。

本作品採用《CC 協議》,轉載必須註明作者和本文連結