Skip to content

fix: support JSON and STRUCT for bbq.sql_scalar #1754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions bigframes/bigquery/_operations/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import google.cloud.bigquery

import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
import bigframes.core.sql
import bigframes.dataframe
import bigframes.dtypes
Expand Down Expand Up @@ -72,16 +73,16 @@ def sql_scalar(
# Another benefit of this is that if there is a syntax error in the SQL
# template, then this will fail with an error earlier in the process,
# aiding users in debugging.
base_series = columns[0]
literals = [
bigframes.dtypes.bigframes_dtype_to_literal(column.dtype) for column in columns
literals_sql = [
sqlglot_ir._literal(None, column.dtype).sql(dialect="bigquery")
for column in columns
]
literals_sql = [bigframes.core.sql.simple_literal(literal) for literal in literals]
select_sql = sql_template.format(*literals_sql)
dry_run_sql = f"SELECT {select_sql}"

# Use the executor directly, because we want the original column IDs, not
# the user-friendly column names that block.to_sql_query() would produce.
select_sql = sql_template.format(*literals_sql)
dry_run_sql = f"SELECT {select_sql}"
base_series = columns[0]
bqclient = base_series._session.bqclient
job = bqclient.query(
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)
Expand Down
7 changes: 7 additions & 0 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import typing

from google.cloud import bigquery
import numpy as np
import pyarrow as pa
import sqlglot as sg
import sqlglot.dialects.bigquery
Expand Down Expand Up @@ -213,7 +214,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
elif dtype == dtypes.BYTES_DTYPE:
return _cast(str(value), sqlglot_type)
elif dtypes.is_time_like(dtype):
if isinstance(value, np.generic):
value = value.item()
return _cast(sge.convert(value.isoformat()), sqlglot_type)
elif dtype in (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE):
return _cast(sge.convert(value), sqlglot_type)
elif dtypes.is_geo_like(dtype):
wkt = value if isinstance(value, str) else to_wkt(value)
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
Expand All @@ -234,6 +239,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
)
return values if len(value) > 0 else _cast(values, sqlglot_type)
else:
if isinstance(value, np.generic):
value = value.item()
return sge.convert(value)


Expand Down
27 changes: 0 additions & 27 deletions bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,33 +499,6 @@ def bigframes_dtype_to_arrow_dtype(
)


def bigframes_dtype_to_literal(
    bigframes_dtype: Dtype,
) -> Any:
    """Create a representative literal value for a bigframes dtype.

    The inverse of infer_literal_type().

    Args:
        bigframes_dtype: The bigframes dtype to produce a sample value for.

    Returns:
        A Python value whose inferred literal type is ``bigframes_dtype``
        (e.g. ``1.0`` for Float64, ``"string"`` for String).

    Raises:
        TypeError: If no representative literal is defined for the dtype.
    """
    # Arrow-backed dtypes (lists, structs, etc.) delegate to the
    # pyarrow-specific helper.
    if isinstance(bigframes_dtype, pd.ArrowDtype):
        arrow_type = bigframes_dtype.pyarrow_dtype
        return arrow_type_to_literal(arrow_type)

    # Simple scalar dtypes map to fixed sample values.
    if isinstance(bigframes_dtype, pd.Float64Dtype):
        return 1.0
    if isinstance(bigframes_dtype, pd.Int64Dtype):
        return 1
    if isinstance(bigframes_dtype, pd.BooleanDtype):
        return True
    if isinstance(bigframes_dtype, pd.StringDtype):
        return "string"
    if isinstance(bigframes_dtype, gpd.array.GeometryDtype):
        return shapely.geometry.Point((0, 0))

    raise TypeError(
        f"No literal conversion for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
    )


def arrow_type_to_literal(
arrow_type: pa.DataType,
) -> Any:
Expand Down
117 changes: 114 additions & 3 deletions tests/system/small/bigquery/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bigframes.bigquery
import pandas as pd
import pytest

import bigframes.bigquery as bbq
import bigframes.dtypes as dtypes
import bigframes.pandas as bpd

def test_sql_scalar_on_scalars_null_index(scalars_df_null_index):
series = bigframes.bigquery.sql_scalar(

def test_sql_scalar_for_all_scalar_types(scalars_df_null_index):
series = bbq.sql_scalar(
"""
CAST({0} AS INT64)
+ BYTE_LENGTH({1})
Expand Down Expand Up @@ -48,3 +53,109 @@ def test_sql_scalar_on_scalars_null_index(scalars_df_null_index):
)
result = series.to_pandas()
assert len(result) == len(scalars_df_null_index)


def test_sql_scalar_for_bool_series(scalars_df_index):
    """Casting a BOOL column through a SQL template should match astype()."""
    bool_column: bpd.Series = scalars_df_index["bool_col"]
    got = bbq.sql_scalar("CAST({0} AS INT64)", [bool_column])
    want = bool_column.astype(dtypes.INT_DTYPE)
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


@pytest.mark.parametrize(
    "column_name",
    [
        "bool_col",
        "bytes_col",
        "date_col",
        "datetime_col",
        "geography_col",
        "int64_col",
        "numeric_col",
        "float64_col",
        "string_col",
        "time_col",
        "timestamp_col",
    ],
)
def test_sql_scalar_outputs_all_scalar_types(scalars_df_index, column_name):
    """An identity SQL template should round-trip each scalar dtype unchanged."""
    column: bpd.Series = scalars_df_index[column_name]
    got = bbq.sql_scalar("{0}", [column])
    want = column
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_for_array_series(repeated_df):
    """Summing ARRAY_LENGTH in SQL should match summing Series.list.len()."""
    list_column_names = [
        "int_list_col",
        "bool_list_col",
        "float_list_col",
        "date_list_col",
        "date_time_list_col",
        "numeric_list_col",
        "string_list_col",
    ]
    list_columns = [repeated_df[name] for name in list_column_names]
    got = bbq.sql_scalar(
        """
    ARRAY_LENGTH({0}) + ARRAY_LENGTH({1}) + ARRAY_LENGTH({2})
    + ARRAY_LENGTH({3}) + ARRAY_LENGTH({4}) + ARRAY_LENGTH({5})
    + ARRAY_LENGTH({6})
    """,
        list_columns,
    )

    # Accumulate in the same left-to-right order as the SQL template.
    want = list_columns[0].list.len()
    for column in list_columns[1:]:
        want = want + column.list.len()
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_outputs_array_series(repeated_df):
    """An identity SQL template should round-trip an ARRAY column unchanged."""
    int_lists = repeated_df["int_list_col"]
    got = bbq.sql_scalar("{0}", [int_lists])
    want = int_lists
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_for_struct_series(nested_structs_df):
    """STRUCT field access in SQL should match the struct accessor results."""
    person = nested_structs_df["person"]
    got = bbq.sql_scalar(
        "CHAR_LENGTH({0}.name) + {0}.age",
        [person],
    )
    name_length = person.struct.field("name").str.len()
    want = name_length + person.struct.field("age")
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_outputs_struct_series(nested_structs_df):
    """An identity SQL template should round-trip a STRUCT column unchanged."""
    person = nested_structs_df["person"]
    got = bbq.sql_scalar("{0}", [person])
    want = person
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_for_json_series(json_df):
    """JSON_VALUE in a SQL template should match bbq.json_value()."""
    json_column = json_df["json_col"]
    got = bbq.sql_scalar(
        """JSON_VALUE({0}, '$.int_value')""",
        [json_column],
    )
    want = bbq.json_value(json_column, "$.int_value")
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())


def test_sql_scalar_outputs_json_series(json_df):
    """An identity SQL template should round-trip a JSON column unchanged."""
    json_column = json_df["json_col"]
    got = bbq.sql_scalar("{0}", [json_column])
    want = json_column
    want.name = None
    pd.testing.assert_series_equal(got.to_pandas(), want.to_pandas())
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
ST_GEOGFROMTEXT('POINT (-122.0838511 37.3860517)'),
123456789,
0,
1.234567890,
CAST(1.234567890 AS NUMERIC),
1.25,
0,
0,
Expand All @@ -27,7 +27,7 @@ WITH `bfcte_0` AS (
ST_GEOGFROMTEXT('POINT (-71.104 42.315)'),
-987654321,
1,
1.234567890,
CAST(1.234567890 AS NUMERIC),
2.51,
1,
1,
Expand All @@ -44,7 +44,7 @@ WITH `bfcte_0` AS (
ST_GEOGFROMTEXT('POINT (-0.124474760143016 51.5007826749545)'),
314159,
0,
101.101010100,
CAST(101.101010100 AS NUMERIC),
25000000000.0,
2,
2,
Expand Down Expand Up @@ -95,7 +95,7 @@ WITH `bfcte_0` AS (
CAST(NULL AS GEOGRAPHY),
55555,
0,
5.555555000,
CAST(5.555555000 AS NUMERIC),
555.555,
5,
5,
Expand All @@ -112,7 +112,7 @@ WITH `bfcte_0` AS (
ST_GEOGFROMTEXT('LINESTRING (-0.127959 51.507728, -0.127026 51.507473)'),
101202303,
2,
-10.090807000,
CAST(-10.090807000 AS NUMERIC),
-123.456,
6,
6,
Expand All @@ -129,7 +129,7 @@ WITH `bfcte_0` AS (
CAST(NULL AS GEOGRAPHY),
-214748367,
2,
11111111.100000000,
CAST(11111111.100000000 AS NUMERIC),
42.42,
7,
7,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def from_ibis(cls, dtype: dt.DataType) -> str:
)
elif dtype.is_integer():
return "INT64"
elif dtype.is_boolean():
return "BOOLEAN"
elif dtype.is_binary():
return "BYTES"
elif dtype.is_string():
Expand Down