Skip to content

Commit 93efc40

Browse files
committed
Added a negative test case to test_decode_stream_with_schema
1 parent b9ec9a5 commit 93efc40

File tree

2 files changed: +97 additions, −42 deletions

arrow-avro/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ tempfile = "3.3"
5858
arrow = { workspace = true }
5959
futures = "0.3.31"
6060
bytes = "1.10.1"
61+
async-stream = "0.3.6"
6162

6263
[[bench]]
6364
name = "avro_reader"

arrow-avro/src/reader/mod.rs

Lines changed: 96 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -430,27 +430,19 @@ mod test {
430430
mut decoder: Decoder,
431431
mut input: S,
432432
) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
433-
let mut buffered = Bytes::new();
434-
futures::stream::poll_fn(move |cx| {
435-
loop {
436-
if buffered.is_empty() {
437-
buffered = match ready!(input.poll_next_unpin(cx)) {
438-
Some(b) => b,
439-
None => break,
440-
};
441-
}
442-
let decoded = match decoder.decode(buffered.as_ref()) {
443-
Ok(decoded) => decoded,
444-
Err(e) => return Poll::Ready(Some(Err(e))),
445-
};
446-
let read = buffered.len();
447-
buffered.advance(decoded);
448-
if decoded != read {
449-
break;
433+
async_stream::try_stream! {
434+
if let Some(data) = input.next().await {
435+
let consumed = decoder.decode(&data)?;
436+
if consumed < data.len() {
437+
Err(ArrowError::ParseError(
438+
"did not consume all bytes".to_string(),
439+
))?;
450440
}
451441
}
452-
Poll::Ready(decoder.flush().transpose())
453-
})
442+
if let Some(batch) = decoder.flush()? {
443+
yield batch
444+
}
445+
}
454446
}
455447

456448
#[test]
@@ -595,29 +587,91 @@ mod test {
595587

596588
#[test]
597589
fn test_decode_stream_with_schema() {
598-
const PROVIDED_SCHEMA: &str =
599-
r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#;
600-
let schema_s2: crate::schema::Schema = serde_json::from_str(PROVIDED_SCHEMA).unwrap();
601-
let record_val = "some_string";
602-
let mut body = vec![];
603-
body.push((record_val.len() as u8) << 1);
604-
body.extend_from_slice(record_val.as_bytes());
605-
let mut reader_placeholder = Cursor::new(&[] as &[u8]);
606-
let decoder = ReaderBuilder::new()
607-
.with_batch_size(1)
608-
.with_schema(schema_s2)
609-
.build_decoder(&mut reader_placeholder)
610-
.unwrap();
611-
let stream = Box::pin(stream::once(async { Bytes::from(body) }));
612-
let decoded_stream = decode_stream(decoder, stream);
613-
let batches: Vec<RecordBatch> = block_on(decoded_stream.try_collect()).unwrap();
614-
let batch = arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
615-
let expected_field = Field::new("f2", DataType::Utf8, false);
616-
let expected_schema = Arc::new(Schema::new(vec![expected_field]));
617-
let expected_array = Arc::new(StringArray::from(vec![record_val]));
618-
let expected_batch = RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap();
619-
assert_eq!(batch, expected_batch);
620-
assert_eq!(batch.schema().field(0).name(), "f2");
590+
struct TestCase<'a> {
591+
name: &'a str,
592+
schema: &'a str,
593+
expected_error: Option<&'a str>,
594+
}
595+
let tests = vec![
596+
TestCase {
597+
name: "success",
598+
schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#,
599+
expected_error: None,
600+
},
601+
TestCase {
602+
name: "valid schema invalid data",
603+
schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#,
604+
expected_error: Some("did not consume all bytes"),
605+
},
606+
];
607+
for test in tests {
608+
let schema_s2: crate::schema::Schema = serde_json::from_str(test.schema).unwrap();
609+
let record_val = "some_string";
610+
let mut body = vec![];
611+
body.push((record_val.len() as u8) << 1);
612+
body.extend_from_slice(record_val.as_bytes());
613+
let mut reader_placeholder = Cursor::new(&[] as &[u8]);
614+
let builder = ReaderBuilder::new()
615+
.with_batch_size(1)
616+
.with_schema(schema_s2);
617+
let decoder_result = builder.build_decoder(&mut reader_placeholder);
618+
let decoder = match decoder_result {
619+
Ok(decoder) => decoder,
620+
Err(e) => {
621+
if let Some(expected) = test.expected_error {
622+
assert!(
623+
e.to_string().contains(expected),
624+
"Test '{}' failed: unexpected error message at build.\nExpected to contain: '{expected}'\nActual: '{e}'",
625+
test.name,
626+
);
627+
continue;
628+
} else {
629+
panic!("Test '{}' failed at decoder build: {e}", test.name);
630+
}
631+
}
632+
};
633+
let stream = Box::pin(stream::once(async { Bytes::from(body) }));
634+
let decoded_stream = decode_stream(decoder, stream);
635+
let batches_result: Result<Vec<RecordBatch>, ArrowError> =
636+
block_on(decoded_stream.try_collect());
637+
match (batches_result, test.expected_error) {
638+
(Ok(batches), None) => {
639+
let batch =
640+
arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
641+
let expected_field = Field::new("f2", DataType::Utf8, false);
642+
let expected_schema = Arc::new(Schema::new(vec![expected_field]));
643+
let expected_array = Arc::new(StringArray::from(vec![record_val]));
644+
let expected_batch =
645+
RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap();
646+
assert_eq!(batch, expected_batch, "Test '{}' failed", test.name);
647+
assert_eq!(
648+
batch.schema().field(0).name(),
649+
"f2",
650+
"Test '{}' failed",
651+
test.name
652+
);
653+
}
654+
(Err(e), Some(expected)) => {
655+
assert!(
656+
e.to_string().contains(expected),
657+
"Test '{}' failed: unexpected error message at decode.\nExpected to contain: '{expected}'\nActual: '{e}'",
658+
test.name,
659+
);
660+
}
661+
(Ok(batches), Some(expected)) => {
662+
panic!(
663+
"Test '{}' was expected to fail with '{expected}', but it succeeded with: {:?}",
664+
test.name, batches
665+
);
666+
}
667+
(Err(e), None) => {
668+
panic!(
669+
"Test '{}' was not expected to fail, but it did with '{e}'",
670+
test.name
671+
);
672+
}
673+
}
674+
}
621675
}
622676

623677
#[test]

0 commit comments

Comments (0)