@@ -430,27 +430,19 @@ mod test {
430430 mut decoder : Decoder ,
431431 mut input : S ,
432432 ) -> impl Stream < Item = Result < RecordBatch , ArrowError > > {
433- let mut buffered = Bytes :: new ( ) ;
434- futures:: stream:: poll_fn ( move |cx| {
435- loop {
436- if buffered. is_empty ( ) {
437- buffered = match ready ! ( input. poll_next_unpin( cx) ) {
438- Some ( b) => b,
439- None => break ,
440- } ;
441- }
442- let decoded = match decoder. decode ( buffered. as_ref ( ) ) {
443- Ok ( decoded) => decoded,
444- Err ( e) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
445- } ;
446- let read = buffered. len ( ) ;
447- buffered. advance ( decoded) ;
448- if decoded != read {
449- break ;
433+ async_stream:: try_stream! {
434+ if let Some ( data) = input. next( ) . await {
435+ let consumed = decoder. decode( & data) ?;
436+ if consumed < data. len( ) {
437+ Err ( ArrowError :: ParseError (
438+ "did not consume all bytes" . to_string( ) ,
439+ ) ) ?;
450440 }
451441 }
452- Poll :: Ready ( decoder. flush ( ) . transpose ( ) )
453- } )
442+ if let Some ( batch) = decoder. flush( ) ? {
443+ yield batch
444+ }
445+ }
454446 }
455447
456448 #[ test]
@@ -595,29 +587,91 @@ mod test {
595587
#[test]
fn test_decode_stream_with_schema() {
    // Table-driven check of `decode_stream`: each case supplies a reader
    // schema and, optionally, the error text expected when decoding a single
    // length-prefixed Avro string value against that schema.
    struct TestCase<'a> {
        name: &'a str,
        schema: &'a str,
        expected_error: Option<&'a str>,
    }
    let cases = [
        TestCase {
            name: "success",
            schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#,
            expected_error: None,
        },
        TestCase {
            name: "valid schema invalid data",
            schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#,
            expected_error: Some("did not consume all bytes"),
        },
    ];
    for case in cases {
        let provided_schema: crate::schema::Schema = serde_json::from_str(case.schema).unwrap();
        // Avro encoding of a string: zig-zag length prefix, then raw UTF-8 bytes.
        let record_val = "some_string";
        let mut body = vec![(record_val.len() as u8) << 1];
        body.extend_from_slice(record_val.as_bytes());
        let mut reader_placeholder = Cursor::new(&[] as &[u8]);
        let build_result = ReaderBuilder::new()
            .with_batch_size(1)
            .with_schema(provided_schema)
            .build_decoder(&mut reader_placeholder);
        // Building the decoder may itself be the failing step for a bad schema.
        let decoder = match build_result {
            Ok(decoder) => decoder,
            Err(e) => match case.expected_error {
                Some(expected) => {
                    assert!(
                        e.to_string().contains(expected),
                        "Test '{}' failed: unexpected error message at build.\nExpected to contain: '{expected}'\nActual: '{e}'",
                        case.name,
                    );
                    continue;
                }
                None => panic!("Test '{}' failed at decoder build: {e}", case.name),
            },
        };
        let stream = Box::pin(stream::once(async { Bytes::from(body) }));
        let collected: Result<Vec<RecordBatch>, ArrowError> =
            block_on(decode_stream(decoder, stream).try_collect());
        // Pair the outcome with the expectation so every combination is handled.
        match (collected, case.expected_error) {
            (Ok(batches), None) => {
                let batch =
                    arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
                let expected_schema =
                    Arc::new(Schema::new(vec![Field::new("f2", DataType::Utf8, false)]));
                let expected_batch = RecordBatch::try_new(
                    expected_schema,
                    vec![Arc::new(StringArray::from(vec![record_val]))],
                )
                .unwrap();
                assert_eq!(batch, expected_batch, "Test '{}' failed", case.name);
                assert_eq!(
                    batch.schema().field(0).name(),
                    "f2",
                    "Test '{}' failed",
                    case.name
                );
            }
            (Err(e), Some(expected)) => {
                assert!(
                    e.to_string().contains(expected),
                    "Test '{}' failed: unexpected error message at decode.\nExpected to contain: '{expected}'\nActual: '{e}'",
                    case.name,
                );
            }
            (Ok(batches), Some(expected)) => panic!(
                "Test '{}' was expected to fail with '{expected}', but it succeeded with: {:?}",
                case.name, batches
            ),
            (Err(e), None) => panic!(
                "Test '{}' was not expected to fail, but it did with '{e}'",
                case.name
            ),
        }
    }
}
622676
623677 #[ test]