diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index d6989d36a699..7801bf4cf4b6 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -4,7 +4,7 @@ and mypyc.irbuild.builder. """ -from typing import List, Optional, Union, Callable +from typing import List, Optional, Union, Callable, cast from mypy.nodes import ( Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr, @@ -13,7 +13,7 @@ SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr, Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS ) -from mypy.types import TupleType, get_proper_type +from mypy.types import TupleType, get_proper_type, Instance from mypyc.common import MAX_LITERAL_SHORT_INT from mypyc.ir.ops import ( @@ -406,8 +406,56 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: - # TODO: Don't produce an expression when used in conditional context + # x in (...)/[...] + # x not in (...)/[...] + if (e.operators[0] in ['in', 'not in'] + and len(e.operators) == 1 + and isinstance(e.operands[1], (TupleExpr, ListExpr))): + items = e.operands[1].items + n_items = len(items) + # x in y -> x == y[0] or ... or x == y[n] + # x not in y -> x != y[0] and ... and x != y[n] + # 16 is arbitrarily chosen to limit code size + if 1 < n_items < 16: + if e.operators[0] == 'in': + bin_op = 'or' + cmp_op = '==' + else: + bin_op = 'and' + cmp_op = '!=' + lhs = e.operands[0] + mypy_file = builder.graph['builtins'].tree + assert mypy_file is not None + bool_type = Instance(cast(TypeInfo, mypy_file.names['bool'].node), []) + exprs = [] + for item in items: + expr = ComparisonExpr([cmp_op], [lhs, item]) + builder.types[expr] = bool_type + exprs.append(expr) + + or_expr = exprs.pop(0) # type: Expression + for expr in exprs: + or_expr = OpExpr(bin_op, or_expr, expr) + builder.types[or_expr] = bool_type + return builder.accept(or_expr) + # x in [y]/(y) -> x == y + # x not in [y]/(y) -> x != y + elif n_items == 1: + if e.operators[0] == 'in': + cmp_op = '==' + else: + cmp_op = '!=' + e.operators = [cmp_op] + e.operands[1] = items[0] + # x in []/() -> False + # x not in []/() -> True + elif n_items == 0: + if e.operators[0] == 'in': + return builder.false() + else: + return builder.true() + # TODO: Don't produce an expression when used in conditional context # All of the trickiness here is due to support for chained conditionals # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once. diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index e3ec3e732b0d..4e94441cf748 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -181,3 +181,67 @@ L2: r2 = CPySequenceTuple_GetItem(nt, 2) r3 = unbox(int, r2) return r3 + + +[case testTupleOperatorIn] +def f(i: int) -> bool: + return i in [1, 2, 3] +[out] +def f(i): + i :: int + r0, r1, r2 :: bool + r3 :: native_int + r4, r5, r6, r7 :: bool + r8 :: native_int + r9, r10, r11, r12 :: bool + r13 :: native_int + r14, r15, r16 :: bool +L0: + r3 = i & 1 + r4 = r3 == 0 + if r4 goto L1 else goto L2 :: bool +L1: + r5 = i == 2 + r2 = r5 + goto L3 +L2: + r6 = CPyTagged_IsEq_(i, 2) + r2 = r6 +L3: + if r2 goto L4 else goto L5 :: bool +L4: + r1 = r2 + goto L9 +L5: + r8 = i & 1 + r9 = r8 == 0 + if r9 goto L6 else goto L7 :: bool +L6: + r10 = i == 4 + r7 = r10 + goto L8 +L7: + r11 = CPyTagged_IsEq_(i, 4) + r7 = r11 +L8: + r1 = r7 +L9: + if r1 goto L10 else goto L11 :: bool +L10: + r0 = r1 + goto L15 +L11: + r13 = i & 1 + r14 = r13 == 0 + if r14 goto L12 else goto L13 :: bool +L12: + r15 = i == 6 + r12 = r15 + goto L14 +L13: + r16 = CPyTagged_IsEq_(i, 6) + r12 = r16 +L14: + r0 = r12 +L15: + return r0 diff --git a/mypyc/test-data/run-lists.test b/mypyc/test-data/run-lists.test index ed9cfa72c586..2f02be67d358 100644 --- a/mypyc/test-data/run-lists.test +++ b/mypyc/test-data/run-lists.test @@ -149,3 +149,122 @@ def test_slicing() -> None: assert s[1:long_int] == ["o", "o", "b", "a", "r"] assert s[long_int:] == [] assert s[-long_int:-1] == ["f", "o", "o", "b", "a"] + +[case testOperatorInExpression] + +def tuple_in_int0(i: int) -> bool: + return i in [] + +def tuple_in_int1(i: int) -> bool: + return i in (1,) + +def tuple_in_int3(i: int) -> bool: + return i in (1, 2, 3) + +def tuple_not_in_int0(i: int) -> bool: + return i not in [] + +def tuple_not_in_int1(i: int) -> bool: + return i not in (1,) + +def tuple_not_in_int3(i: int) -> bool: + return i not in (1, 2, 3) + +def tuple_in_str(s: "str") -> bool: + return s in ("foo", "bar", "baz") + +def tuple_not_in_str(s: "str") -> bool: + return s not in ("foo", "bar", "baz") + +def list_in_int0(i: int) -> bool: + return i in [] + +def list_in_int1(i: int) -> bool: + return i in (1,) + +def list_in_int3(i: int) -> bool: + return i in (1, 2, 3) + +def list_not_in_int0(i: int) -> bool: + return i not in [] + +def list_not_in_int1(i: int) -> bool: + return i not in (1,) + +def list_not_in_int3(i: int) -> bool: + return i not in (1, 2, 3) + +def list_in_str(s: "str") -> bool: + return s in ("foo", "bar", "baz") + +def list_not_in_str(s: "str") -> bool: + return s not in ("foo", "bar", "baz") + +def list_in_mixed(i: object): + return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type] + +[file driver.py] + +from native import * + +assert not tuple_in_int0(0) +assert not tuple_in_int1(0) +assert tuple_in_int1(1) +assert not tuple_in_int3(0) +assert tuple_in_int3(1) +assert tuple_in_int3(2) +assert tuple_in_int3(3) +assert not tuple_in_int3(4) + +assert tuple_not_in_int0(0) +assert tuple_not_in_int1(0) +assert not tuple_not_in_int1(1) +assert tuple_not_in_int3(0) +assert not tuple_not_in_int3(1) +assert not tuple_not_in_int3(2) +assert not tuple_not_in_int3(3) +assert tuple_not_in_int3(4) + +assert tuple_in_str("foo") +assert tuple_in_str("bar") +assert tuple_in_str("baz") +assert not tuple_in_str("apple") +assert not tuple_in_str("pie") +assert not tuple_in_str("\0") +assert not tuple_in_str("") + +assert not list_in_int0(0) +assert not list_in_int1(0) +assert list_in_int1(1) +assert not list_in_int3(0) +assert list_in_int3(1) +assert list_in_int3(2) +assert list_in_int3(3) +assert not list_in_int3(4) + +assert list_not_in_int0(0) +assert list_not_in_int1(0) +assert not list_not_in_int1(1) +assert list_not_in_int3(0) +assert not list_not_in_int3(1) +assert not list_not_in_int3(2) +assert not list_not_in_int3(3) +assert list_not_in_int3(4) + +assert list_in_str("foo") +assert list_in_str("bar") +assert list_in_str("baz") +assert not list_in_str("apple") +assert not list_in_str("pie") +assert not list_in_str("\0") +assert not list_in_str("") + +assert list_in_mixed(0) +assert list_in_mixed([]) +assert list_in_mixed({}) +assert list_in_mixed(()) +assert list_in_mixed(False) +assert list_in_mixed(0.0) +assert not list_in_mixed([1]) +assert not list_in_mixed(object) +assert list_in_mixed(type)