@@ -37,26 +37,241 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
37
37
return expr .op .as_expr (larg , rarg )
38
38
39
39
40
+ class LowerAddRule (op_lowering .OpLoweringRule ):
41
+ @property
42
+ def op (self ) -> type [ops .ScalarOp ]:
43
+ return numeric_ops .AddOp
44
+
45
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
46
+ assert isinstance (expr .op , numeric_ops .AddOp )
47
+ larg , rarg = expr .children [0 ], expr .children [1 ]
48
+
49
+ if (
50
+ larg .output_type == dtypes .BOOL_DTYPE
51
+ and rarg .output_type == dtypes .BOOL_DTYPE
52
+ ):
53
+ int_result = expr .op .as_expr (
54
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg ),
55
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg ),
56
+ )
57
+ return ops .AsTypeOp (to_type = dtypes .BOOL_DTYPE ).as_expr (int_result )
58
+
59
+ if dtypes .is_string_like (larg .output_type ) and dtypes .is_string_like (
60
+ rarg .output_type
61
+ ):
62
+ return ops .strconcat_op .as_expr (larg , rarg )
63
+
64
+ if larg .output_type == dtypes .BOOL_DTYPE :
65
+ larg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg )
66
+ if rarg .output_type == dtypes .BOOL_DTYPE :
67
+ rarg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg )
68
+
69
+ if (
70
+ larg .output_type == dtypes .DATE_DTYPE
71
+ and rarg .output_type == dtypes .TIMEDELTA_DTYPE
72
+ ):
73
+ larg = ops .AsTypeOp (to_type = dtypes .DATETIME_DTYPE ).as_expr (larg )
74
+
75
+ if (
76
+ larg .output_type == dtypes .TIMEDELTA_DTYPE
77
+ and rarg .output_type == dtypes .DATE_DTYPE
78
+ ):
79
+ rarg = ops .AsTypeOp (to_type = dtypes .DATETIME_DTYPE ).as_expr (rarg )
80
+
81
+ return expr .op .as_expr (larg , rarg )
82
+
83
+
84
+ class LowerSubRule (op_lowering .OpLoweringRule ):
85
+ @property
86
+ def op (self ) -> type [ops .ScalarOp ]:
87
+ return numeric_ops .SubOp
88
+
89
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
90
+ assert isinstance (expr .op , numeric_ops .SubOp )
91
+ larg , rarg = expr .children [0 ], expr .children [1 ]
92
+
93
+ if (
94
+ larg .output_type == dtypes .BOOL_DTYPE
95
+ and rarg .output_type == dtypes .BOOL_DTYPE
96
+ ):
97
+ int_result = expr .op .as_expr (
98
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg ),
99
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg ),
100
+ )
101
+ return ops .AsTypeOp (to_type = dtypes .BOOL_DTYPE ).as_expr (int_result )
102
+
103
+ if larg .output_type == dtypes .BOOL_DTYPE :
104
+ larg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg )
105
+ if rarg .output_type == dtypes .BOOL_DTYPE :
106
+ rarg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg )
107
+
108
+ if (
109
+ larg .output_type == dtypes .DATE_DTYPE
110
+ and rarg .output_type == dtypes .TIMEDELTA_DTYPE
111
+ ):
112
+ larg = ops .AsTypeOp (to_type = dtypes .DATETIME_DTYPE ).as_expr (larg )
113
+
114
+ return expr .op .as_expr (larg , rarg )
115
+
116
+
117
+ @dataclasses .dataclass
118
+ class LowerMulRule (op_lowering .OpLoweringRule ):
119
+ @property
120
+ def op (self ) -> type [ops .ScalarOp ]:
121
+ return numeric_ops .MulOp
122
+
123
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
124
+ assert isinstance (expr .op , numeric_ops .MulOp )
125
+ larg , rarg = expr .children [0 ], expr .children [1 ]
126
+
127
+ if (
128
+ larg .output_type == dtypes .BOOL_DTYPE
129
+ and rarg .output_type == dtypes .BOOL_DTYPE
130
+ ):
131
+ int_result = expr .op .as_expr (
132
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg ),
133
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg ),
134
+ )
135
+ return ops .AsTypeOp (to_type = dtypes .BOOL_DTYPE ).as_expr (int_result )
136
+
137
+ if (
138
+ larg .output_type == dtypes .BOOL_DTYPE
139
+ and rarg .output_type != dtypes .BOOL_DTYPE
140
+ ):
141
+ larg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg )
142
+ if (
143
+ rarg .output_type == dtypes .BOOL_DTYPE
144
+ and larg .output_type != dtypes .BOOL_DTYPE
145
+ ):
146
+ rarg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg )
147
+
148
+ return expr .op .as_expr (larg , rarg )
149
+
150
+
151
+ class LowerDivRule (op_lowering .OpLoweringRule ):
152
+ @property
153
+ def op (self ) -> type [ops .ScalarOp ]:
154
+ return numeric_ops .DivOp
155
+
156
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
157
+ assert isinstance (expr .op , numeric_ops .DivOp )
158
+
159
+ dividend = expr .children [0 ]
160
+ divisor = expr .children [1 ]
161
+
162
+ if (
163
+ dividend .output_type == dtypes .TIMEDELTA_DTYPE
164
+ and divisor .output_type == dtypes .INT_DTYPE
165
+ ):
166
+ int_result = expr .op .as_expr (
167
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (dividend ), divisor
168
+ )
169
+ return ops .AsTypeOp (to_type = dtypes .TIMEDELTA_DTYPE ).as_expr (int_result )
170
+
171
+ if (
172
+ dividend .output_type == dtypes .BOOL_DTYPE
173
+ and divisor .output_type == dtypes .BOOL_DTYPE
174
+ ):
175
+ int_result = expr .op .as_expr (
176
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (dividend ),
177
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (divisor ),
178
+ )
179
+ return ops .AsTypeOp (to_type = dtypes .BOOL_DTYPE ).as_expr (int_result )
180
+
181
+ # polars divide doesn't like bools, convert to int always
182
+ # convert numerics to float always
183
+ if dividend .output_type == dtypes .BOOL_DTYPE :
184
+ dividend = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (dividend )
185
+ elif dividend .output_type in (dtypes .BIGNUMERIC_DTYPE , dtypes .NUMERIC_DTYPE ):
186
+ dividend = ops .AsTypeOp (to_type = dtypes .FLOAT_DTYPE ).as_expr (dividend )
187
+ if divisor .output_type == dtypes .BOOL_DTYPE :
188
+ divisor = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (divisor )
189
+
190
+ return numeric_ops .div_op .as_expr (dividend , divisor )
191
+
192
+
40
193
class LowerFloorDivRule (op_lowering .OpLoweringRule ):
41
194
@property
42
195
def op (self ) -> type [ops .ScalarOp ]:
43
196
return numeric_ops .FloorDivOp
44
197
45
198
def lower (self , expr : expression .OpExpression ) -> expression .Expression :
199
+ assert isinstance (expr .op , numeric_ops .FloorDivOp )
200
+
46
201
dividend = expr .children [0 ]
47
202
divisor = expr .children [1 ]
48
- using_floats = (dividend .output_type == dtypes .FLOAT_DTYPE ) or (
49
- divisor .output_type == dtypes .FLOAT_DTYPE
50
- )
51
- inf_or_zero = (
52
- expression .const (float ("INF" )) if using_floats else expression .const (0 )
203
+
204
+ if (
205
+ dividend .output_type == dtypes .TIMEDELTA_DTYPE
206
+ and divisor .output_type == dtypes .INT_DTYPE
207
+ ):
208
+ int_result = expr .op .as_expr (
209
+ ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (dividend ), divisor
210
+ )
211
+ return ops .AsTypeOp (to_type = dtypes .TIMEDELTA_DTYPE ).as_expr (int_result )
212
+
213
+ if dividend .output_type == dtypes .BOOL_DTYPE :
214
+ dividend = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (dividend )
215
+ if divisor .output_type == dtypes .BOOL_DTYPE :
216
+ divisor = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (divisor )
217
+
218
+ return numeric_ops .floor_op .as_expr (
219
+ numeric_ops .div_op .as_expr (dividend , divisor )
53
220
)
54
- zero_result = ops .mul_op .as_expr (inf_or_zero , dividend )
55
- divisor_is_zero = ops .eq_op .as_expr (divisor , expression .const (0 ))
56
- return ops .where_op .as_expr (zero_result , divisor_is_zero , expr )
57
221
58
222
59
- def _coerce_comparables (expr1 : expression .Expression , expr2 : expression .Expression ):
223
+ class LowerModRule (op_lowering .OpLoweringRule ):
224
+ @property
225
+ def op (self ) -> type [ops .ScalarOp ]:
226
+ return numeric_ops .ModOp
227
+
228
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
229
+ og_expr = expr
230
+ assert isinstance (expr .op , numeric_ops .ModOp )
231
+ larg , rarg = expr .children [0 ], expr .children [1 ]
232
+
233
+ if (
234
+ larg .output_type == dtypes .TIMEDELTA_DTYPE
235
+ and rarg .output_type == dtypes .TIMEDELTA_DTYPE
236
+ ):
237
+ larg_int = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg )
238
+ rarg_int = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg )
239
+ int_result = expr .op .as_expr (larg_int , rarg_int )
240
+ w_zero_handling = ops .where_op .as_expr (
241
+ int_result ,
242
+ ops .ne_op .as_expr (rarg_int , expression .const (0 )),
243
+ ops .mul_op .as_expr (rarg_int , expression .const (0 )),
244
+ )
245
+ return ops .AsTypeOp (to_type = dtypes .TIMEDELTA_DTYPE ).as_expr (w_zero_handling )
246
+
247
+ if larg .output_type == dtypes .BOOL_DTYPE :
248
+ larg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (larg )
249
+ if rarg .output_type == dtypes .BOOL_DTYPE :
250
+ rarg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (rarg )
251
+
252
+ wo_bools = expr .op .as_expr (larg , rarg )
253
+
254
+ if og_expr .output_type == dtypes .INT_DTYPE :
255
+ return ops .where_op .as_expr (
256
+ wo_bools ,
257
+ ops .ne_op .as_expr (rarg , expression .const (0 )),
258
+ ops .mul_op .as_expr (rarg , expression .const (0 )),
259
+ )
260
+ return wo_bools
261
+
262
+
263
+ def _coerce_comparables (
264
+ expr1 : expression .Expression ,
265
+ expr2 : expression .Expression ,
266
+ * ,
267
+ bools_only : bool = False
268
+ ):
269
+ if bools_only :
270
+ if (
271
+ expr1 .output_type != dtypes .BOOL_DTYPE
272
+ and expr2 .output_type != dtypes .BOOL_DTYPE
273
+ ):
274
+ return expr1 , expr2
60
275
61
276
target_type = dtypes .coerce_to_common (expr1 .output_type , expr2 .output_type )
62
277
if expr1 .output_type != target_type :
@@ -90,7 +305,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90
305
91
306
POLARS_LOWERING_RULES = (
92
307
* LOWER_COMPARISONS ,
308
+ LowerAddRule (),
309
+ LowerSubRule (),
310
+ LowerMulRule (),
311
+ LowerDivRule (),
93
312
LowerFloorDivRule (),
313
+ LowerModRule (),
94
314
)
95
315
96
316
0 commit comments