enum_ordinalize/
lib.rs

1/*!
2# Enum Ordinalize
3
This crate provides a procedural macro that lets enums not only get their variants' ordinals but also be constructed from an ordinal.
5
6## Ordinalize
7
Use `#[derive(Ordinalize)]` to make an enum (which must only have unit variants) have `from_ordinal_unsafe`, `from_ordinal`, `variants`, and `variant_count` associated functions and an `ordinal` method.
9
10```rust
11use enum_ordinalize::Ordinalize;
12
13#[derive(Debug, PartialEq, Eq, Ordinalize)]
14enum MyEnum {
15    Zero,
16    One,
17    Two,
18}
19
20assert_eq!(0i8, MyEnum::Zero.ordinal());
21assert_eq!(1i8, MyEnum::One.ordinal());
22assert_eq!(2i8, MyEnum::Two.ordinal());
23
24assert_eq!(Some(MyEnum::Zero), MyEnum::from_ordinal(0i8));
25assert_eq!(Some(MyEnum::One), MyEnum::from_ordinal(1i8));
26assert_eq!(Some(MyEnum::Two), MyEnum::from_ordinal(2i8));
27
28assert_eq!(MyEnum::Zero, unsafe { MyEnum::from_ordinal_unsafe(0i8) });
29assert_eq!(MyEnum::One, unsafe { MyEnum::from_ordinal_unsafe(1i8) });
30assert_eq!(MyEnum::Two, unsafe { MyEnum::from_ordinal_unsafe(2i8) });
31```
32
33### Get Variants
34
35```rust
36use enum_ordinalize::Ordinalize;
37
38#[derive(Debug, PartialEq, Eq, Ordinalize)]
39enum MyEnum {
40    Zero,
41    One,
42    Two,
43}
44
45assert_eq!([MyEnum::Zero, MyEnum::One, MyEnum::Two], MyEnum::variants());
46assert_eq!(3, MyEnum::variant_count());
47```
48
49`variants` and `variant_count` are constant functions.
50
51## The (Ordinal) Size of an Enum
52
The ordinal value is an integer whose size is determined by the enum itself. The larger the variants' values are (or the smaller, when they are negative), the bigger the enum size is.
54
55For example,
56
57```rust
58use enum_ordinalize::Ordinalize;
59
60#[derive(Debug, PartialEq, Eq, Ordinalize)]
61enum MyEnum {
62    Zero,
63    One,
64    Two,
65    Thousand = 1000,
66}
67
68assert_eq!(0i16, MyEnum::Zero.ordinal());
69assert_eq!(1i16, MyEnum::One.ordinal());
70assert_eq!(2i16, MyEnum::Two.ordinal());
71
72assert_eq!(Some(MyEnum::Zero), MyEnum::from_ordinal(0i16));
73assert_eq!(Some(MyEnum::One), MyEnum::from_ordinal(1i16));
74assert_eq!(Some(MyEnum::Two), MyEnum::from_ordinal(2i16));
75
76assert_eq!(MyEnum::Zero, unsafe { MyEnum::from_ordinal_unsafe(0i16) });
77assert_eq!(MyEnum::One, unsafe { MyEnum::from_ordinal_unsafe(1i16) });
78assert_eq!(MyEnum::Two, unsafe { MyEnum::from_ordinal_unsafe(2i16) });
79```
80
81In order to store `1000`, the size of `MyEnum` grows. Thus, the ordinal is in `i16` instead of `i8`.
82
83You can use the `#[repr(type)]` attribute to control the size explicitly. For instance,
84
85```rust
86use enum_ordinalize::Ordinalize;
87
88#[derive(Debug, PartialEq, Eq, Ordinalize)]
89#[repr(usize)]
90enum MyEnum {
91    Zero,
92    One,
93    Two,
94    Thousand = 1000,
95}
96
97assert_eq!(0usize, MyEnum::Zero.ordinal());
98assert_eq!(1usize, MyEnum::One.ordinal());
99assert_eq!(2usize, MyEnum::Two.ordinal());
100
101assert_eq!(Some(MyEnum::Zero), MyEnum::from_ordinal(0usize));
102assert_eq!(Some(MyEnum::One), MyEnum::from_ordinal(1usize));
103assert_eq!(Some(MyEnum::Two), MyEnum::from_ordinal(2usize));
104
105assert_eq!(MyEnum::Zero, unsafe { MyEnum::from_ordinal_unsafe(0usize) });
106assert_eq!(MyEnum::One, unsafe { MyEnum::from_ordinal_unsafe(1usize) });
107assert_eq!(MyEnum::Two, unsafe { MyEnum::from_ordinal_unsafe(2usize) });
108```
109
110## Useful Increment
111
The integer represented by each variant is the previous variant's value plus one unless it is set explicitly, and an explicit value can be given at any variant.
113
114```rust
115use enum_ordinalize::Ordinalize;
116
117#[derive(Debug, PartialEq, Eq, Ordinalize)]
118enum MyEnum {
119    Two   = 2,
120    Three,
121    Four,
122    Eight = 8,
123    Nine,
124    NegativeTen = -10,
125    NegativeNine,
126}
127
128assert_eq!(4i8, MyEnum::Four.ordinal());
129assert_eq!(9i8, MyEnum::Nine.ordinal());
130assert_eq!(-9i8, MyEnum::NegativeNine.ordinal());
131
132assert_eq!(Some(MyEnum::Four), MyEnum::from_ordinal(4i8));
133assert_eq!(Some(MyEnum::Nine), MyEnum::from_ordinal(9i8));
134assert_eq!(Some(MyEnum::NegativeNine), MyEnum::from_ordinal(-9i8));
135
136assert_eq!(MyEnum::Four, unsafe { MyEnum::from_ordinal_unsafe(4i8) });
137assert_eq!(MyEnum::Nine, unsafe { MyEnum::from_ordinal_unsafe(9i8) });
138assert_eq!(MyEnum::NegativeNine, unsafe { MyEnum::from_ordinal_unsafe(-9i8) });
139```
140*/
141
142#![no_std]
143
144extern crate alloc;
145
146mod big_int_wrapper;
147mod panic;
148mod variant_type;
149
150use alloc::{string::ToString, vec::Vec};
151use core::str::FromStr;
152
153use big_int_wrapper::BigIntWrapper;
154use num_bigint::BigInt;
155use proc_macro::TokenStream;
156use quote::quote;
157use syn::{Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, UnOp};
158use variant_type::VariantType;
159
160fn derive_input_handler(ast: DeriveInput) -> TokenStream {
161    let mut variant_type = VariantType::default();
162
163    for attr in ast.attrs {
164        if attr.path().is_ident("repr") {
165            // #[repr(u8)], #[repr(u16)], ..., etc.
166            if let Meta::List(list) = attr.meta {
167                let typ_name = list.tokens.to_string();
168
169                variant_type = VariantType::from_str(typ_name);
170            }
171        }
172    }
173
174    let name = &ast.ident;
175
176    match ast.data {
177        Data::Enum(data) => {
178            if data.variants.is_empty() {
179                panic::no_variant();
180            }
181
182            let mut values: Vec<BigIntWrapper> = Vec::with_capacity(data.variants.len());
183            let mut variant_idents: Vec<&Ident> = Vec::with_capacity(data.variants.len());
184            let mut use_constant_counter = false;
185
186            if VariantType::Nondetermined == variant_type {
187                let mut min = BigInt::from(u128::MAX);
188                let mut max = BigInt::from(i128::MIN);
189                let mut counter = BigInt::default();
190
191                for variant in data.variants.iter() {
192                    if let Fields::Unit = variant.fields {
193                        let value = if let Some((_, exp)) = variant.discriminant.as_ref() {
194                            match exp {
195                                Expr::Lit(lit) => {
196                                    let lit = &lit.lit;
197
198                                    let value = match lit {
199                                        Lit::Int(value) => {
200                                            let value = value.base10_digits();
201                                            BigInt::from_str(value).unwrap()
202                                        },
203                                        _ => panic::unsupported_discriminant(),
204                                    };
205
206                                    counter = value.clone();
207
208                                    value
209                                },
210                                Expr::Unary(unary) => match unary.op {
211                                    UnOp::Neg(_) => match unary.expr.as_ref() {
212                                        Expr::Lit(lit) => {
213                                            let lit = &lit.lit;
214
215                                            let value = match lit {
216                                                Lit::Int(value) => {
217                                                    let value = value.base10_digits();
218
219                                                    -BigInt::from_str(value).unwrap()
220                                                },
221                                                _ => panic::unsupported_discriminant(),
222                                            };
223
224                                            counter = value.clone();
225
226                                            value
227                                        },
228                                        Expr::Path(_)
229                                        | Expr::Cast(_)
230                                        | Expr::Binary(_)
231                                        | Expr::Call(_) => {
232                                            panic::constant_variable_on_non_determined_size_enum()
233                                        },
234                                        _ => panic::unsupported_discriminant(),
235                                    },
236                                    _ => panic::unsupported_discriminant(),
237                                },
238                                Expr::Path(_) | Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
239                                    panic::constant_variable_on_non_determined_size_enum()
240                                },
241                                _ => panic::unsupported_discriminant(),
242                            }
243                        } else {
244                            counter.clone()
245                        };
246
247                        if min > value {
248                            min = value.clone();
249                        }
250
251                        if max < value {
252                            max = value.clone();
253                        }
254
255                        variant_idents.push(&variant.ident);
256
257                        counter += 1;
258
259                        values.push(BigIntWrapper::from(value));
260                    } else {
261                        panic::not_unit_variant();
262                    }
263                }
264
265                if min >= BigInt::from(i8::MIN) && max <= BigInt::from(i8::MAX) {
266                    variant_type = VariantType::I8;
267                } else if min >= BigInt::from(i16::MIN) && max <= BigInt::from(i16::MAX) {
268                    variant_type = VariantType::I16;
269                } else if min >= BigInt::from(i32::MIN) && max <= BigInt::from(i32::MAX) {
270                    variant_type = VariantType::I32;
271                } else if min >= BigInt::from(i64::MIN) && max <= BigInt::from(i64::MAX) {
272                    variant_type = VariantType::I64;
273                } else if min >= BigInt::from(i128::MIN) && max <= BigInt::from(i128::MAX) {
274                    variant_type = VariantType::I128;
275                } else {
276                    panic::unsupported_discriminant()
277                }
278            } else {
279                let mut counter = BigInt::default();
280                let mut constant_counter = 0;
281                let mut last_exp: Option<&Expr> = None;
282
283                for variant in data.variants.iter() {
284                    if let Fields::Unit = variant.fields {
285                        if let Some((_, exp)) = variant.discriminant.as_ref() {
286                            match exp {
287                                Expr::Lit(lit) => {
288                                    let lit = &lit.lit;
289
290                                    let value = match lit {
291                                        Lit::Int(value) => {
292                                            let value = value.base10_digits();
293                                            BigInt::from_str(value).unwrap()
294                                        },
295                                        _ => panic::unsupported_discriminant(),
296                                    };
297
298                                    values.push(BigIntWrapper::from(value.clone()));
299
300                                    counter = value + 1;
301
302                                    last_exp = None;
303                                },
304                                Expr::Unary(unary) => match unary.op {
305                                    UnOp::Neg(_) => match unary.expr.as_ref() {
306                                        Expr::Lit(lit) => {
307                                            let lit = &lit.lit;
308
309                                            let value = match lit {
310                                                Lit::Int(value) => {
311                                                    let value = value.base10_digits();
312
313                                                    -BigInt::from_str(value).unwrap()
314                                                },
315                                                _ => panic::unsupported_discriminant(),
316                                            };
317
318                                            values.push(BigIntWrapper::from(value.clone()));
319
320                                            counter = value + 1;
321
322                                            last_exp = None;
323                                        },
324                                        Expr::Path(_) => {
325                                            values.push(BigIntWrapper::from((exp, 0)));
326
327                                            last_exp = Some(exp);
328                                            constant_counter = 1;
329                                        },
330                                        Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
331                                            values.push(BigIntWrapper::from((exp, 0)));
332
333                                            last_exp = Some(exp);
334                                            constant_counter = 1;
335
336                                            use_constant_counter = true;
337                                        },
338                                        _ => panic::unsupported_discriminant(),
339                                    },
340                                    _ => panic::unsupported_discriminant(),
341                                },
342                                Expr::Path(_) => {
343                                    values.push(BigIntWrapper::from((exp, 0)));
344
345                                    last_exp = Some(exp);
346                                    constant_counter = 1;
347                                },
348                                Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
349                                    values.push(BigIntWrapper::from((exp, 0)));
350
351                                    last_exp = Some(exp);
352                                    constant_counter = 1;
353
354                                    use_constant_counter = true;
355                                },
356                                _ => panic::unsupported_discriminant(),
357                            }
358                        } else if let Some(exp) = last_exp.as_ref() {
359                            values.push(BigIntWrapper::from((*exp, constant_counter)));
360
361                            constant_counter += 1;
362
363                            use_constant_counter = true;
364                        } else {
365                            values.push(BigIntWrapper::from(counter.clone()));
366
367                            counter += 1;
368                        }
369
370                        variant_idents.push(&variant.ident);
371                    } else {
372                        panic::not_unit_variant();
373                    }
374                }
375            }
376
377            let ordinal = quote! {
378                #[inline]
379                pub fn ordinal(&self) -> #variant_type {
380                    match self {
381                        #(
382                            Self::#variant_idents => #values,
383                        )*
384                    }
385                }
386            };
387
388            let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
389
390            let from_ordinal_unsafe = if data.variants.len() == 1 {
391                let variant_idents = &data.variants[0].ident;
392
393                quote! {
394                    #[inline]
395                    pub unsafe fn from_ordinal_unsafe(_number: #variant_type) -> #name #ty_generics {
396                        Self::#variant_idents
397                    }
398                }
399            } else {
400                quote! {
401                    #[inline]
402                    pub unsafe fn from_ordinal_unsafe(number: #variant_type) -> #name #ty_generics {
403                        ::core::mem::transmute(number)
404                    }
405                }
406            };
407
408            let from_ordinal = if use_constant_counter {
409                quote! {
410                    #[inline]
411                    pub fn from_ordinal(number: #variant_type) -> Option<#name #ty_generics> {
412                        if false {
413                            unreachable!()
414                        } #( else if number == #values {
415                            Some(Self::#variant_idents)
416                        } )* else {
417                            None
418                        }
419                    }
420                }
421            } else {
422                quote! {
423                    #[inline]
424                    pub fn from_ordinal(number: #variant_type) -> Option<#name #ty_generics> {
425                        match number{
426                            #(
427                                #values => Some(Self::#variant_idents),
428                            )*
429                            _ => None
430                        }
431                    }
432                }
433            };
434
435            let variant_count = variant_idents.len();
436
437            let variants = quote! {
438                #[inline]
439                pub const fn variants() -> [#name #ty_generics; #variant_count] {
440                    [#( Self::#variant_idents, )*]
441                }
442            };
443
444            let variant_count = quote! {
445                #[inline]
446                pub const fn variant_count() -> usize {
447                    #variant_count
448                }
449            };
450
451            let ordinalize_impl = quote! {
452                impl #impl_generics #name #ty_generics #where_clause {
453                    #from_ordinal_unsafe
454
455                    #from_ordinal
456
457                    #ordinal
458
459                    #variants
460
461                    #variant_count
462                }
463            };
464
465            ordinalize_impl.into()
466        },
467        _ => {
468            panic::not_enum();
469        },
470    }
471}
472
473#[proc_macro_derive(Ordinalize)]
474pub fn ordinalize_derive(input: TokenStream) -> TokenStream {
475    derive_input_handler(syn::parse(input).unwrap())
476}