@@ -29,14 +29,15 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
2929 let ( lifetime_without_bounds, lifetime_with_bounds) =
3030 build_arbitrary_lifetime ( input. generics . clone ( ) ) ;
3131
32+ // This won't be used if `needs_recursive_count` ends up false.
3233 let recursive_count = syn:: Ident :: new (
3334 & format ! ( "RECURSIVE_COUNT_{}" , input. ident) ,
3435 Span :: call_site ( ) ,
3536 ) ;
3637
37- let arbitrary_method =
38+ let ( arbitrary_method, needs_recursive_count ) =
3839 gen_arbitrary_method ( & input, lifetime_without_bounds. clone ( ) , & recursive_count) ?;
39- let size_hint_method = gen_size_hint_method ( & input) ?;
40+ let size_hint_method = gen_size_hint_method ( & input, needs_recursive_count ) ?;
4041 let name = input. ident ;
4142
4243 // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
@@ -56,17 +57,25 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
5657 // Build TypeGenerics and WhereClause without a lifetime
5758 let ( _, ty_generics, where_clause) = generics. split_for_impl ( ) ;
5859
59- Ok ( quote ! {
60- const _ : ( ) = {
60+ let recursive_count = needs_recursive_count . then ( || {
61+ Some ( quote ! {
6162 :: std:: thread_local! {
6263 #[ allow( non_upper_case_globals) ]
6364 static #recursive_count: :: core:: cell:: Cell <u32 > = const {
6465 :: core:: cell:: Cell :: new( 0 )
6566 } ;
6667 }
68+ } )
69+ } ) ;
70+
71+ Ok ( quote ! {
72+ const _: ( ) = {
73+ #recursive_count
6774
6875 #[ automatically_derived]
69- impl #impl_generics arbitrary:: Arbitrary <#lifetime_without_bounds> for #name #ty_generics #where_clause {
76+ impl #impl_generics arbitrary:: Arbitrary <#lifetime_without_bounds>
77+ for #name #ty_generics #where_clause
78+ {
7079 #arbitrary_method
7180 #size_hint_method
7281 }
@@ -149,10 +158,7 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
149158 generics
150159}
151160
152- fn with_recursive_count_guard (
153- recursive_count : & syn:: Ident ,
154- expr : impl quote:: ToTokens ,
155- ) -> impl quote:: ToTokens {
161+ fn with_recursive_count_guard ( recursive_count : & syn:: Ident , expr : TokenStream ) -> TokenStream {
156162 quote ! {
157163 let guard_against_recursion = u. is_empty( ) ;
158164 if guard_against_recursion {
@@ -181,7 +187,7 @@ fn gen_arbitrary_method(
181187 input : & DeriveInput ,
182188 lifetime : LifetimeParam ,
183189 recursive_count : & syn:: Ident ,
184- ) -> Result < TokenStream > {
190+ ) -> Result < ( TokenStream , bool ) > {
185191 fn arbitrary_structlike (
186192 fields : & Fields ,
187193 ident : & syn:: Ident ,
@@ -219,28 +225,36 @@ fn gen_arbitrary_method(
219225 recursive_count : & syn:: Ident ,
220226 unstructured : TokenStream ,
221227 variants : & [ TokenStream ] ,
222- ) -> impl quote:: ToTokens {
228+ needs_recursive_count : bool ,
229+ ) -> TokenStream {
223230 let count = variants. len ( ) as u64 ;
224- with_recursive_count_guard (
225- recursive_count,
226- quote ! {
227- // Use a multiply + shift to generate a ranged random number
228- // with slight bias. For details, see:
229- // https://lemire.me/blog/2016/06/30/fast-random-shuffling
230- Ok ( match ( u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( #unstructured) ?) * #count) >> 32 {
231- #( #variants, ) *
232- _ => unreachable!( )
233- } )
234- } ,
235- )
231+
232+ let do_variants = quote ! {
233+ // Use a multiply + shift to generate a ranged random number
234+ // with slight bias. For details, see:
235+ // https://lemire.me/blog/2016/06/30/fast-random-shuffling
236+ Ok ( match (
237+ u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( #unstructured) ?) * #count
238+ ) >> 32
239+ {
240+ #( #variants, ) *
241+ _ => unreachable!( )
242+ } )
243+ } ;
244+
245+ if needs_recursive_count {
246+ with_recursive_count_guard ( recursive_count, do_variants)
247+ } else {
248+ do_variants
249+ }
236250 }
237251
238252 fn arbitrary_enum (
239253 DataEnum { variants, .. } : & DataEnum ,
240254 enum_name : & Ident ,
241255 lifetime : LifetimeParam ,
242256 recursive_count : & syn:: Ident ,
243- ) -> Result < TokenStream > {
257+ ) -> Result < ( TokenStream , bool ) > {
244258 let filtered_variants = variants. iter ( ) . filter ( not_skipped) ;
245259
246260 // Check attributes of all variants:
@@ -254,11 +268,16 @@ fn gen_arbitrary_method(
254268 . map ( |( index, variant) | ( index as u64 , variant) ) ;
255269
256270 // Construct `match`-arms for the `arbitrary` method.
271+ let mut needs_recursive_count = false ;
257272 let variants = enumerated_variants
258273 . clone ( )
259274 . map ( |( index, Variant { fields, ident, .. } ) | {
260- construct ( fields, |_, field| gen_constructor_for_field ( field) )
261- . map ( |ctor| arbitrary_variant ( index, enum_name, ident, ctor) )
275+ construct ( fields, |_, field| gen_constructor_for_field ( field) ) . map ( |ctor| {
276+ if !ctor. is_empty ( ) {
277+ needs_recursive_count = true ;
278+ }
279+ arbitrary_variant ( index, enum_name, ident, ctor)
280+ } )
262281 } )
263282 . collect :: < Result < Vec < TokenStream > > > ( ) ?;
264283
@@ -277,34 +296,56 @@ fn gen_arbitrary_method(
277296 ( !variants. is_empty ( ) )
278297 . then ( || {
279298 // TODO: Improve dealing with `u` vs. `&mut u`.
280- let arbitrary = arbitrary_enum_method ( recursive_count, quote ! { u } , & variants) ;
281- let arbitrary_take_rest = arbitrary_enum_method ( recursive_count, quote ! { & mut u } , & variants_take_rest) ;
282-
283- quote ! {
284- fn arbitrary( u: & mut arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
285- #arbitrary
286- }
299+ let arbitrary = arbitrary_enum_method (
300+ recursive_count,
301+ quote ! { u } ,
302+ & variants,
303+ needs_recursive_count,
304+ ) ;
305+ let arbitrary_take_rest = arbitrary_enum_method (
306+ recursive_count,
307+ quote ! { & mut u } ,
308+ & variants_take_rest,
309+ needs_recursive_count,
310+ ) ;
311+
312+ (
313+ quote ! {
314+ fn arbitrary( u: & mut arbitrary:: Unstructured <#lifetime>)
315+ -> arbitrary:: Result <Self >
316+ {
317+ #arbitrary
318+ }
287319
288- fn arbitrary_take_rest( mut u: arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
289- #arbitrary_take_rest
290- }
291- }
320+ fn arbitrary_take_rest( mut u: arbitrary:: Unstructured <#lifetime>)
321+ -> arbitrary:: Result <Self >
322+ {
323+ #arbitrary_take_rest
324+ }
325+ } ,
326+ needs_recursive_count,
327+ )
328+ } )
329+ . ok_or_else ( || {
330+ Error :: new_spanned (
331+ enum_name,
332+ "Enum must have at least one variant, that is not skipped" ,
333+ )
292334 } )
293- . ok_or_else ( || Error :: new_spanned (
294- enum_name,
295- "Enum must have at least one variant, that is not skipped"
296- ) )
297335 }
298336
299337 let ident = & input. ident ;
338+ let needs_recursive_count = true ;
300339 match & input. data {
301- Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count) ,
340+ Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count)
341+ . map ( |ts| ( ts, needs_recursive_count) ) ,
302342 Data :: Union ( data) => arbitrary_structlike (
303343 & Fields :: Named ( data. fields . clone ( ) ) ,
304344 ident,
305345 lifetime,
306346 recursive_count,
307- ) ,
347+ )
348+ . map ( |ts| ( ts, needs_recursive_count) ) ,
308349 Data :: Enum ( data) => arbitrary_enum ( data, ident, lifetime, recursive_count) ,
309350 }
310351}
@@ -357,7 +398,7 @@ fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
357398 } )
358399}
359400
360- fn gen_size_hint_method ( input : & DeriveInput ) -> Result < TokenStream > {
401+ fn gen_size_hint_method ( input : & DeriveInput , needs_recursive_count : bool ) -> Result < TokenStream > {
361402 let size_hint_fields = |fields : & Fields | {
362403 fields
363404 . iter ( )
@@ -372,9 +413,9 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
372413 quote ! { <#ty as arbitrary:: Arbitrary >:: try_size_hint( depth) }
373414 }
374415
375- // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
376- // just an educated guess, although it's gonna be inaccurate for dynamically
377- // allocated types (Vec, HashMap, etc.).
416+ // Note that in this case it's hard to determine what size_hint must be, so
417+ // size_of::<T>() is just an educated guess, although it's gonna be
418+ // inaccurate for dynamically allocated types (Vec, HashMap, etc.).
378419 FieldConstructor :: With ( _) => {
379420 quote ! { Ok ( ( :: core:: mem:: size_of:: <#ty>( ) , None ) ) }
380421 }
@@ -391,6 +432,7 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
391432 } )
392433 } ;
393434 let size_hint_structlike = |fields : & Fields | {
435+ assert ! ( needs_recursive_count) ;
394436 size_hint_fields ( fields) . map ( |hint| {
395437 quote ! {
396438 #[ inline]
@@ -399,7 +441,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
399441 }
400442
401443 #[ inline]
402- fn try_size_hint( depth: usize ) -> :: core:: result:: Result <( usize , :: core:: option:: Option <usize >) , arbitrary:: MaxRecursionReached > {
444+ fn try_size_hint( depth: usize )
445+ -> :: core:: result:: Result <
446+ ( usize , :: core:: option:: Option <usize >) ,
447+ arbitrary:: MaxRecursionReached ,
448+ >
449+ {
403450 arbitrary:: size_hint:: try_recursion_guard( depth, |depth| #hint)
404451 }
405452 }
@@ -413,24 +460,44 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
413460 . iter ( )
414461 . filter ( not_skipped)
415462 . map ( |Variant { fields, .. } | {
463+ if !needs_recursive_count {
464+ assert ! ( fields. is_empty( ) ) ;
465+ }
416466 // The attributes of all variants are checked in `gen_arbitrary_method` above
417- // and can therefore assume that they are valid.
467+ // and can therefore assume that they are valid.
418468 size_hint_fields ( fields)
419469 } )
420470 . collect :: < Result < Vec < TokenStream > > > ( )
421471 . map ( |variants| {
422- quote ! {
423- fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
424- Self :: try_size_hint( depth) . unwrap_or_default( )
472+ if needs_recursive_count {
473+ // The enum might be recursive: `try_size_hint` is the primary one, and
474+ // `size_hint` is defined in terms of it.
475+ quote ! {
476+ fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
477+ Self :: try_size_hint( depth) . unwrap_or_default( )
478+ }
479+ #[ inline]
480+ fn try_size_hint( depth: usize )
481+ -> :: core:: result:: Result <
482+ ( usize , :: core:: option:: Option <usize >) ,
483+ arbitrary:: MaxRecursionReached ,
484+ >
485+ {
486+ Ok ( arbitrary:: size_hint:: and(
487+ <u32 as arbitrary:: Arbitrary >:: size_hint( depth) ,
488+ arbitrary:: size_hint:: try_recursion_guard( depth, |depth| {
489+ Ok ( arbitrary:: size_hint:: or_all( & [ #( #variants? ) , * ] ) )
490+ } ) ?,
491+ ) )
492+ }
425493 }
426- #[ inline]
427- fn try_size_hint( depth: usize ) -> :: core:: result:: Result <( usize , :: core:: option:: Option <usize >) , arbitrary:: MaxRecursionReached > {
428- Ok ( arbitrary:: size_hint:: and(
429- <u32 as arbitrary:: Arbitrary >:: try_size_hint( depth) ?,
430- arbitrary:: size_hint:: try_recursion_guard( depth, |depth| {
431- Ok ( arbitrary:: size_hint:: or_all( & [ #( #variants? ) , * ] ) )
432- } ) ?,
433- ) )
494+ } else {
495+ // The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the
496+ // primary one, and the default `try_size_hint` is good enough.
497+ quote ! {
498+ fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
499+ <u32 as arbitrary:: Arbitrary >:: size_hint( depth)
500+ }
434501 }
435502 }
436503 } ) ,
0 commit comments