@@ -111,7 +111,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty
111
111
112
112
113
113
def _rewrite_scalar_constant_expr (expr : ex .ScalarConstantExpression ) -> _TypedExpr :
114
- if expr .dtype is dtypes .TIMEDELTA_DTYPE :
114
+ if expr .dtype == dtypes .TIMEDELTA_DTYPE :
115
115
int_repr = utils .timedelta_to_micros (expr .value ) # type: ignore
116
116
return _TypedExpr (ex .const (int_repr , expr .dtype ), expr .dtype )
117
117
@@ -148,31 +148,31 @@ def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
148
148
if dtypes .is_datetime_like (left .dtype ) and dtypes .is_datetime_like (right .dtype ):
149
149
return _TypedExpr .create_op_expr (ops .timestamp_diff_op , left , right )
150
150
151
- if dtypes .is_datetime_like (left .dtype ) and right .dtype is dtypes .TIMEDELTA_DTYPE :
151
+ if dtypes .is_datetime_like (left .dtype ) and right .dtype == dtypes .TIMEDELTA_DTYPE :
152
152
return _TypedExpr .create_op_expr (ops .timestamp_sub_op , left , right )
153
153
154
154
if left .dtype == dtypes .DATE_DTYPE and right .dtype == dtypes .DATE_DTYPE :
155
155
return _TypedExpr .create_op_expr (ops .date_diff_op , left , right )
156
156
157
- if left .dtype == dtypes .DATE_DTYPE and right .dtype is dtypes .TIMEDELTA_DTYPE :
157
+ if left .dtype == dtypes .DATE_DTYPE and right .dtype == dtypes .TIMEDELTA_DTYPE :
158
158
return _TypedExpr .create_op_expr (ops .date_sub_op , left , right )
159
159
160
160
return _TypedExpr .create_op_expr (ops .sub_op , left , right )
161
161
162
162
163
163
def _rewrite_add_op (left : _TypedExpr , right : _TypedExpr ) -> _TypedExpr :
164
- if dtypes .is_datetime_like (left .dtype ) and right .dtype is dtypes .TIMEDELTA_DTYPE :
164
+ if dtypes .is_datetime_like (left .dtype ) and right .dtype == dtypes .TIMEDELTA_DTYPE :
165
165
return _TypedExpr .create_op_expr (ops .timestamp_add_op , left , right )
166
166
167
- if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_datetime_like (right .dtype ):
167
+ if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_datetime_like (right .dtype ):
168
168
# Re-arrange operands such that timestamp is always on the left and timedelta is
169
169
# always on the right.
170
170
return _TypedExpr .create_op_expr (ops .timestamp_add_op , right , left )
171
171
172
- if left .dtype == dtypes .DATE_DTYPE and right .dtype is dtypes .TIMEDELTA_DTYPE :
172
+ if left .dtype == dtypes .DATE_DTYPE and right .dtype == dtypes .TIMEDELTA_DTYPE :
173
173
return _TypedExpr .create_op_expr (ops .date_add_op , left , right )
174
174
175
- if left .dtype is dtypes .TIMEDELTA_DTYPE and right .dtype == dtypes .DATE_DTYPE :
175
+ if left .dtype == dtypes .TIMEDELTA_DTYPE and right .dtype == dtypes .DATE_DTYPE :
176
176
# Re-arrange operands such that date is always on the left and timedelta is
177
177
# always on the right.
178
178
return _TypedExpr .create_op_expr (ops .date_add_op , right , left )
@@ -183,9 +183,9 @@ def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
183
183
def _rewrite_mul_op (left : _TypedExpr , right : _TypedExpr ) -> _TypedExpr :
184
184
result = _TypedExpr .create_op_expr (ops .mul_op , left , right )
185
185
186
- if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
186
+ if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
187
187
return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
188
- if dtypes .is_numeric (left .dtype ) and right .dtype is dtypes .TIMEDELTA_DTYPE :
188
+ if dtypes .is_numeric (left .dtype ) and right .dtype == dtypes .TIMEDELTA_DTYPE :
189
189
return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
190
190
191
191
return result
@@ -194,7 +194,7 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
194
194
def _rewrite_div_op (left : _TypedExpr , right : _TypedExpr ) -> _TypedExpr :
195
195
result = _TypedExpr .create_op_expr (ops .div_op , left , right )
196
196
197
- if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
197
+ if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
198
198
return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
199
199
200
200
return result
@@ -203,14 +203,14 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
203
203
def _rewrite_floordiv_op (left : _TypedExpr , right : _TypedExpr ) -> _TypedExpr :
204
204
result = _TypedExpr .create_op_expr (ops .floordiv_op , left , right )
205
205
206
- if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
206
+ if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
207
207
return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
208
208
209
209
return result
210
210
211
211
212
212
def _rewrite_to_timedelta_op (op : ops .ToTimedeltaOp , arg : _TypedExpr ):
213
- if arg .dtype is dtypes .TIMEDELTA_DTYPE :
213
+ if arg .dtype == dtypes .TIMEDELTA_DTYPE :
214
214
# Do nothing for values that are already timedeltas
215
215
return arg
216
216
@@ -239,19 +239,19 @@ def _rewrite_aggregation(
239
239
aggs .DateSeriesDiffOp (aggregation .op .periods ), aggregation .arg
240
240
)
241
241
242
- if isinstance (aggregation .op , aggs .StdOp ) and input_type is dtypes .TIMEDELTA_DTYPE :
242
+ if isinstance (aggregation .op , aggs .StdOp ) and input_type == dtypes .TIMEDELTA_DTYPE :
243
243
return ex .UnaryAggregation (
244
244
aggs .StdOp (should_floor_result = True ), aggregation .arg
245
245
)
246
246
247
- if isinstance (aggregation .op , aggs .MeanOp ) and input_type is dtypes .TIMEDELTA_DTYPE :
247
+ if isinstance (aggregation .op , aggs .MeanOp ) and input_type == dtypes .TIMEDELTA_DTYPE :
248
248
return ex .UnaryAggregation (
249
249
aggs .MeanOp (should_floor_result = True ), aggregation .arg
250
250
)
251
251
252
252
if (
253
253
isinstance (aggregation .op , aggs .QuantileOp )
254
- and input_type is dtypes .TIMEDELTA_DTYPE
254
+ and input_type == dtypes .TIMEDELTA_DTYPE
255
255
):
256
256
return ex .UnaryAggregation (
257
257
aggs .QuantileOp (q = aggregation .op .q , should_floor_result = True ),
0 commit comments