From 74c305136f570d10d71450b16464e4dd81362188 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 11 Jul 2025 23:27:35 +0000 Subject: [PATCH 1/4] feat: Allow local arithmetic execution in hybrid engine --- bigframes/core/compile/polars/compiler.py | 18 +- bigframes/core/compile/polars/lowering.py | 257 +++++++++++++++++- bigframes/core/compile/scalar_op_compiler.py | 49 +++- bigframes/core/pyarrow_utils.py | 18 +- bigframes/operations/numeric_ops.py | 39 ++- bigframes/session/polars_executor.py | 26 +- bigframes/testing/engine_utils.py | 2 +- .../system/small/engines/test_numeric_ops.py | 158 +++++++++++ 8 files changed, 519 insertions(+), 48 deletions(-) create mode 100644 tests/system/small/engines/test_numeric_ops.py diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index c31c122078..12f944c211 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -35,6 +35,7 @@ import bigframes.operations.comparison_ops as comp_ops import bigframes.operations.generic_ops as gen_ops import bigframes.operations.numeric_ops as num_ops +import bigframes.operations.string_ops as string_ops polars_installed = True if TYPE_CHECKING: @@ -146,6 +147,14 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: return input.abs() + @compile_op.register(num_ops.FloorOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.floor() + + @compile_op.register(num_ops.CeilOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.ceil() + @compile_op.register(num_ops.PosOp) def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: return input.__pos__() @@ -182,10 +191,6 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: return l_input // r_input - @compile_op.register(num_ops.FloorDivOp) - def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: - return l_input // r_input - @compile_op.register(num_ops.ModOp) def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: return l_input % r_input @@ -270,6 +275,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: # eg. We want "True" instead of "true" for bool to strin return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe) + @compile_op.register(string_ops.StrConcatOp) + def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: + assert isinstance(op, string_ops.StrConcatOp) + return pl.concat_str(l_input, r_input) + @dataclasses.dataclass(frozen=True) class PolarsAggregateCompiler: scalar_compiler = PolarsExpressionCompiler() diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py index 48d63e9ed9..63aae79482 100644 --- a/bigframes/core/compile/polars/lowering.py +++ b/bigframes/core/compile/polars/lowering.py @@ -37,26 +37,258 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression: return expr.op.as_expr(larg, rarg) +class LowerAddRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return numeric_ops.AddOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, numeric_ops.AddOp) + larg, rarg = expr.children[0], expr.children[1] + + if ( + larg.output_type == dtypes.BOOL_DTYPE + and rarg.output_type == dtypes.BOOL_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg), + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg), + ) + return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result) + + if dtypes.is_string_like(larg.output_type) and dtypes.is_string_like( + rarg.output_type + ): + return ops.strconcat_op.as_expr(larg, rarg) + + if larg.output_type == dtypes.BOOL_DTYPE: + larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg) + if rarg.output_type == dtypes.BOOL_DTYPE: + rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg) + + if ( + larg.output_type == dtypes.DATE_DTYPE + and rarg.output_type == dtypes.TIMEDELTA_DTYPE + ): + larg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(larg) + + if ( + larg.output_type == dtypes.TIMEDELTA_DTYPE + and rarg.output_type == dtypes.DATE_DTYPE + ): + rarg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(rarg) + + return expr.op.as_expr(larg, rarg) + + +class LowerSubRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return numeric_ops.SubOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, numeric_ops.SubOp) + larg, rarg = expr.children[0], expr.children[1] + + if ( + larg.output_type == dtypes.BOOL_DTYPE + and rarg.output_type == dtypes.BOOL_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg), + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg), + ) + return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result) + + if larg.output_type == dtypes.BOOL_DTYPE: + larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg) + if rarg.output_type == dtypes.BOOL_DTYPE: + rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg) + + if ( + larg.output_type == dtypes.DATE_DTYPE + and rarg.output_type == dtypes.TIMEDELTA_DTYPE + ): + larg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(larg) + + return expr.op.as_expr(larg, rarg) + + +@dataclasses.dataclass +class LowerMulRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return numeric_ops.MulOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, numeric_ops.MulOp) + larg, rarg = expr.children[0], expr.children[1] + + if ( + larg.output_type == dtypes.BOOL_DTYPE + and rarg.output_type == dtypes.BOOL_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg), + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg), + ) + return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result) + + if ( + larg.output_type == dtypes.BOOL_DTYPE + and rarg.output_type != dtypes.BOOL_DTYPE + ): + larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg) + if ( + rarg.output_type == dtypes.BOOL_DTYPE + and larg.output_type != dtypes.BOOL_DTYPE + ): + rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg) + + return expr.op.as_expr(larg, rarg) + + +class LowerDivRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return numeric_ops.DivOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, numeric_ops.DivOp) + + dividend = expr.children[0] + divisor = expr.children[1] + + if ( + dividend.output_type == dtypes.TIMEDELTA_DTYPE + and divisor.output_type == dtypes.INT_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor + ) + return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result) + + if ( + dividend.output_type == dtypes.BOOL_DTYPE + and divisor.output_type == dtypes.BOOL_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor), + ) + return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result) + + # polars divide doesn't like bools, convert to int always + # convert numerics to float always + if dividend.output_type == dtypes.BOOL_DTYPE: + dividend = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend) + elif dividend.output_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE): + dividend = ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr(dividend) + if divisor.output_type == dtypes.BOOL_DTYPE: + divisor = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor) + + return numeric_ops.div_op.as_expr(dividend, divisor) + + class LowerFloorDivRule(op_lowering.OpLoweringRule): @property def op(self) -> type[ops.ScalarOp]: return numeric_ops.FloorDivOp def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, numeric_ops.FloorDivOp) + dividend = expr.children[0] divisor = expr.children[1] - using_floats = (dividend.output_type == dtypes.FLOAT_DTYPE) or ( - divisor.output_type == dtypes.FLOAT_DTYPE - ) - inf_or_zero = ( - expression.const(float("INF")) if using_floats else expression.const(0) - ) - zero_result = ops.mul_op.as_expr(inf_or_zero, dividend) - divisor_is_zero = ops.eq_op.as_expr(divisor, expression.const(0)) - return ops.where_op.as_expr(zero_result, divisor_is_zero, expr) + + if ( + dividend.output_type == dtypes.TIMEDELTA_DTYPE + and divisor.output_type == dtypes.TIMEDELTA_DTYPE + ): + int_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor), + ) + return int_result + if dividend.output_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric( + divisor.output_type + ): + # this is pretty fragile as zero will break it, and must fit back into int + numeric_result = expr.op.as_expr( + ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor + ) + int_result = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(numeric_result) + return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result) + + if dividend.output_type == dtypes.BOOL_DTYPE: + dividend = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend) + if divisor.output_type == dtypes.BOOL_DTYPE: + divisor = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor) + + if expr.output_type != dtypes.FLOAT_DTYPE: + # need to guard against zero divisor + # multiply dividend in this case to propagate nulls + return ops.where_op.as_expr( + ops.mul_op.as_expr(dividend, expression.const(0)), + ops.eq_op.as_expr(divisor, expression.const(0)), + numeric_ops.floordiv_op.as_expr(dividend, divisor), + ) + else: + return expr.op.as_expr(dividend, divisor) + + +class LowerModRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return numeric_ops.ModOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + og_expr = expr + assert isinstance(expr.op, numeric_ops.ModOp) + larg, rarg = expr.children[0], expr.children[1] + + if ( + larg.output_type == dtypes.TIMEDELTA_DTYPE + and rarg.output_type == dtypes.TIMEDELTA_DTYPE + ): + larg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg) + rarg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg) + int_result = expr.op.as_expr(larg_int, rarg_int) + w_zero_handling = ops.where_op.as_expr( + int_result, + ops.ne_op.as_expr(rarg_int, expression.const(0)), + ops.mul_op.as_expr(rarg_int, expression.const(0)), + ) + return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(w_zero_handling) + + if larg.output_type == dtypes.BOOL_DTYPE: + larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg) + if rarg.output_type == dtypes.BOOL_DTYPE: + rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg) + + wo_bools = expr.op.as_expr(larg, rarg) + + if og_expr.output_type == dtypes.INT_DTYPE: + return ops.where_op.as_expr( + wo_bools, + ops.ne_op.as_expr(rarg, expression.const(0)), + ops.mul_op.as_expr(rarg, expression.const(0)), + ) + return wo_bools -def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression): +def _coerce_comparables( + expr1: expression.Expression, + expr2: expression.Expression, + *, + bools_only: bool = False +): + if bools_only: + if ( + expr1.output_type != dtypes.BOOL_DTYPE + and expr2.output_type != dtypes.BOOL_DTYPE + ): + return expr1, expr2 target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type) if expr1.output_type != target_type: @@ -90,7 +322,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): POLARS_LOWERING_RULES = ( *LOWER_COMPARISONS, + LowerAddRule(), + LowerSubRule(), + LowerMulRule(), + LowerDivRule(), LowerFloorDivRule(), + LowerModRule(), ) diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index 30da6b2cb2..7c7890cd6e 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -1498,7 +1498,7 @@ def eq_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x == y @@ -1508,7 +1508,7 @@ def eq_nulls_match_op( y: ibis_types.Value, ): """Variant of eq_op where nulls match each other. Only use where dtypes are known to be same.""" - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) literal = ibis_types.literal("$NULL_SENTINEL$") if hasattr(x, "fill_null"): left = x.cast(ibis_dtypes.str).fill_null(literal) @@ -1525,7 +1525,7 @@ def ne_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x != y @@ -1537,13 +1537,10 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue ) -def _coerce_comparables( - x: ibis_types.Value, - y: ibis_types.Value, -): - if x.type().is_boolean() and not y.type().is_boolean(): +def _coerce_bools(x: ibis_types.Value, y: ibis_types.Value, *, always: bool = False): + if x.type().is_boolean() and (always or not y.type().is_boolean()): x = x.cast(ibis_dtypes.int64) - elif y.type().is_boolean() and not x.type().is_boolean(): + if y.type().is_boolean() and (always or not x.type().is_boolean()): y = y.cast(ibis_dtypes.int64) return x, y @@ -1604,8 +1601,18 @@ def add_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_bools(x, y) if isinstance(x, ibis_types.NullScalar) or isinstance(x, ibis_types.NullScalar): return ibis_types.null() + + if x.type().is_boolean() and y.type().is_boolean(): + x, y = _coerce_bools(x, y, always=True) + return ( + typing.cast(ibis_types.NumericValue, x) + + typing.cast(ibis_types.NumericValue, x) + ).cast(ibis_dtypes.Boolean) + + x, y = _coerce_bools(x, y) return x + y # type: ignore @@ -1615,6 +1622,7 @@ def sub_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_bools(x, y) return typing.cast(ibis_types.NumericValue, x) - typing.cast( ibis_types.NumericValue, y ) @@ -1626,6 +1634,13 @@ def mul_op( x: ibis_types.Value, y: ibis_types.Value, ): + if x.type().is_boolean() and y.type().is_boolean(): + x, y = _coerce_bools(x, y, always=True) + return ( + typing.cast(ibis_types.NumericValue, x) + * typing.cast(ibis_types.NumericValue, x) + ).cast(ibis_dtypes.Boolean) + x, y = _coerce_bools(x, y) return typing.cast(ibis_types.NumericValue, x) * typing.cast( ibis_types.NumericValue, y ) @@ -1637,6 +1652,7 @@ def div_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_bools(x, y) return typing.cast(ibis_types.NumericValue, x) / typing.cast( ibis_types.NumericValue, y ) @@ -1648,6 +1664,7 @@ def pow_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_bools(x, y) if x.type().is_integer() and y.type().is_integer(): return _int_pow_op(x, y) else: @@ -1661,6 +1678,7 @@ def unsafe_pow_op( y: ibis_types.Value, ): """For internal use only - where domain and overflow checks are not needed.""" + x, y = _coerce_bools(x, y) return typing.cast(ibis_types.NumericValue, x) ** typing.cast( ibis_types.NumericValue, y ) @@ -1749,7 +1767,7 @@ def lt_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x < y @@ -1759,7 +1777,7 @@ def le_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x <= y @@ -1769,7 +1787,7 @@ def gt_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x > y @@ -1779,7 +1797,7 @@ def ge_op( x: ibis_types.Value, y: ibis_types.Value, ): - x, y = _coerce_comparables(x, y) + x, y = _coerce_bools(x, y) return x >= y @@ -1789,6 +1807,10 @@ def floordiv_op( x: ibis_types.Value, y: ibis_types.Value, ): + if x.type().is_boolean(): + x = x.cast(ibis_dtypes.int64) + elif y.type().is_boolean(): + y = y.cast(ibis_dtypes.int64) x_numeric = typing.cast(ibis_types.NumericValue, x) y_numeric = typing.cast(ibis_types.NumericValue, y) floordiv_expr = x_numeric // y_numeric @@ -1827,6 +1849,7 @@ def mod_op( if isinstance(op, ibis_generic.Literal) and op.value == 0: return ibis_types.null().cast(x.type()) + x, y = _coerce_bools(x, y) if x.type().is_integer() and y.type().is_integer(): # both are ints, no casting necessary return _int_mod( diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py index b9dc2ea2b3..d86457e878 100644 --- a/bigframes/core/pyarrow_utils.py +++ b/bigframes/core/pyarrow_utils.py @@ -78,10 +78,20 @@ def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch: if batch.schema == schema: return batch # TODO: Use RecordBatch.cast once min pyarrow>=16.0 - return pa.record_batch( - [arr.cast(type) for arr, type in zip(batch.columns, schema.types)], - schema=schema, - ) + # return pa.record_batch( + # [arr.cast(type) for arr, type in zip(batch.columns, schema.types)], + # schema=schema, + # ) + arrs = [] + + for arr, type in zip(batch.columns, schema.types): + try: + value = arr.cast(type) + except Exception as e: + print(e) + raise + arrs.append(value) + return pa.record_batch(arrs, schema=schema) def truncate_pyarrow_iterable( diff --git a/bigframes/operations/numeric_ops.py b/bigframes/operations/numeric_ops.py index 64eec9d8a1..7ef38f48be 100644 --- a/bigframes/operations/numeric_ops.py +++ b/bigframes/operations/numeric_ops.py @@ -140,7 +140,8 @@ class AddOp(base_ops.BinaryOp): def output_type(self, *input_types): left_type = input_types[0] right_type = input_types[1] - if all(map(dtypes.is_string_like, input_types)) and len(set(input_types)) == 1: + # TODO: Binary/bytes addition requires impl + if all(map(lambda t: t == dtypes.STRING_DTYPE, input_types)): # String addition return input_types[0] @@ -179,7 +180,7 @@ def output_type(self, *input_types): left_type = input_types[0] right_type = input_types[1] - if dtypes.is_datetime_like(left_type) and dtypes.is_datetime_like(right_type): + if left_type == dtypes.DATETIME_DTYPE and right_type == dtypes.DATETIME_DTYPE: return dtypes.TIMEDELTA_DTYPE if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE: @@ -194,6 +195,9 @@ def output_type(self, *input_types): if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE + if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE: + raise TypeError(f"Cannot subtract dtypes {left_type} and {right_type}") + if (left_type is None or dtypes.is_numeric(left_type)) and ( right_type is None or dtypes.is_numeric(right_type) ): @@ -214,9 +218,15 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and right_type in ( + dtypes.INT_DTYPE, + dtypes.FLOAT_DTYPE, + ): return dtypes.TIMEDELTA_DTYPE - if dtypes.is_numeric(left_type) and right_type == dtypes.TIMEDELTA_DTYPE: + if ( + left_type in (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE) + and right_type == dtypes.TIMEDELTA_DTYPE + ): return dtypes.TIMEDELTA_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( @@ -238,12 +248,16 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == (dtypes.INT_DTYPE): + # will fail outright if right value is zero though return dtypes.TIMEDELTA_DTYPE if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.FLOAT_DTYPE + if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE: + raise TypeError(f"Cannot divide dtypes {left_type} and {right_type}") + if (left_type is None or dtypes.is_numeric(left_type)) and ( right_type is None or dtypes.is_numeric(right_type) ): @@ -265,11 +279,14 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: + return dtypes.INT_DTYPE + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): return dtypes.TIMEDELTA_DTYPE - if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: - return dtypes.INT_DTYPE + if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE: + raise TypeError(f"Cannot floor divide dtypes {left_type} and {right_type}") if (left_type is None or dtypes.is_numeric(left_type)) and ( right_type is None or dtypes.is_numeric(right_type) @@ -292,6 +309,14 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE + if left_type in ( + dtypes.NUMERIC_DTYPE, + dtypes.BIGNUMERIC_DTYPE, + ) or right_type in (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE): + raise TypeError(f"Cannot mod dtypes {left_type} and {right_type}") + + if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE: + raise TypeError(f"Cannot mod dtypes {left_type} and {right_type}") if (left_type is None or dtypes.is_numeric(left_type)) and ( right_type is None or dtypes.is_numeric(right_type) diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 3c23e4c200..5dbaa30c2f 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -21,6 +21,7 @@ from bigframes.core import array_value, bigframe_node, expression, local_data, nodes import bigframes.operations from bigframes.operations import aggregations as agg_ops +from bigframes.operations import comparison_ops, numeric_ops from bigframes.session import executor, semi_executor if TYPE_CHECKING: @@ -41,13 +42,19 @@ ) _COMPATIBLE_SCALAR_OPS = ( - bigframes.operations.eq_op, - bigframes.operations.eq_null_match_op, - bigframes.operations.ne_op, - bigframes.operations.gt_op, - bigframes.operations.lt_op, - bigframes.operations.ge_op, - bigframes.operations.le_op, + comparison_ops.EqOp, + comparison_ops.EqNullsMatchOp, + comparison_ops.NeOp, + comparison_ops.LtOp, + comparison_ops.GtOp, + comparison_ops.LeOp, + comparison_ops.GeOp, + numeric_ops.AddOp, + numeric_ops.SubOp, + numeric_ops.MulOp, + numeric_ops.DivOp, + numeric_ops.FloorDivOp, + numeric_ops.ModOp, ) _COMPATIBLE_AGG_OPS = ( agg_ops.SizeOp, @@ -74,7 +81,7 @@ def _is_node_polars_executable(node: nodes.BigFrameNode): if not type(expr.op) in _COMPATIBLE_AGG_OPS: return False if isinstance(expr, expression.Expression): - if not _get_expr_ops(expr).issubset(_COMPATIBLE_SCALAR_OPS): + if not set(map(type, _get_expr_ops(expr))).issubset(_COMPATIBLE_SCALAR_OPS): return False return True @@ -117,7 +124,8 @@ def _can_execute(self, plan: bigframe_node.BigFrameNode): def _adapt_array(self, array: pa.Array) -> pa.Array: target_type = local_data.logical_type_replacements(array.type) if target_type != array.type: - return array.cast(target_type) + # Safe is false to handle weird polars decimal scaling + return array.cast(target_type, safe=False) return array def _adapt_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: diff --git a/bigframes/testing/engine_utils.py b/bigframes/testing/engine_utils.py index 8aa52cf51a..625d1727ee 100644 --- a/bigframes/testing/engine_utils.py +++ b/bigframes/testing/engine_utils.py @@ -31,4 +31,4 @@ def assert_equivalence_execution( assert e1_result.schema == e2_result.schema e1_table = e1_result.to_pandas() e2_table = e2_result.to_pandas() - pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-10) + pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-5) diff --git a/tests/system/small/engines/test_numeric_ops.py b/tests/system/small/engines/test_numeric_ops.py new file mode 100644 index 0000000000..ef14a3ff7e --- /dev/null +++ b/tests/system/small/engines/test_numeric_ops.py @@ -0,0 +1,158 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import itertools + +import pytest + +from bigframes.core import array_value, expression +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +def apply_op_pairwise( + array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[] +) -> array_value.ArrayValue: + exprs = [] + labels = [] + for l_arg, r_arg in itertools.product(array.column_ids, array.column_ids): + if (l_arg in excluded_cols) or (r_arg in excluded_cols): + continue + try: + _ = op.output_type( + array.get_column_type(l_arg), array.get_column_type(r_arg) + ) + expr = op.as_expr(l_arg, r_arg) + exprs.append(expr) + labels.append(f"{l_arg}_{r_arg}") + except TypeError: + continue + assert len(exprs) > 0 + new_arr, ids = array.compute_values(exprs) + new_arr = new_arr.rename_columns( + {new_col: label for new_col, label in zip(ids, labels)} + ) + return new_arr + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_add( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise(scalars_array_value, ops.add_op) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_sub( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise(scalars_array_value, ops.sub_op) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_mul( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise(scalars_array_value, ops.mul_op) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_div(scalars_array_value: array_value.ArrayValue, engine): + # TODO: Duration div is sensitive to zeroes + # TODO: Numeric col is sensitive to scale shifts + arr = apply_op_pairwise( + scalars_array_value, ops.div_op, excluded_cols=["duration_col", "numeric_col"] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_div_durations( + scalars_array_value: array_value.ArrayValue, engine +): + arr, _ = scalars_array_value.compute_values( + [ + ops.div_op.as_expr( + expression.deref("duration_col"), + expression.const(datetime.timedelta(seconds=3)), + ), + ops.div_op.as_expr( + expression.deref("duration_col"), + expression.const(datetime.timedelta(seconds=-3)), + ), + ops.div_op.as_expr(expression.deref("duration_col"), expression.const(4)), + ops.div_op.as_expr(expression.deref("duration_col"), expression.const(-4)), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_floordiv( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise( + scalars_array_value, + ops.floordiv_op, + excluded_cols=["duration_col", "numeric_col"], + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_floordiv_durations( + scalars_array_value: array_value.ArrayValue, engine +): + arr, _ = scalars_array_value.compute_values( + [ + ops.floordiv_op.as_expr( + expression.deref("duration_col"), + expression.const(datetime.timedelta(seconds=3)), + ), + ops.floordiv_op.as_expr( + expression.deref("duration_col"), + expression.const(datetime.timedelta(seconds=-3)), + ), + ops.floordiv_op.as_expr( + expression.deref("duration_col"), expression.const(4) + ), + ops.floordiv_op.as_expr( + expression.deref("duration_col"), expression.const(-4) + ), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_mod( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise(scalars_array_value, ops.mod_op) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) From e6d96a2e1dba49ba8c37441e35ecaaecfd7872fa Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 16 Jul 2025 21:50:45 +0000 Subject: [PATCH 2/4] revert cast_batch --- bigframes/core/pyarrow_utils.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py index d86457e878..b9dc2ea2b3 100644 --- a/bigframes/core/pyarrow_utils.py +++ b/bigframes/core/pyarrow_utils.py @@ -78,20 +78,10 @@ def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch: if batch.schema == schema: return batch # TODO: Use RecordBatch.cast once min pyarrow>=16.0 - # return pa.record_batch( - # [arr.cast(type) for arr, type in zip(batch.columns, schema.types)], - # schema=schema, - # ) - arrs = [] - - for arr, type in zip(batch.columns, schema.types): - try: - value = arr.cast(type) - except Exception as e: - print(e) - raise - arrs.append(value) - return pa.record_batch(arrs, schema=schema) + return pa.record_batch( + [arr.cast(type) for arr, type in zip(batch.columns, schema.types)], + schema=schema, + ) def truncate_pyarrow_iterable( From 806d18610852aa3fe6b22070f2aec4b1b5e3e2bb Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 16 Jul 2025 23:41:57 +0000 Subject: [PATCH 3/4] fix timestamp diff --- bigframes/operations/numeric_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bigframes/operations/numeric_ops.py b/bigframes/operations/numeric_ops.py index 7ef38f48be..efd963c141 100644 --- a/bigframes/operations/numeric_ops.py +++ b/bigframes/operations/numeric_ops.py @@ -183,6 +183,9 @@ def output_type(self, *input_types): if left_type == dtypes.DATETIME_DTYPE and right_type == dtypes.DATETIME_DTYPE: return dtypes.TIMEDELTA_DTYPE + if left_type == dtypes.TIMESTAMP_DTYPE and right_type == dtypes.TIMESTAMP_DTYPE: + return dtypes.TIMEDELTA_DTYPE + if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE: return dtypes.TIMEDELTA_DTYPE From 5aec1d1b274e310abb745f5e3e2f60740203e56d Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 16 Jul 2025 23:56:21 +0000 Subject: [PATCH 4/4] fix timedelta division --- bigframes/core/compile/polars/lowering.py | 9 +++++---- bigframes/operations/numeric_ops.py | 4 ++-- tests/system/small/engines/test_numeric_ops.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py index 63aae79482..ee0933b450 100644 --- a/bigframes/core/compile/polars/lowering.py +++ b/bigframes/core/compile/polars/lowering.py @@ -159,13 +159,14 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression: dividend = expr.children[0] divisor = expr.children[1] - if ( - dividend.output_type == dtypes.TIMEDELTA_DTYPE - and divisor.output_type == dtypes.INT_DTYPE + if dividend.output_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric( + divisor.output_type ): - int_result = expr.op.as_expr( + # exact same as floordiv impl for timedelta + numeric_result = ops.floordiv_op.as_expr( ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor ) + int_result = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(numeric_result) return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result) if ( diff --git a/bigframes/operations/numeric_ops.py b/bigframes/operations/numeric_ops.py index efd963c141..afdc924c0b 100644 --- a/bigframes/operations/numeric_ops.py +++ b/bigframes/operations/numeric_ops.py @@ -251,8 +251,8 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type == dtypes.TIMEDELTA_DTYPE and right_type == (dtypes.INT_DTYPE): - # will fail outright if right value is zero though + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + # will fail outright if result undefined or otherwise can't be coerced back into an int return dtypes.TIMEDELTA_DTYPE if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: diff --git a/tests/system/small/engines/test_numeric_ops.py b/tests/system/small/engines/test_numeric_ops.py index ef14a3ff7e..b53da977f5 100644 --- a/tests/system/small/engines/test_numeric_ops.py +++ b/tests/system/small/engines/test_numeric_ops.py @@ -106,6 +106,12 @@ def test_engines_project_div_durations( ), ops.div_op.as_expr(expression.deref("duration_col"), expression.const(4)), ops.div_op.as_expr(expression.deref("duration_col"), expression.const(-4)), + ops.div_op.as_expr( + expression.deref("duration_col"), expression.const(55.55) + ), + ops.div_op.as_expr( + expression.deref("duration_col"), expression.const(-55.55) + ), ] ) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -144,6 +150,12 @@ def test_engines_project_floordiv_durations( ops.floordiv_op.as_expr( expression.deref("duration_col"), expression.const(-4) ), + ops.floordiv_op.as_expr( + expression.deref("duration_col"), expression.const(55.55) + ), + ops.floordiv_op.as_expr( + expression.deref("duration_col"), expression.const(-55.55) + ), ] ) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)