Skip to content

Commit 1b4552c

Browse files
bpo-41428: Implementation for PEP 604 (GH-21515)
See https://www.python.org/dev/peps/pep-0604/ for more information. Co-authored-by: Pablo Galindo <[email protected]>
1 parent fa8c9e7 commit 1b4552c

File tree

13 files changed

+693
-17
lines changed

13 files changed

+693
-17
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#ifndef Py_INTERNAL_UNIONOBJECT_H
2+
#define Py_INTERNAL_UNIONOBJECT_H
3+
#ifdef __cplusplus
4+
extern "C" {
5+
#endif
6+
7+
#ifndef Py_BUILD_CORE
8+
# error "this header requires Py_BUILD_CORE define"
9+
#endif
10+
11+
PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args);
12+
PyAPI_DATA(PyTypeObject) _Py_UnionType;
13+
14+
#ifdef __cplusplus
15+
}
16+
#endif
17+
#endif /* !Py_INTERNAL_UNIONOBJECT_H */

‎Lib/test/test_isinstance.py‎

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import unittest
66
import sys
7+
import typing
78

89

910

@@ -208,6 +209,25 @@ def test_isinstance_abstract(self):
208209
self.assertEqual(False, isinstance(AbstractChild(), Super))
209210
self.assertEqual(False, isinstance(AbstractChild(), Child))
210211

212+
def test_isinstance_with_or_union(self):
213+
self.assertTrue(isinstance(Super(), Super | int))
214+
self.assertFalse(isinstance(None, str | int))
215+
self.assertTrue(isinstance(3, str | int))
216+
self.assertTrue(isinstance("", str | int))
217+
self.assertTrue(isinstance([], typing.List | typing.Tuple))
218+
self.assertTrue(isinstance(2, typing.List | int))
219+
self.assertFalse(isinstance(2, typing.List | typing.Tuple))
220+
self.assertTrue(isinstance(None, int | None))
221+
self.assertFalse(isinstance(3.14, int | str))
222+
with self.assertRaises(TypeError):
223+
isinstance(2, list[int])
224+
with self.assertRaises(TypeError):
225+
isinstance(2, list[int] | int)
226+
with self.assertRaises(TypeError):
227+
isinstance(2, int | str | list[int] | float)
228+
229+
230+
211231
def test_subclass_normal(self):
212232
# normal classes
213233
self.assertEqual(True, issubclass(Super, Super))
@@ -217,6 +237,8 @@ def test_subclass_normal(self):
217237
self.assertEqual(True, issubclass(Child, Child))
218238
self.assertEqual(True, issubclass(Child, Super))
219239
self.assertEqual(False, issubclass(Child, AbstractSuper))
240+
self.assertTrue(issubclass(typing.List, typing.List|typing.Tuple))
241+
self.assertFalse(issubclass(int, typing.List|typing.Tuple))
220242

221243
def test_subclass_abstract(self):
222244
# abstract classes
@@ -251,6 +273,16 @@ def test_isinstance_recursion_limit(self):
251273
# blown
252274
self.assertRaises(RecursionError, blowstack, isinstance, '', str)
253275

276+
def test_subclass_with_union(self):
277+
self.assertTrue(issubclass(int, int | float | int))
278+
self.assertTrue(issubclass(str, str | Child | str))
279+
self.assertFalse(issubclass(dict, float|str))
280+
self.assertFalse(issubclass(object, float|str))
281+
with self.assertRaises(TypeError):
282+
issubclass(2, Child | Super)
283+
with self.assertRaises(TypeError):
284+
issubclass(int, list[int] | Child)
285+
254286
def test_issubclass_refcount_handling(self):
255287
# bpo-39382: abstract_issubclass() didn't hold item reference while
256288
# peeking in the bases tuple, in the single inheritance case.

‎Lib/test/test_types.py‎

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,20 @@
22

33
from test.support import run_with_locale
44
import collections.abc
5+
from collections import namedtuple
56
import inspect
67
import pickle
78
import locale
89
import sys
910
import types
1011
import unittest.mock
1112
import weakref
13+
import typing
14+
15+
class Example:
16+
pass
17+
18+
class Forward: ...
1219

1320
class TypesTests(unittest.TestCase):
1421

@@ -598,6 +605,113 @@ def test_method_descriptor_types(self):
598605
self.assertIsInstance(int.from_bytes, types.BuiltinMethodType)
599606
self.assertIsInstance(int.__new__, types.BuiltinMethodType)
600607

608+
def test_or_types_operator(self):
609+
self.assertEqual(int | str, typing.Union[int, str])
610+
self.assertNotEqual(int | list, typing.Union[int, str])
611+
self.assertEqual(str | int, typing.Union[int, str])
612+
self.assertEqual(int | None, typing.Union[int, None])
613+
self.assertEqual(None | int, typing.Union[int, None])
614+
self.assertEqual(int | str | list, typing.Union[int, str, list])
615+
self.assertEqual(int | (str | list), typing.Union[int, str, list])
616+
self.assertEqual(str | (int | list), typing.Union[int, str, list])
617+
self.assertEqual(typing.List | typing.Tuple, typing.Union[typing.List, typing.Tuple])
618+
self.assertEqual(typing.List[int] | typing.Tuple[int], typing.Union[typing.List[int], typing.Tuple[int]])
619+
self.assertEqual(typing.List[int] | None, typing.Union[typing.List[int], None])
620+
self.assertEqual(None | typing.List[int], typing.Union[None, typing.List[int]])
621+
self.assertEqual(str | float | int | complex | int, (int | str) | (float | complex))
622+
self.assertEqual(typing.Union[str, int, typing.List[int]], str | int | typing.List[int])
623+
self.assertEqual(int | int, int)
624+
self.assertEqual(
625+
BaseException |
626+
bool |
627+
bytes |
628+
complex |
629+
float |
630+
int |
631+
list |
632+
map |
633+
set,
634+
typing.Union[
635+
BaseException,
636+
bool,
637+
bytes,
638+
complex,
639+
float,
640+
int,
641+
list,
642+
map,
643+
set,
644+
])
645+
with self.assertRaises(TypeError):
646+
int | 3
647+
with self.assertRaises(TypeError):
648+
3 | int
649+
with self.assertRaises(TypeError):
650+
Example() | int
651+
with self.assertRaises(TypeError):
652+
(int | str) < typing.Union[str, int]
653+
with self.assertRaises(TypeError):
654+
(int | str) < (int | bool)
655+
with self.assertRaises(TypeError):
656+
(int | str) <= (int | str)
657+
with self.assertRaises(TypeError):
658+
# Check that we don't crash if typing.Union does not have a tuple in __args__
659+
x = typing.Union[str, int]
660+
x.__args__ = [str, int]
661+
(int | str ) == x
662+
663+
def test_or_type_operator_with_TypeVar(self):
664+
TV = typing.TypeVar('T')
665+
assert TV | str == typing.Union[TV, str]
666+
assert str | TV == typing.Union[str, TV]
667+
668+
def test_or_type_operator_with_forward(self):
669+
T = typing.TypeVar('T')
670+
ForwardAfter = T | 'Forward'
671+
ForwardBefore = 'Forward' | T
672+
def forward_after(x: ForwardAfter[int]) -> None: ...
673+
def forward_before(x: ForwardBefore[int]) -> None: ...
674+
assert typing.get_args(typing.get_type_hints(forward_after)['x']) == (int, Forward)
675+
assert typing.get_args(typing.get_type_hints(forward_before)['x']) == (int, Forward)
676+
677+
def test_or_type_operator_with_Protocol(self):
678+
class Proto(typing.Protocol):
679+
def meth(self) -> int:
680+
...
681+
assert Proto | str == typing.Union[Proto, str]
682+
683+
def test_or_type_operator_with_Alias(self):
684+
assert list | str == typing.Union[list, str]
685+
assert typing.List | str == typing.Union[typing.List, str]
686+
687+
def test_or_type_operator_with_NamedTuple(self):
688+
NT=namedtuple('A', ['B', 'C', 'D'])
689+
assert NT | str == typing.Union[NT,str]
690+
691+
def test_or_type_operator_with_TypedDict(self):
692+
class Point2D(typing.TypedDict):
693+
x: int
694+
y: int
695+
label: str
696+
assert Point2D | str == typing.Union[Point2D, str]
697+
698+
def test_or_type_operator_with_NewType(self):
699+
UserId = typing.NewType('UserId', int)
700+
assert UserId | str == typing.Union[UserId, str]
701+
702+
def test_or_type_operator_with_IO(self):
703+
assert typing.IO | str == typing.Union[typing.IO, str]
704+
705+
def test_or_type_operator_with_SpecialForm(self):
706+
assert typing.Any | str == typing.Union[typing.Any, str]
707+
assert typing.NoReturn | str == typing.Union[typing.NoReturn, str]
708+
assert typing.Optional[int] | str == typing.Union[typing.Optional[int], str]
709+
assert typing.Optional[int] | str == typing.Union[int, str, None]
710+
assert typing.Union[int, bool] | str == typing.Union[int, bool, str]
711+
712+
def test_or_type_repr(self):
713+
assert repr(int | None) == "int | None"
714+
assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]"
601715

602716
class MappingProxyTests(unittest.TestCase):
603717
mappingproxy = types.MappingProxyType

‎Lib/test/test_typing.py‎

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,6 @@ def test_subclass_error(self):
244244
issubclass(int, Union)
245245
with self.assertRaises(TypeError):
246246
issubclass(Union, int)
247-
with self.assertRaises(TypeError):
248-
issubclass(int, Union[int, str])
249247
with self.assertRaises(TypeError):
250248
issubclass(Union[int, str], int)
251249

@@ -347,10 +345,6 @@ def test_empty(self):
347345
with self.assertRaises(TypeError):
348346
Union[()]
349347

350-
def test_union_instance_type_error(self):
351-
with self.assertRaises(TypeError):
352-
isinstance(42, Union[int, str])
353-
354348
def test_no_eval_union(self):
355349
u = Union[int, str]
356350
def f(x: u): ...

‎Lib/types.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def wrapped(*args, **kwargs):
294294

295295

296296
GenericAlias = type(list[int])
297+
Union = type(int | str)
297298

298299

299300
__all__ = [n for n in globals() if n[:1] != '_']

‎Lib/typing.py‎

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@
117117
# namespace, but excluded from __all__ because they might stomp on
118118
# legitimate imports of those modules.
119119

120-
121120
def _type_check(arg, msg, is_argument=True):
122121
"""Check that the argument is a type, and return it (internal helper).
123122
@@ -145,7 +144,7 @@ def _type_check(arg, msg, is_argument=True):
145144
return arg
146145
if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
147146
raise TypeError(f"Plain {arg} is not valid as type argument")
148-
if isinstance(arg, (type, TypeVar, ForwardRef)):
147+
if isinstance(arg, (type, TypeVar, ForwardRef, types.Union)):
149148
return arg
150149
if not callable(arg):
151150
raise TypeError(f"{msg} Got {arg!r:.100}.")
@@ -205,7 +204,7 @@ def _remove_dups_flatten(parameters):
205204
# Flatten out Union[Union[...], ...].
206205
params = []
207206
for p in parameters:
208-
if isinstance(p, _UnionGenericAlias):
207+
if isinstance(p, (_UnionGenericAlias, types.Union)):
209208
params.extend(p.__args__)
210209
elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union:
211210
params.extend(p[1:])
@@ -586,6 +585,12 @@ def __init__(self, name, *constraints, bound=None,
586585
if def_mod != 'typing':
587586
self.__module__ = def_mod
588587

588+
def __or__(self, right):
589+
return Union[self, right]
590+
591+
def __ror__(self, right):
592+
return Union[self, right]
593+
589594
def __repr__(self):
590595
if self.__covariant__:
591596
prefix = '+'
@@ -693,6 +698,12 @@ def __eq__(self, other):
693698
def __hash__(self):
694699
return hash((self.__origin__, self.__args__))
695700

701+
def __or__(self, right):
702+
return Union[self, right]
703+
704+
def __ror__(self, right):
705+
return Union[self, right]
706+
696707
@_tp_cache
697708
def __getitem__(self, params):
698709
if self.__origin__ in (Generic, Protocol):
@@ -792,6 +803,11 @@ def __subclasscheck__(self, cls):
792803
def __reduce__(self):
793804
return self._name
794805

806+
def __or__(self, right):
807+
return Union[self, right]
808+
809+
def __ror__(self, right):
810+
return Union[self, right]
795811

796812
class _CallableGenericAlias(_GenericAlias, _root=True):
797813
def __repr__(self):
@@ -878,6 +894,15 @@ def __repr__(self):
878894
return f'typing.Optional[{_type_repr(args[0])}]'
879895
return super().__repr__()
880896

897+
def __instancecheck__(self, obj):
898+
return self.__subclasscheck__(type(obj))
899+
900+
def __subclasscheck__(self, cls):
901+
for arg in self.__args__:
902+
if issubclass(cls, arg):
903+
return True
904+
905+
881906

882907
class Generic:
883908
"""Abstract base class for generic types.

‎Makefile.pre.in‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ OBJECT_OBJS= \
432432
Objects/typeobject.o \
433433
Objects/unicodeobject.o \
434434
Objects/unicodectype.o \
435+
Objects/unionobject.o \
435436
Objects/weakrefobject.o
436437

437438
##########################################################################
@@ -1128,6 +1129,7 @@ PYTHON_HEADERS= \
11281129
$(srcdir)/Include/internal/pycore_sysmodule.h \
11291130
$(srcdir)/Include/internal/pycore_traceback.h \
11301131
$(srcdir)/Include/internal/pycore_tuple.h \
1132+
$(srcdir)/Include/internal/pycore_unionobject.h \
11311133
$(srcdir)/Include/internal/pycore_warnings.h \
11321134
$(DTRACE_HEADERS)
11331135

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement PEP 604. This supports (int | str) etc. in place of Union[str, int].

‎Objects/abstract.c‎

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/* Abstract Object Interface (many thanks to Jim Fulton) */
22

33
#include "Python.h"
4+
#include "pycore_unionobject.h" // _Py_UnionType && _Py_Union()
45
#include "pycore_abstract.h" // _PyIndex_Check()
56
#include "pycore_ceval.h" // _Py_EnterRecursiveCall()
67
#include "pycore_pyerrors.h" // _PyErr_Occurred()
@@ -839,7 +840,6 @@ binary_op(PyObject *v, PyObject *w, const int op_slot, const char *op_name)
839840
Py_TYPE(w)->tp_name);
840841
return NULL;
841842
}
842-
843843
return binop_type_error(v, w, op_name);
844844
}
845845
return result;
@@ -2412,7 +2412,6 @@ object_isinstance(PyObject *inst, PyObject *cls)
24122412
PyObject *icls;
24132413
int retval;
24142414
_Py_IDENTIFIER(__class__);
2415-
24162415
if (PyType_Check(cls)) {
24172416
retval = PyObject_TypeCheck(inst, (PyTypeObject *)cls);
24182417
if (retval == 0) {
@@ -2432,7 +2431,7 @@ object_isinstance(PyObject *inst, PyObject *cls)
24322431
}
24332432
else {
24342433
if (!check_class(cls,
2435-
"isinstance() arg 2 must be a type or tuple of types"))
2434+
"isinstance() arg 2 must be a type, a tuple of types or a union"))
24362435
return -1;
24372436
retval = _PyObject_LookupAttrId(inst, &PyId___class__, &icls);
24382437
if (icls != NULL) {
@@ -2525,10 +2524,14 @@ recursive_issubclass(PyObject *derived, PyObject *cls)
25252524
if (!check_class(derived,
25262525
"issubclass() arg 1 must be a class"))
25272526
return -1;
2528-
if (!check_class(cls,
2529-
"issubclass() arg 2 must be a class"
2530-
" or tuple of classes"))
2527+
2528+
PyTypeObject *type = Py_TYPE(cls);
2529+
int is_union = (PyType_Check(type) && type == &_Py_UnionType);
2530+
if (!is_union && !check_class(cls,
2531+
"issubclass() arg 2 must be a class,"
2532+
" a tuple of classes, or a union.")) {
25312533
return -1;
2534+
}
25322535

25332536
return abstract_issubclass(derived, cls);
25342537
}

0 commit comments

Comments
 (0)