borsh_derive/internals/serialize/structs/mod.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::quote;
3use syn::{Fields, ItemStruct, Path};
4
5use crate::internals::{
6    attributes::{field, BoundType},
7    generics, serialize,
8};
9
10pub fn process(input: &ItemStruct, cratename: Path) -> syn::Result<TokenStream2> {
11    let name = &input.ident;
12    let generics = generics::without_defaults(&input.generics);
13    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
14    let mut where_clause = generics::default_where(where_clause);
15    let mut body = TokenStream2::new();
16    let mut generics_output = serialize::GenericsOutput::new(&generics);
17    match &input.fields {
18        Fields::Named(fields) => {
19            for field in &fields.named {
20                let field_id = serialize::FieldId::Struct(field.ident.clone().unwrap());
21
22                process_field(field, field_id, &cratename, &mut generics_output, &mut body)?;
23            }
24        }
25        Fields::Unnamed(fields) => {
26            for (field_idx, field) in fields.unnamed.iter().enumerate() {
27                let field_id = serialize::FieldId::new_struct_unnamed(field_idx)?;
28
29                process_field(field, field_id, &cratename, &mut generics_output, &mut body)?;
30            }
31        }
32        Fields::Unit => {}
33    }
34    generics_output.extend(&mut where_clause, &cratename);
35
36    Ok(quote! {
37        impl #impl_generics #cratename::ser::BorshSerialize for #name #ty_generics #where_clause {
38            fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> {
39                #body
40                Ok(())
41            }
42        }
43    })
44}
45
46fn process_field(
47    field: &syn::Field,
48    field_id: serialize::FieldId,
49    cratename: &Path,
50    generics: &mut serialize::GenericsOutput,
51    body: &mut TokenStream2,
52) -> syn::Result<()> {
53    let parsed = field::Attributes::parse(&field.attrs)?;
54    let needs_bounds_derive = parsed.needs_bounds_derive(BoundType::Serialize);
55
56    generics
57        .overrides
58        .extend(parsed.collect_bounds(BoundType::Serialize));
59    if !parsed.skip {
60        let delta = field_id.serialize_output(cratename, parsed.serialize_with);
61        body.extend(delta);
62
63        if needs_bounds_derive {
64            generics.serialize_visitor.visit_field(field);
65        }
66    }
67    Ok(())
68}
69
#[cfg(test)]
mod tests {
    // Each test expands a struct through `process` and compares the
    // pretty-printed output (or the error) against a stored insta snapshot.
    // NOTE(review): the snapshot macros presumably key stored snapshots on the
    // enclosing test function's name (insta's default), so renaming a test
    // likely orphans its snapshot file. Confirm before renaming.
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_debug_snapshot, local_insta_assert_snapshot,
        pretty_print_syn_str,
    };

    use super::*;

    #[test]
    fn simple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A {
                x: u64,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_struct_with_custom_crate() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A {
                x: u64,
                y: String,
            }
        })
        .unwrap();

        // Generated paths must go through the re-exporting crate instead of `borsh`.
        let crate_: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = process(&item_struct, crate_).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_generics() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V> {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_generic_tuple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct TupleA<T>(T, u32);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn bound_generics() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K: Key, V> where V: Value {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn recursive_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct CRecC {
                a: String,
                b: HashMap<String, CRecC>,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip1() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> (
                #[borsh(skip)]
                HashMap<K, V>,
                U,
            );
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip2() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> (
                HashMap<K, V>,
                #[borsh(skip)]
                U,
            );
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_named_fields_struct_borsh_skip() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> {
                #[borsh(skip)]
                x: HashMap<K, V>,
                y: U,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_associated_type() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<T, V>
            where
                T: TraitName,
            {
                field: T::Associated,
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_serialize_bound() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct C<T: Debug, U> {
                a: String,
                #[borsh(bound(serialize =
                    "T: borsh::ser::BorshSerialize + PartialOrd,
                     U: borsh::ser::BorshSerialize"
                ))]
                b: HashMap<T, U>,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn override_generic_associated_type_wrong_derive() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<T, V> where T: TraitName {
                #[borsh(bound(serialize =
                    "<T as TraitName>::Associated: borsh::ser::BorshSerialize"
                ))]
                field: <T as TraitName>::Associated,
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn check_serialize_with_attr() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K: Ord, V> {
                #[borsh(serialize_with = "third_party_impl::serialize_third_party")]
                x: ThirdParty<K, V>,
                y: u64,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn check_serialize_with_skip_conflict() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K: Ord, V> {
                #[borsh(skip,serialize_with = "third_party_impl::serialize_third_party")]
                x: ThirdParty<K, V>,
                y: u64,
            }
        })
        .unwrap();

        // `skip` and `serialize_with` conflict, so expansion is expected to fail.
        let err = process(&item_struct, default_cratename()).expect_err("expecting error here");
        local_insta_assert_debug_snapshot!(err);
    }
}
302}