Skip to content

[mypyc] Speed up in operations for list/tuple #9004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
and mypyc.irbuild.builder.
"""

from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, cast

from mypy.nodes import (
Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr,
Expand All @@ -13,7 +13,7 @@
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS
)
from mypy.types import TupleType, get_proper_type
from mypy.types import TupleType, get_proper_type, Instance

from mypyc.common import MAX_LITERAL_SHORT_INT
from mypyc.ir.ops import (
Expand Down Expand Up @@ -406,8 +406,56 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val


def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
# TODO: Don't produce an expression when used in conditional context
# x in (...)/[...]
# x not in (...)/[...]
if (e.operators[0] in ['in', 'not in']
and len(e.operators) == 1
and isinstance(e.operands[1], (TupleExpr, ListExpr))):
items = e.operands[1].items
n_items = len(items)
# x in y -> x == y[0] or ... or x == y[n]
# x not in y -> x != y[0] and ... and x != y[n]
# 16 is arbitrarily chosen to limit code size
if 1 < n_items < 16:
if e.operators[0] == 'in':
bin_op = 'or'
cmp_op = '=='
else:
bin_op = 'and'
cmp_op = '!='
lhs = e.operands[0]
mypy_file = builder.graph['builtins'].tree
assert mypy_file is not None
bool_type = Instance(cast(TypeInfo, mypy_file.names['bool'].node), [])
exprs = []
for item in items:
expr = ComparisonExpr([cmp_op], [lhs, item])
builder.types[expr] = bool_type
exprs.append(expr)

or_expr = exprs.pop(0) # type: Expression
for expr in exprs:
or_expr = OpExpr(bin_op, or_expr, expr)
builder.types[or_expr] = bool_type
return builder.accept(or_expr)
# x in [y]/(y) -> x == y
# x not in [y]/(y) -> x != y
elif n_items == 1:
if e.operators[0] == 'in':
cmp_op = '=='
else:
cmp_op = '!='
e.operators = [cmp_op]
e.operands[1] = items[0]
# x in []/() -> False
# x not in []/() -> True
elif n_items == 0:
if e.operators[0] == 'in':
return builder.false()
else:
return builder.true()

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
Expand Down
64 changes: 64 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,67 @@ L2:
r2 = CPySequenceTuple_GetItem(nt, 2)
r3 = unbox(int, r2)
return r3


[case testTupleOperatorIn]
def f(i: int) -> bool:
return i in [1, 2, 3]
[out]
def f(i):
i :: int
r0, r1, r2 :: bool
r3 :: native_int
r4, r5, r6, r7 :: bool
r8 :: native_int
r9, r10, r11, r12 :: bool
r13 :: native_int
r14, r15, r16 :: bool
L0:
r3 = i & 1
r4 = r3 == 0
if r4 goto L1 else goto L2 :: bool
L1:
r5 = i == 2
r2 = r5
goto L3
L2:
r6 = CPyTagged_IsEq_(i, 2)
r2 = r6
L3:
if r2 goto L4 else goto L5 :: bool
L4:
r1 = r2
goto L9
L5:
r8 = i & 1
r9 = r8 == 0
if r9 goto L6 else goto L7 :: bool
L6:
r10 = i == 4
r7 = r10
goto L8
L7:
r11 = CPyTagged_IsEq_(i, 4)
r7 = r11
L8:
r1 = r7
L9:
if r1 goto L10 else goto L11 :: bool
L10:
r0 = r1
goto L15
L11:
r13 = i & 1
r14 = r13 == 0
if r14 goto L12 else goto L13 :: bool
L12:
r15 = i == 6
r12 = r15
goto L14
L13:
r16 = CPyTagged_IsEq_(i, 6)
r12 = r16
L14:
r0 = r12
L15:
return r0
119 changes: 119 additions & 0 deletions mypyc/test-data/run-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,122 @@ def test_slicing() -> None:
assert s[1:long_int] == ["o", "o", "b", "a", "r"]
assert s[long_int:] == []
assert s[-long_int:-1] == ["f", "o", "o", "b", "a"]

[case testOperatorInExpression]

def tuple_in_int0(i: int) -> bool:
return i in []

def tuple_in_int1(i: int) -> bool:
return i in (1,)

def tuple_in_int3(i: int) -> bool:
return i in (1, 2, 3)

def tuple_not_in_int0(i: int) -> bool:
return i not in []

def tuple_not_in_int1(i: int) -> bool:
return i not in (1,)

def tuple_not_in_int3(i: int) -> bool:
return i not in (1, 2, 3)

def tuple_in_str(s: "str") -> bool:
return s in ("foo", "bar", "baz")

def tuple_not_in_str(s: "str") -> bool:
return s not in ("foo", "bar", "baz")

def list_in_int0(i: int) -> bool:
return i in []

def list_in_int1(i: int) -> bool:
return i in (1,)

def list_in_int3(i: int) -> bool:
return i in (1, 2, 3)

def list_not_in_int0(i: int) -> bool:
return i not in []

def list_not_in_int1(i: int) -> bool:
return i not in (1,)

def list_not_in_int3(i: int) -> bool:
return i not in (1, 2, 3)

def list_in_str(s: "str") -> bool:
return s in ("foo", "bar", "baz")

def list_not_in_str(s: "str") -> bool:
return s not in ("foo", "bar", "baz")

def list_in_mixed(i: object):
return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type]

[file driver.py]

from native import *

assert not tuple_in_int0(0)
assert not tuple_in_int1(0)
assert tuple_in_int1(1)
assert not tuple_in_int3(0)
assert tuple_in_int3(1)
assert tuple_in_int3(2)
assert tuple_in_int3(3)
assert not tuple_in_int3(4)

assert tuple_not_in_int0(0)
assert tuple_not_in_int1(0)
assert not tuple_not_in_int1(1)
assert tuple_not_in_int3(0)
assert not tuple_not_in_int3(1)
assert not tuple_not_in_int3(2)
assert not tuple_not_in_int3(3)
assert tuple_not_in_int3(4)

assert tuple_in_str("foo")
assert tuple_in_str("bar")
assert tuple_in_str("baz")
assert not tuple_in_str("apple")
assert not tuple_in_str("pie")
assert not tuple_in_str("\0")
assert not tuple_in_str("")

assert not list_in_int0(0)
assert not list_in_int1(0)
assert list_in_int1(1)
assert not list_in_int3(0)
assert list_in_int3(1)
assert list_in_int3(2)
assert list_in_int3(3)
assert not list_in_int3(4)

assert list_not_in_int0(0)
assert list_not_in_int1(0)
assert not list_not_in_int1(1)
assert list_not_in_int3(0)
assert not list_not_in_int3(1)
assert not list_not_in_int3(2)
assert not list_not_in_int3(3)
assert list_not_in_int3(4)

assert list_in_str("foo")
assert list_in_str("bar")
assert list_in_str("baz")
assert not list_in_str("apple")
assert not list_in_str("pie")
assert not list_in_str("\0")
assert not list_in_str("")

assert list_in_mixed(0)
assert list_in_mixed([])
assert list_in_mixed({})
assert list_in_mixed(())
assert list_in_mixed(False)
assert list_in_mixed(0.0)
assert not list_in_mixed([1])
assert not list_in_mixed(object)
assert list_in_mixed(type)