Skip to content

Commit a1e99e4

Browse files
authored
Merge pull request #1 from featuremesh/feature/json-extract
json extract
2 parents 6e231f2 + b3aa835 commit a1e99e4

File tree

8 files changed

+187
-3
lines changed

8 files changed

+187
-3
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ datafusion = { version = "46", default-features = false }
1515
jiter = "0.9"
1616
paste = "1"
1717
log = "0.4"
18+
jsonpath-rust = "1.0.0"
1819

1920
[dev-dependencies]
2021
datafusion = { version = "46", default-features = false, features = ["nested_expressions"] }
2122
codspeed-criterion-compat = "2.6"
2223
criterion = "0.5.1"
2324
clap = "4"
2425
tokio = { version = "1.43", features = ["full"] }
26+
rstest = "0.25.0"
2527

2628
[lints.clippy]
2729
dbg_macro = "deny"

src/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ fn dict_key_type(d: &DataType) -> Option<DataType> {
6868
None
6969
}
7070

71-
#[derive(Debug)]
71+
#[derive(Debug, PartialEq, Eq)]
7272
pub enum JsonPath<'s> {
7373
Key(&'s str),
7474
Index(usize),

src/json_extract.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use std::any::Any;
2+
use datafusion::arrow::array::StringArray;
3+
use datafusion::arrow::datatypes::{DataType, DataType::Utf8};
4+
use datafusion::common::{exec_err, Result as DataFusionResult, ScalarValue};
5+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
6+
use jsonpath_rust::parser::model::{Segment, Selector};
7+
use jsonpath_rust::parser::parse_json_path;
8+
use crate::common::{invoke, return_type_check, JsonPath};
9+
use crate::common_macros::make_udf_function;
10+
use crate::json_get_json::jiter_json_get_json;
11+
12+
make_udf_function!(
13+
JsonExtract,
14+
json_extract,
15+
json_data path,
16+
r#"Get a value from a JSON string by its "path" in JSONPath format"#
17+
);
18+
19+
#[derive(Debug)]
20+
pub(super) struct JsonExtract {
21+
signature: Signature,
22+
aliases: [String; 1],
23+
}
24+
25+
impl Default for JsonExtract {
26+
fn default() -> Self {
27+
Self {
28+
signature: Signature::exact(
29+
vec![Utf8, Utf8], // JSON data and JSONPath as strings
30+
Volatility::Immutable,
31+
),
32+
aliases: ["json_extract".to_string()],
33+
}
34+
}
35+
}
36+
37+
impl ScalarUDFImpl for JsonExtract {
38+
fn as_any(&self) -> &dyn Any {
39+
self
40+
}
41+
42+
fn name(&self) -> &str {
43+
self.aliases[0].as_str()
44+
}
45+
46+
fn signature(&self) -> &Signature {
47+
&self.signature
48+
}
49+
50+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
51+
return_type_check(arg_types, self.name(), Utf8)
52+
}
53+
54+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
55+
if args.args.len() != 2 {
56+
return exec_err!(
57+
"'{}' expects exactly 2 arguments (JSON data, path), got {}",
58+
self.name(),
59+
args.args.len()
60+
);
61+
}
62+
63+
let json_arg = &args.args[0];
64+
let path_arg = &args.args[1];
65+
66+
let path_str = match path_arg {
67+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s,
68+
_ => return exec_err!("'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument", self.name()),
69+
};
70+
71+
let path = parse_jsonpath(path_str);
72+
73+
invoke::<StringArray>(&[json_arg.clone()], |json, _| {
74+
jiter_json_get_json(json, &path)
75+
})
76+
}
77+
78+
fn aliases(&self) -> &[String] {
79+
&self.aliases
80+
}
81+
}
82+
83+
fn parse_jsonpath(path: &str) -> Vec<JsonPath<'static>> {
84+
let segments = parse_json_path(path)
85+
.map(|it| it.segments)
86+
.unwrap_or(Vec::new());
87+
88+
segments.into_iter().map(|segment| {
89+
match segment {
90+
Segment::Selector(s) => match s {
91+
Selector::Name(name) => JsonPath::Key(Box::leak(name.into_boxed_str())),
92+
Selector::Index(idx) => JsonPath::Index(idx as usize),
93+
_ => JsonPath::None,
94+
},
95+
_ => JsonPath::None,
96+
}
97+
}).collect::<Vec<_>>()
98+
}
99+
100+
#[cfg(test)]
101+
mod tests {
102+
use rstest::rstest;
103+
use super::*;
104+
105+
// Test cases for parse_jsonpath
106+
#[rstest]
107+
#[case("$.a.aa", vec![JsonPath::Key("a"), JsonPath::Key("aa")])]
108+
#[case("$.a.ab[0].ac", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(0), JsonPath::Key("ac")])]
109+
#[case("$.a.ab[1].ad", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(1), JsonPath::Key("ad")])]
110+
#[case(r#"$.a["a b"].ad"#, vec![JsonPath::Key("a"), JsonPath::Key("\"a b\""), JsonPath::Key("ad")])]
111+
#[tokio::test]
112+
async fn test_parse_jsonpath(
113+
#[case] path: &str,
114+
#[case] expected: Vec<JsonPath<'static>>,
115+
) {
116+
let result = parse_jsonpath(path);
117+
assert_eq!(result, expected);
118+
}
119+
}

src/json_get.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ impl InvokeResult for JsonUnion {
9393
}
9494
}
9595

96-
fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
96+
pub(crate) fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
9797
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
9898
build_union(&mut jiter, peek)
9999
} else {

src/json_get_json.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl ScalarUDFImpl for JsonGetJson {
5656
}
5757
}
5858

59-
fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
59+
pub(crate) fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
6060
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
6161
let start = jiter.current_index();
6262
jiter.known_skip(peek)?;

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ mod json_get_str;
1919
mod json_length;
2020
mod json_object_keys;
2121
mod rewrite;
22+
mod json_extract;
2223

2324
pub use common_union::{JsonUnionEncoder, JsonUnionValue};
2425

2526
pub mod functions {
2627
pub use crate::json_as_text::json_as_text;
2728
pub use crate::json_contains::json_contains;
2829
pub use crate::json_get::json_get;
30+
pub use crate::json_extract::json_extract;
2931
pub use crate::json_get_bool::json_get_bool;
3032
pub use crate::json_get_float::json_get_float;
3133
pub use crate::json_get_int::json_get_int;
@@ -60,6 +62,7 @@ pub mod udfs {
6062
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
6163
let functions: Vec<Arc<ScalarUDF>> = vec![
6264
json_get::json_get_udf(),
65+
json_extract::json_extract_udf(),
6366
json_get_bool::json_get_bool_udf(),
6467
json_get_float::json_get_float_udf(),
6568
json_get_int::json_get_int_udf(),

tests/json_extract_test.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use rstest::{fixture, rstest};
2+
use crate::utils::{display_val, run_query};
3+
4+
mod utils;
5+
6+
#[fixture]
7+
fn json_data() -> String {
8+
let json = r#"{"a": {"a a": "My Collection","ab": [{"ac": "Dune", "ca": "Frank Herbert"},{"ad": "Foundation", "da": "Isaac Asimov"}]}}"#;
9+
json.to_string()
10+
}
11+
12+
#[rstest]
13+
#[case("$.a.ab", "[{\"ac\": \"Dune\", \"ca\": \"Frank Herbert\"},{\"ad\": \"Foundation\", \"da\": \"Isaac Asimov\"}]")]
14+
#[tokio::test]
15+
async fn test_json_paths(
16+
json_data: String,
17+
#[case] path: &str,
18+
#[case] expected: &str,
19+
) {
20+
let result = json_extract(&json_data, path).await;
21+
assert_eq!(result, expected.to_string());
22+
}
23+
24+
#[rstest]
25+
#[tokio::test]
26+
#[ignore]
27+
async fn test_invalid_json_path(json_data: String) {
28+
let result = json_extract(&json_data, "store.invalid.path").await;
29+
assert_eq!(result, "".to_string());
30+
}
31+
32+
33+
async fn json_extract(json: &str, path: &str) -> String {
34+
let sql = format!("select json_extract('{}', '{}')", json, path);
35+
let batches = run_query(sql.as_str()).await.unwrap();
36+
display_val(batches).await.1
37+
}
38+

tests/main.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,28 @@ async fn test_json_get_union() {
8282
assert_batches_eq!(expected, &batches);
8383
}
8484

85+
#[tokio::test]
86+
async fn test_json_extract_union() {
87+
let batches = run_query("select name, json_extract(json_data, '$.foo') as foo from test")
88+
.await
89+
.unwrap();
90+
91+
let expected = [
92+
"+------------------+-------------+",
93+
"| name | foo |",
94+
"+------------------+-------------+",
95+
"| object_foo | {str=abc} |",
96+
"| object_foo_array | {array=[1]} |",
97+
"| object_foo_obj | {object={}} |",
98+
"| object_foo_null | {null=} |",
99+
"| object_bar | {null=} |",
100+
"| list_foo | {null=} |",
101+
"| invalid_json | {null=} |",
102+
"+------------------+-------------+",
103+
];
104+
assert_batches_eq!(expected, &batches);
105+
}
106+
85107
#[tokio::test]
86108
async fn test_json_get_array() {
87109
let sql = "select json_get('[1, 2, 3]', 2)";

0 commit comments

Comments
 (0)