@@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
202202 f" actual { alen } , expected { elen } " )
203203
204204
205+ def _deduplicate (params ):
206+ # Weed out strict duplicates, preserving the first of each occurrence.
207+ all_params = set (params )
208+ if len (all_params ) < len (params ):
209+ new_params = []
210+ for t in params :
211+ if t in all_params :
212+ new_params .append (t )
213+ all_params .remove (t )
214+ params = new_params
215+ assert not all_params , all_params
216+ return params
217+
218+
205219def _remove_dups_flatten (parameters ):
206220 """An internal helper for Union creation and substitution: flatten Unions
207221 among parameters, then remove duplicates.
@@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
215229 params .extend (p [1 :])
216230 else :
217231 params .append (p )
218- # Weed out strict duplicates, preserving the first of each occurrence.
219- all_params = set (params )
220- if len (all_params ) < len (params ):
221- new_params = []
222- for t in params :
223- if t in all_params :
224- new_params .append (t )
225- all_params .remove (t )
226- params = new_params
227- assert not all_params , all_params
232+
233+ return tuple (_deduplicate (params ))
234+
235+
236+ def _flatten_literal_params (parameters ):
237+ """An internal helper for Literal creation: flatten Literals among parameters"""
238+ params = []
239+ for p in parameters :
240+ if isinstance (p , _LiteralGenericAlias ):
241+ params .extend (p .__args__ )
242+ else :
243+ params .append (p )
228244 return tuple (params )
229245
230246
231247_cleanups = []
232248
233249
234- def _tp_cache (func ):
250+ def _tp_cache (func = None , / , * , typed = False ):
235251 """Internal wrapper caching __getitem__ of generic types with a fallback to
236252 original function for non-hashable arguments.
237253 """
238- cached = functools .lru_cache ()(func )
239- _cleanups .append (cached .cache_clear )
254+ def decorator (func ):
255+ cached = functools .lru_cache (typed = typed )(func )
256+ _cleanups .append (cached .cache_clear )
240257
241- @functools .wraps (func )
242- def inner (* args , ** kwds ):
243- try :
244- return cached (* args , ** kwds )
245- except TypeError :
246- pass # All real errors (not unhashable args) are raised below.
247- return func (* args , ** kwds )
248- return inner
258+ @functools .wraps (func )
259+ def inner (* args , ** kwds ):
260+ try :
261+ return cached (* args , ** kwds )
262+ except TypeError :
263+ pass # All real errors (not unhashable args) are raised below.
264+ return func (* args , ** kwds )
265+ return inner
249266
267+ if func is not None :
268+ return decorator (func )
269+
270+ return decorator
250271
251272def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
252273 """Evaluate all forward references in the given type t.
@@ -319,6 +340,13 @@ def __subclasscheck__(self, cls):
319340 def __getitem__ (self , parameters ):
320341 return self ._getitem (self , parameters )
321342
343+
344+ class _LiteralSpecialForm (_SpecialForm , _root = True ):
345+ @_tp_cache (typed = True )
346+ def __getitem__ (self , parameters ):
347+ return self ._getitem (self , parameters )
348+
349+
322350@_SpecialForm
323351def Any (self , parameters ):
324352 """Special type indicating an unconstrained type.
@@ -436,7 +464,7 @@ def Optional(self, parameters):
436464 arg = _type_check (parameters , f"{ self } requires a single type." )
437465 return Union [arg , type (None )]
438466
439- @_SpecialForm
467+ @_LiteralSpecialForm
440468def Literal (self , parameters ):
441469 """Special typing form to define literal types (a.k.a. value types).
442470
@@ -460,7 +488,17 @@ def open_helper(file: str, mode: MODE) -> str:
460488 """
461489 # There is no '_type_check' call because arguments to Literal[...] are
462490 # values, not types.
463- return _GenericAlias (self , parameters )
491+ if not isinstance (parameters , tuple ):
492+ parameters = (parameters ,)
493+
494+ parameters = _flatten_literal_params (parameters )
495+
496+ try :
497+ parameters = tuple (p for p , _ in _deduplicate (list (_value_and_type_iter (parameters ))))
498+ except TypeError : # unhashable parameters
499+ pass
500+
501+ return _LiteralGenericAlias (self , parameters )
464502
465503
466504@_SpecialForm
@@ -930,6 +968,21 @@ def __subclasscheck__(self, cls):
930968 return True
931969
932970
971+ def _value_and_type_iter (parameters ):
972+ return ((p , type (p )) for p in parameters )
973+
974+
975+ class _LiteralGenericAlias (_GenericAlias , _root = True ):
976+
977+ def __eq__ (self , other ):
978+ if not isinstance (other , _LiteralGenericAlias ):
979+ return NotImplemented
980+
981+ return set (_value_and_type_iter (self .__args__ )) == set (_value_and_type_iter (other .__args__ ))
982+
983+ def __hash__ (self ):
984+ return hash (tuple (_value_and_type_iter (self .__args__ )))
985+
933986
934987class Generic :
935988 """Abstract base class for generic types.
0 commit comments