Skip to content

Commit 2b613e5

Browse files
authored
Fix type narrowing of == None and in (None,) conditions (#15760)
1 parent 54bc37c commit 2b613e5

File tree

7 files changed

+48
-16
lines changed

7 files changed

+48
-16
lines changed

mypy/checker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@
216216
is_literal_type,
217217
is_named_instance,
218218
)
219-
from mypy.types_utils import is_optional, remove_optional, store_argument_type, strip_type
219+
from mypy.types_utils import is_overlapping_none, remove_optional, store_argument_type, strip_type
220220
from mypy.typetraverser import TypeTraverserVisitor
221221
from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars
222222
from mypy.util import is_dunder, is_sunder, is_typeshed_file
@@ -5660,13 +5660,13 @@ def has_no_custom_eq_checks(t: Type) -> bool:
56605660

56615661
if left_index in narrowable_operand_index_to_hash:
56625662
# We only try and narrow away 'None' for now
5663-
if is_optional(item_type):
5663+
if is_overlapping_none(item_type):
56645664
collection_item_type = get_proper_type(
56655665
builtin_item_type(iterable_type)
56665666
)
56675667
if (
56685668
collection_item_type is not None
5669-
and not is_optional(collection_item_type)
5669+
and not is_overlapping_none(collection_item_type)
56705670
and not (
56715671
isinstance(collection_item_type, Instance)
56725672
and collection_item_type.type.fullname == "builtins.object"
@@ -6073,7 +6073,7 @@ def refine_away_none_in_comparison(
60736073
non_optional_types = []
60746074
for i in chain_indices:
60756075
typ = operand_types[i]
6076-
if not is_optional(typ):
6076+
if not is_overlapping_none(typ):
60776077
non_optional_types.append(typ)
60786078

60796079
# Make sure we have a mixture of optional and non-optional types.
@@ -6083,7 +6083,7 @@ def refine_away_none_in_comparison(
60836083
if_map = {}
60846084
for i in narrowable_operand_indices:
60856085
expr_type = operand_types[i]
6086-
if not is_optional(expr_type):
6086+
if not is_overlapping_none(expr_type):
60876087
continue
60886088
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
60896089
if_map[operands[i]] = remove_optional(expr_type)

mypy/checkexpr.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,12 @@
169169
is_named_instance,
170170
split_with_prefix_and_suffix,
171171
)
172-
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
172+
from mypy.types_utils import (
173+
is_generic_instance,
174+
is_overlapping_none,
175+
is_self_type_like,
176+
remove_optional,
177+
)
173178
from mypy.typestate import type_state
174179
from mypy.typevars import fill_typevars
175180
from mypy.typevartuples import find_unpack_in_list
@@ -1809,7 +1814,7 @@ def infer_function_type_arguments_using_context(
18091814
# valid results.
18101815
erased_ctx = replace_meta_vars(ctx, ErasedType())
18111816
ret_type = callable.ret_type
1812-
if is_optional(ret_type) and is_optional(ctx):
1817+
if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
18131818
# If both the context and the return type are optional, unwrap the optional,
18141819
# since in 99% cases this is what a user expects. In other words, we replace
18151820
# Optional[T] <: Optional[int]

mypy/plugins/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
deserialize_type,
4444
get_proper_type,
4545
)
46-
from mypy.types_utils import is_optional
46+
from mypy.types_utils import is_overlapping_none
4747
from mypy.typevars import fill_typevars
4848
from mypy.util import get_unique_redefinition_name
4949

@@ -141,7 +141,7 @@ def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) ->
141141
break
142142
elif (
143143
arg_none
144-
and not is_optional(arg_type)
144+
and not is_overlapping_none(arg_type)
145145
and not (
146146
isinstance(arg_type, Instance)
147147
and arg_type.type.fullname == "builtins.object"

mypy/suggestions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
UnionType,
8080
get_proper_type,
8181
)
82-
from mypy.types_utils import is_optional, remove_optional
82+
from mypy.types_utils import is_overlapping_none, remove_optional
8383
from mypy.util import split_target
8484

8585

@@ -752,7 +752,7 @@ def score_type(self, t: Type, arg_pos: bool) -> int:
752752
return 20
753753
if any(has_any_type(x) for x in t.items):
754754
return 15
755-
if not is_optional(t):
755+
if not is_overlapping_none(t):
756756
return 10
757757
if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)):
758758
return 10
@@ -868,7 +868,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> str:
868868
return t.fallback.accept(self)
869869

870870
def visit_union_type(self, t: UnionType) -> str:
871-
if len(t.items) == 2 and is_optional(t):
871+
if len(t.items) == 2 and is_overlapping_none(t):
872872
return f"Optional[{remove_optional(t).accept(self)}]"
873873
else:
874874
return super().visit_union_type(t)

mypy/types_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ def is_generic_instance(tp: Type) -> bool:
101101
return isinstance(tp, Instance) and bool(tp.args)
102102

103103

104-
def is_optional(t: Type) -> bool:
104+
def is_overlapping_none(t: Type) -> bool:
105105
t = get_proper_type(t)
106-
return isinstance(t, UnionType) and any(
107-
isinstance(get_proper_type(e), NoneType) for e in t.items
106+
return isinstance(t, NoneType) or (
107+
isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) for e in t.items)
108108
)
109109

110110

test-data/unit/check-narrowing.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,32 @@ def g() -> None:
12631263
[builtins fixtures/dict.pyi]
12641264

12651265

1266+
[case testNarrowingOptionalEqualsNone]
1267+
from typing import Optional
1268+
1269+
class A: ...
1270+
1271+
val: Optional[A]
1272+
1273+
if val == None:
1274+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1275+
else:
1276+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1277+
if val != None:
1278+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1279+
else:
1280+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1281+
1282+
if val in (None,):
1283+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1284+
else:
1285+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1286+
if val not in (None,):
1287+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1288+
else:
1289+
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1290+
[builtins fixtures/primitives.pyi]
1291+
12661292
[case testNarrowingWithTupleOfTypes]
12671293
from typing import Tuple, Type
12681294

test-data/unit/fixtures/primitives.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class memoryview(Sequence[int]):
4545
def __iter__(self) -> Iterator[int]: pass
4646
def __contains__(self, other: object) -> bool: pass
4747
def __getitem__(self, item: int) -> int: pass
48-
class tuple(Generic[T]): pass
48+
class tuple(Generic[T]):
49+
def __contains__(self, other: object) -> bool: pass
4950
class list(Sequence[T]):
5051
def __iter__(self) -> Iterator[T]: pass
5152
def __contains__(self, other: object) -> bool: pass

0 commit comments

Comments
 (0)