Skip to content

Commit d73668f

Browse files
authored
Rust enums validator (#1235)
1 parent 36a1b6c commit d73668f

File tree

18 files changed

+993
-35
lines changed

18 files changed

+993
-35
lines changed

‎python/pydantic_core/core_schema.py‎

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,69 @@ def literal_schema(
11651165
return _dict_not_none(type='literal', expected=expected, ref=ref, metadata=metadata, serialization=serialization)
11661166

11671167

1168+
class EnumSchema(TypedDict, total=False):
1169+
type: Required[Literal['enum']]
1170+
cls: Required[Any]
1171+
members: Required[List[Any]]
1172+
sub_type: Literal['str', 'int', 'float']
1173+
missing: Callable[[Any], Any]
1174+
strict: bool
1175+
ref: str
1176+
metadata: Any
1177+
serialization: SerSchema
1178+
1179+
1180+
def enum_schema(
1181+
cls: Any,
1182+
members: list[Any],
1183+
*,
1184+
sub_type: Literal['str', 'int', 'float'] | None = None,
1185+
missing: Callable[[Any], Any] | None = None,
1186+
strict: bool | None = None,
1187+
ref: str | None = None,
1188+
metadata: Any = None,
1189+
serialization: SerSchema | None = None,
1190+
) -> EnumSchema:
1191+
"""
1192+
Returns a schema that matches an enum value, e.g.:
1193+
1194+
```py
1195+
from enum import Enum
1196+
from pydantic_core import SchemaValidator, core_schema
1197+
1198+
class Color(Enum):
1199+
RED = 1
1200+
GREEN = 2
1201+
BLUE = 3
1202+
1203+
schema = core_schema.enum_schema(Color, list(Color.__members__.values()))
1204+
v = SchemaValidator(schema)
1205+
assert v.validate_python(2) is Color.GREEN
1206+
```
1207+
1208+
Args:
1209+
cls: The enum class
1210+
members: The members of the enum, generally `list(MyEnum.__members__.values())`
1211+
sub_type: The type of the enum, one of 'str', 'int' or 'float', or None for plain enums
1212+
missing: A function to use when the value is not found in the enum, from `_missing_`
1213+
strict: Whether to use strict mode, defaults to False
1214+
ref: optional unique identifier of the schema, used to reference the schema in other places
1215+
metadata: Any other information you want to include with the schema, not used by pydantic-core
1216+
serialization: Custom serialization schema
1217+
"""
1218+
return _dict_not_none(
1219+
type='enum',
1220+
cls=cls,
1221+
members=members,
1222+
sub_type=sub_type,
1223+
missing=missing,
1224+
strict=strict,
1225+
ref=ref,
1226+
metadata=metadata,
1227+
serialization=serialization,
1228+
)
1229+
1230+
11681231
# must match input/parse_json.rs::JsonType::try_from
11691232
JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict']
11701233

@@ -3670,6 +3733,7 @@ def definition_reference_schema(
36703733
DatetimeSchema,
36713734
TimedeltaSchema,
36723735
LiteralSchema,
3736+
EnumSchema,
36733737
IsInstanceSchema,
36743738
IsSubclassSchema,
36753739
CallableSchema,
@@ -3724,6 +3788,7 @@ def definition_reference_schema(
37243788
'datetime',
37253789
'timedelta',
37263790
'literal',
3791+
'enum',
37273792
'is-instance',
37283793
'is-subclass',
37293794
'callable',

‎src/input/input_abstract.rs‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
6464
None
6565
}
6666

67+
fn input_is_exact_instance(&self, _class: &Bound<'py, PyType>) -> bool {
68+
false
69+
}
70+
6771
fn is_python(&self) -> bool {
6872
false
6973
}

‎src/input/input_python.rs‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
138138
}
139139
}
140140

141+
fn input_is_exact_instance(&self, class: &Bound<'py, PyType>) -> bool {
142+
self.is_exact_instance(class)
143+
}
144+
141145
fn is_python(&self) -> bool {
142146
true
143147
}

‎src/serializers/shared.rs‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ combined_serializer! {
139139
JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer;
140140
Union: super::type_serializers::union::UnionSerializer;
141141
Literal: super::type_serializers::literal::LiteralSerializer;
142+
Enum: super::type_serializers::enum_::EnumSerializer;
142143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
143144
Tuple: super::type_serializers::tuple::TupleSerializer;
144145
}
@@ -246,6 +247,7 @@ impl PyGcTraverse for CombinedSerializer {
246247
CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit),
247248
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
248249
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
250+
CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit),
249251
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
250252
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
251253
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
‎src/serializers/type_serializers/enum_.rs‎

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::borrow::Cow;
2+
3+
use crate::build_tools::py_schema_err;
4+
use pyo3::intern;
5+
use pyo3::prelude::*;
6+
use pyo3::types::{PyDict, PyType};
7+
8+
use crate::definitions::DefinitionsBuilder;
9+
use crate::serializers::errors::py_err_se_err;
10+
use crate::serializers::infer::{infer_json_key, infer_serialize, infer_to_python};
11+
use crate::tools::SchemaDict;
12+
13+
use super::float::FloatSerializer;
14+
use super::simple::IntSerializer;
15+
use super::string::StrSerializer;
16+
use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer};
17+
18+
#[derive(Debug, Clone)]
19+
pub struct EnumSerializer {
20+
class: Py<PyType>,
21+
serializer: Option<Box<CombinedSerializer>>,
22+
}
23+
24+
impl BuildSerializer for EnumSerializer {
25+
const EXPECTED_TYPE: &'static str = "enum";
26+
27+
fn build(
28+
schema: &Bound<'_, PyDict>,
29+
config: Option<&Bound<'_, PyDict>>,
30+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
31+
) -> PyResult<CombinedSerializer> {
32+
let sub_type: Option<String> = schema.get_as(intern!(schema.py(), "sub_type"))?;
33+
34+
let serializer = match sub_type.as_deref() {
35+
Some("int") => Some(Box::new(IntSerializer::new().into())),
36+
Some("str") => Some(Box::new(StrSerializer::new().into())),
37+
Some("float") => Some(Box::new(FloatSerializer::new(schema.py(), config)?.into())),
38+
Some(_) => return py_schema_err!("`sub_type` must be one of: 'int', 'str', 'float' or None"),
39+
None => None,
40+
};
41+
Ok(Self {
42+
class: schema.get_as_req(intern!(schema.py(), "cls"))?,
43+
serializer,
44+
}
45+
.into())
46+
}
47+
}
48+
49+
impl_py_gc_traverse!(EnumSerializer { serializer });
50+
51+
impl TypeSerializer for EnumSerializer {
52+
fn to_python(
53+
&self,
54+
value: &Bound<'_, PyAny>,
55+
include: Option<&Bound<'_, PyAny>>,
56+
exclude: Option<&Bound<'_, PyAny>>,
57+
extra: &Extra,
58+
) -> PyResult<PyObject> {
59+
let py = value.py();
60+
if value.is_exact_instance(self.class.bind(py)) {
61+
// if we're in JSON mode, we need to get the value attribute and serialize that
62+
if extra.mode.is_json() {
63+
let dot_value = value.getattr(intern!(py, "value"))?;
64+
match self.serializer {
65+
Some(ref s) => s.to_python(&dot_value, include, exclude, extra),
66+
None => infer_to_python(&dot_value, include, exclude, extra),
67+
}
68+
} else {
69+
// if we're not in JSON mode, we assume the value is safe to return directly
70+
Ok(value.into_py(py))
71+
}
72+
} else {
73+
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
74+
infer_to_python(value, include, exclude, extra)
75+
}
76+
}
77+
78+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
79+
let py = key.py();
80+
if key.is_exact_instance(self.class.bind(py)) {
81+
let dot_value = key.getattr(intern!(py, "value"))?;
82+
let k = match self.serializer {
83+
Some(ref s) => s.json_key(&dot_value, extra),
84+
None => infer_json_key(&dot_value, extra),
85+
}?;
86+
// since dot_value is a local reference, the key borrowed from it cannot outlive this
87+
// function, so we allocate and return the owned variant of Cow.
88+
Ok(Cow::Owned(k.into_owned()))
89+
} else {
90+
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
91+
infer_json_key(key, extra)
92+
}
93+
}
94+
95+
fn serde_serialize<S: serde::ser::Serializer>(
96+
&self,
97+
value: &Bound<'_, PyAny>,
98+
serializer: S,
99+
include: Option<&Bound<'_, PyAny>>,
100+
exclude: Option<&Bound<'_, PyAny>>,
101+
extra: &Extra,
102+
) -> Result<S::Ok, S::Error> {
103+
if value.is_exact_instance(self.class.bind(value.py())) {
104+
let dot_value = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
105+
match self.serializer {
106+
Some(ref s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra),
107+
None => infer_serialize(&dot_value, serializer, include, exclude, extra),
108+
}
109+
} else {
110+
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
111+
infer_serialize(value, serializer, include, exclude, extra)
112+
}
113+
}
114+
115+
fn get_name(&self) -> &str {
116+
Self::EXPECTED_TYPE
117+
}
118+
119+
fn retry_with_lax_check(&self) -> bool {
120+
match self.serializer {
121+
Some(ref s) => s.retry_with_lax_check(),
122+
None => false,
123+
}
124+
}
125+
}

‎src/serializers/type_serializers/float.rs‎

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ pub struct FloatSerializer {
2020
inf_nan_mode: InfNanMode,
2121
}
2222

23+
impl FloatSerializer {
24+
pub fn new(py: Python, config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
25+
let inf_nan_mode = config
26+
.and_then(|c| c.get_as(intern!(py, "ser_json_inf_nan")).transpose())
27+
.transpose()?
28+
.unwrap_or_default();
29+
Ok(Self { inf_nan_mode })
30+
}
31+
}
32+
2333
impl BuildSerializer for FloatSerializer {
2434
const EXPECTED_TYPE: &'static str = "float";
2535

@@ -28,11 +38,7 @@ impl BuildSerializer for FloatSerializer {
2838
config: Option<&Bound<'_, PyDict>>,
2939
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
3040
) -> PyResult<CombinedSerializer> {
31-
let inf_nan_mode = config
32-
.and_then(|c| c.get_as(intern!(schema.py(), "ser_json_inf_nan")).transpose())
33-
.transpose()?
34-
.unwrap_or_default();
35-
Ok(Self { inf_nan_mode }.into())
41+
Self::new(schema.py(), config).map(Into::into)
3642
}
3743
}
3844

‎src/serializers/type_serializers/format.rs‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl TypeSerializer for FormatSerializer {
122122
}
123123
}
124124

125-
fn json_key<'py>(&self, key: &Bound<'py, PyAny>, _extra: &Extra) -> PyResult<Cow<'py, str>> {
125+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, _extra: &Extra) -> PyResult<Cow<'a, str>> {
126126
if self.when_used.should_use_json(key) {
127127
let py_str = self
128128
.call(key)
@@ -198,7 +198,7 @@ impl TypeSerializer for ToStringSerializer {
198198
}
199199
}
200200

201-
fn json_key<'py>(&self, key: &Bound<'py, PyAny>, _extra: &Extra) -> PyResult<Cow<'py, str>> {
201+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, _extra: &Extra) -> PyResult<Cow<'a, str>> {
202202
if self.when_used.should_use_json(key) {
203203
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
204204
} else {

‎src/serializers/type_serializers/mod.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod datetime_etc;
55
pub mod decimal;
66
pub mod definitions;
77
pub mod dict;
8+
pub mod enum_;
89
pub mod float;
910
pub mod format;
1011
pub mod function;

‎src/serializers/type_serializers/simple.rs‎

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ macro_rules! build_simple_serializer {
9090
#[derive(Debug, Clone)]
9191
pub struct $struct_name;
9292

93+
impl $struct_name {
94+
pub fn new() -> Self {
95+
Self {}
96+
}
97+
}
98+
9399
impl BuildSerializer for $struct_name {
94100
const EXPECTED_TYPE: &'static str = $expected_type;
95101

@@ -98,7 +104,7 @@ macro_rules! build_simple_serializer {
98104
_config: Option<&Bound<'_, PyDict>>,
99105
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
100106
) -> PyResult<CombinedSerializer> {
101-
Ok(Self {}.into())
107+
Ok(Self::new().into())
102108
}
103109
}
104110

@@ -172,13 +178,13 @@ macro_rules! build_simple_serializer {
172178
};
173179
}
174180

175-
pub(crate) fn to_str_json_key<'py>(key: &Bound<'py, PyAny>) -> PyResult<Cow<'py, str>> {
181+
pub(crate) fn to_str_json_key<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<Cow<'a, str>> {
176182
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
177183
}
178184

179185
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true);
180186

181-
pub(crate) fn bool_json_key<'py>(key: &Bound<'py, PyAny>) -> PyResult<Cow<'py, str>> {
187+
pub(crate) fn bool_json_key<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<Cow<'a, str>> {
182188
let v = if key.is_truthy().unwrap_or(false) {
183189
"true"
184190
} else {

‎src/serializers/type_serializers/string.rs‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ use super::{
1313
#[derive(Debug, Clone)]
1414
pub struct StrSerializer;
1515

16+
impl StrSerializer {
17+
pub fn new() -> Self {
18+
Self {}
19+
}
20+
}
21+
1622
impl BuildSerializer for StrSerializer {
1723
const EXPECTED_TYPE: &'static str = "str";
1824

@@ -21,7 +27,7 @@ impl BuildSerializer for StrSerializer {
2127
_config: Option<&Bound<'_, PyDict>>,
2228
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
2329
) -> PyResult<CombinedSerializer> {
24-
Ok(Self {}.into())
30+
Ok(Self::new().into())
2531
}
2632
}
2733

0 commit comments

Comments
 (0)