diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index bf3c0ee639..ea8e608a84 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -111,7 +111,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedExpr: - if expr.dtype is dtypes.TIMEDELTA_DTYPE: + if expr.dtype == dtypes.TIMEDELTA_DTYPE: int_repr = utils.timedelta_to_micros(expr.value) # type: ignore return _TypedExpr(ex.const(int_repr, expr.dtype), expr.dtype) @@ -148,31 +148,31 @@ def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype): return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right) - if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.timestamp_sub_op, left, right) if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.DATE_DTYPE: return _TypedExpr.create_op_expr(ops.date_diff_op, left, right) - if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE: + if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.date_sub_op, left, right) return _TypedExpr.create_op_expr(ops.sub_op, left, right) def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: - if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right) - if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype): + if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype): # Re-arrange operands such that timestamp is always on the left and timedelta is # always on the right. return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left) - if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE: + if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.date_add_op, left, right) - if left.dtype is dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE: + if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE: # Re-arrange operands such that date is always on the left and timedelta is # always on the right. return _TypedExpr.create_op_expr(ops.date_add_op, right, left) @@ -183,9 +183,9 @@ def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: result = _TypedExpr.create_op_expr(ops.mul_op, left, right) - if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): + if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result) - if dtypes.is_numeric(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result) return result @@ -194,7 +194,7 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: result = _TypedExpr.create_op_expr(ops.div_op, left, right) - if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): + if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result) return result @@ -203,14 +203,14 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right) - if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): + if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result) return result def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr): - if arg.dtype is dtypes.TIMEDELTA_DTYPE: + if arg.dtype == dtypes.TIMEDELTA_DTYPE: # Do nothing for values that are already timedeltas return arg @@ -239,19 +239,19 @@ def _rewrite_aggregation( aggs.DateSeriesDiffOp(aggregation.op.periods), aggregation.arg ) - if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE: + if isinstance(aggregation.op, aggs.StdOp) and input_type == dtypes.TIMEDELTA_DTYPE: return ex.UnaryAggregation( aggs.StdOp(should_floor_result=True), aggregation.arg ) - if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE: + if isinstance(aggregation.op, aggs.MeanOp) and input_type == dtypes.TIMEDELTA_DTYPE: return ex.UnaryAggregation( aggs.MeanOp(should_floor_result=True), aggregation.arg ) if ( isinstance(aggregation.op, aggs.QuantileOp) - and input_type is dtypes.TIMEDELTA_DTYPE + and input_type == dtypes.TIMEDELTA_DTYPE ): return ex.UnaryAggregation( aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True), diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index a714f5804c..0ae4516dfd 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -142,7 +142,7 @@ class SumOp(UnaryAggregateOp): name: ClassVar[str] = "sum" def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if input_types[0] is dtypes.TIMEDELTA_DTYPE: + if input_types[0] == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE if dtypes.is_numeric(input_types[0]): @@ -185,7 +185,7 @@ def order_independent(self) -> bool: return True def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if input_types[0] is dtypes.TIMEDELTA_DTYPE: + if input_types[0] == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0]) @@ -233,7 +233,7 @@ class MeanOp(UnaryAggregateOp): should_floor_result: bool = False def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if input_types[0] is dtypes.TIMEDELTA_DTYPE: + if input_types[0] == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0]) @@ -275,7 +275,7 @@ class StdOp(UnaryAggregateOp): should_floor_result: bool = False def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if input_types[0] is dtypes.TIMEDELTA_DTYPE: + if input_types[0] == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE return signatures.FixedOutputType( diff --git a/bigframes/operations/numeric_ops.py b/bigframes/operations/numeric_ops.py index ae23aff707..d06d6eb336 100644 --- a/bigframes/operations/numeric_ops.py +++ b/bigframes/operations/numeric_ops.py @@ -124,9 +124,9 @@ def output_type(self, *input_types): return input_types[0] # Temporal addition. - if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE: return left_type - if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type): return right_type if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: @@ -135,7 +135,7 @@ def output_type(self, *input_types): if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.DATE_DTYPE: return dtypes.DATETIME_DTYPE - if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE: + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( @@ -164,13 +164,13 @@ def output_type(self, *input_types): if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE: return dtypes.TIMEDELTA_DTYPE - if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE: return left_type if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.DATETIME_DTYPE - if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE: + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( @@ -193,9 +193,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): return dtypes.TIMEDELTA_DTYPE - if dtypes.is_numeric(left_type) and right_type is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_numeric(left_type) and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( @@ -217,10 +217,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): return dtypes.TIMEDELTA_DTYPE - if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE: + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.FLOAT_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( @@ -244,10 +244,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT left_type = input_types[0] right_type = input_types[1] - if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): + if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type): return dtypes.TIMEDELTA_DTYPE - if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE: + if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE: return dtypes.INT_DTYPE if (left_type is None or dtypes.is_numeric(left_type)) and ( diff --git a/bigframes/operations/timedelta_ops.py b/bigframes/operations/timedelta_ops.py index b831e3f864..5e9a1189e4 100644 --- a/bigframes/operations/timedelta_ops.py +++ b/bigframes/operations/timedelta_ops.py @@ -46,7 +46,7 @@ class TimedeltaFloorOp(base_ops.UnaryOp): def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: input_type = input_types[0] - if dtypes.is_numeric(input_type) or input_type is dtypes.TIMEDELTA_DTYPE: + if dtypes.is_numeric(input_type) or input_type == dtypes.TIMEDELTA_DTYPE: return dtypes.TIMEDELTA_DTYPE raise TypeError(f"unsupported type: {input_type}") @@ -62,11 +62,11 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT # timestamp + timedelta => timestamp if ( dtypes.is_datetime_like(input_types[0]) - and input_types[1] is dtypes.TIMEDELTA_DTYPE + and input_types[1] == dtypes.TIMEDELTA_DTYPE ): return input_types[0] # timedelta + timestamp => timestamp - if input_types[0] is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like( + if input_types[0] == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like( input_types[1] ): return input_types[1] @@ -87,7 +87,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT # timestamp - timedelta => timestamp if ( dtypes.is_datetime_like(input_types[0]) - and input_types[1] is dtypes.TIMEDELTA_DTYPE + and input_types[1] == dtypes.TIMEDELTA_DTYPE ): return input_types[0] diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 53cb5f7419..0cf394e454 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -58,7 +58,8 @@ def temporal_dfs(session): pd.Timedelta(-4, "m"), pd.Timedelta(6, "h"), ], - "numeric_col": [1.5, 2, -3], + "float_col": [1.5, 2, -3], + "int_col": [1, 2, -3], } ) @@ -92,10 +93,10 @@ def _assert_series_equal(actual: pd.Series, expected: pd.Series): (operator.sub, "timedelta_col_1", "timedelta_col_2"), (operator.truediv, "timedelta_col_1", "timedelta_col_2"), (operator.floordiv, "timedelta_col_1", "timedelta_col_2"), - (operator.truediv, "timedelta_col_1", "numeric_col"), - (operator.floordiv, "timedelta_col_1", "numeric_col"), - (operator.mul, "timedelta_col_1", "numeric_col"), - (operator.mul, "numeric_col", "timedelta_col_1"), + (operator.truediv, "timedelta_col_1", "float_col"), + (operator.floordiv, "timedelta_col_1", "float_col"), + (operator.mul, "timedelta_col_1", "float_col"), + (operator.mul, "float_col", "timedelta_col_1"), ], ) def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2): @@ -117,7 +118,7 @@ def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2): (operator.truediv, "timedelta_col_1", 3), (operator.floordiv, "timedelta_col_1", 3), (operator.mul, "timedelta_col_1", 3), - (operator.mul, "numeric_col", pd.Timedelta(1, "s")), + (operator.mul, "float_col", pd.Timedelta(1, "s")), ], ) def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal): @@ -136,10 +137,10 @@ def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal) (operator.sub, "timedelta_col_1", pd.Timedelta(2, "s")), (operator.truediv, "timedelta_col_1", pd.Timedelta(2, "s")), (operator.floordiv, "timedelta_col_1", pd.Timedelta(2, "s")), - (operator.truediv, "numeric_col", pd.Timedelta(2, "s")), - (operator.floordiv, "numeric_col", pd.Timedelta(2, "s")), + (operator.truediv, "float_col", pd.Timedelta(2, "s")), + (operator.floordiv, "float_col", pd.Timedelta(2, "s")), (operator.mul, "timedelta_col_1", 3), - (operator.mul, "numeric_col", pd.Timedelta(1, "s")), + (operator.mul, "float_col", pd.Timedelta(1, "s")), ], ) def test_timedelta_binary_ops_literal_and_series(temporal_dfs, op, col, literal): @@ -181,6 +182,16 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype) ) +@pytest.mark.parametrize("column", ["datetime_col", "timestamp_col"]) +def test_timestamp_add__ts_series_plus_td_series__explicit_cast(temporal_dfs, column): + bf_df, _ = temporal_dfs + dtype = pd.ArrowDtype(pa.duration("us")) + + actual_result = bf_df[column] + bf_df["int_col"].astype(dtype) + + assert len(actual_result) > 0 + + @pytest.mark.parametrize( "literal", [