Skip to content

Commit 8767541

Browse files
committed
feat: support astype conversions to and from JSON dtypes
1 parent c260fc8 commit 8767541

File tree

3 files changed

+199
-4
lines changed

3 files changed

+199
-4
lines changed

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,35 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
11481148
elif to_type == ibis_dtypes.time:
11491149
return x_converted.time()
11501150

1151+
if to_type == ibis_dtypes.json:
1152+
if x.type() == ibis_dtypes.string:
1153+
return parse_json_in_safe(x) if op.safe else parse_json(x)
1154+
if x.type() == ibis_dtypes.bool:
1155+
x_bool = typing.cast(
1156+
ibis_types.StringValue,
1157+
bigframes.core.compile.ibis_types.cast_ibis_value(
1158+
x, ibis_dtypes.string, safe=op.safe
1159+
),
1160+
).lower()
1161+
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
1162+
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
1163+
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
1164+
x, ibis_dtypes.string, safe=op.safe
1165+
)
1166+
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)
1167+
1168+
if x.type() == ibis_dtypes.json:
1169+
if to_type == ibis_dtypes.int64:
1170+
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
1171+
if to_type == ibis_dtypes.float64:
1172+
return (
1173+
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
1174+
)
1175+
if to_type == ibis_dtypes.bool:
1176+
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
1177+
if to_type == ibis_dtypes.string:
1178+
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
1179+
11511180
# TODO: either inline this function, or push rest of this op into the function
11521181
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)
11531182

@@ -2031,6 +2060,11 @@ def parse_json(json_str: str) -> ibis_dtypes.JSON: # type: ignore[empty-body]
20312060
"""Converts a JSON-formatted STRING value to a JSON value."""
20322061

20332062

2063+
@ibis_udf.scalar.builtin(name="SAFE.PARSE_JSON")
2064+
def parse_json_in_safe(json_str: str) -> ibis_dtypes.JSON: # type: ignore[empty-body]
2065+
"""Converts a JSON-formatted STRING value to a JSON value in the safe mode."""
2066+
2067+
20342068
@ibis_udf.scalar.builtin(name="json_set")
20352069
def json_set( # type: ignore[empty-body]
20362070
json_obj: ibis_dtypes.JSON, json_path: ibis_dtypes.String, json_value
@@ -2059,6 +2093,46 @@ def json_value( # type: ignore[empty-body]
20592093
"""Retrieve value of a JSON field as plain STRING."""
20602094

20612095

2096+
@ibis_udf.scalar.builtin(name="INT64")
2097+
def cast_json_to_int64(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Int64: # type: ignore[empty-body]
2098+
"""Converts a JSON number to a SQL INT64 value."""
2099+
2100+
2101+
@ibis_udf.scalar.builtin(name="SAFE.INT64")
2102+
def cast_json_to_int64_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Int64: # type: ignore[empty-body]
2103+
"""Converts a JSON number to a SQL INT64 value in the safe mode."""
2104+
2105+
2106+
@ibis_udf.scalar.builtin(name="FLOAT64")
2107+
def cast_json_to_float64(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Float64: # type: ignore[empty-body]
2108+
"""Attempts to convert a JSON value to a SQL FLOAT64 value."""
2109+
2110+
2111+
@ibis_udf.scalar.builtin(name="SAFE.FLOAT64")
2112+
def cast_json_to_float64_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Float64: # type: ignore[empty-body]
2113+
"""Attempts to convert a JSON value to a SQL FLOAT64 value in the safe mode (NULL instead of an error on failure)."""
2114+
2115+
2116+
@ibis_udf.scalar.builtin(name="BOOL")
2117+
def cast_json_to_bool(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Boolean: # type: ignore[empty-body]
2118+
"""Attempts to convert a JSON value to a SQL BOOL value."""
2119+
2120+
2121+
@ibis_udf.scalar.builtin(name="SAFE.BOOL")
2122+
def cast_json_to_bool_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.Boolean: # type: ignore[empty-body]
2123+
"""Attempts to convert a JSON value to a SQL BOOL value in the safe mode (NULL instead of an error on failure)."""
2124+
2125+
2126+
@ibis_udf.scalar.builtin(name="STRING")
2127+
def cast_json_to_string(json_str: ibis_dtypes.JSON) -> ibis_dtypes.String: # type: ignore[empty-body]
2128+
"""Attempts to convert a JSON value to a SQL STRING value."""
2129+
2130+
2131+
@ibis_udf.scalar.builtin(name="SAFE.STRING")
2132+
def cast_json_to_string_in_safe(json_str: ibis_dtypes.JSON) -> ibis_dtypes.String: # type: ignore[empty-body]
2133+
"""Attempts to convert a JSON value to a SQL STRING value in the safe mode (NULL instead of an error on failure)."""
2134+
2135+
20622136
@ibis_udf.scalar.builtin(name="ML.DISTANCE")
20632137
def vector_distance(vector1, vector2, type: str) -> ibis_dtypes.Float64: # type: ignore[empty-body]
20642138
"""Computes the distance between two vectors using specified type ("EUCLIDEAN", "MANHATTAN", or "COSINE")"""

tests/system/small/test_series.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
import datetime as dt
16+
import json
1617
import math
1718
import re
1819
import tempfile
1920

2021
import db_dtypes # type: ignore
2122
import geopandas as gpd # type: ignore
23+
import google.api_core.exceptions
2224
import numpy
2325
from packaging.version import Version
2426
import pandas as pd
@@ -3474,9 +3476,11 @@ def foo(x):
34743476
("int64_col", pd.ArrowDtype(pa.timestamp("us"))),
34753477
("int64_col", pd.ArrowDtype(pa.timestamp("us", tz="UTC"))),
34763478
("int64_col", "time64[us][pyarrow]"),
3479+
("int64_col", pd.ArrowDtype(db_dtypes.JSONArrowType())),
34773480
("bool_col", "Int64"),
34783481
("bool_col", "string[pyarrow]"),
34793482
("bool_col", "Float64"),
3483+
("bool_col", pd.ArrowDtype(db_dtypes.JSONArrowType())),
34803484
("string_col", "binary[pyarrow]"),
34813485
("bytes_col", "string[pyarrow]"),
34823486
# pandas actually doesn't let folks convert to/from naive timestamp and
@@ -3541,7 +3545,7 @@ def test_astype_safe(session):
35413545
pd.testing.assert_series_equal(result, exepcted)
35423546

35433547

3544-
def test_series_astype_error_error(session):
3548+
def test_series_astype_w_invalid_error(session):
35453549
input = pd.Series(["hello", "world", "3.11", "4000"])
35463550
with pytest.raises(ValueError):
35473551
session.read_pandas(input).astype("Float64", errors="bad_value")
@@ -3676,6 +3680,119 @@ def test_timestamp_astype_string():
36763680
assert bf_result.dtype == "string[pyarrow]"
36773681

36783682

3683+
@pytest.mark.parametrize("errors", ["raise", "null"])
3684+
def test_float_astype_json(errors):
3685+
data = ["1.25", "2500000000", None, "-12323.24"]
3686+
bf_series = series.Series(data, dtype=dtypes.FLOAT_DTYPE)
3687+
3688+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors=errors)
3689+
assert bf_result.dtype == dtypes.JSON_DTYPE
3690+
3691+
expected_result = pd.Series(data, dtype=dtypes.JSON_DTYPE)
3692+
expected_result.index = expected_result.index.astype("Int64")
3693+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected_result)
3694+
3695+
3696+
@pytest.mark.parametrize("errors", ["raise", "null"])
3697+
def test_string_astype_json(errors):
3698+
data = [
3699+
"1",
3700+
None,
3701+
'["1","3","5"]',
3702+
'{"a":1,"b":["x","y"],"c":{"x":[],"z":false}}',
3703+
]
3704+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3705+
3706+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors=errors)
3707+
assert bf_result.dtype == dtypes.JSON_DTYPE
3708+
3709+
pd_result = bf_series.to_pandas().astype(dtypes.JSON_DTYPE)
3710+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
3711+
3712+
3713+
def test_string_astype_json_in_safe_mode():
3714+
data = ["this is not a valid json string"]
3715+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3716+
bf_result = bf_series.astype(dtypes.JSON_DTYPE, errors="null")
3717+
assert bf_result.dtype == dtypes.JSON_DTYPE
3718+
3719+
expected = pd.Series([None], dtype=dtypes.JSON_DTYPE)
3720+
expected.index = expected.index.astype("Int64")
3721+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3722+
3723+
3724+
def test_string_astype_json_raise_error():
3725+
data = ["this is not a valid json string"]
3726+
bf_series = series.Series(data, dtype=dtypes.STRING_DTYPE)
3727+
with pytest.raises(
3728+
google.api_core.exceptions.BadRequest,
3729+
match="syntax error while parsing value",
3730+
):
3731+
bf_series.astype(dtypes.JSON_DTYPE, errors="raise").to_pandas()
3732+
3733+
3734+
@pytest.mark.parametrize("errors", ["raise", "null"])
3735+
@pytest.mark.parametrize(
3736+
("data", "to_type"),
3737+
[
3738+
pytest.param(["1", "10.0", None], dtypes.INT_DTYPE, id="to_int"),
3739+
pytest.param(["0.0001", "2500000000", None], dtypes.FLOAT_DTYPE, id="to_float"),
3740+
pytest.param(["true", "false", None], dtypes.BOOL_DTYPE, id="to_bool"),
3741+
pytest.param(['"str"', None], dtypes.STRING_DTYPE, id="to_string"),
3742+
pytest.param(
3743+
['"str"', None],
3744+
dtypes.TIME_DTYPE,
3745+
id="invalid",
3746+
marks=pytest.mark.xfail(raises=TypeError),
3747+
),
3748+
],
3749+
)
3750+
def test_json_astype_others(data, to_type, errors):
3751+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3752+
3753+
bf_result = bf_series.astype(to_type, errors=errors)
3754+
assert bf_result.dtype == to_type
3755+
3756+
load_data = [json.loads(item) if item is not None else None for item in data]
3757+
expected = pd.Series(load_data, dtype=to_type)
3758+
expected.index = expected.index.astype("Int64")
3759+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3760+
3761+
3762+
@pytest.mark.parametrize(
3763+
("data", "to_type"),
3764+
[
3765+
pytest.param(["10.2", None], dtypes.INT_DTYPE, id="to_int"),
3766+
pytest.param(["false", None], dtypes.FLOAT_DTYPE, id="to_float"),
3767+
pytest.param(["10.2", None], dtypes.BOOL_DTYPE, id="to_bool"),
3768+
pytest.param(["true", None], dtypes.STRING_DTYPE, id="to_string"),
3769+
],
3770+
)
3771+
def test_json_astype_others_raise_error(data, to_type):
3772+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3773+
with pytest.raises(google.api_core.exceptions.BadRequest):
3774+
bf_series.astype(to_type, errors="raise").to_pandas()
3775+
3776+
3777+
@pytest.mark.parametrize(
3778+
("data", "to_type"),
3779+
[
3780+
pytest.param(["10.2", None], dtypes.INT_DTYPE, id="to_int"),
3781+
pytest.param(["false", None], dtypes.FLOAT_DTYPE, id="to_float"),
3782+
pytest.param(["10.2", None], dtypes.BOOL_DTYPE, id="to_bool"),
3783+
pytest.param(["true", None], dtypes.STRING_DTYPE, id="to_string"),
3784+
],
3785+
)
3786+
def test_json_astype_others_in_safe_mode(data, to_type):
3787+
bf_series = series.Series(data, dtype=dtypes.JSON_DTYPE)
3788+
bf_result = bf_series.astype(to_type, errors="null")
3789+
assert bf_result.dtype == to_type
3790+
3791+
expected = pd.Series([None, None], dtype=to_type)
3792+
expected.index = expected.index.astype("Int64")
3793+
pd.testing.assert_series_equal(bf_result.to_pandas(), expected)
3794+
3795+
36793796
@pytest.mark.parametrize(
36803797
"index",
36813798
[0, 5, -2],
@@ -3687,9 +3804,7 @@ def test_iloc_single_integer(scalars_df_index, scalars_pandas_df_index, index):
36873804
assert bf_result == pd_result
36883805

36893806

3690-
def test_iloc_single_integer_out_of_bound_error(
3691-
scalars_df_index, scalars_pandas_df_index
3692-
):
3807+
def test_iloc_single_integer_out_of_bound_error(scalars_df_index):
36933808
with pytest.raises(IndexError, match="single positional indexer is out-of-bounds"):
36943809
scalars_df_index.string_col.iloc[99]
36953810

third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,12 @@ def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
12221222
# not actually a table, but easier to quote individual namespace
12231223
# components this way
12241224
namespace = op.__udf_namespace__
1225+
1226+
# Function names prefixed with "SAFE.", such as `SAFE.PARSE_JSON`, are
1227+
# returned verbatim: they skip the namespace quoting applied below.
1228+
if funcname.startswith("SAFE."):
1229+
return funcname
1230+
12251231
return sg.table(funcname, db=namespace.database, catalog=namespace.catalog).sql(
12261232
self.dialect
12271233
)

0 commit comments

Comments
 (0)