@@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
200200 f" actual { alen } , expected { elen } " )
201201
202202
203+ def _deduplicate (params ):
204+ # Weed out strict duplicates, preserving the first of each occurrence.
205+ all_params = set (params )
206+ if len (all_params ) < len (params ):
207+ new_params = []
208+ for t in params :
209+ if t in all_params :
210+ new_params .append (t )
211+ all_params .remove (t )
212+ params = new_params
213+ assert not all_params , all_params
214+ return params
215+
216+
203217def _remove_dups_flatten (parameters ):
204218 """An internal helper for Union creation and substitution: flatten Unions
205219 among parameters, then remove duplicates.
@@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
213227 params .extend (p [1 :])
214228 else :
215229 params .append (p )
216- # Weed out strict duplicates, preserving the first of each occurrence.
217- all_params = set (params )
218- if len (all_params ) < len (params ):
219- new_params = []
220- for t in params :
221- if t in all_params :
222- new_params .append (t )
223- all_params .remove (t )
224- params = new_params
225- assert not all_params , all_params
230+
231+ return tuple (_deduplicate (params ))
232+
233+
234+ def _flatten_literal_params (parameters ):
235+ """An internal helper for Literal creation: flatten Literals among parameters"""
236+ params = []
237+ for p in parameters :
238+ if isinstance (p , _LiteralGenericAlias ):
239+ params .extend (p .__args__ )
240+ else :
241+ params .append (p )
226242 return tuple (params )
227243
228244
229245_cleanups = []
230246
231247
232- def _tp_cache (func ):
248+ def _tp_cache (func = None , / , * , typed = False ):
233249 """Internal wrapper caching __getitem__ of generic types with a fallback to
234250 original function for non-hashable arguments.
235251 """
236- cached = functools .lru_cache ()(func )
237- _cleanups .append (cached .cache_clear )
252+ def decorator (func ):
253+ cached = functools .lru_cache (typed = typed )(func )
254+ _cleanups .append (cached .cache_clear )
238255
239- @functools .wraps (func )
240- def inner (* args , ** kwds ):
241- try :
242- return cached (* args , ** kwds )
243- except TypeError :
244- pass # All real errors (not unhashable args) are raised below.
245- return func (* args , ** kwds )
246- return inner
256+ @functools .wraps (func )
257+ def inner (* args , ** kwds ):
258+ try :
259+ return cached (* args , ** kwds )
260+ except TypeError :
261+ pass # All real errors (not unhashable args) are raised below.
262+ return func (* args , ** kwds )
263+ return inner
264+
265+ if func is not None :
266+ return decorator (func )
247267
268+ return decorator
248269
249270def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
250271 """Evaluate all forward references in the given type t.
@@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
317338 def __getitem__ (self , parameters ):
318339 return self ._getitem (self , parameters )
319340
341+
342+ class _LiteralSpecialForm (_SpecialForm , _root = True ):
343+ @_tp_cache (typed = True )
344+ def __getitem__ (self , parameters ):
345+ return self ._getitem (self , parameters )
346+
347+
320348@_SpecialForm
321349def Any (self , parameters ):
322350 """Special type indicating an unconstrained type.
@@ -434,7 +462,7 @@ def Optional(self, parameters):
434462 arg = _type_check (parameters , f"{ self } requires a single type." )
435463 return Union [arg , type (None )]
436464
437- @_SpecialForm
465+ @_LiteralSpecialForm
438466def Literal (self , parameters ):
439467 """Special typing form to define literal types (a.k.a. value types).
440468
@@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
458486 """
459487 # There is no '_type_check' call because arguments to Literal[...] are
460488 # values, not types.
461- return _GenericAlias (self , parameters )
489+ if not isinstance (parameters , tuple ):
490+ parameters = (parameters ,)
491+
492+ parameters = _flatten_literal_params (parameters )
493+
494+ try :
495+ parameters = tuple (p for p , _ in _deduplicate (list (_value_and_type_iter (parameters ))))
496+ except TypeError : # unhashable parameters
497+ pass
498+
499+ return _LiteralGenericAlias (self , parameters )
462500
463501
464502class ForwardRef (_Final , _root = True ):
@@ -881,6 +919,22 @@ def __repr__(self):
881919 return super ().__repr__ ()
882920
883921
922+ def _value_and_type_iter (parameters ):
923+ return ((p , type (p )) for p in parameters )
924+
925+
926+ class _LiteralGenericAlias (_GenericAlias , _root = True ):
927+
928+ def __eq__ (self , other ):
929+ if not isinstance (other , _LiteralGenericAlias ):
930+ return NotImplemented
931+
932+ return set (_value_and_type_iter (self .__args__ )) == set (_value_and_type_iter (other .__args__ ))
933+
934+ def __hash__ (self ):
935+ return hash (tuple (_value_and_type_iter (self .__args__ )))
936+
937+
884938class Generic :
885939 """Abstract base class for generic types.
886940
0 commit comments