diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 391c7a59d60b..ee99d4f9187e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -35,7 +35,7 @@ ParamSpecExpr, ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, ) -from mypy.literals import literal +from mypy.literals import literal, try_literal_math from mypy import nodes from mypy import operators import mypy.checker @@ -2570,6 +2570,27 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: any_type = AnyType(TypeOfAny.from_another_any, source_any=right_type) return any_type, any_type + # STEP 0: + # We support `Literal` type math. For example, we want to reveal `1 + 3` + # as `Literal[4]`. So, we check if we have literal expressions first. + # We consider this to be the fast path, we move on if it is not a literal. + # But, operations on literal types are not processed further. + + if isinstance(context, OpExpr) and isinstance(left_type, (LiteralType, Instance)): + fallback_left_type = ( + left_type.fallback + if isinstance(left_type, LiteralType) + else left_type + ) + literal_result = try_literal_math( + context.op, + left_expr, left_type, + right_expr, right_type, + fallback=fallback_left_type, + ) + if literal_result is not None: + return literal_result, literal_result + # STEP 1: # We start by getting the __op__ and __rop__ methods, if they exist. diff --git a/mypy/literals.py b/mypy/literals.py index 00cf5916bec2..904a12b15d43 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -1,6 +1,8 @@ +import operator from typing import Optional, Union, Any, Tuple, Iterable from typing_extensions import Final +from mypy.types import Type, LiteralType, Instance, LiteralValue, get_proper_type from mypy.nodes import ( Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES, LITERAL_NO, NameExpr, LITERAL_TYPE, IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr, @@ -246,3 +248,66 @@ def visit_temp_node(self, e: TempNode) -> None: _hasher: Final = _Hasher() + + +_SUPPORTED_LITERAL_OPERATIONS: Final = { + int: ('+', '-', '*', '//'), # `/` returns `float` + str: ('+',), + bool: ('and', 'or'), +} +_OP_FUNCTIONS: Final = { + '+': operator.add, + '-': operator.sub, + '*': operator.mul, + '//': operator.floordiv, + 'and': operator.and_, + 'or': operator.or_, +} + + +def try_literal_math( + op: str, + left_expr: Expression, left_type: Type, + right_expr: Expression, right_type: Type, + *, + fallback: Instance, +) -> Optional[Instance]: + left_literal = _get_literal_value(left_expr, left_type) + if left_literal is None: + return None + right_literal = _get_literal_value(right_expr, right_type) + if right_literal is None: + return None + + lit_type = type(left_literal) + if (lit_type != type(right_literal) + or lit_type not in _SUPPORTED_LITERAL_OPERATIONS + or op not in _SUPPORTED_LITERAL_OPERATIONS[lit_type]): + return None + + op_method = _OP_FUNCTIONS[op] + try: + new_value = op_method(left_literal, right_literal) + except Exception: # We catch any possible problem: overflow, type error, etc. + return None + else: + return fallback.copy_modified(last_known_value=LiteralType( + new_value, + fallback=fallback, + )) + + +def _get_literal_value(expr: Expression, typ: Type) -> Optional[LiteralValue]: + # We can work with a literal type: + typ = get_proper_type(typ) + if isinstance(typ, LiteralType): + return typ.value + elif isinstance(typ, Instance) and typ.last_known_value: + return typ.last_known_value.value + + # Or a literal node (`True` / `False` are already `Literal[True] | [False]`): + if isinstance(expr, (IntExpr, StrExpr)): + return expr.value + + # It is not a literal: + return None diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index e3ec55c516db..873f6b77711a 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -617,7 +617,7 @@ class B: pass [case testInferLambdaTypeUsingContext] x : str = (lambda x: x + 1)(1) # E: Incompatible types in assignment (expression has type "int", variable has type "str") -reveal_type((lambda x, y: x + y)(1, 2)) # N: Revealed type is "builtins.int" +reveal_type((lambda x, y: x + y)(1, 2)) # N: Revealed type is "Literal[3]?" (lambda x, y: x + y)(1, "") # E: Unsupported operand types for + ("int" and "str") (lambda *, x, y: x + y)(x=1, y="") # E: Unsupported operand types for + ("int" and "str") reveal_type((lambda s, i: s)(i=0, s='x')) # N: Revealed type is "Literal['x']?" diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 37ae12419151..45775f36494a 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1745,19 +1745,19 @@ c: Literal[4] d: Literal['foo'] e: str -reveal_type(a + a) # N: Revealed type is "builtins.int" +reveal_type(a + a) # N: Revealed type is "Literal[6]?" reveal_type(a + b) # N: Revealed type is "builtins.int" reveal_type(b + a) # N: Revealed type is "builtins.int" -reveal_type(a + 1) # N: Revealed type is "builtins.int" -reveal_type(1 + a) # N: Revealed type is "builtins.int" -reveal_type(a + c) # N: Revealed type is "builtins.int" -reveal_type(c + a) # N: Revealed type is "builtins.int" +reveal_type(a + 1) # N: Revealed type is "Literal[4]?" +reveal_type(1 + a) # N: Revealed type is "Literal[4]?" +reveal_type(a + c) # N: Revealed type is "Literal[7]?" +reveal_type(c + a) # N: Revealed type is "Literal[7]?" -reveal_type(d + d) # N: Revealed type is "builtins.str" +reveal_type(d + d) # N: Revealed type is "Literal['foofoo']?" reveal_type(d + e) # N: Revealed type is "builtins.str" reveal_type(e + d) # N: Revealed type is "builtins.str" -reveal_type(d + 'foo') # N: Revealed type is "builtins.str" -reveal_type('foo' + d) # N: Revealed type is "builtins.str" +reveal_type(d + 'foo') # N: Revealed type is "Literal['foofoo']?" +reveal_type('foo' + d) # N: Revealed type is "Literal['foofoo']?" reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int" reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int" @@ -3346,3 +3346,167 @@ def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False] else: return (bool(), 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") [builtins fixtures/bool.pyi] + + +# Literal math +# ============ + +[case testLiteralIntMath] +from typing_extensions import Literal, Final + +reveal_type(1 + 2) # N: Revealed type is "Literal[3]?" +reveal_type(2 + 1) # N: Revealed type is "Literal[3]?" + +reveal_type(2 // 2) # N: Revealed type is "Literal[1]?" +reveal_type(5 // 2) # N: Revealed type is "Literal[2]?" + +reveal_type(1 + 2 + 3) # N: Revealed type is "Literal[6]?" +reveal_type(2 + 2 * 2) # N: Revealed type is "Literal[6]?" +reveal_type(2 * 2 + 2) # N: Revealed type is "Literal[6]?" +reveal_type(100 - 2 + 1) # N: Revealed type is "Literal[99]?" + +a: Literal[3] +b: Literal[4] +c: Final = 5 + +reveal_type(a + b) # N: Revealed type is "Literal[7]?" +reveal_type(a + 1 + b + 1) # N: Revealed type is "Literal[9]?" +reveal_type(1 + a + 1 + b) # N: Revealed type is "Literal[9]?" +reveal_type(a + c) # N: Revealed type is "Literal[8]?" +reveal_type(c + a) # N: Revealed type is "Literal[8]?" +reveal_type(c + c) # N: Revealed type is "Literal[10]?" + +i: int + +reveal_type(a + i) # N: Revealed type is "builtins.int" +reveal_type(i + a) # N: Revealed type is "builtins.int" +reveal_type(i * 2) # N: Revealed type is "builtins.int" +reveal_type(2 * i) # N: Revealed type is "builtins.int" +reveal_type(i // 2) # N: Revealed type is "builtins.int" +reveal_type(2 // i) # N: Revealed type is "builtins.int" +reveal_type(i - 2) # N: Revealed type is "builtins.int" +reveal_type(2 - i) # N: Revealed type is "builtins.int" +reveal_type(i - c) # N: Revealed type is "builtins.int" + +# Corner cases: + +reveal_type(9223372036854775807 + 9223372036854775807) # N: Revealed type is "Literal[18446744073709551614]?" +reveal_type(9223372036854775807 * 9223372036854775807) # N: Revealed type is "Literal[85070591730234615847396907784232501249]?" + +reveal_type(1 // 0) # N: Revealed type is "builtins.int" +reveal_type(1 + 0) # N: Revealed type is "Literal[1]?" +[builtins fixtures/primitives.pyi] + + +[case testLiteralStrMath] +from typing_extensions import Literal, Final + +reveal_type('a' + 'b') # N: Revealed type is "Literal['ab']?" +reveal_type('b' + 'a') # N: Revealed type is "Literal['ba']?" + +a: Literal['a'] +b: Literal['b'] +c: Final = 'c' + +reveal_type(a + '!') # N: Revealed type is "Literal['a!']?" +reveal_type('!' + a) # N: Revealed type is "Literal['!a']?" +reveal_type(a + b + c) # N: Revealed type is "Literal['abc']?" +reveal_type(c + b + a) # N: Revealed type is "Literal['cba']?" +reveal_type(a + '!' + b + '?' + c) # N: Revealed type is "Literal['a!b?c']?" +reveal_type(c + '1' + a + '2' + b) # N: Revealed type is "Literal['c1a2b']?" + +s: str + +reveal_type(s + 'a') # N: Revealed type is "builtins.str" +reveal_type('a' + s) # N: Revealed type is "builtins.str" +reveal_type(s + a) # N: Revealed type is "builtins.str" +reveal_type(a + s) # N: Revealed type is "builtins.str" +reveal_type(s + c) # N: Revealed type is "builtins.str" +reveal_type(c + s) # N: Revealed type is "builtins.str" + +# Corner cases: + +reveal_type('a' + '') # N: Revealed type is "Literal['a']?" +reveal_type(a + '') # N: Revealed type is "Literal['a']?" +reveal_type('' + '') # N: Revealed type is "Literal['']?" +[builtins fixtures/primitives.pyi] + + +[case testLiteralBytesMath] +from typing_extensions import Literal, Final + +reveal_type(b'a' + b'b') # N: Revealed type is "Literal[b'ab']?" +reveal_type(b'b' + b'a') # N: Revealed type is "Literal[b'ba']?" + +a: Literal[b'a'] +b: Literal[b'b'] +c: Final = b'c' + +reveal_type(a + b'!') # N: Revealed type is "Literal[b'a!']?" +reveal_type(b'!' + a) # N: Revealed type is "Literal[b'!a']?" +reveal_type(a + b + c) # N: Revealed type is "Literal[b'abc']?" +reveal_type(c + b + a) # N: Revealed type is "Literal[b'cba']?" +reveal_type(a + b'!' + b + b'?' + c) # N: Revealed type is "Literal[b'a!b?c']?" +reveal_type(c + b'1' + a + b'2' + b) # N: Revealed type is "Literal[b'c1a2b']?" + +s: bytes + +reveal_type(s + b'a') # N: Revealed type is "builtins.bytes" +reveal_type(b'a' + s) # N: Revealed type is "builtins.bytes" +reveal_type(s + a) # N: Revealed type is "builtins.bytes" +reveal_type(a + s) # N: Revealed type is "builtins.bytes" +reveal_type(s + c) # N: Revealed type is "builtins.bytes" +reveal_type(c + s) # N: Revealed type is "builtins.bytes" + +# Corner cases: + +reveal_type(b'a' + b'') # N: Revealed type is "Literal[b'a']?" +reveal_type(a + b'') # N: Revealed type is "Literal[b'a']?" +reveal_type(b'' + b'') # N: Revealed type is "Literal[b'']?" +[builtins fixtures/primitives.pyi] + + +[case testLiteralBoolMath] +from typing_extensions import Literal, Final + +reveal_type(True or False) # N: Revealed type is "Literal[True]?" +reveal_type(False or True) # N: Revealed type is "Literal[True]" +reveal_type(True or True) # N: Revealed type is "Literal[True]?" +reveal_type(False or False) # N: Revealed type is "Literal[False]" + +reveal_type(True and False) # N: Revealed type is "Literal[False]" +reveal_type(False and True) # N: Revealed type is "Literal[False]?" +reveal_type(True and True) # N: Revealed type is "Literal[True]" +reveal_type(False and False) # N: Revealed type is "Literal[False]?" + +reveal_type(True and False and True) # N: Revealed type is "Literal[False]" +reveal_type(True or False or False) # N: Revealed type is "Literal[True]?" + +t: Literal[True] +f: Literal[False] +c: Final = True + +reveal_type(t or False) # N: Revealed type is "Literal[True]" +reveal_type(False or t) # N: Revealed type is "Literal[True]" +reveal_type(t or f) # N: Revealed type is "Literal[True]" +reveal_type(t and c) # N: Revealed type is "Literal[True]" + +b: bool + +reveal_type(True or b) # N: Revealed type is "Literal[True]?" +reveal_type(b and True) # N: Revealed type is "builtins.bool" +reveal_type(b and False) # N: Revealed type is "Literal[False]" +[builtins fixtures/primitives.pyi] + + +[case testLiteralMathLoopContext] +def func1(loop_count: int): + x = 1 + reveal_type(x) # N: Revealed type is "builtins.int" + x = x + 1 + reveal_type(x) # N: Revealed type is "builtins.int" + + for _ in [1, 2, 3]: + x = x + 1 + reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 71f59a9c1d8c..1e37413660b2 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -17,6 +17,8 @@ class int: def __init__(self, x: object = ..., base: int = ...) -> None: pass def __add__(self, i: int) -> int: pass def __rmul__(self, x: int) -> int: pass + def __sub__(self, x: int) -> int: pass + def __floordiv__(self, x: int) -> int: pass class float: def __float__(self) -> float: pass class complex: pass @@ -28,6 +30,7 @@ class str(Sequence[str]): def __getitem__(self, item: int) -> str: pass def format(self, *args, **kwargs) -> str: pass class bytes(Sequence[int]): + def __add__(self, s: bytes) -> bytes: pass def __iter__(self) -> Iterator[int]: pass def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> int: pass diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 7a0115f17e9c..a5de72602d7a 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -121,7 +121,7 @@ class float: class type: pass class str: pass [out] -OpExpr(3) : builtins.int +OpExpr(3) : Literal[3]? OpExpr(4) : builtins.float OpExpr(5) : builtins.float OpExpr(6) : builtins.float