Skip to content

Commit 1072c78

Browse files
Fix Literal strings containing pipe characters (#17148)
Fixes #16367 During semantic analysis, we try to parse all strings as types, including those inside Literal[]. Previously, we preserved the original string in the `UnboundType.original_str_expr` attribute, but if a type is parsed as a Union, we didn't have a place to put the value. This PR instead always wraps string types in a RawExpressionType node, which now optionally includes a `.node` attribute containing the parsed type. This way, we don't need to worry about preserving the original string as a custom attribute on different kinds of types that can appear in this context. The downside is that more code needs to be aware of RawExpressionType.
1 parent df35dcf commit 1072c78

15 files changed

+142
-90
lines changed

mypy/fastparse.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -319,14 +319,7 @@ def parse_type_string(
319319
"""
320320
try:
321321
_, node = parse_type_comment(f"({expr_string})", line=line, column=column, errors=None)
322-
if isinstance(node, UnboundType) and node.original_str_expr is None:
323-
node.original_str_expr = expr_string
324-
node.original_str_fallback = expr_fallback_name
325-
return node
326-
elif isinstance(node, UnionType):
327-
return node
328-
else:
329-
return RawExpressionType(expr_string, expr_fallback_name, line, column)
322+
return RawExpressionType(expr_string, expr_fallback_name, line, column, node=node)
330323
except (SyntaxError, ValueError):
331324
# Note: the parser will raise a `ValueError` instead of a SyntaxError if
332325
# the string happens to contain things like \x00.
@@ -1034,6 +1027,8 @@ def set_type_optional(self, type: Type | None, initializer: Expression | None) -
10341027
return
10351028
# Indicate that type should be wrapped in an Optional if arg is initialized to None.
10361029
optional = isinstance(initializer, NameExpr) and initializer.name == "None"
1030+
if isinstance(type, RawExpressionType) and type.node is not None:
1031+
type = type.node
10371032
if isinstance(type, UnboundType):
10381033
type.optional = optional
10391034

mypy/semanal.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3231,10 +3231,10 @@ def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
32313231
def analyze_lvalues(self, s: AssignmentStmt) -> None:
32323232
# We cannot use s.type, because analyze_simple_literal_type() will set it.
32333233
explicit = s.unanalyzed_type is not None
3234-
if self.is_final_type(s.unanalyzed_type):
3234+
final_type = self.unwrap_final_type(s.unanalyzed_type)
3235+
if final_type is not None:
32353236
# We need to exclude bare Final.
3236-
assert isinstance(s.unanalyzed_type, UnboundType)
3237-
if not s.unanalyzed_type.args:
3237+
if not final_type.args:
32383238
explicit = False
32393239

32403240
if s.rvalue:
@@ -3300,19 +3300,19 @@ def unwrap_final(self, s: AssignmentStmt) -> bool:
33003300
33013301
Returns True if Final[...] was present.
33023302
"""
3303-
if not s.unanalyzed_type or not self.is_final_type(s.unanalyzed_type):
3303+
final_type = self.unwrap_final_type(s.unanalyzed_type)
3304+
if final_type is None:
33043305
return False
3305-
assert isinstance(s.unanalyzed_type, UnboundType)
3306-
if len(s.unanalyzed_type.args) > 1:
3307-
self.fail("Final[...] takes at most one type argument", s.unanalyzed_type)
3306+
if len(final_type.args) > 1:
3307+
self.fail("Final[...] takes at most one type argument", final_type)
33083308
invalid_bare_final = False
3309-
if not s.unanalyzed_type.args:
3309+
if not final_type.args:
33103310
s.type = None
33113311
if isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs:
33123312
invalid_bare_final = True
33133313
self.fail("Type in Final[...] can only be omitted if there is an initializer", s)
33143314
else:
3315-
s.type = s.unanalyzed_type.args[0]
3315+
s.type = final_type.args[0]
33163316

33173317
if s.type is not None and self.is_classvar(s.type):
33183318
self.fail("Variable should not be annotated with both ClassVar and Final", s)
@@ -4713,13 +4713,18 @@ def is_classvar(self, typ: Type) -> bool:
47134713
return False
47144714
return sym.node.fullname == "typing.ClassVar"
47154715

4716-
def is_final_type(self, typ: Type | None) -> bool:
4716+
def unwrap_final_type(self, typ: Type | None) -> UnboundType | None:
4717+
if typ is None:
4718+
return None
4719+
typ = typ.resolve_string_annotation()
47174720
if not isinstance(typ, UnboundType):
4718-
return False
4721+
return None
47194722
sym = self.lookup_qualified(typ.name, typ)
47204723
if not sym or not sym.node:
4721-
return False
4722-
return sym.node.fullname in FINAL_TYPE_NAMES
4724+
return None
4725+
if sym.node.fullname in FINAL_TYPE_NAMES:
4726+
return typ
4727+
return None
47234728

47244729
def fail_invalid_classvar(self, context: Context) -> None:
47254730
self.fail(message_registry.CLASS_VAR_OUTSIDE_OF_CLASS, context)

mypy/server/astmerge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,8 @@ def visit_typeddict_type(self, typ: TypedDictType) -> None:
507507
typ.fallback.accept(self)
508508

509509
def visit_raw_expression_type(self, t: RawExpressionType) -> None:
510-
pass
510+
if t.node is not None:
511+
t.node.accept(self)
511512

512513
def visit_literal_type(self, typ: LiteralType) -> None:
513514
typ.fallback.accept(self)

mypy/stubutil.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@
1717
from mypy.modulefinder import ModuleNotFoundReason
1818
from mypy.moduleinspect import InspectError, ModuleInspect
1919
from mypy.stubdoc import ArgSig, FunctionSig
20-
from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType
20+
from mypy.types import (
21+
AnyType,
22+
NoneType,
23+
RawExpressionType,
24+
Type,
25+
TypeList,
26+
TypeStrVisitor,
27+
UnboundType,
28+
UnionType,
29+
)
2130

2231
# Modules that may fail when imported, or that may have side effects (fully qualified).
2332
NOT_IMPORTABLE_MODULES = ()
@@ -291,12 +300,11 @@ def args_str(self, args: Iterable[Type]) -> str:
291300
The main difference from list_str is the preservation of quotes for string
292301
arguments
293302
"""
294-
types = ["builtins.bytes", "builtins.str"]
295303
res = []
296304
for arg in args:
297305
arg_str = arg.accept(self)
298-
if isinstance(arg, UnboundType) and arg.original_str_fallback in types:
299-
res.append(f"'{arg_str}'")
306+
if isinstance(arg, RawExpressionType):
307+
res.append(repr(arg.literal_value))
300308
else:
301309
res.append(arg_str)
302310
return ", ".join(res)

mypy/type_visitor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> T:
376376
return self.query_types(t.items.values())
377377

378378
def visit_raw_expression_type(self, t: RawExpressionType) -> T:
379+
if t.node is not None:
380+
return t.node.accept(self)
379381
return self.strategy([])
380382

381383
def visit_literal_type(self, t: LiteralType) -> T:
@@ -516,6 +518,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> bool:
516518
return self.query_types(list(t.items.values()))
517519

518520
def visit_raw_expression_type(self, t: RawExpressionType) -> bool:
521+
if t.node is not None:
522+
return t.node.accept(self)
519523
return self.default
520524

521525
def visit_literal_type(self, t: LiteralType) -> bool:

mypy/typeanal.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
10701070
return ret
10711071

10721072
def anal_type_guard(self, t: Type) -> Type | None:
1073+
t = t.resolve_string_annotation()
10731074
if isinstance(t, UnboundType):
10741075
sym = self.lookup_qualified(t.name, t)
10751076
if sym is not None and sym.node is not None:
@@ -1088,6 +1089,7 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Type | None:
10881089
return None
10891090

10901091
def anal_type_is(self, t: Type) -> Type | None:
1092+
t = t.resolve_string_annotation()
10911093
if isinstance(t, UnboundType):
10921094
sym = self.lookup_qualified(t.name, t)
10931095
if sym is not None and sym.node is not None:
@@ -1105,6 +1107,7 @@ def anal_type_is_arg(self, t: UnboundType, fullname: str) -> Type | None:
11051107

11061108
def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type:
11071109
"""Analyze signature argument type for *args and **kwargs argument."""
1110+
t = t.resolve_string_annotation()
11081111
if isinstance(t, UnboundType) and t.name and "." in t.name and not t.args:
11091112
components = t.name.split(".")
11101113
tvar_name = ".".join(components[:-1])
@@ -1195,6 +1198,8 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:
11951198
# make signatures like "foo(x: 20) -> None" legal, we can change
11961199
# this method so it generates and returns an actual LiteralType
11971200
# instead.
1201+
if t.node is not None:
1202+
return t.node.accept(self)
11981203

11991204
if self.report_invalid_types:
12001205
if t.base_type_name in ("builtins.int", "builtins.bool"):
@@ -1455,6 +1460,7 @@ def analyze_callable_args(
14551460
invalid_unpacks: list[Type] = []
14561461
second_unpack_last = False
14571462
for i, arg in enumerate(arglist.items):
1463+
arg = arg.resolve_string_annotation()
14581464
if isinstance(arg, CallableArgument):
14591465
args.append(arg.typ)
14601466
names.append(arg.name)
@@ -1535,18 +1541,6 @@ def analyze_literal_type(self, t: UnboundType) -> Type:
15351541
return UnionType.make_union(output, line=t.line)
15361542

15371543
def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> list[Type] | None:
1538-
# This UnboundType was originally defined as a string.
1539-
if isinstance(arg, UnboundType) and arg.original_str_expr is not None:
1540-
assert arg.original_str_fallback is not None
1541-
return [
1542-
LiteralType(
1543-
value=arg.original_str_expr,
1544-
fallback=self.named_type(arg.original_str_fallback),
1545-
line=arg.line,
1546-
column=arg.column,
1547-
)
1548-
]
1549-
15501544
# If arg is an UnboundType that was *not* originally defined as
15511545
# a string, try expanding it in case it's a type alias or something.
15521546
if isinstance(arg, UnboundType):
@@ -2528,7 +2522,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> None:
25282522
self.process_types(list(t.items.values()))
25292523

25302524
def visit_raw_expression_type(self, t: RawExpressionType) -> None:
2531-
pass
2525+
if t.node is not None:
2526+
t.node.accept(self)
25322527

25332528
def visit_literal_type(self, t: LiteralType) -> None:
25342529
pass

0 commit comments

Comments
 (0)