Skip to content

Commit b51105a

Browse files
authored
Add SchemaSerializer.__reduce__ method to enable pickle serialization (#1006)
Signed-off-by: Edward Oakes <[email protected]>
1 parent 8e66bd9 commit b51105a

File tree

7 files changed

+158
-28
lines changed

7 files changed

+158
-28
lines changed

‎src/serializers/mod.rs‎

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ mod ob_type;
2626
mod shared;
2727
mod type_serializers;
2828

29-
#[pyclass(module = "pydantic_core._pydantic_core")]
29+
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
3030
#[derive(Debug)]
3131
pub struct SchemaSerializer {
3232
serializer: CombinedSerializer,
3333
definitions: Definitions<CombinedSerializer>,
3434
expected_json_size: AtomicUsize,
3535
config: SerializationConfig,
36+
// References to the Python schema and config objects are saved to enable
37+
// reconstructing the object for pickle support (see `__reduce__`).
38+
py_schema: Py<PyDict>,
39+
py_config: Option<Py<PyDict>>,
3640
}
3741

3842
impl SchemaSerializer {
@@ -71,15 +75,19 @@ impl SchemaSerializer {
7175
#[pymethods]
7276
impl SchemaSerializer {
7377
#[new]
74-
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
78+
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
7579
let mut definitions_builder = DefinitionsBuilder::new();
76-
7780
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
7881
Ok(Self {
7982
serializer,
8083
definitions: definitions_builder.finish()?,
8184
expected_json_size: AtomicUsize::new(1024),
8285
config: SerializationConfig::from_config(config)?,
86+
py_schema: schema.into_py(py),
87+
py_config: match config {
88+
Some(c) if !c.is_empty() => Some(c.into_py(py)),
89+
_ => None,
90+
},
8391
})
8492
}
8593

@@ -174,6 +182,14 @@ impl SchemaSerializer {
174182
Ok(py_bytes.into())
175183
}
176184

185+
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
186+
// Enables support for `pickle` serialization.
187+
let py = slf.py();
188+
let cls = slf.get_type().into();
189+
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
190+
Ok((cls, init_args))
191+
}
192+
177193
pub fn __repr__(&self) -> String {
178194
format!(
179195
"SchemaSerializer(serializer={:#?}, definitions={:#?})",
@@ -182,6 +198,10 @@ impl SchemaSerializer {
182198
}
183199

184200
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
201+
visit.call(&self.py_schema)?;
202+
if let Some(ref py_config) = self.py_config {
203+
visit.call(py_config)?;
204+
}
185205
self.serializer.py_gc_traverse(&visit)?;
186206
self.definitions.py_gc_traverse(&visit)?;
187207
Ok(())

‎src/validators/mod.rs‎

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,15 @@ impl PySome {
9797
}
9898
}
9999

100-
#[pyclass(module = "pydantic_core._pydantic_core")]
100+
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
101101
#[derive(Debug)]
102102
pub struct SchemaValidator {
103103
validator: CombinedValidator,
104104
definitions: Definitions<CombinedValidator>,
105-
schema: PyObject,
105+
// References to the Python schema and config objects are saved to enable
106+
// reconstructing the object for cloudpickle support (see `__reduce__`).
107+
py_schema: Py<PyAny>,
108+
py_config: Option<Py<PyDict>>,
106109
#[pyo3(get)]
107110
title: PyObject,
108111
hide_input_in_errors: bool,
@@ -121,6 +124,11 @@ impl SchemaValidator {
121124
for val in definitions.values() {
122125
val.get().unwrap().complete()?;
123126
}
127+
let py_schema = schema.into_py(py);
128+
let py_config = match config {
129+
Some(c) if !c.is_empty() => Some(c.into_py(py)),
130+
_ => None,
131+
};
124132
let config_title = match config {
125133
Some(c) => c.get_item("title"),
126134
None => None,
@@ -134,18 +142,20 @@ impl SchemaValidator {
134142
Ok(Self {
135143
validator,
136144
definitions,
137-
schema: schema.into_py(py),
145+
py_schema,
146+
py_config,
138147
title,
139148
hide_input_in_errors,
140149
validation_error_cause,
141150
})
142151
}
143152

144-
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<PyObject> {
153+
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
154+
// Enables support for `pickle` serialization.
145155
let py = slf.py();
146-
let args = (slf.try_borrow()?.schema.to_object(py),);
147-
let cls = slf.getattr("__class__")?;
148-
Ok((cls, args).into_py(py))
156+
let cls = slf.get_type().into();
157+
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
158+
Ok((cls, init_args))
149159
}
150160

151161
#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))]
@@ -307,7 +317,10 @@ impl SchemaValidator {
307317

308318
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
309319
self.validator.py_gc_traverse(&visit)?;
310-
visit.call(&self.schema)?;
320+
visit.call(&self.py_schema)?;
321+
if let Some(ref py_config) = self.py_config {
322+
visit.call(py_config)?;
323+
}
311324
Ok(())
312325
}
313326
}
@@ -396,7 +409,8 @@ impl<'py> SelfValidator<'py> {
396409
Ok(SchemaValidator {
397410
validator,
398411
definitions,
399-
schema: py.None(),
412+
py_schema: py.None(),
413+
py_config: None,
400414
title: "Self Schema".into_py(py),
401415
hide_input_in_errors: false,
402416
validation_error_cause: false,

‎tests/serializers/test_pickling.py‎

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
import pickle
3+
from datetime import timedelta
4+
5+
import pytest
6+
7+
from pydantic_core import core_schema
8+
from pydantic_core._pydantic_core import SchemaSerializer
9+
10+
11+
def repr_function(value, _info):
12+
return repr(value)
13+
14+
15+
def test_basic_schema_serializer():
16+
s = SchemaSerializer(core_schema.dict_schema())
17+
s = pickle.loads(pickle.dumps(s))
18+
assert s.to_python({'a': 1, b'b': 2, 33: 3}) == {'a': 1, b'b': 2, 33: 3}
19+
assert s.to_python({'a': 1, b'b': 2, 33: 3, True: 4}, mode='json') == {'a': 1, 'b': 2, '33': 3, 'true': 4}
20+
assert s.to_json({'a': 1, b'b': 2, 33: 3, True: 4}) == b'{"a":1,"b":2,"33":3,"true":4}'
21+
22+
assert s.to_python({(1, 2): 3}) == {(1, 2): 3}
23+
assert s.to_python({(1, 2): 3}, mode='json') == {'1,2': 3}
24+
assert s.to_json({(1, 2): 3}) == b'{"1,2":3}'
25+
26+
27+
@pytest.mark.parametrize(
28+
'value,expected_python,expected_json',
29+
[(None, 'None', b'"None"'), (1, '1', b'"1"'), ([1, 2, 3], '[1, 2, 3]', b'"[1, 2, 3]"')],
30+
)
31+
def test_schema_serializer_capturing_function(value, expected_python, expected_json):
32+
# Test a SchemaSerializer that captures a function.
33+
s = SchemaSerializer(
34+
core_schema.any_schema(
35+
serialization=core_schema.plain_serializer_function_ser_schema(repr_function, info_arg=True)
36+
)
37+
)
38+
s = pickle.loads(pickle.dumps(s))
39+
assert s.to_python(value) == expected_python
40+
assert s.to_json(value) == expected_json
41+
assert s.to_python(value, mode='json') == json.loads(expected_json)
42+
43+
44+
def test_schema_serializer_containing_config():
45+
s = SchemaSerializer(core_schema.timedelta_schema(), config={'ser_json_timedelta': 'float'})
46+
s = pickle.loads(pickle.dumps(s))
47+
48+
assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000)
49+
assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5
50+
assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5'

‎tests/test.rs‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ mod tests {
4646
]
4747
}"#;
4848
let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap();
49-
SchemaSerializer::py_new(schema, None).unwrap();
49+
SchemaSerializer::py_new(py, schema, None).unwrap();
5050
});
5151
}
5252

@@ -77,7 +77,7 @@ a = A()
7777
py.run(code, None, Some(locals)).unwrap();
7878
let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap();
7979
let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap();
80-
let serialized: Vec<u8> = SchemaSerializer::py_new(schema, None)
80+
let serialized: Vec<u8> = SchemaSerializer::py_new(py, schema, None)
8181
.unwrap()
8282
.to_json(py, a, None, None, None, true, false, false, false, false, true, None)
8383
.unwrap()

‎tests/test_garbage_collection.py‎

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ class BaseModel:
2727
__schema__: SchemaSerializer
2828

2929
def __init_subclass__(cls) -> None:
30-
cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER))
30+
cls.__schema__ = SchemaSerializer(
31+
core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), config={'ser_json_timedelta': 'float'}
32+
)
3133

3234
cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()
3335

@@ -56,7 +58,10 @@ class BaseModel:
5658
__validator__: SchemaValidator
5759

5860
def __init_subclass__(cls) -> None:
59-
cls.__validator__ = SchemaValidator(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER))
61+
cls.__validator__ = SchemaValidator(
62+
core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER),
63+
config=core_schema.CoreConfig(extra_fields_behavior='allow'),
64+
)
6065

6166
cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()
6267

‎tests/validators/test_datetime.py‎

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import json
3-
import pickle
43
import platform
54
import re
65
from datetime import date, datetime, time, timedelta, timezone, tzinfo
@@ -480,17 +479,6 @@ def test_tz_constraint_wrong():
480479
validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong'))
481480

482481

483-
def test_tz_pickle() -> None:
484-
"""
485-
https://github.com/pydantic/pydantic-core/issues/589
486-
"""
487-
v = SchemaValidator(core_schema.datetime_schema())
488-
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
489-
validated = v.validate_python('2022-06-08T12:13:14-12:15')
490-
assert validated == original
491-
assert pickle.loads(pickle.dumps(validated)) == validated == original
492-
493-
494482
def test_tz_hash() -> None:
495483
v = SchemaValidator(core_schema.datetime_schema())
496484
lookup: Dict[datetime, str] = {}

‎tests/validators/test_pickling.py‎

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pickle
2+
import re
3+
from datetime import datetime, timedelta, timezone
4+
5+
import pytest
6+
7+
from pydantic_core import core_schema, validate_core_schema
8+
from pydantic_core._pydantic_core import SchemaValidator, ValidationError
9+
10+
11+
def test_basic_schema_validator():
12+
v = SchemaValidator(
13+
validate_core_schema(
14+
{'type': 'dict', 'strict': True, 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}}
15+
)
16+
)
17+
v = pickle.loads(pickle.dumps(v))
18+
assert v.validate_python({'1': 2, '3': 4}) == {1: 2, 3: 4}
19+
assert v.validate_python({}) == {}
20+
with pytest.raises(ValidationError, match=re.escape('[type=dict_type, input_value=[], input_type=list]')):
21+
v.validate_python([])
22+
23+
24+
def test_schema_validator_containing_config():
25+
"""
26+
Verify that the config object is not lost during (de)serialization.
27+
"""
28+
v = SchemaValidator(
29+
core_schema.model_fields_schema({'f': core_schema.model_field(core_schema.str_schema())}),
30+
config=core_schema.CoreConfig(extra_fields_behavior='allow'),
31+
)
32+
v = pickle.loads(pickle.dumps(v))
33+
34+
m, model_extra, fields_set = v.validate_python({'f': 'x', 'extra_field': '123'})
35+
assert m == {'f': 'x'}
36+
# If the config was lost during (de)serialization, the below checks would fail as
37+
# the default behavior is to ignore extra fields.
38+
assert model_extra == {'extra_field': '123'}
39+
assert fields_set == {'f', 'extra_field'}
40+
41+
v.validate_assignment(m, 'f', 'y')
42+
assert m == {'f': 'y'}
43+
44+
45+
def test_schema_validator_tz_pickle() -> None:
46+
"""
47+
https://github.com/pydantic/pydantic-core/issues/589
48+
"""
49+
v = SchemaValidator(core_schema.datetime_schema())
50+
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
51+
validated = v.validate_python('2022-06-08T12:13:14-12:15')
52+
assert validated == original
53+
assert pickle.loads(pickle.dumps(validated)) == validated == original

0 commit comments

Comments
 (0)