Skip to content

Commit dd71fac

Browse files
feat: Allow local arithmetic execution in hybrid engine
1 parent 07222bf commit dd71fac

File tree

5 files changed

+140
-19
lines changed

5 files changed

+140
-19
lines changed

bigframes/core/compile/polars/lowering.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@
2626
@dataclasses.dataclass
2727
class CoerceArgsRule(op_lowering.OpLoweringRule):
2828
op_type: type[ops.BinaryOp]
29+
bools_only: bool = False
2930

3031
@property
3132
def op(self) -> type[ops.ScalarOp]:
3233
return self.op_type
3334

3435
def lower(self, expr: expression.OpExpression) -> expression.Expression:
3536
assert isinstance(expr.op, self.op_type)
36-
larg, rarg = _coerce_comparables(expr.children[0], expr.children[1])
37+
larg, rarg = _coerce_comparables(
38+
expr.children[0], expr.children[1], bools_only=self.bools_only
39+
)
3740
return expr.op.as_expr(larg, rarg)
3841

3942

@@ -56,7 +59,18 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
5659
return ops.where_op.as_expr(zero_result, divisor_is_zero, expr)
5760

5861

59-
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):
62+
def _coerce_comparables(
63+
expr1: expression.Expression,
64+
expr2: expression.Expression,
65+
*,
66+
bools_only: bool = False
67+
):
68+
if bools_only:
69+
if (
70+
expr1.output_type != dtypes.BOOL_DTYPE
71+
and expr2.output_type != dtypes.BOOL_DTYPE
72+
):
73+
return expr1, expr2
6074

6175
target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
6276
if expr1.output_type != target_type:
@@ -88,8 +102,20 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
88102
)
89103
)
90104

105+
LOWER_NUMERICS = tuple(
106+
CoerceArgsRule(op, bools_only=True)
107+
for op in (
108+
numeric_ops.AddOp,
109+
numeric_ops.SubOp,
110+
numeric_ops.MulOp,
111+
numeric_ops.DivOp,
112+
numeric_ops.FloorDivOp,
113+
numeric_ops.ModOp,
114+
)
115+
)
91116
POLARS_LOWERING_RULES = (
92117
*LOWER_COMPARISONS,
118+
*LOWER_NUMERICS,
93119
LowerFloorDivRule(),
94120
)
95121

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ def eq_op(
14981498
x: ibis_types.Value,
14991499
y: ibis_types.Value,
15001500
):
1501-
x, y = _coerce_comparables(x, y)
1501+
x, y = _coerce_bools(x, y)
15021502
return x == y
15031503

15041504

@@ -1508,7 +1508,7 @@ def eq_nulls_match_op(
15081508
y: ibis_types.Value,
15091509
):
15101510
"""Variant of eq_op where nulls match each other. Only use where dtypes are known to be same."""
1511-
x, y = _coerce_comparables(x, y)
1511+
x, y = _coerce_bools(x, y)
15121512
literal = ibis_types.literal("$NULL_SENTINEL$")
15131513
if hasattr(x, "fill_null"):
15141514
left = x.cast(ibis_dtypes.str).fill_null(literal)
@@ -1525,7 +1525,7 @@ def ne_op(
15251525
x: ibis_types.Value,
15261526
y: ibis_types.Value,
15271527
):
1528-
x, y = _coerce_comparables(x, y)
1528+
x, y = _coerce_bools(x, y)
15291529
return x != y
15301530

15311531

@@ -1537,7 +1537,7 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue
15371537
)
15381538

15391539

1540-
def _coerce_comparables(
1540+
def _coerce_bools(
15411541
x: ibis_types.Value,
15421542
y: ibis_types.Value,
15431543
):
@@ -1604,6 +1604,7 @@ def add_op(
16041604
x: ibis_types.Value,
16051605
y: ibis_types.Value,
16061606
):
1607+
x, y = _coerce_bools(x, y)
16071608
if isinstance(x, ibis_types.NullScalar) or isinstance(x, ibis_types.NullScalar):
16081609
return ibis_types.null()
16091610
return x + y # type: ignore
@@ -1615,6 +1616,7 @@ def sub_op(
16151616
x: ibis_types.Value,
16161617
y: ibis_types.Value,
16171618
):
1619+
x, y = _coerce_bools(x, y)
16181620
return typing.cast(ibis_types.NumericValue, x) - typing.cast(
16191621
ibis_types.NumericValue, y
16201622
)
@@ -1626,6 +1628,7 @@ def mul_op(
16261628
x: ibis_types.Value,
16271629
y: ibis_types.Value,
16281630
):
1631+
x, y = _coerce_bools(x, y)
16291632
return typing.cast(ibis_types.NumericValue, x) * typing.cast(
16301633
ibis_types.NumericValue, y
16311634
)
@@ -1637,6 +1640,7 @@ def div_op(
16371640
x: ibis_types.Value,
16381641
y: ibis_types.Value,
16391642
):
1643+
x, y = _coerce_bools(x, y)
16401644
return typing.cast(ibis_types.NumericValue, x) / typing.cast(
16411645
ibis_types.NumericValue, y
16421646
)
@@ -1648,6 +1652,7 @@ def pow_op(
16481652
x: ibis_types.Value,
16491653
y: ibis_types.Value,
16501654
):
1655+
x, y = _coerce_bools(x, y)
16511656
if x.type().is_integer() and y.type().is_integer():
16521657
return _int_pow_op(x, y)
16531658
else:
@@ -1661,6 +1666,7 @@ def unsafe_pow_op(
16611666
y: ibis_types.Value,
16621667
):
16631668
"""For internal use only - where domain and overflow checks are not needed."""
1669+
x, y = _coerce_bools(x, y)
16641670
return typing.cast(ibis_types.NumericValue, x) ** typing.cast(
16651671
ibis_types.NumericValue, y
16661672
)
@@ -1749,7 +1755,7 @@ def lt_op(
17491755
x: ibis_types.Value,
17501756
y: ibis_types.Value,
17511757
):
1752-
x, y = _coerce_comparables(x, y)
1758+
x, y = _coerce_bools(x, y)
17531759
return x < y
17541760

17551761

@@ -1759,7 +1765,7 @@ def le_op(
17591765
x: ibis_types.Value,
17601766
y: ibis_types.Value,
17611767
):
1762-
x, y = _coerce_comparables(x, y)
1768+
x, y = _coerce_bools(x, y)
17631769
return x <= y
17641770

17651771

@@ -1769,7 +1775,7 @@ def gt_op(
17691775
x: ibis_types.Value,
17701776
y: ibis_types.Value,
17711777
):
1772-
x, y = _coerce_comparables(x, y)
1778+
x, y = _coerce_bools(x, y)
17731779
return x > y
17741780

17751781

@@ -1779,7 +1785,7 @@ def ge_op(
17791785
x: ibis_types.Value,
17801786
y: ibis_types.Value,
17811787
):
1782-
x, y = _coerce_comparables(x, y)
1788+
x, y = _coerce_bools(x, y)
17831789
return x >= y
17841790

17851791

@@ -1822,6 +1828,7 @@ def mod_op(
18221828
x: ibis_types.Value,
18231829
y: ibis_types.Value,
18241830
):
1831+
x, y = _coerce_bools(x, y)
18251832
# Hacky short-circuit to avoid passing zero-literal to sql backend, evaluate locally instead to null.
18261833
op = y.op()
18271834
if isinstance(op, ibis_generic.Literal) and op.value == 0:

bigframes/operations/numeric_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def output_type(self, *input_types):
179179
left_type = input_types[0]
180180
right_type = input_types[1]
181181

182-
if dtypes.is_datetime_like(left_type) and dtypes.is_datetime_like(right_type):
182+
if left_type == dtypes.DATETIME_DTYPE and right_type == dtypes.DATETIME_DTYPE:
183183
return dtypes.TIMEDELTA_DTYPE
184184

185185
if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE:

bigframes/session/polars_executor.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
2222
import bigframes.operations
2323
from bigframes.operations import aggregations as agg_ops
24+
from bigframes.operations import comparison_ops, generic_ops, numeric_ops
2425
from bigframes.session import executor, semi_executor
2526

2627
if TYPE_CHECKING:
@@ -41,13 +42,20 @@
4142
)
4243

4344
_COMPATIBLE_SCALAR_OPS = (
44-
bigframes.operations.eq_op,
45-
bigframes.operations.eq_null_match_op,
46-
bigframes.operations.ne_op,
47-
bigframes.operations.gt_op,
48-
bigframes.operations.lt_op,
49-
bigframes.operations.ge_op,
50-
bigframes.operations.le_op,
45+
comparison_ops.EqOp,
46+
comparison_ops.EqNullsMatchOp,
47+
comparison_ops.NeOp,
48+
comparison_ops.LtOp,
49+
comparison_ops.GtOp,
50+
comparison_ops.LeOp,
51+
comparison_ops.GeOp,
52+
generic_ops.WhereOp,
53+
numeric_ops.AddOp,
54+
numeric_ops.SubOp,
55+
numeric_ops.MulOp,
56+
numeric_ops.DivOp,
57+
numeric_ops.FloorDivOp,
58+
numeric_ops.ModOp,
5159
)
5260
_COMPATIBLE_AGG_OPS = (
5361
agg_ops.SizeOp,
@@ -74,7 +82,7 @@ def _is_node_polars_executable(node: nodes.BigFrameNode):
7482
if not type(expr.op) in _COMPATIBLE_AGG_OPS:
7583
return False
7684
if isinstance(expr, expression.Expression):
77-
if not _get_expr_ops(expr).issubset(_COMPATIBLE_SCALAR_OPS):
85+
if not set(map(type, _get_expr_ops(expr))).issubset(_COMPATIBLE_SCALAR_OPS):
7886
return False
7987
return True
8088

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
from typing import Callable
17+
18+
import pytest
19+
20+
from bigframes.core import array_value, expression
21+
import bigframes.operations as ops
22+
from bigframes.session import polars_executor
23+
from bigframes.testing.engine_utils import assert_equivalence_execution
24+
25+
# These tests need polars; skip the whole module when it cannot be imported.
pytest.importorskip("polars")

# Polars is used as the reference engine because it is fast and runs locally.
# Where the engines disagree, though, generally prefer the gbq engine.
REFERENCE_ENGINE = polars_executor.PolarsExecutor()

# Maps an op type to a builder of a boolean guard expression. The builder
# receives the left and right column ids (in that order) of a binary
# expression. For division, the guard is "right operand != 0".
OP_CONSTRAINTS: dict[
    type[ops.BinaryOp],
    Callable[[str, str], expression.Expression],
] = {
    ops.DivOp: lambda _left, right: ops.ne_op.as_expr(right, expression.const(0)),
}
34+
35+
36+
def apply_op_pairwise(
    array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=()
) -> array_value.ArrayValue:
    """Apply ``op`` to every ordered pair of distinct columns of ``array``.

    A pair is skipped when either column id is in ``excluded_cols`` or when
    building its expression raises ``TypeError`` (``op.output_type`` is called
    on the two column types purely as a validity check). If ``type(op)`` has
    an entry in ``OP_CONSTRAINTS``, the pair's expression is wrapped in
    ``ops.where_op`` with the constraint as the condition and a null constant
    as the alternative.

    Args:
        array: Array value whose columns supply the operands.
        op: Binary op to apply to each column pair.
        excluded_cols: Column ids that must not be used as either operand.

    Returns:
        The first element of ``array.compute_values(...)`` for the generated
        expressions.

    Raises:
        AssertionError: If no column pair yields a valid expression.
    """
    # The constraint depends only on the op, so look it up once.
    constraint = OP_CONSTRAINTS.get(type(op))
    exprs = []
    for l_arg, r_arg in itertools.permutations(array.column_ids, 2):
        if (l_arg in excluded_cols) or (r_arg in excluded_cols):
            continue
        try:
            # Validation only: the result is unused, a TypeError means
            # "this type pair is not supported by the op".
            op.output_type(
                array.get_column_type(l_arg), array.get_column_type(r_arg)
            )
            expr = op.as_expr(l_arg, r_arg)
            if constraint is not None:
                expr = ops.where_op.as_expr(
                    expr, constraint(l_arg, r_arg), expression.const(None)
                )
        except TypeError:
            continue
        # Keep the (possibly constrained) expression; appending a freshly
        # built unconstrained one would silently drop the guard.
        exprs.append(expr)
    assert len(exprs) > 0
    new_arr, _ = array.compute_values(exprs)
    return new_arr
59+
60+
61+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "op",
    [
        ops.add_op,
        ops.sub_op,
        ops.mul_op,
        ops.div_op,
        # ops.floordiv_op and ops.mod_op are currently not exercised here.
    ],
)
def test_engines_project_numeric_op(
    scalars_array_value: array_value.ArrayValue, engine, op
):
    """Project ``op`` over all column pairs and compare engines.

    The pairwise projection of ``scalars_array_value`` is executed on the
    parametrized ``engine`` and checked against ``REFERENCE_ENGINE`` via
    ``assert_equivalence_execution``.
    """
    # NOTE(review): earlier notes here mentioned excluding string columns and
    # bool columns misbehaving on the bq engine (with a commented-out
    # select_columns(["datetime_col", "duration_col"]) call), but no columns
    # are actually excluded -- confirm whether a restriction is still needed.
    projected = apply_op_pairwise(scalars_array_value, op)
    assert_equivalence_execution(projected.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)