Skip to content

Commit 0db248b

Browse files
authored
fix: use == instead of is for timedelta type equality checks (#1480)
* fix: use instead of for timedelta type equality checks * use int column for casting
1 parent e720f41 commit 0db248b

File tree

5 files changed

+54
-43
lines changed

5 files changed

+54
-43
lines changed

bigframes/core/rewrite/timedeltas.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty
111111

112112

113113
def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedExpr:
114-
if expr.dtype is dtypes.TIMEDELTA_DTYPE:
114+
if expr.dtype == dtypes.TIMEDELTA_DTYPE:
115115
int_repr = utils.timedelta_to_micros(expr.value) # type: ignore
116116
return _TypedExpr(ex.const(int_repr, expr.dtype), expr.dtype)
117117

@@ -148,31 +148,31 @@ def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
148148
if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
149149
return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right)
150150

151-
if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
151+
if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
152152
return _TypedExpr.create_op_expr(ops.timestamp_sub_op, left, right)
153153

154154
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.DATE_DTYPE:
155155
return _TypedExpr.create_op_expr(ops.date_diff_op, left, right)
156156

157-
if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE:
157+
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
158158
return _TypedExpr.create_op_expr(ops.date_sub_op, left, right)
159159

160160
return _TypedExpr.create_op_expr(ops.sub_op, left, right)
161161

162162

163163
def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
164-
if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
164+
if dtypes.is_datetime_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
165165
return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right)
166166

167-
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
167+
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
168168
# Re-arrange operands such that timestamp is always on the left and timedelta is
169169
# always on the right.
170170
return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)
171171

172-
if left.dtype == dtypes.DATE_DTYPE and right.dtype is dtypes.TIMEDELTA_DTYPE:
172+
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
173173
return _TypedExpr.create_op_expr(ops.date_add_op, left, right)
174174

175-
if left.dtype is dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE:
175+
if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.DATE_DTYPE:
176176
# Re-arrange operands such that date is always on the left and timedelta is
177177
# always on the right.
178178
return _TypedExpr.create_op_expr(ops.date_add_op, right, left)
@@ -183,9 +183,9 @@ def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
183183
def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
184184
result = _TypedExpr.create_op_expr(ops.mul_op, left, right)
185185

186-
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
186+
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
187187
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
188-
if dtypes.is_numeric(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
188+
if dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
189189
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
190190

191191
return result
@@ -194,7 +194,7 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
194194
def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
195195
result = _TypedExpr.create_op_expr(ops.div_op, left, right)
196196

197-
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
197+
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
198198
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
199199

200200
return result
@@ -203,14 +203,14 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
203203
def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
204204
result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right)
205205

206-
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
206+
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
207207
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
208208

209209
return result
210210

211211

212212
def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr):
213-
if arg.dtype is dtypes.TIMEDELTA_DTYPE:
213+
if arg.dtype == dtypes.TIMEDELTA_DTYPE:
214214
# Do nothing for values that are already timedeltas
215215
return arg
216216

@@ -239,19 +239,19 @@ def _rewrite_aggregation(
239239
aggs.DateSeriesDiffOp(aggregation.op.periods), aggregation.arg
240240
)
241241

242-
if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE:
242+
if isinstance(aggregation.op, aggs.StdOp) and input_type == dtypes.TIMEDELTA_DTYPE:
243243
return ex.UnaryAggregation(
244244
aggs.StdOp(should_floor_result=True), aggregation.arg
245245
)
246246

247-
if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE:
247+
if isinstance(aggregation.op, aggs.MeanOp) and input_type == dtypes.TIMEDELTA_DTYPE:
248248
return ex.UnaryAggregation(
249249
aggs.MeanOp(should_floor_result=True), aggregation.arg
250250
)
251251

252252
if (
253253
isinstance(aggregation.op, aggs.QuantileOp)
254-
and input_type is dtypes.TIMEDELTA_DTYPE
254+
and input_type == dtypes.TIMEDELTA_DTYPE
255255
):
256256
return ex.UnaryAggregation(
257257
aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True),

bigframes/operations/aggregations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class SumOp(UnaryAggregateOp):
142142
name: ClassVar[str] = "sum"
143143

144144
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
145-
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
145+
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
146146
return dtypes.TIMEDELTA_DTYPE
147147

148148
if dtypes.is_numeric(input_types[0]):
@@ -185,7 +185,7 @@ def order_independent(self) -> bool:
185185
return True
186186

187187
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
188-
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
188+
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
189189
return dtypes.TIMEDELTA_DTYPE
190190
return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])
191191

@@ -233,7 +233,7 @@ class MeanOp(UnaryAggregateOp):
233233
should_floor_result: bool = False
234234

235235
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
236-
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
236+
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
237237
return dtypes.TIMEDELTA_DTYPE
238238
return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])
239239

@@ -275,7 +275,7 @@ class StdOp(UnaryAggregateOp):
275275
should_floor_result: bool = False
276276

277277
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
278-
if input_types[0] is dtypes.TIMEDELTA_DTYPE:
278+
if input_types[0] == dtypes.TIMEDELTA_DTYPE:
279279
return dtypes.TIMEDELTA_DTYPE
280280

281281
return signatures.FixedOutputType(

bigframes/operations/numeric_ops.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ def output_type(self, *input_types):
124124
return input_types[0]
125125

126126
# Temporal addition.
127-
if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
127+
if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
128128
return left_type
129-
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type):
129+
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type):
130130
return right_type
131131

132132
if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
@@ -135,7 +135,7 @@ def output_type(self, *input_types):
135135
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.DATE_DTYPE:
136136
return dtypes.DATETIME_DTYPE
137137

138-
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
138+
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
139139
return dtypes.TIMEDELTA_DTYPE
140140

141141
if (left_type is None or dtypes.is_numeric(left_type)) and (
@@ -164,13 +164,13 @@ def output_type(self, *input_types):
164164
if left_type == dtypes.DATE_DTYPE and right_type == dtypes.DATE_DTYPE:
165165
return dtypes.TIMEDELTA_DTYPE
166166

167-
if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
167+
if dtypes.is_datetime_like(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
168168
return left_type
169169

170170
if left_type == dtypes.DATE_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
171171
return dtypes.DATETIME_DTYPE
172172

173-
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
173+
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
174174
return dtypes.TIMEDELTA_DTYPE
175175

176176
if (left_type is None or dtypes.is_numeric(left_type)) and (
@@ -193,9 +193,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
193193
left_type = input_types[0]
194194
right_type = input_types[1]
195195

196-
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
196+
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
197197
return dtypes.TIMEDELTA_DTYPE
198-
if dtypes.is_numeric(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
198+
if dtypes.is_numeric(left_type) and right_type == dtypes.TIMEDELTA_DTYPE:
199199
return dtypes.TIMEDELTA_DTYPE
200200

201201
if (left_type is None or dtypes.is_numeric(left_type)) and (
@@ -217,10 +217,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
217217
left_type = input_types[0]
218218
right_type = input_types[1]
219219

220-
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
220+
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
221221
return dtypes.TIMEDELTA_DTYPE
222222

223-
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
223+
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
224224
return dtypes.FLOAT_DTYPE
225225

226226
if (left_type is None or dtypes.is_numeric(left_type)) and (
@@ -244,10 +244,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
244244
left_type = input_types[0]
245245
right_type = input_types[1]
246246

247-
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
247+
if left_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
248248
return dtypes.TIMEDELTA_DTYPE
249249

250-
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
250+
if left_type == dtypes.TIMEDELTA_DTYPE and right_type == dtypes.TIMEDELTA_DTYPE:
251251
return dtypes.INT_DTYPE
252252

253253
if (left_type is None or dtypes.is_numeric(left_type)) and (

bigframes/operations/timedelta_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class TimedeltaFloorOp(base_ops.UnaryOp):
4646

4747
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
4848
input_type = input_types[0]
49-
if dtypes.is_numeric(input_type) or input_type is dtypes.TIMEDELTA_DTYPE:
49+
if dtypes.is_numeric(input_type) or input_type == dtypes.TIMEDELTA_DTYPE:
5050
return dtypes.TIMEDELTA_DTYPE
5151
raise TypeError(f"unsupported type: {input_type}")
5252

@@ -62,11 +62,11 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
6262
# timestamp + timedelta => timestamp
6363
if (
6464
dtypes.is_datetime_like(input_types[0])
65-
and input_types[1] is dtypes.TIMEDELTA_DTYPE
65+
and input_types[1] == dtypes.TIMEDELTA_DTYPE
6666
):
6767
return input_types[0]
6868
# timedelta + timestamp => timestamp
69-
if input_types[0] is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(
69+
if input_types[0] == dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(
7070
input_types[1]
7171
):
7272
return input_types[1]
@@ -87,7 +87,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
8787
# timestamp - timedelta => timestamp
8888
if (
8989
dtypes.is_datetime_like(input_types[0])
90-
and input_types[1] is dtypes.TIMEDELTA_DTYPE
90+
and input_types[1] == dtypes.TIMEDELTA_DTYPE
9191
):
9292
return input_types[0]
9393

tests/system/small/operations/test_timedeltas.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def temporal_dfs(session):
5858
pd.Timedelta(-4, "m"),
5959
pd.Timedelta(6, "h"),
6060
],
61-
"numeric_col": [1.5, 2, -3],
61+
"float_col": [1.5, 2, -3],
62+
"int_col": [1, 2, -3],
6263
}
6364
)
6465

@@ -92,10 +93,10 @@ def _assert_series_equal(actual: pd.Series, expected: pd.Series):
9293
(operator.sub, "timedelta_col_1", "timedelta_col_2"),
9394
(operator.truediv, "timedelta_col_1", "timedelta_col_2"),
9495
(operator.floordiv, "timedelta_col_1", "timedelta_col_2"),
95-
(operator.truediv, "timedelta_col_1", "numeric_col"),
96-
(operator.floordiv, "timedelta_col_1", "numeric_col"),
97-
(operator.mul, "timedelta_col_1", "numeric_col"),
98-
(operator.mul, "numeric_col", "timedelta_col_1"),
96+
(operator.truediv, "timedelta_col_1", "float_col"),
97+
(operator.floordiv, "timedelta_col_1", "float_col"),
98+
(operator.mul, "timedelta_col_1", "float_col"),
99+
(operator.mul, "float_col", "timedelta_col_1"),
99100
],
100101
)
101102
def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2):
@@ -117,7 +118,7 @@ def test_timedelta_binary_ops_between_series(temporal_dfs, op, col_1, col_2):
117118
(operator.truediv, "timedelta_col_1", 3),
118119
(operator.floordiv, "timedelta_col_1", 3),
119120
(operator.mul, "timedelta_col_1", 3),
120-
(operator.mul, "numeric_col", pd.Timedelta(1, "s")),
121+
(operator.mul, "float_col", pd.Timedelta(1, "s")),
121122
],
122123
)
123124
def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal):
@@ -136,10 +137,10 @@ def test_timedelta_binary_ops_series_and_literal(temporal_dfs, op, col, literal)
136137
(operator.sub, "timedelta_col_1", pd.Timedelta(2, "s")),
137138
(operator.truediv, "timedelta_col_1", pd.Timedelta(2, "s")),
138139
(operator.floordiv, "timedelta_col_1", pd.Timedelta(2, "s")),
139-
(operator.truediv, "numeric_col", pd.Timedelta(2, "s")),
140-
(operator.floordiv, "numeric_col", pd.Timedelta(2, "s")),
140+
(operator.truediv, "float_col", pd.Timedelta(2, "s")),
141+
(operator.floordiv, "float_col", pd.Timedelta(2, "s")),
141142
(operator.mul, "timedelta_col_1", 3),
142-
(operator.mul, "numeric_col", pd.Timedelta(1, "s")),
143+
(operator.mul, "float_col", pd.Timedelta(1, "s")),
143144
],
144145
)
145146
def test_timedelta_binary_ops_literal_and_series(temporal_dfs, op, col, literal):
@@ -181,6 +182,16 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype)
181182
)
182183

183184

185+
@pytest.mark.parametrize("column", ["datetime_col", "timestamp_col"])
186+
def test_timestamp_add__ts_series_plus_td_series__explicit_cast(temporal_dfs, column):
187+
bf_df, _ = temporal_dfs
188+
dtype = pd.ArrowDtype(pa.duration("us"))
189+
190+
actual_result = bf_df[column] + bf_df["int_col"].astype(dtype)
191+
192+
assert len(actual_result) > 0
193+
194+
184195
@pytest.mark.parametrize(
185196
"literal",
186197
[

0 commit comments

Comments
 (0)