Skip to content

fix: use == instead of is for timedelta type equality checks #1480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions bigframes/core/rewrite/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty


def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedExpr:
if expr.dtype is dtypes.TIMEDELTA_DTYPE:
if expr.dtype == dtypes.TIMEDELTA_DTYPE:
int_repr = utils.timedelta_to_micros(expr.value) # type: ignore
return _TypedExpr(ex.const(int_repr, expr.dtype), expr.dtype)

Expand Down Expand Up @@ -148,31 +148,31 @@ def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right)

if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
return _TypedExpr.create_op_expr(ops.timestamp_sub_op, left, right)

if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.DATE_DTYPE:
return _TypedExpr.create_op_expr(ops.date_diff_op, left, right)

if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE:
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
return _TypedExpr.create_op_expr(ops.date_sub_op, left, right)

return _TypedExpr.create_op_expr(ops.sub_op, left, right)


def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right)

if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
# Re-arrange operands such that timestamp is always on the left and timedelta is
# always on the right.
return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)

if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE:
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
return _TypedExpr.create_op_expr(ops.date_add_op, left, right)

if left.dtype is dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE:
if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE:
# Re-arrange operands such that date is always on the left and timedelta is
# always on the right.
return _TypedExpr.create_op_expr(ops.date_add_op, right, left)
Expand All @@ -183,9 +183,9 @@ def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
result = _TypedExpr.create_op_expr(ops.mul_op, left, right)

if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
if dtypes.is_numeric(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)

return result
Expand All @@ -194,7 +194,7 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
result = _TypedExpr.create_op_expr(ops.div_op, left, right)

if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)

return result
Expand All @@ -203,14 +203,14 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right)

if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)

return result


def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr):
if arg.dtype is dtypes.TIMEDELTA_DTYPE:
if arg.dtype == dtypes.TIMEDELTA_DTYPE:
# Do nothing for values that are already timedeltas
return arg

Expand Down Expand Up @@ -239,19 +239,19 @@ def _rewrite_aggregation(
aggs.DateSeriesDiffOp(aggregation.op.periods), aggregation.arg
)

if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE:
if isinstance(aggregation.op, aggs.StdOp) and input_type == dtypes.TIMEDELTA_DTYPE:
return ex.UnaryAggregation(
aggs.StdOp(should_floor_result=True), aggregation.arg
)

if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE:
if isinstance(aggregation.op, aggs.MeanOp) and input_type == dtypes.TIMEDELTA_DTYPE:
return ex.UnaryAggregation(
aggs.MeanOp(should_floor_result=True), aggregation.arg
)

if (
isinstance(aggregation.op, aggs.QuantileOp)
and input_type is dtypes.TIMEDELTA_DTYPE
and input_type == dtypes.TIMEDELTA_DTYPE
):
return ex.UnaryAggregation(
aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True),
Expand Down
8 changes: 4 additions & 4 deletions bigframes/operations/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class SumOp(UnaryAggregateOp):
name: ClassVar[str] = "sum"

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE

if dtypes.is_numeric(input_types[0]):
Expand Down Expand Up @@ -185,7 +185,7 @@ def order_independent(self) -> bool:
return True

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE
return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])

Expand Down Expand Up @@ -233,7 +233,7 @@ class MeanOp(UnaryAggregateOp):
should_floor_result: bool = False

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE
return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])

Expand Down Expand Up @@ -275,7 +275,7 @@ class StdOp(UnaryAggregateOp):
should_floor_result: bool = False

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE

return signatures.FixedOutputType(
Expand Down
22 changes: 11 additions & 11 deletions bigframes/operations/numeric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def output_type(self, *input_types):
return input_types[0]

# Temporal addition.
if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
return left_type
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type):
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type):
return right_type

if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
Expand All @@ -135,7 +135,7 @@ def output_type(self, *input_types):
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.DATE_DTYPE:
return dtypes.DATETIME_DTYPE

if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE

if (left_type is None or dtypes.is_numeric(left_type)) and (
Expand Down Expand Up @@ -164,13 +164,13 @@ def output_type(self, *input_types):
if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE:
return dtypes.TIMEDELTA_DTYPE

if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
return left_type

if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.DATETIME_DTYPE

if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE

if (left_type is None or dtypes.is_numeric(left_type)) and (
Expand All @@ -193,9 +193,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
left_type = input_types[0]
right_type = input_types[1]

if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
return dtypes.TIMEDELTA_DTYPE
if dtypes.is_numeric(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_numeric(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE

if (left_type is None or dtypes.is_numeric(left_type)) and (
Expand All @@ -217,10 +217,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
left_type = input_types[0]
right_type = input_types[1]

if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
return dtypes.TIMEDELTA_DTYPE

if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.FLOAT_DTYPE

if (left_type is None or dtypes.is_numeric(left_type)) and (
Expand All @@ -244,10 +244,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
left_type = input_types[0]
right_type = input_types[1]

if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
return dtypes.TIMEDELTA_DTYPE

if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.INT_DTYPE

if (left_type is None or dtypes.is_numeric(left_type)) and (
Expand Down
8 changes: 4 additions & 4 deletions bigframes/operations/timedelta_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TimedeltaFloorOp(base_ops.UnaryOp):

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
input_type = input_types[0]
if dtypes.is_numeric(input_type) or input_type is dtypes.TIMEDELTA_DTYPE:
if dtypes.is_numeric(input_type) or input_type == dtypes.TIMEDELTA_DTYPE:
return dtypes.TIMEDELTA_DTYPE
raise TypeError(f"unsupported type: {input_type}")

Expand All @@ -62,11 +62,11 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
# timestamp + timedelta => timestamp
if (
dtypes.is_datetime_like(input_types[0])
and input_types[1] is dtypes.TIMEDELTA_DTYPE
and input_types[1] == dtypes.TIMEDELTA_DTYPE
):
return input_types[0]
# timedelta + timestamp => timestamp
if input_types[0] is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(
if input_types[0] == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(
input_types[1]
):
return input_types[1]
Expand All @@ -87,7 +87,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
# timestamp - timedelta => timestamp
if (
dtypes.is_datetime_like(input_types[0])
and input_types[1] is dtypes.TIMEDELTA_DTYPE
and input_types[1] == dtypes.TIMEDELTA_DTYPE
):
return input_types[0]

Expand Down
29 changes: 20 additions & 9 deletions tests/system/small/operations/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def temporal_dfs(session):
pd.Timedelta(-4, "m"),
pd.Timedelta(6, "h"),
],
"numeric_col": [1.5, 2, -3],
"float_col": [1.5, 2, -3],
"int_col": [1, 2, -3],
}
)

Expand Down Expand Up @@ -92,10 +93,10 @@ def _assert_series_equal(actual: pd.Series, expected: pd.Series):
(operator.sub, "timedelta_col_1", "timedelta_col_2"),
(operator.truediv, "timedelta_col_1", "timedelta_col_2"),
(operator.floordiv, "timedelta_col_1", "timedelta_col_2"),
(operator.truediv, "timedelta_col_1", "numeric_col"),
(operator.floordiv, "timedelta_col_1", "numeric_col"),
(operator.mul, "timedelta_col_1", "numeric_col"),
(operator.mul, "numeric_col", "timedelta_col_1"),
(operator.truediv, "timedelta_col_1", "float_col"),
(operator.floordiv, "timedelta_col_1", "float_col"),
(operator.mul, "timedelta_col_1", "float_col"),
(operator.mul, "float_col", "timedelta_col_1"),
],
)
def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2):
Expand All @@ -117,7 +118,7 @@ def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2):
(operator.truediv, "timedelta_col_1", 3),
(operator.floordiv, "timedelta_col_1", 3),
(operator.mul, "timedelta_col_1", 3),
(operator.mul, "numeric_col", pd.Timedelta(1, "s")),
(operator.mul, "float_col", pd.Timedelta(1, "s")),
],
)
def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal):
Expand All @@ -136,10 +137,10 @@ def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal)
(operator.sub, "timedelta_col_1", pd.Timedelta(2, "s")),
(operator.truediv, "timedelta_col_1", pd.Timedelta(2, "s")),
(operator.floordiv, "timedelta_col_1", pd.Timedelta(2, "s")),
(operator.truediv, "numeric_col", pd.Timedelta(2, "s")),
(operator.floordiv, "numeric_col", pd.Timedelta(2, "s")),
(operator.truediv, "float_col", pd.Timedelta(2, "s")),
(operator.floordiv, "float_col", pd.Timedelta(2, "s")),
(operator.mul, "timedelta_col_1", 3),
(operator.mul, "numeric_col", pd.Timedelta(1, "s")),
(operator.mul, "float_col", pd.Timedelta(1, "s")),
],
)
def test_timedelta_binary_ops_literal_and_series(temporal_dfs, op, col, literal):
Expand Down Expand Up @@ -181,6 +182,16 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype)
)


@pytest.mark.parametrize("column", ["datetime_col", "timestamp_col"])
def test_timestamp_add__ts_series_plus_td_series__explicit_cast(temporal_dfs, column):
bf_df, _ = temporal_dfs
dtype = pd.ArrowDtype(pa.duration("us"))

actual_result = bf_df[column] + bf_df["int_col"].astype(dtype)

assert len(actual_result) > 0


@pytest.mark.parametrize(
"literal",
[
Expand Down