Skip to content

Commit a082510

Browse files
feat: Allow local arithmetic execution in hybrid engine
1 parent 07222bf commit a082510

File tree

8 files changed

+465
-54
lines changed

8 files changed

+465
-54
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import bigframes.operations.comparison_ops as comp_ops
3636
import bigframes.operations.generic_ops as gen_ops
3737
import bigframes.operations.numeric_ops as num_ops
38+
import bigframes.operations.string_ops as string_ops
3839

3940
polars_installed = True
4041
if TYPE_CHECKING:
@@ -146,6 +147,14 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
146147
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
147148
return input.abs()
148149

150+
@compile_op.register(num_ops.FloorOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``FloorOp`` by calling ``floor()`` on the input expression."""
    floored = input.floor()
    return floored
153+
154+
@compile_op.register(num_ops.CeilOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``CeilOp`` by calling ``ceil()`` on the input expression."""
    ceiled = input.ceil()
    return ceiled
157+
149158
@compile_op.register(num_ops.PosOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``PosOp`` by invoking the expression's ``__pos__`` hook."""
    positive = input.__pos__()
    return positive
@@ -182,10 +191,6 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
182191
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
183192
return l_input // r_input
184193

185-
@compile_op.register(num_ops.FloorDivOp)
186-
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
187-
return l_input // r_input
188-
189194
@compile_op.register(num_ops.ModOp)
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
    """Compile ``ModOp`` as ``l_input % r_input`` on the two operand expressions."""
    remainder = l_input % r_input
    return remainder
@@ -270,6 +275,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
270275
# e.g. we want "True" instead of "true" for bool to string
271276
return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe)
272277

278+
@compile_op.register(string_ops.StrConcatOp)
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
    """Compile ``StrConcatOp`` into ``pl.concat_str`` over the left and right operands."""
    assert isinstance(op, string_ops.StrConcatOp)
    # Left operand first, so it leads the concatenated string.
    operands = [l_input, r_input]
    return pl.concat_str(operands)
282+
273283
@dataclasses.dataclass(frozen=True)
274284
class PolarsAggregateCompiler:
275285
scalar_compiler = PolarsExpressionCompiler()

bigframes/core/compile/polars/lowering.py

Lines changed: 229 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,241 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3737
return expr.op.as_expr(larg, rarg)
3838

3939

40+
class LowerAddRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``AddOp`` that inserts casts around bool and date operands."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.AddOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary add expression ``expr``.

        Cases, checked in order:

        * bool + bool: both operands are cast to int, added, and the sum is
          cast back to bool.
        * string-like + string-like: replaced by ``ops.strconcat_op``.
        * otherwise: each bool operand is cast to int, a date operand paired
          with a timedelta operand is cast to datetime, and the add is
          re-applied to the (possibly cast) operands.
        """
        assert isinstance(expr.op, numeric_ops.AddOp)

        def _to_int(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)

        def _to_datetime(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(arg)

        left, right = expr.children[0], expr.children[1]
        left_type, right_type = left.output_type, right.output_type
        left_is_bool = left_type == dtypes.BOOL_DTYPE
        right_is_bool = right_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            int_sum = expr.op.as_expr(_to_int(left), _to_int(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_sum)

        if dtypes.is_string_like(left_type) and dtypes.is_string_like(right_type):
            return ops.strconcat_op.as_expr(left, right)

        # At most one side is bool here; that side is added as an int.
        if left_is_bool:
            left = _to_int(left)
        if right_is_bool:
            right = _to_int(right)

        # Only the date side of a date/timedelta pair is cast to datetime.
        if (
            left_type == dtypes.DATE_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            left = _to_datetime(left)
        elif (
            left_type == dtypes.TIMEDELTA_DTYPE
            and right_type == dtypes.DATE_DTYPE
        ):
            right = _to_datetime(right)

        return expr.op.as_expr(left, right)
82+
83+
84+
class LowerSubRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``SubOp`` that inserts casts around bool and date operands."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.SubOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary subtract expression ``expr``.

        Cases, checked in order:

        * bool - bool: both operands are cast to int, subtracted, and the
          difference is cast back to bool.
        * otherwise: each bool operand is cast to int, a date left operand
          paired with a timedelta right operand is cast to datetime, and the
          subtraction is re-applied to the (possibly cast) operands.
        """
        assert isinstance(expr.op, numeric_ops.SubOp)

        def _to_int(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)

        minuend, subtrahend = expr.children[0], expr.children[1]
        minuend_type = minuend.output_type
        subtrahend_type = subtrahend.output_type
        minuend_is_bool = minuend_type == dtypes.BOOL_DTYPE
        subtrahend_is_bool = subtrahend_type == dtypes.BOOL_DTYPE

        if minuend_is_bool and subtrahend_is_bool:
            int_diff = expr.op.as_expr(_to_int(minuend), _to_int(subtrahend))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_diff)

        # At most one side is bool here; that side is subtracted as an int.
        if minuend_is_bool:
            minuend = _to_int(minuend)
        if subtrahend_is_bool:
            subtrahend = _to_int(subtrahend)

        # date - timedelta: the date side is cast to datetime first.
        if (
            minuend_type == dtypes.DATE_DTYPE
            and subtrahend_type == dtypes.TIMEDELTA_DTYPE
        ):
            minuend = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(minuend)

        return expr.op.as_expr(minuend, subtrahend)
115+
116+
117+
class LowerMulRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``MulOp`` that inserts int casts around bool operands.

    Declared as a plain class, like the other lowering rules in this module:
    the rule has no fields, so dataclass-generated ``__init__``/``__eq__`` add
    nothing and would only make instances unhashable.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.MulOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary multiply expression ``expr``.

        Cases, checked in order:

        * bool * bool: both operands are cast to int, multiplied, and the
          product is cast back to bool.
        * exactly one bool operand: that operand is cast to int and the
          multiply is re-applied.
        * no bool operands: the multiply is re-applied to the original
          operands unchanged.
        """
        assert isinstance(expr.op, numeric_ops.MulOp)
        larg, rarg = expr.children[0], expr.children[1]
        left_is_bool = larg.output_type == dtypes.BOOL_DTYPE
        right_is_bool = rarg.output_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            int_result = expr.op.as_expr(
                ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg),
                ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg),
            )
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)

        # The bool*bool case returned above, so at most one operand is bool
        # here and no extra "other side is not bool" check is needed.
        if left_is_bool:
            larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
        elif right_is_bool:
            rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)

        return expr.op.as_expr(larg, rarg)
149+
150+
151+
class LowerDivRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``DivOp`` that inserts casts around timedelta, bool and decimal operands."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.DivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary divide expression ``expr``.

        Cases, checked in order:

        * timedelta / int: the dividend is cast to int, divided, and the
          quotient is cast back to timedelta.
        * bool / bool: both operands are cast to int, divided, and the
          quotient is cast back to bool.
        * otherwise: a bool operand is cast to int, a NUMERIC/BIGNUMERIC
          dividend is cast to float, and ``numeric_ops.div_op`` is applied.
        """
        assert isinstance(expr.op, numeric_ops.DivOp)

        def _to_int(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)

        dividend, divisor = expr.children[0], expr.children[1]
        dividend_type, divisor_type = dividend.output_type, divisor.output_type

        if (
            dividend_type == dtypes.TIMEDELTA_DTYPE
            and divisor_type == dtypes.INT_DTYPE
        ):
            quotient = expr.op.as_expr(_to_int(dividend), divisor)
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(quotient)

        if (
            dividend_type == dtypes.BOOL_DTYPE
            and divisor_type == dtypes.BOOL_DTYPE
        ):
            # NOTE(review): the quotient is cast back to bool here; confirm
            # this matches the declared output type of bool / bool division.
            quotient = expr.op.as_expr(_to_int(dividend), _to_int(divisor))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(quotient)

        # polars divide doesn't like bools, convert to int always
        # convert numerics to float always
        if dividend_type == dtypes.BOOL_DTYPE:
            dividend = _to_int(dividend)
        elif dividend_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE):
            dividend = ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr(dividend)
        if divisor_type == dtypes.BOOL_DTYPE:
            divisor = _to_int(divisor)

        return numeric_ops.div_op.as_expr(dividend, divisor)
191+
192+
40193
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``FloorDivOp``, expressed as ``floor`` applied to a division."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary floor-divide expression ``expr``.

        Cases, checked in order:

        * timedelta // int: the dividend is cast to int, floor-divided, and
          the result is cast back to timedelta.
        * otherwise: each bool operand is cast to int and the result is
          ``floor_op`` applied to ``div_op`` of the two operands.
        """
        assert isinstance(expr.op, numeric_ops.FloorDivOp)

        def _to_int(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)

        dividend, divisor = expr.children[0], expr.children[1]
        dividend_type, divisor_type = dividend.output_type, divisor.output_type

        if (
            dividend_type == dtypes.TIMEDELTA_DTYPE
            and divisor_type == dtypes.INT_DTYPE
        ):
            int_quotient = expr.op.as_expr(_to_int(dividend), divisor)
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(
                int_quotient
            )

        if dividend_type == dtypes.BOOL_DTYPE:
            dividend = _to_int(dividend)
        if divisor_type == dtypes.BOOL_DTYPE:
            divisor = _to_int(divisor)

        # NOTE(review): no explicit zero-divisor handling in this path;
        # confirm that floor(dividend / 0) gives the intended result.
        quotient = numeric_ops.div_op.as_expr(dividend, divisor)
        return numeric_ops.floor_op.as_expr(quotient)
57221

58222

59-
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):
223+
class LowerModRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``ModOp`` that casts operands and guards zero divisors."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class this rule applies to."""
        return numeric_ops.ModOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten form of the binary modulo expression ``expr``.

        Cases, checked in order:

        * timedelta % timedelta: both operands are cast to int, the modulo is
          wrapped in a zero-divisor guard, and the result is cast back to
          timedelta.
        * otherwise: each bool operand is cast to int and the modulo is
          re-applied; when ``expr`` has an int output type the result is
          additionally wrapped in the zero-divisor guard.
        """
        assert isinstance(expr.op, numeric_ops.ModOp)

        def _to_int(arg: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)

        def _guard_zero_divisor(
            result: expression.Expression, divisor: expression.Expression
        ) -> expression.Expression:
            # Keep `result` where divisor != 0; elsewhere use divisor * 0.
            return ops.where_op.as_expr(
                result,
                ops.ne_op.as_expr(divisor, expression.const(0)),
                ops.mul_op.as_expr(divisor, expression.const(0)),
            )

        left, right = expr.children[0], expr.children[1]
        left_type, right_type = left.output_type, right.output_type

        if (
            left_type == dtypes.TIMEDELTA_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            left_int = _to_int(left)
            right_int = _to_int(right)
            int_mod = expr.op.as_expr(left_int, right_int)
            guarded = _guard_zero_divisor(int_mod, right_int)
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(guarded)

        if left_type == dtypes.BOOL_DTYPE:
            left = _to_int(left)
        if right_type == dtypes.BOOL_DTYPE:
            right = _to_int(right)

        bool_free_mod = expr.op.as_expr(left, right)

        # The guard is applied only when the incoming expression is int-typed.
        if expr.output_type == dtypes.INT_DTYPE:
            return _guard_zero_divisor(bool_free_mod, right)
        return bool_free_mod
261+
262+
263+
def _coerce_comparables(
264+
expr1: expression.Expression,
265+
expr2: expression.Expression,
266+
*,
267+
bools_only: bool = False
268+
):
269+
if bools_only:
270+
if (
271+
expr1.output_type != dtypes.BOOL_DTYPE
272+
and expr2.output_type != dtypes.BOOL_DTYPE
273+
):
274+
return expr1, expr2
60275

61276
target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
62277
if expr1.output_type != target_type:
@@ -90,7 +305,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90305

91306
# Arithmetic lowering rules, one instance per rule class, listed after the
# comparison rules in the combined tuple below.
_ARITHMETIC_LOWERING_RULES = (
    LowerAddRule(),
    LowerSubRule(),
    LowerMulRule(),
    LowerDivRule(),
    LowerFloorDivRule(),
    LowerModRule(),
)

# All lowering rules for the Polars compiler: comparison rules first, then
# the arithmetic rules.
POLARS_LOWERING_RULES = (*LOWER_COMPARISONS, *_ARITHMETIC_LOWERING_RULES)
95315

96316

0 commit comments

Comments
 (0)