@@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
202
202
f" actual { alen } , expected { elen } " )
203
203
204
204
205
+ def _deduplicate (params ):
206
+ # Weed out strict duplicates, preserving the first of each occurrence.
207
+ all_params = set (params )
208
+ if len (all_params ) < len (params ):
209
+ new_params = []
210
+ for t in params :
211
+ if t in all_params :
212
+ new_params .append (t )
213
+ all_params .remove (t )
214
+ params = new_params
215
+ assert not all_params , all_params
216
+ return params
217
+
218
+
205
219
def _remove_dups_flatten (parameters ):
206
220
"""An internal helper for Union creation and substitution: flatten Unions
207
221
among parameters, then remove duplicates.
@@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
215
229
params .extend (p [1 :])
216
230
else :
217
231
params .append (p )
218
- # Weed out strict duplicates, preserving the first of each occurrence.
219
- all_params = set (params )
220
- if len (all_params ) < len (params ):
221
- new_params = []
222
- for t in params :
223
- if t in all_params :
224
- new_params .append (t )
225
- all_params .remove (t )
226
- params = new_params
227
- assert not all_params , all_params
232
+
233
+ return tuple (_deduplicate (params ))
234
+
235
+
236
+ def _flatten_literal_params (parameters ):
237
+ """An internal helper for Literal creation: flatten Literals among parameters"""
238
+ params = []
239
+ for p in parameters :
240
+ if isinstance (p , _LiteralGenericAlias ):
241
+ params .extend (p .__args__ )
242
+ else :
243
+ params .append (p )
228
244
return tuple (params )
229
245
230
246
231
247
_cleanups = []
232
248
233
249
234
- def _tp_cache (func ):
250
+ def _tp_cache (func = None , / , * , typed = False ):
235
251
"""Internal wrapper caching __getitem__ of generic types with a fallback to
236
252
original function for non-hashable arguments.
237
253
"""
238
- cached = functools .lru_cache ()(func )
239
- _cleanups .append (cached .cache_clear )
254
+ def decorator (func ):
255
+ cached = functools .lru_cache (typed = typed )(func )
256
+ _cleanups .append (cached .cache_clear )
240
257
241
- @functools .wraps (func )
242
- def inner (* args , ** kwds ):
243
- try :
244
- return cached (* args , ** kwds )
245
- except TypeError :
246
- pass # All real errors (not unhashable args) are raised below.
247
- return func (* args , ** kwds )
248
- return inner
258
+ @functools .wraps (func )
259
+ def inner (* args , ** kwds ):
260
+ try :
261
+ return cached (* args , ** kwds )
262
+ except TypeError :
263
+ pass # All real errors (not unhashable args) are raised below.
264
+ return func (* args , ** kwds )
265
+ return inner
249
266
267
+ if func is not None :
268
+ return decorator (func )
269
+
270
+ return decorator
250
271
251
272
def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
252
273
"""Evaluate all forward references in the given type t.
@@ -319,6 +340,13 @@ def __subclasscheck__(self, cls):
319
340
def __getitem__ (self , parameters ):
320
341
return self ._getitem (self , parameters )
321
342
343
+
344
+ class _LiteralSpecialForm (_SpecialForm , _root = True ):
345
+ @_tp_cache (typed = True )
346
+ def __getitem__ (self , parameters ):
347
+ return self ._getitem (self , parameters )
348
+
349
+
322
350
@_SpecialForm
323
351
def Any (self , parameters ):
324
352
"""Special type indicating an unconstrained type.
@@ -436,7 +464,7 @@ def Optional(self, parameters):
436
464
arg = _type_check (parameters , f"{ self } requires a single type." )
437
465
return Union [arg , type (None )]
438
466
439
- @_SpecialForm
467
+ @_LiteralSpecialForm
440
468
def Literal (self , parameters ):
441
469
"""Special typing form to define literal types (a.k.a. value types).
442
470
@@ -460,7 +488,17 @@ def open_helper(file: str, mode: MODE) -> str:
460
488
"""
461
489
# There is no '_type_check' call because arguments to Literal[...] are
462
490
# values, not types.
463
- return _GenericAlias (self , parameters )
491
+ if not isinstance (parameters , tuple ):
492
+ parameters = (parameters ,)
493
+
494
+ parameters = _flatten_literal_params (parameters )
495
+
496
+ try :
497
+ parameters = tuple (p for p , _ in _deduplicate (list (_value_and_type_iter (parameters ))))
498
+ except TypeError : # unhashable parameters
499
+ pass
500
+
501
+ return _LiteralGenericAlias (self , parameters )
464
502
465
503
466
504
@_SpecialForm
@@ -930,6 +968,21 @@ def __subclasscheck__(self, cls):
930
968
return True
931
969
932
970
971
+ def _value_and_type_iter (parameters ):
972
+ return ((p , type (p )) for p in parameters )
973
+
974
+
975
+ class _LiteralGenericAlias (_GenericAlias , _root = True ):
976
+
977
+ def __eq__ (self , other ):
978
+ if not isinstance (other , _LiteralGenericAlias ):
979
+ return NotImplemented
980
+
981
+ return set (_value_and_type_iter (self .__args__ )) == set (_value_and_type_iter (other .__args__ ))
982
+
983
+ def __hash__ (self ):
984
+ return hash (tuple (_value_and_type_iter (self .__args__ )))
985
+
933
986
934
987
class Generic :
935
988
"""Abstract base class for generic types.
0 commit comments