Skip to content

Commit f03d318

Browse files
authored
bpo-42345: Fix three issues with typing.Literal parameters (GH-23294)
Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments.
1 parent b0aba1f commit f03d318

File tree

4 files changed

+104
-23
lines changed

4 files changed

+104
-23
lines changed

Lib/test/test_typing.py

+25
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ def test_repr(self):
528528
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
529529
self.assertEqual(repr(Literal), "typing.Literal")
530530
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
531+
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
531532

532533
def test_cannot_init(self):
533534
with self.assertRaises(TypeError):
@@ -559,6 +560,30 @@ def test_no_multiple_subscripts(self):
559560
with self.assertRaises(TypeError):
560561
Literal[1][1]
561562

563+
def test_equal(self):
564+
self.assertNotEqual(Literal[0], Literal[False])
565+
self.assertNotEqual(Literal[True], Literal[1])
566+
self.assertNotEqual(Literal[1], Literal[2])
567+
self.assertNotEqual(Literal[1, True], Literal[1])
568+
self.assertEqual(Literal[1], Literal[1])
569+
self.assertEqual(Literal[1, 2], Literal[2, 1])
570+
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
571+
572+
def test_args(self):
573+
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
574+
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
575+
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
576+
# Mutable arguments will not be deduplicated
577+
self.assertEqual(Literal[[], []].__args__, ([], []))
578+
579+
def test_flatten(self):
580+
l1 = Literal[Literal[1], Literal[2], Literal[3]]
581+
l2 = Literal[Literal[1, 2], 3]
582+
l3 = Literal[Literal[1, 2, 3]]
583+
for l in l1, l2, l3:
584+
self.assertEqual(l, Literal[1, 2, 3])
585+
self.assertEqual(l.__args__, (1, 2, 3))
586+
562587

563588
XK = TypeVar('XK', str, bytes)
564589
XV = TypeVar('XV')

Lib/typing.py

+76-23
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
202202
f" actual {alen}, expected {elen}")
203203

204204

205+
def _deduplicate(params):
206+
# Weed out strict duplicates, preserving the first of each occurrence.
207+
all_params = set(params)
208+
if len(all_params) < len(params):
209+
new_params = []
210+
for t in params:
211+
if t in all_params:
212+
new_params.append(t)
213+
all_params.remove(t)
214+
params = new_params
215+
assert not all_params, all_params
216+
return params
217+
218+
205219
def _remove_dups_flatten(parameters):
206220
"""An internal helper for Union creation and substitution: flatten Unions
207221
among parameters, then remove duplicates.
@@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
215229
params.extend(p[1:])
216230
else:
217231
params.append(p)
218-
# Weed out strict duplicates, preserving the first of each occurrence.
219-
all_params = set(params)
220-
if len(all_params) < len(params):
221-
new_params = []
222-
for t in params:
223-
if t in all_params:
224-
new_params.append(t)
225-
all_params.remove(t)
226-
params = new_params
227-
assert not all_params, all_params
232+
233+
return tuple(_deduplicate(params))
234+
235+
236+
def _flatten_literal_params(parameters):
237+
"""An internal helper for Literal creation: flatten Literals among parameters"""
238+
params = []
239+
for p in parameters:
240+
if isinstance(p, _LiteralGenericAlias):
241+
params.extend(p.__args__)
242+
else:
243+
params.append(p)
228244
return tuple(params)
229245

230246

231247
_cleanups = []
232248

233249

234-
def _tp_cache(func):
250+
def _tp_cache(func=None, /, *, typed=False):
235251
"""Internal wrapper caching __getitem__ of generic types with a fallback to
236252
original function for non-hashable arguments.
237253
"""
238-
cached = functools.lru_cache()(func)
239-
_cleanups.append(cached.cache_clear)
254+
def decorator(func):
255+
cached = functools.lru_cache(typed=typed)(func)
256+
_cleanups.append(cached.cache_clear)
240257

241-
@functools.wraps(func)
242-
def inner(*args, **kwds):
243-
try:
244-
return cached(*args, **kwds)
245-
except TypeError:
246-
pass # All real errors (not unhashable args) are raised below.
247-
return func(*args, **kwds)
248-
return inner
258+
@functools.wraps(func)
259+
def inner(*args, **kwds):
260+
try:
261+
return cached(*args, **kwds)
262+
except TypeError:
263+
pass # All real errors (not unhashable args) are raised below.
264+
return func(*args, **kwds)
265+
return inner
249266

267+
if func is not None:
268+
return decorator(func)
269+
270+
return decorator
250271

251272
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
252273
"""Evaluate all forward references in the given type t.
@@ -319,6 +340,13 @@ def __subclasscheck__(self, cls):
319340
def __getitem__(self, parameters):
320341
return self._getitem(self, parameters)
321342

343+
344+
class _LiteralSpecialForm(_SpecialForm, _root=True):
345+
@_tp_cache(typed=True)
346+
def __getitem__(self, parameters):
347+
return self._getitem(self, parameters)
348+
349+
322350
@_SpecialForm
323351
def Any(self, parameters):
324352
"""Special type indicating an unconstrained type.
@@ -436,7 +464,7 @@ def Optional(self, parameters):
436464
arg = _type_check(parameters, f"{self} requires a single type.")
437465
return Union[arg, type(None)]
438466

439-
@_SpecialForm
467+
@_LiteralSpecialForm
440468
def Literal(self, parameters):
441469
"""Special typing form to define literal types (a.k.a. value types).
442470
@@ -460,7 +488,17 @@ def open_helper(file: str, mode: MODE) -> str:
460488
"""
461489
# There is no '_type_check' call because arguments to Literal[...] are
462490
# values, not types.
463-
return _GenericAlias(self, parameters)
491+
if not isinstance(parameters, tuple):
492+
parameters = (parameters,)
493+
494+
parameters = _flatten_literal_params(parameters)
495+
496+
try:
497+
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
498+
except TypeError: # unhashable parameters
499+
pass
500+
501+
return _LiteralGenericAlias(self, parameters)
464502

465503

466504
@_SpecialForm
@@ -930,6 +968,21 @@ def __subclasscheck__(self, cls):
930968
return True
931969

932970

971+
def _value_and_type_iter(parameters):
972+
return ((p, type(p)) for p in parameters)
973+
974+
975+
class _LiteralGenericAlias(_GenericAlias, _root=True):
976+
977+
def __eq__(self, other):
978+
if not isinstance(other, _LiteralGenericAlias):
979+
return NotImplemented
980+
981+
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
982+
983+
def __hash__(self):
984+
return hash(tuple(_value_and_type_iter(self.__args__)))
985+
933986

934987
class Generic:
935988
"""Abstract base class for generic types.

Misc/ACKS

+1
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ Jan Kanis
861861
Rafe Kaplan
862862
Jacob Kaplan-Moss
863863
Allison Kaptur
864+
Yurii Karabas
864865
Janne Karila
865866
Per Øyvind Karlsen
866867
Anton Kasyanov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix various issues with ``typing.Literal`` parameter handling (flatten,
2+
deduplicate, use type to cache key). Patch provided by Yurii Karabas.

0 commit comments

Comments
 (0)