Skip to content

Commit 8ef4de1

Browse files
authored
feat: support astype conversions to and from JSON dtypes (#1716)
1 parent 5585f7a commit 8ef4de1

File tree

3 files changed

+199
-4
lines changed

3 files changed

+199
-4
lines changed

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,35 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
11641164
elif to_type == ibis_dtypes.time:
11651165
return x_converted.time()
11661166

1167+
if to_type == ibis_dtypes.json:
1168+
if x.type() == ibis_dtypes.string:
1169+
return parse_json_in_safe(x) if op.safe else parse_json(x)
1170+
if x.type() == ibis_dtypes.bool:
1171+
x_bool = typing.cast(
1172+
ibis_types.StringValue,
1173+
bigframes.core.compile.ibis_types.cast_ibis_value(
1174+
x, ibis_dtypes.string, safe=op.safe
1175+
),
1176+
).lower()
1177+
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
1178+
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
1179+
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
1180+
x, ibis_dtypes.string, safe=op.safe
1181+
)
1182+
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)
1183+
1184+
if x.type() == ibis_dtypes.json:
1185+
if to_type == ibis_dtypes.int64:
1186+
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
1187+
if to_type == ibis_dtypes.float64:
1188+
return (
1189+
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
1190+
)
1191+
if to_type == ibis_dtypes.bool:
1192+
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
1193+
if to_type == ibis_dtypes.string:
1194+
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
1195+
11671196
# TODO: either inline this function, or push rest of this op into the function
11681197
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)
11691198

@@ -2047,6 +2076,11 @@ def parse_json(json_str: str) -> ibis_dtypes.JSON: # type: ignore[empty-body]
20472076
"""Converts a JSON-formatted STRING value to a JSON value."""
20482077

20492078

2079+
@ibis_udf.scalar.builtin(name="SAFE.PARSE_JSON")
2080+
def parse_json_in_safe(json_str: str) -> ibis_dtypes.JSON: # type: ignore[empty-body]
2081+
"""Converts a JSON-formatted STRING value to a JSON value in the safe mode."""
2082+
2083+
20502084
@ibis_udf.scalar.builtin(name="json_set")
20512085
def json_set( # type: ignore[empty-body]
20522086
json_obj: ibis_dtypes.JSON, json_path: ibis_dtypes.String, json_value
@@ -2075,6 +2109,46 @@ def json_value( # type: ignore[empty-body]
20752109
"""Retrieve value of a JSON field as plain STRING."""
20762110

20772111

2112+
@ibis_udf.scalar.builtin(name="INT64")
2113+
def cast_json_to_int64(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Int64: # type: ignore[empty-body]
2114+
"""Converts a JSON number to a SQL INT64 value."""
2115+
2116+
2117+
@ibis_udf.scalar.builtin(name="SAFE.INT64")
2118+
def cast_json_to_int64_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Int64: # type: ignore[empty-body]
2119+
"""Converts a JSON number to a SQL INT64 value in the safe mode."""
2120+
2121+
2122+
@ibis_udf.scalar.builtin(name="FLOAT64")
2123+
def cast_json_to_float64(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Float64: # type: ignore[empty-body]
2124+
"""Attempts to convert a JSON value to a SQL FLOAT64 value."""
2125+
2126+
2127+
@ibis_udf.scalar.builtin(name="SAFE.FLOAT64")
2128+
def cast_json_to_float64_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Float64: # type: ignore[empty-body]
2129+
"""Attempts to convert a JSON value to a SQL FLOAT64 value."""
2130+
2131+
2132+
@ibis_udf.scalar.builtin(name="BOOL")
2133+
def cast_json_to_bool(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Boolean: # type: ignore[empty-body]
2134+
"""Attempts to convert a JSON value to a SQL BOOL value."""
2135+
2136+
2137+
@ibis_udf.scalar.builtin(name="SAFE.BOOL")
2138+
def cast_json_to_bool_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Boolean: # type: ignore[empty-body]
2139+
"""Attempts to convert a JSON value to a SQL BOOL value."""
2140+
2141+
2142+
@ibis_udf.scalar.builtin(name="STRING")
2143+
def cast_json_to_string(json_str: ibis_dtypes.JSON) -> ibis_dtypes.String: # type: ignore[empty-body]
2144+
"""Attempts to convert a JSON value to a SQL STRING value."""
2145+
2146+
2147+
@ibis_udf.scalar.builtin(name="SAFE.STRING")
2148+
def cast_json_to_string_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.String: # type: ignore[empty-body]
2149+
"""Attempts to convert a JSON value to a SQL STRING value."""
2150+
2151+
20782152
@ibis_udf.scalar.builtin(name="ML.DISTANCE")
20792153
def vector_distance(vector1, vector2, type: str) -> ibis_dtypes.Float64: # type: ignore[empty-body]
20802154
"""Computes the distance between two vectors using specified type ("EUCLIDEAN", "MANHATTAN", or "COSINE")"""

tests/system/small/test_series.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
import datetime as dt
16+
import json
1617
import math
1718
import re
1819
import tempfile
1920

2021
import db_dtypes # type: ignore
2122
import geopandas as gpd # type: ignore
23+
import google.api_core.exceptions
2224
import numpy
2325
from packaging.version import Version
2426
import pandas as pd
@@ -3474,9 +3476,11 @@ def foo(x):
34743476
("int64_col", pd.ArrowDtype(pa.timestamp("us"))),
34753477
("int64_col", pd.ArrowDtype(pa.timestamp("us", tz="UTC"))),
34763478
("int64_col", "time64[us][pyarrow]"),
3479+
("int64_col", pd.ArrowDtype(db_dtypes.JSONArrowType())),
34773480
("bool_col", "Int64"),
34783481
("bool_col", "string[pyarrow]"),
34793482
("bool_col", "Float64"),
3483+
("bool_col", pd.ArrowDtype(db_dtypes.JSONArrowType())),
34803484
("string_col", "binary[pyarrow]"),
34813485
("bytes_col", "string[pyarrow]"),
34823486
# pandas actually doesn't let folks convert to/from naive timestamp and
@@ -3541,7 +3545,7 @@ def test_astype_safe(session):
35413545
pd.testing.assert_series_equal(result, exepcted)
35423546

35433547

3544-
def test_series_astype_error_error(session):
3548+
def test_series_astype_w_invalid_error(session):
35453549
input = pd.Series(["hello", "world", "3.11", "4000"])
35463550
with pytest.raises(ValueError):
35473551
session.read_pandas(input).astype("Float64", errors="bad_value")
@@ -3676,6 +3680,119 @@ def test_timestamp_astype_string():
36763680
assert bf_result.dtype == "string[pyarrow]"
36773681

36783682

3683+
@pytest.mark.parametrize("errors", ["raise", "null"])
3684+
def test_float_astype_json(errors):
3685+
data = ["1.25", "2500000000", None, "-12323.24"]
3686+
bf_series = series.Series(data, dtype=dtypes.FLOAT_DTYPE)
3687+
3688+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors=errors)
3689+
assert bf_result.dtype == dtypes.JSON_DTYPE
3690+
3691+
expected_result = pd.Series(data, dtype=dtypes.JSON_DTYPE)
3692+
expected_result.index = expected_result.index.astype("Int64")
3693+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected_result)
3694+
3695+
3696+
@pytest.mark.parametrize("errors", ["raise", "null"])
3697+
def test_string_astype_json(errors):
3698+
data = [
3699+
"1",
3700+
None,
3701+
'["1","3","5"]',
3702+
'{"a":1,"b":["x","y"],"c":{"x":[],"z":false}}',
3703+
]
3704+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3705+
3706+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors=errors)
3707+
assert bf_result.dtype == dtypes.JSON_DTYPE
3708+
3709+
pd_result = bf_series.to_pandas().astype(dtypes.JSON_DTYPE)
3710+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
3711+
3712+
3713+
def test_string_astype_json_in_safe_mode():
3714+
data = ["this is not a valid json string"]
3715+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3716+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors="null")
3717+
assert bf_result.dtype == dtypes.JSON_DTYPE
3718+
3719+
expected = pd.Series([None], dtype=dtypes.JSON_DTYPE)
3720+
expected.index = expected.index.astype("Int64")
3721+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3722+
3723+
3724+
def test_string_astype_json_raise_error():
3725+
data = ["this is not a valid json string"]
3726+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3727+
with pytest.raises(
3728+
google.api_core.exceptions.BadRequest,
3729+
match="syntax error while parsing value",
3730+
):
3731+
bf_series.astype(dtypes.JSON_DTYPE, errors="raise").to_pandas()
3732+
3733+
3734+
@pytest.mark.parametrize("errors", ["raise", "null"])
3735+
@pytest.mark.parametrize(
3736+
("data", "to_type"),
3737+
[
3738+
pytest.param(["1", "10.0", None], dtypes.INT_DTYPE, id="to_int"),
3739+
pytest.param(["0.0001", "2500000000", None], dtypes.FLOAT_DTYPE, id="to_float"),
3740+
pytest.param(["true", "false", None], dtypes.BOOL_DTYPE, id="to_bool"),
3741+
pytest.param(['"str"', None], dtypes.STRING_DTYPE, id="to_string"),
3742+
pytest.param(
3743+
['"str"', None],
3744+
dtypes.TIME_DTYPE,
3745+
id="invalid",
3746+
marks=pytest.mark.xfail(raises=TypeError),
3747+
),
3748+
],
3749+
)
3750+
def test_json_astype_others(data, to_type, errors):
3751+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3752+
3753+
bf_result = bf_series.astype(to_type, errors=errors)
3754+
assert bf_result.dtype == to_type
3755+
3756+
load_data = [json.loads(item) if item is not None else None for item in data]
3757+
expected = pd.Series(load_data, dtype=to_type)
3758+
expected.index = expected.index.astype("Int64")
3759+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3760+
3761+
3762+
@pytest.mark.parametrize(
3763+
("data", "to_type"),
3764+
[
3765+
pytest.param(["10.2", None], dtypes.INT_DTYPE, id="to_int"),
3766+
pytest.param(["false", None], dtypes.FLOAT_DTYPE, id="to_float"),
3767+
pytest.param(["10.2", None], dtypes.BOOL_DTYPE, id="to_bool"),
3768+
pytest.param(["true", None], dtypes.STRING_DTYPE, id="to_string"),
3769+
],
3770+
)
3771+
def test_json_astype_others_raise_error(data, to_type):
3772+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3773+
with pytest.raises(google.api_core.exceptions.BadRequest):
3774+
bf_series.astype(to_type, errors="raise").to_pandas()
3775+
3776+
3777+
@pytest.mark.parametrize(
3778+
("data", "to_type"),
3779+
[
3780+
pytest.param(["10.2", None], dtypes.INT_DTYPE, id="to_int"),
3781+
pytest.param(["false", None], dtypes.FLOAT_DTYPE, id="to_float"),
3782+
pytest.param(["10.2", None], dtypes.BOOL_DTYPE, id="to_bool"),
3783+
pytest.param(["true", None], dtypes.STRING_DTYPE, id="to_string"),
3784+
],
3785+
)
3786+
def test_json_astype_others_in_safe_mode(data, to_type):
3787+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3788+
bf_result = bf_series.astype(to_type, errors="null")
3789+
assert bf_result.dtype == to_type
3790+
3791+
expected = pd.Series([None, None], dtype=to_type)
3792+
expected.index = expected.index.astype("Int64")
3793+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3794+
3795+
36793796
@pytest.mark.parametrize(
36803797
"index",
36813798
[0, 5, -2],
@@ -3687,9 +3804,7 @@ def test_iloc_single_integer(scalars_df_index, scalars_pandas_df_index, index):
36873804
assert bf_result == pd_result
36883805

36893806

3690-
def test_iloc_single_integer_out_of_bound_error(
3691-
scalars_df_index, scalars_pandas_df_index
3692-
):
3807+
def test_iloc_single_integer_out_of_bound_error(scalars_df_index):
36933808
with pytest.raises(IndexError, match="single positional indexer is out-of-bounds"):
36943809
scalars_df_index.string_col.iloc[99]
36953810

third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,12 @@ def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
12221222
# not actually a table, but easier to quote individual namespace
12231223
# components this way
12241224
namespace = op.__udf_namespace__
1225+
1226+
# Function names prefixed with "SAFE.", such as `SAFE.PARSE_JSON`,
1227+
# are typically not quoted.
1228+
if funcname.startswith("SAFE."):
1229+
return funcname
1230+
12251231
return sg.table(funcname, db=namespace.database, catalog=namespace.catalog).sql(
12261232
self.dialect
12271233
)

0 commit comments

Comments
 (0)