// This test case covers one more heuristic that is often worth incorporating
// into derive macros that infer trait bounds. Here we look for the use of an
// associated type of a type parameter.
//
// The generated impl will need to look like:
//
// impl<T: Trait> Debug for Field<T>
// where
// T::Value: Debug,
// {...}
//
// You can identify associated types as any syn::TypePath in which the first
// path segment is one of the type parameters and there is more than one
// segment.
//
//
// Resources:
//
// - The relevant types in the input will be represented in this syntax tree
// node: https://docs.rs/syn/1.0/syn/struct.TypePath.html
use derive_debug::CustomDebug;
use std::fmt::Debug;
// A trait with an associated type. `Field<T>` below stores `T::Value`
// values and never a `T` itself.
pub trait Trait {
type Value;
}
// The derive must produce `where T::Value: Debug` rather than `T: Debug`
// (see the header comment of this test).
#[derive(CustomDebug)]
pub struct Field<T: Trait> {
values: Vec<T::Value>,
}
// Compiles only if `F: Debug`; acts as a compile-time assertion in `main`.
fn assert_debug<F: Debug>() {}
fn main() {
// Does not implement Debug, but its associated type does.
struct Id;
impl Trait for Id {
type Value = u8;
}
// Would fail to compile if the generated impl required `Id: Debug`.
assert_debug::<Field<Id>>();
}
這裡出現了一個奇葩的情況,那就是 `A::B::C` 這種多段路徑。
之前我們在限定泛型時,只要泛型不是僅被 `PhantomData` 修飾,就會為它加上 `T: Debug` 的約束。
但是這裡出現了一個奇葩的情況 `T::Value`:也就是說,欄位涉及到了 `T`,但是卻沒有直接使用 `T`,實際使用的是 `T::Value` 這種關聯型別。
回顧一下我們先前對於過濾後的泛型的處理方式:
// Excerpt from the previous solution (not a standalone definition).
// Every type parameter gets a `std::fmt::Debug` bound, except one whose name
// is in `phantom_generic_type_names` (presumably: it appears inside
// `PhantomData<...>`) and is not in `fields_type_names`.
for generic in generics.params.iter_mut() {
// Lifetime and const parameters are left untouched.
if let syn::GenericParam::Type(t) = generic {
let type_param_name = t.ident.to_string();
if phantom_generic_type_names.contains(&type_param_name)
&& !fields_type_names.contains(&type_param_name)
{
continue;
}
t.bounds.push(syn::parse_quote!(std::fmt::Debug));
}
}
很明顯地出現了一個問題:我們會將泛型全部限定為 `T: Debug`,即使它只在泛型宣告中出現,並未直接參與欄位宣告。
因此,我們當前的任務是:更細粒度的控制泛型約束。
這道題中,我們應該約束 `T::Value`,而不約束 `T`,因為 `T` 並沒有直接出現在欄位的型別宣告中。
其中涉及兩點:
- 關聯型別的提取
- `where_clause` 的修改

自動遍歷
// common.rs
/// Syntax-tree visitor that records every type path rooted at one of the
/// tracked type parameters, i.e. associated-type projections such as
/// `T::Value` or `<T as Trait>::Value`.
///
/// Needs syn's `visit` feature:
/// `syn = { version = "1.0.84", features = ["visit"] }`.
struct TypePathVisitor {
    /// Names of the type parameters to track (e.g. `"T"`).
    interst_generic_type_names: Vec<String>,
    /// Every matching path, keyed by the type-parameter name it starts from.
    /// A path is pushed once per occurrence, so duplicates are possible.
    associated_type_names: std::collections::HashMap<String, Vec<syn::TypePath>>,
}

impl TypePathVisitor {
    /// Returns the tracked type-parameter name that `node` projects from, or
    /// `None` if `node` is not such a projection.
    fn associated_type_root(&self, node: &syn::TypePath) -> Option<String> {
        let root = match &node.qself {
            // `<T as Trait>::Value` / `<T>::Value`: the parameter is the
            // qualified self type, which must itself be a bare `T`.
            Some(qself) => match &*qself.ty {
                syn::Type::Path(inner)
                    if inner.qself.is_none() && inner.path.segments.len() == 1 =>
                {
                    inner.path.segments[0].ident.to_string()
                }
                _ => return None,
            },
            // `T::Value`: more than one segment, the first being the parameter.
            None if node.path.segments.len() > 1 => node.path.segments[0].ident.to_string(),
            None => return None,
        };
        if self.interst_generic_type_names.contains(&root) {
            Some(root)
        } else {
            None
        }
    }
}

impl<'ast> syn::visit::Visit<'ast> for TypePathVisitor {
    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
        if let Some(root) = self.associated_type_root(node) {
            self.associated_type_names
                .entry(root)
                .or_default()
                .push(node.clone());
        }
        // Keep descending so projections nested in generic arguments
        // (e.g. the `T::Value` inside `Vec<T::Value>`) are found as well.
        syn::visit::visit_type_path(self, node);
    }
}
`syn::visit::Visit` 會自動遍歷整棵語法樹,我們只需覆寫感興趣的節點(這裡是 `visit_type_path`)並收集資料,詳細用法可以參考官方文件。
// common.rs
/// Collects, per type parameter of `ast`, the type paths `TypePathVisitor`
/// reports for it (associated-type paths such as `T::Value`).
///
/// The whole derive input is walked: fields, and also the generics and where
/// clause. Parameters with no reported path are absent from the map, and
/// lifetime and const parameters are never tracked.
pub(crate) fn parse_generic_associated_types(
    ast: &syn::DeriveInput,
) -> std::collections::HashMap<String, Vec<syn::TypePath>> {
    let type_param_names = ast
        .generics
        .type_params()
        .map(|param| param.ident.to_string())
        .collect::<Vec<_>>();
    let mut visitor = TypePathVisitor {
        interst_generic_type_names: type_param_names,
        associated_type_names: Default::default(),
    };
    // Fully-qualified call: does not depend on `Visit` being imported.
    syn::visit::Visit::visit_derive_input(&mut visitor, ast);
    visitor.associated_type_names
}
到這裡,我們已經收集到所有以頂層泛型參數開頭的型別路徑(例如 `T::Value`),接下來就是定製 `where_clause`。
// solution7.rs
/// Generates `impl std::fmt::Debug` for the deriving struct (exercise 07).
///
/// Bounding rules:
/// - A type parameter gets `: std::fmt::Debug`, except when its name is not a
///   field type name AND it is reached only through `PhantomData` or as the
///   root of an associated-type path such as `T::Value`.
/// - Every associated-type path found (e.g. `T::Value`) gets a
///   `T::Value: std::fmt::Debug` predicate in the where clause.
///
/// Predicates are de-duplicated and emitted in a stable (sorted) order, so the
/// macro output is identical from one compilation to the next.
///
/// # Errors
/// Propagates errors from the `crate::common` field parsers and from
/// `generate_field_stream_vec`.
pub(super) fn soution(
    fields: &crate::common::FieldsType,
    origin_ident: &syn::Ident,
    ast: &syn::DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
    let mut origin_field_type_names = vec![];
    let mut phantom_generic_type_names = vec![];
    for field in fields.iter() {
        if let Some(name) = crate::common::parse_field_type_name(field)? {
            origin_field_type_names.push(name);
        }
        if let Some(name) = crate::common::parse_phantom_generic_type_name(field)? {
            phantom_generic_type_names.push(name);
        }
    }
    let associated_generics_type_map = crate::common::parse_generic_associated_types(ast);
    let mut generics = crate::common::parse_generic_type(ast);

    // Bound each type parameter with `Debug`, except those that are only used
    // through `PhantomData` or through an associated type like `T::Value`.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_name = t.ident.to_string();
            let used_as_field_type = origin_field_type_names.contains(&type_name);
            let used_indirectly = phantom_generic_type_names.contains(&type_name)
                || associated_generics_type_map.contains_key(&type_name);
            if used_indirectly && !used_as_field_type {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

    // The map may hold the same path several times (once per occurrence), and
    // `HashMap` iteration order varies between runs. Keying by the rendered
    // tokens in a `BTreeMap` removes duplicates and fixes the order, which
    // keeps the generated where clause reproducible.
    let associated_types: std::collections::BTreeMap<String, &syn::TypePath> =
        associated_generics_type_map
            .values()
            .flatten()
            .map(|ty| (quote::quote!(#ty).to_string(), ty))
            .collect();
    let where_clause = generics.make_where_clause();
    for associated_type in associated_types.values() {
        where_clause
            .predicates
            .push(syn::parse_quote!(#associated_type: std::fmt::Debug));
    }

    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
    let fields_stream_vec = generate_field_stream_vec(fields)?;
    let origin_ident_string = origin_ident.to_string();
    Ok(quote::quote! {
        impl #impl_generics std::fmt::Debug for #origin_ident #type_generics #where_clause {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                fmt.debug_struct(#origin_ident_string)
                    #(
                        #fields_stream_vec
                    )*
                    .finish()
            }
        }
    })
}
// 無關的debug欄位設定
fn generate_field_stream_vec(
fields: &crate::common::FieldsType,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
fields
.iter()
.map(|f| {
let ident = &f.ident;
let ident_string = ident.as_ref().unwrap().to_string();
let mut format = "{:?}".to_string();
if let Some(customer_format) = crate::common::parse_format(f)? {
format = customer_format;
}
syn::Result::Ok(quote::quote! {
.field(#ident_string, &format_args!(#format, &self.#ident))
})
})
.collect()
}
mod common;
mod solution2;
mod solution3;
mod solution4;
mod solution56;
mod solution7;
mod solution8;
#[proc_macro_derive(CustomDebug, attributes(debug))]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(input as syn::DeriveInput);
match solution1(&ast) {
syn::Result::Ok(token_stream) => {
return proc_macro::TokenStream::from(token_stream);
}
syn::Result::Err(e) => {
return e.into_compile_error().into();
}
}
}
/// Runs every exercise's solution in order and returns the tokens produced by
/// the latest one (`solution7`).
///
/// The earlier solutions' outputs are discarded, but an error from any of them
/// still aborts the derive.
fn solution1(ast: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let fields = crate::common::parse_fields(ast)?;
    let origin_ident = &ast.ident;
    // Run solutions 2 through 5/6 only for the errors they may report.
    solution2::solution(fields, origin_ident)?;
    solution3::solution(fields, origin_ident)?;
    solution4::solution(fields, origin_ident, ast)?;
    solution56::solution(fields, origin_ident, ast)?;
    solution7::soution(fields, origin_ident, ast)
}
可以用 `cargo expand` 觀察一下展開後的結果:
// `cargo expand` output for the test case (generated code, shown verbatim).
#![feature(prelude_import)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use derive_debug::CustomDebug;
use std::fmt::Debug;
pub trait Trait {
type Value;
}
pub struct Field<T: Trait> {
values: Vec<T::Value>,
}
// Generated impl: `T` keeps only its declared `Trait` bound (no `T: Debug`);
// the where clause requires `T::Value: std::fmt::Debug` instead.
impl<T: Trait> std::fmt::Debug for Field<T>
where
T::Value: std::fmt::Debug,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("Field")
// Expansion of `format_args!("{:?}", &self.values)`, the default format.
.field(
"values",
&::core::fmt::Arguments::new_v1(
&[""],
&match (&&self.values,) {
_args => [::core::fmt::ArgumentV1::new(
_args.0,
::core::fmt::Debug::fmt,
)],
},
),
)
.finish()
}
}
fn assert_debug<F: Debug>() {}
fn main() {
struct Id;
impl Trait for Id {
type Value = u8;
}
assert_debug::<Field<Id>>();
}
可以看到 `where T::Value: std::fmt::Debug` 單獨地限定了 `T::Value`,而 `T` 本身並沒有被加上 `Debug` 約束。
本作品採用《CC 協議》,轉載必須註明作者和本文連結