Skip to content

Introduce temporary named expressions for match subjects #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
CallExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
Context,
ContinueStmt,
Decorator,
Expand Down Expand Up @@ -350,6 +351,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
# functions such as open(), etc.
plugin: Plugin

# A helper state to produce unique temporary names on demand.
_unique_id: int

Comment on lines +354 to +356
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

_unique_id must be added to __slots__ to avoid AttributeError.

TypeChecker uses __slots__ to save memory. Writing self._unique_id = 0 in __init__ will raise at runtime unless the new attribute is appended to the __slots__ tuple defined near the top of the class.

Suggested follow‑up (illustrative – adjust exact location):

 class TypeChecker(TraverserVisitor):
-    __slots__ = (
-        "...",
-        # existing slots
+    __slots__ = (
+        "...",
+        "_unique_id",   # ← add here
     )

Committable suggestion skipped: line range outside the PR's diff.

def __init__(
self,
errors: Errors,
Expand Down Expand Up @@ -414,6 +418,7 @@ def __init__(
self, self.msg, self.plugin, per_line_checking_time_ns
)
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
self._unique_id = 0

@property
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
Expand Down Expand Up @@ -5413,21 +5418,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return

def visit_match_stmt(self, s: MatchStmt) -> None:
named_subject: Expression
if isinstance(s.subject, CallExpr):
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
if s.subject_dummy is None:
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
name = "dummy-match-" + id
v = Var(name)
s.subject_dummy = NameExpr(name)
s.subject_dummy.node = v
named_subject = s.subject_dummy
else:
named_subject = s.subject

named_subject = self._make_named_statement_for_match(s)
with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))

Expand Down Expand Up @@ -5459,6 +5450,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
pattern_map, else_map = conditional_types_to_typemaps(
named_subject, pattern_type.type, pattern_type.rest_type
)
# Maybe the subject type can be inferred from constraints on
# its attribute/item?
if pattern_map and named_subject in pattern_map:
pattern_map[s.subject] = pattern_map[named_subject]
if else_map and named_subject in else_map:
Comment on lines +5453 to +5457
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unchecked Type Assignment in Match Statement Subject Processing

The code assigns type information from the named_subject (a temporary variable) directly to the original subject without validating type compatibility. This could lead to incorrect type narrowing if the pattern constraints are not fully compatible with the original subject type. An attacker could potentially craft code that exploits this behavior to bypass type checks or cause the type checker to make incorrect assumptions about variable types.

Suggested change
# Maybe the subject type can be inferred from constraints on
# its attribute/item?
if pattern_map and named_subject in pattern_map:
pattern_map[s.subject] = pattern_map[named_subject]
if else_map and named_subject in else_map:
if pattern_map and named_subject in pattern_map:
inferred_type = pattern_map[named_subject]
if self.is_compatible_with(inferred_type, self.expr_checker.accept(s.subject)):
pattern_map[s.subject] = inferred_type
if else_map and named_subject in else_map:
inferred_type = else_map[named_subject]
if self.is_compatible_with(inferred_type, self.expr_checker.accept(s.subject)):
else_map[s.subject] = inferred_type
Rationale
  • The fix ensures that type information is only propagated when the inferred type is compatible with the original subject's type.
  • This prevents potential type confusion that could lead to incorrect type narrowing or type-based security issues.
  • By validating type compatibility, we maintain the type system's integrity and prevent potential type-related vulnerabilities.
  • The fix follows the principle of defensive programming by not making assumptions about type compatibility.
References
  • cwe: CWE-704
  • owasp: A08:2021-Software and Data Integrity Failures

else_map[s.subject] = else_map[named_subject]
Comment on lines +5453 to +5458
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Type Propagation for Temporary Match Expressions

The code only propagates type information from the temporary named subject to the original subject but not in the reverse direction. This one-way propagation creates a risk of state inconsistency where the original subject's type information might be more refined than the temporary subject's. In complex pattern matching scenarios, this could lead to incorrect type inference and potentially missed type errors.

Suggested change
# Maybe the subject type can be inferred from constraints on
# its attribute/item?
if pattern_map and named_subject in pattern_map:
pattern_map[s.subject] = pattern_map[named_subject]
if else_map and named_subject in else_map:
else_map[s.subject] = else_map[named_subject]
# Ensure bidirectional type propagation between original and temporary subjects
if pattern_map:
if named_subject in pattern_map:
pattern_map[s.subject] = pattern_map[named_subject]
elif s.subject in pattern_map and named_subject != s.subject:
pattern_map[named_subject] = pattern_map[s.subject]
if else_map:
if named_subject in else_map:
else_map[s.subject] = else_map[named_subject]
elif s.subject in else_map and named_subject != s.subject:
else_map[named_subject] = else_map[s.subject]
Rationale
  • This fix ensures that type information flows bidirectionally between the original subject and the temporary named subject, maintaining state consistency.
  • It implements the principle of type coherence where all representations of the same value should have consistent type information.
  • The approach prevents potential missed type errors by ensuring that refined type information from either source is properly shared.
  • The additional condition 'named_subject != s.subject' prevents unnecessary assignments when the subject is already a named expression.
References
  • Standard: ISO/IEC 25010 Functional Correctness - State Consistency

pattern_map = self.propagate_up_typemap_info(pattern_map)
else_map = self.propagate_up_typemap_info(else_map)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
Expand Down Expand Up @@ -5506,6 +5503,36 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
"""Construct a fake NameExpr for inference if a match clause is complex."""
subject = s.subject
expressions_to_preserve = (
# Already named - we should infer type of it as given
NameExpr,
AssignmentExpr,
# Primitive literals - their type is known, no need to name them
IntExpr,
StrExpr,
BytesExpr,
FloatExpr,
ComplexExpr,
EllipsisExpr,
)
if isinstance(subject, expressions_to_preserve):
return subject
elif s.subject_dummy is not None:
return s.subject_dummy
else:
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
name = self.new_unique_dummy_name("match")
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
s.subject_dummy = named_subject
Comment on lines 5503 to +5533
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incomplete Type Inference for Complex Match Subjects

The _make_named_statement_for_match method creates a temporary variable for complex match subjects but doesn't properly annotate it with the subject's type. This can lead to imprecise type inference for complex expressions in match statements. While the code will execute, it may miss potential type errors or provide less helpful type information in match case blocks, reducing the effectiveness of the type checker.

Suggested change
with self.binder.frame_context(can_skip=False, fall_through=2):
pass
def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
"""Construct a fake NameExpr for inference if a match clause is complex."""
subject = s.subject
expressions_to_preserve = (
# Already named - we should infer type of it as given
NameExpr,
AssignmentExpr,
# Primitive literals - their type is known, no need to name them
IntExpr,
StrExpr,
BytesExpr,
FloatExpr,
ComplexExpr,
EllipsisExpr,
)
if isinstance(subject, expressions_to_preserve):
return subject
elif s.subject_dummy is not None:
return s.subject_dummy
else:
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
name = self.new_unique_dummy_name("match")
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
s.subject_dummy = named_subject
v = Var(name)
# Preserve the subject's type for better type inference
v.type = self.expr_checker.accept(subject)
named_subject = NameExpr(name)
named_subject.node = v
Rationale
  • This fix ensures that the temporary variable preserves the type information from the original subject expression.
  • It implements the principle of type preservation, ensuring that no type information is lost during the transformation.
  • The approach improves the precision of type checking in match statement case blocks.
  • It maintains functional correctness by ensuring that the type checker can make accurate decisions based on complete type information.
References
  • Standard: ISO/IEC 25010 Functional Correctness - Type Consistency

return named_subject

Comment on lines +5506 to +5535
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Preserve source‑location info when synthesising a dummy subject.

NameExpr() defaults to line/column = -1, which degrades error positioning and can break linters relying on accurate locations.
Copy the location from the original subject so diagnostics continue to point to the correct place.

-            named_subject = NameExpr(name)
+            named_subject = NameExpr(name)
+            # Keep error messages anchored to the original expression
+            named_subject.set_line(subject.line)
+            named_subject.set_column(subject.column)

(If set_column is unavailable, use copy_location(named_subject, subject) helper that is used elsewhere in mypy.)

Optionally, also populate v.line = subject.line for completeness.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
"""Construct a fake NameExpr for inference if a match clause is complex."""
subject = s.subject
expressions_to_preserve = (
# Already named - we should infer type of it as given
NameExpr,
AssignmentExpr,
# Primitive literals - their type is known, no need to name them
IntExpr,
StrExpr,
BytesExpr,
FloatExpr,
ComplexExpr,
EllipsisExpr,
)
if isinstance(subject, expressions_to_preserve):
return subject
elif s.subject_dummy is not None:
return s.subject_dummy
else:
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
name = self.new_unique_dummy_name("match")
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
s.subject_dummy = named_subject
return named_subject
def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
"""Construct a fake NameExpr for inference if a match clause is complex."""
subject = s.subject
expressions_to_preserve = (
# Already named - we should infer type of it as given
NameExpr,
AssignmentExpr,
# Primitive literals - their type is known, no need to name them
IntExpr,
StrExpr,
BytesExpr,
FloatExpr,
ComplexExpr,
EllipsisExpr,
)
if isinstance(subject, expressions_to_preserve):
return subject
elif s.subject_dummy is not None:
return s.subject_dummy
else:
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
name = self.new_unique_dummy_name("match")
v = Var(name)
named_subject = NameExpr(name)
# Keep error messages anchored to the original expression
named_subject.set_line(subject.line)
named_subject.set_column(subject.column)
named_subject.node = v
s.subject_dummy = named_subject
return named_subject

def _get_recursive_sub_patterns_map(
self, expr: Expression, typ: Type
) -> dict[Expression, Type]:
Expand Down Expand Up @@ -7885,6 +7912,12 @@ def warn_deprecated_overload_item(
if candidate == target:
self.warn_deprecated(item.func, context)

def new_unique_dummy_name(self, namespace: str) -> str:
"""Generate a name that is guaranteed to be unique for this TypeChecker instance."""
name = f"dummy-{namespace}-{self._unique_id}"
self._unique_id += 1
return name

# leafs

def visit_pass_stmt(self, o: PassStmt, /) -> None:
Expand Down
85 changes: 84 additions & 1 deletion test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def main() -> None:
case a:
reveal_type(a) # N: Revealed type is "builtins.int"

[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail]
[case testMatchCapturePatternFromAsyncFunctionReturningUnion]
async def func1(arg: bool) -> str | int: ...
async def func2(arg: bool) -> bytes | int: ...

Expand Down Expand Up @@ -2586,6 +2586,89 @@ def fn2(x: Some | int | str) -> None:
pass
[builtins fixtures/dict.pyi]

[case testMatchFunctionCall]
# flags: --warn-unreachable

def fn() -> int | str: ...

match fn():
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchAttribute]
# flags: --warn-unreachable

class A:
foo: int | str

match A().foo:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchOperations]
# flags: --warn-unreachable

x: int
match -x:
case -1 as s:
reveal_type(s) # N: Revealed type is "Literal[-1]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 + 2:
case 3 as s:
reveal_type(s) # N: Revealed type is "Literal[3]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 > 2:
case True as s:
reveal_type(s) # N: Revealed type is "Literal[True]"
case False as s:
reveal_type(s) # N: Revealed type is "Literal[False]"
case other:
other # E: Statement is unreachable
[builtins fixtures/ops.pyi]

[case testMatchDictItem]
# flags: --warn-unreachable

m: dict[str, int | str]
k: str

match m[k]:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[builtins fixtures/dict.pyi]

[case testMatchLiteralValuePathological]
# flags: --warn-unreachable

match 0:
case 0 as i:
reveal_type(i) # N: Revealed type is "Literal[0]?"
case int(i):
i # E: Statement is unreachable
case other:
other # E: Statement is unreachable

[case testMatchNamedTupleSequence]
from typing import Any, NamedTuple

Expand Down