11use clap:: { CommandFactory , Parser , Subcommand } ;
22use rustc_apfloat:: Float as _;
33use std:: fmt;
4+ use std:: io:: Write ;
45use std:: mem:: MaybeUninit ;
6+ use std:: num:: NonZeroUsize ;
57use std:: path:: PathBuf ;
68
79// See `build.rs` and `ops.rs` for how `FuzzOp` is generated.
810include ! ( concat!( env!( "OUT_DIR" ) , "/generated_fuzz_ops.rs" ) ) ;
911
10- #[ derive( Parser , Debug ) ]
12+ #[ derive( Clone , Parser , Debug ) ]
1113struct Args {
1214 /// Disable comparison with C++ (LLVM's original) APFloat
1315 #[ arg( long) ]
@@ -37,10 +39,13 @@ struct Args {
3739 command : Option < Commands > ,
3840}
3941
40- #[ derive( Subcommand , Debug ) ]
42+ #[ derive( Clone , Subcommand , Debug ) ]
4143enum Commands {
4244 /// Decode fuzzing in/out testcases (binary serialized `FuzzOp`s)
4345 Decode { files : Vec < PathBuf > } ,
46+
47+ /// Exhaustively test all possible ops and inputs for tiny (8-bit) formats
48+ BruteforceTiny ,
4449}
4550
4651/// Trait implemented for types that describe a floating-point format supported
@@ -69,6 +74,10 @@ trait FloatRepr: Copy + Default + Eq + fmt::Display {
6974 // format with the same `BIT_WIDTH`, so it's not unambiguous on its own.
7075 const REPR_TAG : u8 = Self :: BIT_WIDTH as u8 ;
7176
77+ fn short_lowercase_name ( ) -> String {
78+ Self :: NAME . to_ascii_lowercase ( ) . replace ( "ieee" , "f" )
79+ }
80+
7281 // FIXME(eddyb) these should ideally be using `[u8; Self::BYTE_LEN]`.
7382 fn from_le_bytes ( bytes : & [ u8 ] ) -> Self ;
7483 fn write_as_le_bytes_into ( self , out_bytes : & mut Vec < u8 > ) ;
@@ -234,13 +243,11 @@ struct FuzzOpEvalOutputs<F: FloatRepr> {
234243}
235244
236245impl < F : FloatRepr > FuzzOpEvalOutputs < F > {
237- fn assert_all_match ( self ) {
238- if let Some ( cxx_apf) = self . cxx_apf {
239- assert ! ( cxx_apf == self . rs_apf) ;
240- }
241- if let Some ( hard) = self . hard {
242- assert ! ( hard == self . rs_apf) ;
243- }
246+ fn all_match ( self ) -> bool {
247+ [ self . cxx_apf , self . hard ]
248+ . into_iter ( )
249+ . flatten ( )
250+ . all ( |x| x == self . rs_apf )
244251 }
245252}
246253
@@ -430,8 +437,11 @@ where
430437 }
431438 }
432439
433- let short_float_type_name = F :: NAME . to_ascii_lowercase ( ) . replace ( "ieee" , "f" ) ;
434- println ! ( " {short_float_type_name}.{:?}" , self . map( FloatPrintHelper ) ) ;
440+ println ! (
441+ " {}.{:?}" ,
442+ F :: short_lowercase_name( ) ,
443+ self . map( FloatPrintHelper )
444+ ) ;
435445
436446 // HACK(eddyb) this lets us show all files even if some cause panics.
437447 let FuzzOpEvalOutputs {
@@ -463,6 +473,143 @@ where
463473 cxx_apf. map ( |x| print ( x, "C++ / llvm::APFloat" ) ) ;
464474 hard. map ( |x| print ( x, "native hardware floats" ) ) ;
465475 }
476+
477+ /// [`Commands::BruteforceTiny`] implementation (for a specific choice of `F`),
478+ /// returning `Err(mismatch_count)` if there were any mismatches.
479+ //
480+ // HACK(eddyb) this is a method here because of the bounds `eval` needs, which
481+ // are thankfully on the whole `impl`, so `Self::eval` is callable.
482+ fn bruteforce_tiny ( cli_args : & Args ) -> Result < ( ) , NonZeroUsize > {
483+ // Here "tiny" is "8-bit" - 16-bit floats could maybe also be bruteforced,
484+ // but the cost increases exponentially, so less useful relative to fuzzing.
485+ if F :: BIT_WIDTH > 8 {
486+ return Ok ( ( ) ) ;
487+ }
488+
489+ // HACK(eddyb) avoid reporting panics while iterating.
490+ std:: panic:: set_hook ( Box :: new ( |_| { } ) ) ;
491+
492+ let all_ops = ( 0 ..)
493+ . map ( FuzzOp :: from_tag)
494+ . take_while ( |op| op. is_some ( ) )
495+ . map ( |op| op. unwrap ( ) ) ;
496+
497+ let op_to_exhaustive_cases = |op : FuzzOp < ( ) > | {
498+ let mut total_bit_width = 0 ;
499+ op. map ( |( ) | total_bit_width += F :: BIT_WIDTH ) ;
500+ ( 0 ..usize:: checked_shl ( 1 , total_bit_width as u32 ) . unwrap ( ) ) . map ( move |i| -> Self {
501+ let mut combined_input_bits = i;
502+ let op_with_inputs = op. map ( |( ) | {
503+ let x = combined_input_bits & ( ( 1 << F :: BIT_WIDTH ) - 1 ) ;
504+ combined_input_bits >>= F :: BIT_WIDTH ;
505+ F :: from_bits_u128 ( x. try_into ( ) . unwrap ( ) )
506+ } ) ;
507+ assert_eq ! ( combined_input_bits, 0 ) ;
508+ op_with_inputs
509+ } )
510+ } ;
511+
512+ let num_total_cases = all_ops
513+ . clone ( )
514+ . map ( |op| op_to_exhaustive_cases ( op) . len ( ) )
515+ . try_fold ( 0 , usize:: checked_add)
516+ . unwrap ( ) ;
517+
518+ let float_name = F :: short_lowercase_name ( ) ;
519+ println ! ( "Exhaustively checking all {num_total_cases} cases for {float_name}:" , ) ;
520+
521+ const NUM_DOTS : usize = 80 ;
522+ let cases_per_dot = num_total_cases / NUM_DOTS ;
523+ let mut cases_in_this_dot = 0 ;
524+ let mut mismatches_in_this_dot = false ;
525+ let mut num_mismatches = 0 ;
526+ let mut select_mismatches = vec ! [ ] ;
527+ let mut all_panics = vec ! [ ] ;
528+ for op in all_ops {
529+ let mut first_mismatch = None ;
530+ for op_with_inputs in op_to_exhaustive_cases ( op) {
531+ cases_in_this_dot += 1 ;
532+ if cases_in_this_dot >= cases_per_dot {
533+ cases_in_this_dot -= cases_per_dot;
534+ if mismatches_in_this_dot {
535+ mismatches_in_this_dot = false ;
536+ print ! ( "X" ) ;
537+ } else {
538+ print ! ( "." )
539+ }
540+ // HACK(eddyb) get around `stdout` line buffering.
541+ std:: io:: stdout ( ) . flush ( ) . unwrap ( ) ;
542+ }
543+
544+ // HACK(eddyb) there are still panics we need to account for,
545+ // e.g. https://github.com/llvm/llvm-project/issues/63895, and
546+ // even if the Rust code didn't panic, LLVM asserts would trip.
547+ match std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( || {
548+ op_with_inputs. eval ( cli_args)
549+ } ) ) {
550+ Ok ( out) => {
551+ if !out. all_match ( ) {
552+ num_mismatches += 1 ;
553+ mismatches_in_this_dot = true ;
554+ if first_mismatch. is_none ( ) {
555+ first_mismatch = Some ( op_with_inputs) ;
556+ }
557+ }
558+ }
559+ Err ( _) => {
560+ mismatches_in_this_dot = true ;
561+ all_panics. push ( op_with_inputs) ;
562+ }
563+ }
564+ }
565+ select_mismatches. extend ( first_mismatch) ;
566+ }
567+ println ! ( ) ;
568+
569+ // HACK(eddyb) undo what we did at the start of this function.
570+ let _ = std:: panic:: take_hook ( ) ;
571+
572+ if num_mismatches > 0 {
573+ assert ! ( !select_mismatches. is_empty( ) ) ;
574+ println ! ( ) ;
575+ println ! (
576+ "!!! found {num_mismatches} ({:.1}%) mismatches for {float_name}, showing {} of them:" ,
577+ ( num_mismatches as f64 ) / ( num_total_cases as f64 ) * 100.0 ,
578+ select_mismatches. len( ) ,
579+ ) ;
580+ for mismatch in select_mismatches {
581+ mismatch. print_op_and_eval_outputs ( cli_args) ;
582+ }
583+ println ! ( ) ;
584+ } else {
585+ assert ! ( select_mismatches. is_empty( ) ) ;
586+ }
587+
588+ if !all_panics. is_empty ( ) {
589+ // HACK(eddyb) there is a good chance C++ will also fail, so avoid
590+ // triggering the (more fatal) C++ assertion failure.
591+ let cli_args_plus_ignore_cxx = Args {
592+ ignore_cxx : true ,
593+ ..cli_args. clone ( )
594+ } ;
595+
596+ println ! (
597+ "!!! found {} panics for {float_name}, showing them (without trying C++):" ,
598+ all_panics. len( )
599+ ) ;
600+ for & panicking_case in & all_panics {
601+ panicking_case. print_op_and_eval_outputs ( & cli_args_plus_ignore_cxx) ;
602+ }
603+ println ! ( ) ;
604+ }
605+
606+ if num_mismatches == 0 && all_panics. is_empty ( ) {
607+ println ! ( "all {num_total_cases} cases match" ) ;
608+ println ! ( ) ;
609+ }
610+
611+ NonZeroUsize :: new ( num_mismatches + all_panics. len ( ) ) . map_or ( Ok ( ( ) ) , Err )
612+ }
466613}
467614
468615fn main ( ) {
@@ -491,6 +638,20 @@ fn main() {
491638 . unwrap_or_else ( |e| println ! ( " invalid data ({e})" ) ) ;
492639 }
493640 }
641+ Commands :: BruteforceTiny => {
642+ let mut any_mismatches = false ;
643+ for repr_tag in 0 ..=u8:: MAX {
644+ dispatch_any_float_repr_by_repr_tag ! ( match repr_tag {
645+ for <F : FloatRepr > => {
646+ any_mismatches |= FuzzOp :: <F >:: bruteforce_tiny( & cli_args) . is_err( ) ;
647+ }
648+ } ) ;
649+ }
650+ if any_mismatches {
651+ // FIXME(eddyb) use `fn main() -> ExitStatus`.
652+ std:: process:: exit ( 1 ) ;
653+ }
654+ }
494655 }
495656 return ;
496657 }
@@ -500,7 +661,7 @@ fn main() {
500661 data. split_first ( ) . and_then ( |( & repr_tag, data) | {
501662 dispatch_any_float_repr_by_repr_tag ! ( match repr_tag {
502663 for <F : FloatRepr > => return Some (
503- FuzzOp :: <F >:: try_decode( data) . ok( ) ?. eval( & cli_args) . assert_all_match ( )
664+ assert! ( FuzzOp :: <F >:: try_decode( data) . ok( ) ?. eval( & cli_args) . all_match ( ) )
504665 )
505666 } ) ;
506667 None
0 commit comments