Skip to content

Commit 1e96f1d

Browse files
authored
[mypyc] Optimize truth value testing for strings (#10269)
Add a new primitive to check for empty strings. Related to mypyc/mypyc#768.
1 parent e46aa6d commit 1e96f1d

File tree

7 files changed

+84
-42
lines changed

7 files changed

+84
-42
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
)
6464
from mypyc.primitives.int_ops import int_comparison_op_mapping
6565
from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op
66-
from mypyc.primitives.str_ops import unicode_compare
66+
from mypyc.primitives.str_ops import unicode_compare, str_check_if_true
6767
from mypyc.primitives.set_ops import new_set_op
6868
from mypyc.rt_subtype import is_runtime_subtype
6969
from mypyc.subtype import is_subtype
@@ -958,6 +958,8 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
958958
zero = Integer(0, short_int_rprimitive)
959959
self.compare_tagged_condition(value, zero, '!=', true, false, value.line)
960960
return
961+
elif is_same_type(value.type, str_rprimitive):
962+
value = self.call_c(str_check_if_true, [value], value.line)
961963
elif is_same_type(value.type, list_rprimitive):
962964
length = self.builtin_len(value, value.line)
963965
zero = Integer(0)

mypyc/lib-rt/str_ops.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,8 @@ PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
102102
}
103103
return CPyObject_GetSlice(obj, start, end);
104104
}
105+
/* Check if the given string is true (i.e. it's length isn't zero) */
106+
bool CPyStr_IsTrue(PyObject *obj) {
107+
Py_ssize_t length = PyUnicode_GET_LENGTH(obj);
108+
return length != 0;
109+
}

mypyc/primitives/str_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
66
from mypyc.ir.rtypes import (
77
RType, object_rprimitive, str_rprimitive, int_rprimitive, list_rprimitive,
8-
c_int_rprimitive, pointer_rprimitive, bool_rprimitive
8+
c_int_rprimitive, pointer_rprimitive, bool_rprimitive, bit_rprimitive
99
)
1010
from mypyc.primitives.registry import (
1111
method_op, binary_op, function_op,
@@ -126,3 +126,11 @@
126126
return_type=str_rprimitive,
127127
c_function_name="CPyStr_Replace",
128128
error_kind=ERR_MAGIC)
129+
130+
# check if a string is true (isn't an empty string)
131+
str_check_if_true = custom_op(
132+
arg_types=[str_rprimitive],
133+
return_type=bit_rprimitive,
134+
c_function_name="CPyStr_IsTrue",
135+
error_kind=ERR_NEVER,
136+
)

mypyc/test-data/irbuild-basic.test

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -196,24 +196,20 @@ def f(x: object, y: object) -> str:
196196
def f(x, y):
197197
x, y :: object
198198
r0 :: str
199-
r1 :: int32
200-
r2 :: bit
201-
r3 :: bool
202-
r4, r5 :: str
199+
r1 :: bit
200+
r2, r3 :: str
203201
L0:
204202
r0 = PyObject_Str(x)
205-
r1 = PyObject_IsTrue(r0)
206-
r2 = r1 >= 0 :: signed
207-
r3 = truncate r1: int32 to builtins.bool
208-
if r3 goto L1 else goto L2 :: bool
203+
r1 = CPyStr_IsTrue(r0)
204+
if r1 goto L1 else goto L2 :: bool
209205
L1:
210-
r4 = r0
206+
r2 = r0
211207
goto L3
212208
L2:
213-
r5 = PyObject_Str(y)
214-
r4 = r5
209+
r3 = PyObject_Str(y)
210+
r2 = r3
215211
L3:
216-
return r4
212+
return r2
217213

218214
[case testOr]
219215
def f(x: int, y: int) -> int:
@@ -276,24 +272,20 @@ def f(x: object, y: object) -> str:
276272
def f(x, y):
277273
x, y :: object
278274
r0 :: str
279-
r1 :: int32
280-
r2 :: bit
281-
r3 :: bool
282-
r4, r5 :: str
275+
r1 :: bit
276+
r2, r3 :: str
283277
L0:
284278
r0 = PyObject_Str(x)
285-
r1 = PyObject_IsTrue(r0)
286-
r2 = r1 >= 0 :: signed
287-
r3 = truncate r1: int32 to builtins.bool
288-
if r3 goto L2 else goto L1 :: bool
279+
r1 = CPyStr_IsTrue(r0)
280+
if r1 goto L2 else goto L1 :: bool
289281
L1:
290-
r4 = r0
282+
r2 = r0
291283
goto L3
292284
L2:
293-
r5 = PyObject_Str(y)
294-
r4 = r5
285+
r3 = PyObject_Str(y)
286+
r2 = r3
295287
L3:
296-
return r4
288+
return r2
297289

298290
[case testSimpleNot]
299291
def f(x: int, y: int) -> int:

mypyc/test-data/irbuild-statements.test

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -720,28 +720,24 @@ def complex_msg(x, s):
720720
r0 :: object
721721
r1 :: bit
722722
r2 :: str
723-
r3 :: int32
724-
r4 :: bit
725-
r5 :: bool
726-
r6 :: object
727-
r7 :: str
728-
r8, r9 :: object
723+
r3 :: bit
724+
r4 :: object
725+
r5 :: str
726+
r6, r7 :: object
729727
L0:
730728
r0 = load_address _Py_NoneStruct
731729
r1 = x != r0
732730
if r1 goto L1 else goto L2 :: bool
733731
L1:
734732
r2 = cast(str, x)
735-
r3 = PyObject_IsTrue(r2)
736-
r4 = r3 >= 0 :: signed
737-
r5 = truncate r3: int32 to builtins.bool
738-
if r5 goto L3 else goto L2 :: bool
733+
r3 = CPyStr_IsTrue(r2)
734+
if r3 goto L3 else goto L2 :: bool
739735
L2:
740-
r6 = builtins :: module
741-
r7 = 'AssertionError'
742-
r8 = CPyObject_GetAttr(r6, r7)
743-
r9 = PyObject_CallFunctionObjArgs(r8, s, 0)
744-
CPy_Raise(r9)
736+
r4 = builtins :: module
737+
r5 = 'AssertionError'
738+
r6 = CPyObject_GetAttr(r4, r5)
739+
r7 = PyObject_CallFunctionObjArgs(r6, s, 0)
740+
CPy_Raise(r7)
745741
unreachable
746742
L3:
747743
return 1

mypyc/test-data/irbuild-str.test

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,24 @@ L4:
138138
return r6
139139
L5:
140140
unreachable
141-
141+
142+
[case testStrToBool]
143+
def is_true(x: str) -> bool:
144+
if x:
145+
return True
146+
else:
147+
return False
148+
[out]
149+
def is_true(x):
150+
x :: str
151+
r0 :: bit
152+
L0:
153+
r0 = CPyStr_IsTrue(x)
154+
if r0 goto L1 else goto L2 :: bool
155+
L1:
156+
return 1
157+
L2:
158+
return 0
159+
L3:
160+
unreachable
161+

mypyc/test-data/run-strings.test

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,22 @@ def test_str_replace() -> None:
159159
assert a.replace("foo", "bar", 4) == "barbarbar"
160160
assert a.replace("aaa", "bar") == "foofoofoo"
161161
assert a.replace("ofo", "xyzw") == "foxyzwxyzwo"
162+
163+
def is_true(x: str) -> bool:
164+
if x:
165+
return True
166+
else:
167+
return False
168+
169+
def is_false(x: str) -> bool:
170+
if not x:
171+
return True
172+
else:
173+
return False
174+
175+
def test_str_to_bool() -> None:
176+
assert is_false('')
177+
assert not is_true('')
178+
for x in 'a', 'foo', 'bar', 'some string':
179+
assert is_true(x)
180+
assert not is_false(x)

0 commit comments

Comments
 (0)