Skip to content

Commit c013f26

Browse files
crepererumCopilot
andcommitted
fix: test code inconsistencies
Co-authored-by: Copilot <[email protected]>
1 parent 1bce88d commit c013f26

File tree

5 files changed

+102
-65
lines changed

5 files changed

+102
-65
lines changed

guests/evil/src/env.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Payload that tries to read environment variables.
2-
use std::sync::Arc;
2+
use std::{hash::Hash, sync::Arc};
33

44
use arrow::datatypes::DataType;
55
use datafusion_common::{Result as DataFusionResult, ScalarValue};
@@ -49,6 +49,20 @@ impl std::fmt::Debug for StringUdf {
4949
}
5050
}
5151

52+
impl PartialEq<Self> for StringUdf {
53+
fn eq(&self, other: &Self) -> bool {
54+
self.name == other.name
55+
}
56+
}
57+
58+
impl Eq for StringUdf {}
59+
60+
impl Hash for StringUdf {
61+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
62+
self.name.hash(state);
63+
}
64+
}
65+
5266
impl ScalarUDFImpl for StringUdf {
5367
fn as_any(&self) -> &dyn std::any::Any {
5468
self

guests/evil/src/fs.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Payload that interact with the (virtual) file sytem.
2-
use std::sync::Arc;
2+
use std::{hash::Hash, sync::Arc};
33

44
use arrow::{array::StringArray, datatypes::DataType};
55
use datafusion_common::{Result as DataFusionResult, cast::as_string_array};
@@ -49,6 +49,20 @@ impl std::fmt::Debug for String1Udf {
4949
}
5050
}
5151

52+
impl PartialEq<Self> for String1Udf {
53+
fn eq(&self, other: &Self) -> bool {
54+
self.name == other.name
55+
}
56+
}
57+
58+
impl Eq for String1Udf {}
59+
60+
impl Hash for String1Udf {
61+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
62+
self.name.hash(state);
63+
}
64+
}
65+
5266
impl ScalarUDFImpl for String1Udf {
5367
fn as_any(&self) -> &dyn std::any::Any {
5468
self
@@ -111,7 +125,7 @@ impl String2Udf {
111125
Self {
112126
name,
113127
effect: Box::new(effect),
114-
signature: Signature::uniform(1, vec![DataType::Utf8], Volatility::Immutable),
128+
signature: Signature::uniform(2, vec![DataType::Utf8], Volatility::Immutable),
115129
}
116130
}
117131
}
@@ -132,6 +146,20 @@ impl std::fmt::Debug for String2Udf {
132146
}
133147
}
134148

149+
impl PartialEq<Self> for String2Udf {
150+
fn eq(&self, other: &Self) -> bool {
151+
self.name == other.name
152+
}
153+
}
154+
155+
impl Eq for String2Udf {}
156+
157+
impl Hash for String2Udf {
158+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
159+
self.name.hash(state);
160+
}
161+
}
162+
135163
impl ScalarUDFImpl for String2Udf {
136164
fn as_any(&self) -> &dyn std::any::Any {
137165
self

host/tests/integration_tests/evil/env.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use std::sync::Arc;
22

33
use arrow::datatypes::{DataType, Field};
4-
use datafusion_common::{cast::as_string_array, config::ConfigOptions};
4+
use datafusion_common::{ScalarValue, config::ConfigOptions};
55
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, async_udf::AsyncScalarUDFImpl};
66
use datafusion_udf_wasm_host::WasmScalarUdf;
77

8-
use crate::integration_tests::evil::test_utils::try_scalar_udfs;
8+
use crate::integration_tests::{evil::test_utils::try_scalar_udfs, test_utils::ColumnarValueExt};
99

1010
#[tokio::test]
1111
async fn test_args() {
@@ -45,23 +45,20 @@ async fn udf(name: &'static str) -> WasmScalarUdf {
4545
}
4646

4747
async fn call(udf: &WasmScalarUdf) -> Option<String> {
48-
let array = udf
49-
.invoke_async_with_args(
50-
ScalarFunctionArgs {
51-
args: vec![],
52-
arg_fields: vec![],
53-
number_rows: 1,
54-
return_field: Arc::new(Field::new("r", DataType::Null, true)),
55-
},
56-
&ConfigOptions::default(),
57-
)
48+
let scalar = udf
49+
.invoke_async_with_args(ScalarFunctionArgs {
50+
args: vec![],
51+
arg_fields: vec![],
52+
number_rows: 1,
53+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
54+
config_options: Arc::new(ConfigOptions::default()),
55+
})
5856
.await
59-
.unwrap();
60-
assert_eq!(array.len(), 1);
61-
as_string_array(&array)
6257
.unwrap()
63-
.into_iter()
64-
.next()
65-
.unwrap()
66-
.map(|s| s.to_owned())
58+
.unwrap_scalar();
59+
if let ScalarValue::Utf8(s) = scalar {
60+
s
61+
} else {
62+
unreachable!("invalid scalar type: {scalar:?}")
63+
}
6764
}

host/tests/integration_tests/evil/fs.rs

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use datafusion_expr::{
99
use datafusion_udf_wasm_host::WasmScalarUdf;
1010
use std::{fmt::Write, sync::Arc};
1111

12-
use crate::integration_tests::evil::test_utils::try_scalar_udfs;
12+
use crate::integration_tests::{evil::test_utils::try_scalar_udfs, test_utils::ColumnarValueExt};
1313

1414
const PATHS: &[&str] = &[
1515
// NOTE: it seems that the WASI guest transforms `` (empty string) into `.` before it even reaches the host
@@ -980,22 +980,21 @@ fn cross(array: &[&'static str]) -> (Vec<&'static str>, Vec<&'static str>) {
980980
/// Run UDF that expects one string input.
981981
async fn run_1(udf: &WasmScalarUdf) -> String {
982982
let array = udf
983-
.invoke_async_with_args(
984-
ScalarFunctionArgs {
985-
args: vec![ColumnarValue::Array(Arc::new(
986-
PATHS
987-
.iter()
988-
.map(|p| Some(p.to_owned()))
989-
.collect::<StringArray>(),
990-
))],
991-
arg_fields: vec![Arc::new(Field::new("a", DataType::Utf8, true))],
992-
number_rows: PATHS.len(),
993-
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
994-
},
995-
&ConfigOptions::default(),
996-
)
983+
.invoke_async_with_args(ScalarFunctionArgs {
984+
args: vec![ColumnarValue::Array(Arc::new(
985+
PATHS
986+
.iter()
987+
.map(|p| Some(p.to_owned()))
988+
.collect::<StringArray>(),
989+
))],
990+
arg_fields: vec![Arc::new(Field::new("a", DataType::Utf8, true))],
991+
number_rows: PATHS.len(),
992+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
993+
config_options: Arc::new(ConfigOptions::default()),
994+
})
997995
.await
998-
.unwrap();
996+
.unwrap()
997+
.unwrap_array();
999998
let array = as_string_array(&array).unwrap();
1000999

10011000
let mut out = String::new();
@@ -1017,33 +1016,32 @@ async fn run_2(udf: &WasmScalarUdf) -> String {
10171016
let (paths_a, paths_b) = cross(PATHS);
10181017

10191018
let array = udf
1020-
.invoke_async_with_args(
1021-
ScalarFunctionArgs {
1022-
args: vec![
1023-
ColumnarValue::Array(Arc::new(
1024-
paths_a
1025-
.iter()
1026-
.map(|p| Some(p.to_owned()))
1027-
.collect::<StringArray>(),
1028-
)),
1029-
ColumnarValue::Array(Arc::new(
1030-
paths_b
1031-
.iter()
1032-
.map(|p| Some(p.to_owned()))
1033-
.collect::<StringArray>(),
1034-
)),
1035-
],
1036-
arg_fields: vec![
1037-
Arc::new(Field::new("a", DataType::Utf8, true)),
1038-
Arc::new(Field::new("b", DataType::Utf8, true)),
1039-
],
1040-
number_rows: paths_a.len(),
1041-
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
1042-
},
1043-
&ConfigOptions::default(),
1044-
)
1019+
.invoke_async_with_args(ScalarFunctionArgs {
1020+
args: vec![
1021+
ColumnarValue::Array(Arc::new(
1022+
paths_a
1023+
.iter()
1024+
.map(|p| Some(p.to_owned()))
1025+
.collect::<StringArray>(),
1026+
)),
1027+
ColumnarValue::Array(Arc::new(
1028+
paths_b
1029+
.iter()
1030+
.map(|p| Some(p.to_owned()))
1031+
.collect::<StringArray>(),
1032+
)),
1033+
],
1034+
arg_fields: vec![
1035+
Arc::new(Field::new("a", DataType::Utf8, true)),
1036+
Arc::new(Field::new("b", DataType::Utf8, true)),
1037+
],
1038+
number_rows: paths_a.len(),
1039+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
1040+
config_options: Arc::new(ConfigOptions::default()),
1041+
})
10451042
.await
1046-
.unwrap();
1043+
.unwrap()
1044+
.unwrap_array();
10471045
let array = as_string_array(&array).unwrap();
10481046

10491047
let mut out = String::new();

host/tests/integration_tests/evil/runtime.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async fn test_maxptr() {
6161
let err = err_call_no_params(&udf).await.to_string();
6262

6363
// linear memory size is nondeterministic
64-
let err = Regex::new(r#"size 0x[0-9]+"#)
64+
let err = Regex::new(r#"size 0x[0-9a-f]+"#)
6565
.unwrap()
6666
.replace_all(&err, "size <SIZE>");
6767

0 commit comments

Comments
 (0)