Skip to content

Commit 6740e73

Browse files
committed
implement suggestions
1 parent 0797e59 commit 6740e73

File tree

4 files changed

+52
-140
lines changed

4 files changed

+52
-140
lines changed

‎src/serializers/fields.rs‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use super::extra::Extra;
1515
use super::filter::SchemaFilter;
1616
use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer};
1717
use super::shared::PydanticSerializer;
18-
use super::shared::{CombinedSerializer, DictIterator, TypeSerializer};
18+
use super::shared::{CombinedSerializer, TypeSerializer};
1919

2020
/// representation of a field for serialization
2121
#[derive(Debug, Clone)]
@@ -321,7 +321,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
321321
return infer_to_python(value, include, exclude, &td_extra);
322322
};
323323

324-
let output_dict = self.main_to_python(py, DictIterator::new(main_dict), include, exclude, td_extra)?;
324+
let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;
325325

326326
// this is used to include `__pydantic_extra__` in serialization on models
327327
if let Some(extra_dict) = extra_dict {
@@ -373,7 +373,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
373373
// NOTE! As above, we maintain the order of the input dict assuming that's right
374374
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
375375
let mut map = self.main_serde_serialize(
376-
DictIterator::new(main_dict),
376+
main_dict.iter().map(Ok),
377377
expected_len,
378378
serializer,
379379
include,

‎src/serializers/infer.rs‎

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
1919
use super::extra::{Extra, SerMode};
2020
use super::filter::{AnyFilter, SchemaFilter};
2121
use super::ob_type::ObType;
22-
use super::shared::{AnyDataclassIterator, DictIterator, PydanticSerializer, TypeSerializer};
22+
use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer};
2323
use super::SchemaSerializer;
2424

2525
pub(crate) fn infer_to_python(
@@ -151,7 +151,10 @@ pub(crate) fn infer_to_python_known(
151151
PyList::new(py, elements).into_py(py)
152152
}
153153
ObType::Dict => {
154-
serialize_pairs_python_mode_json(py, DictIterator::new(value.downcast()?), include, exclude, extra)?
154+
let dict: &PyDict = value.downcast()?;
155+
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, |k| {
156+
Ok(PyString::new(py, &infer_json_key(k, extra)?))
157+
})?
155158
}
156159
ObType::Datetime => {
157160
let py_dt: &PyDateTime = value.downcast()?;
@@ -190,7 +193,9 @@ pub(crate) fn infer_to_python_known(
190193
}
191194
ObType::PydanticSerializable => serialize_with_serializer()?,
192195
ObType::Dataclass => {
193-
serialize_pairs_python_mode_json(py, AnyDataclassIterator::new(value)?, include, exclude, extra)?
196+
serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, |k| {
197+
Ok(PyString::new(py, &infer_json_key(k, extra)?))
198+
})?
194199
}
195200
ObType::Enum => {
196201
let v = value.getattr(intern!(py, "value"))?;
@@ -241,11 +246,12 @@ pub(crate) fn infer_to_python_known(
241246
let elements = serialize_seq!(PyFrozenSet);
242247
PyFrozenSet::new(py, &elements)?.into_py(py)
243248
}
244-
ObType::Dict => serialize_pairs_python(py, DictIterator::new(value.downcast()?), include, exclude, extra)?,
245-
ObType::PydanticSerializable => serialize_with_serializer()?,
246-
ObType::Dataclass => {
247-
serialize_pairs_python(py, AnyDataclassIterator::new(value)?, include, exclude, extra)?
249+
ObType::Dict => {
250+
let dict: &PyDict = value.downcast()?;
251+
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, Ok)?
248252
}
253+
ObType::PydanticSerializable => serialize_with_serializer()?,
254+
ObType::Dataclass => serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, Ok)?,
249255
ObType::Generator => {
250256
let iter = super::type_serializers::generator::SerializationIterator::new(
251257
value.downcast()?,
@@ -404,7 +410,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
404410
}
405411
ObType::Dict => {
406412
let dict = value.downcast::<PyDict>().map_err(py_err_se_err)?;
407-
serialize_pairs_json(DictIterator::new(dict), serializer, include, exclude, extra)
413+
serialize_pairs_json(dict.iter().map(Ok), dict.len(), serializer, include, exclude, extra)
408414
}
409415
ObType::List => serialize_seq_filter!(PyList),
410416
ObType::Tuple => serialize_seq_filter!(PyTuple),
@@ -463,13 +469,10 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
463469
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
464470
pydantic_serializer.serialize(serializer)
465471
}
466-
ObType::Dataclass => serialize_pairs_json(
467-
AnyDataclassIterator::new(value).map_err(py_err_se_err)?,
468-
serializer,
469-
include,
470-
exclude,
471-
extra,
472-
),
472+
ObType::Dataclass => {
473+
let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?;
474+
serialize_pairs_json(pairs_iter, fields_dict.len(), serializer, include, exclude, extra)
475+
}
473476
ObType::Uuid => {
474477
let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?;
475478
let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?;
@@ -645,6 +648,7 @@ fn serialize_pairs_python<'py>(
645648
include: Option<&PyAny>,
646649
exclude: Option<&PyAny>,
647650
extra: &Extra,
651+
key_transform: impl Fn(&'py PyAny) -> PyResult<&'py PyAny>,
648652
) -> PyResult<PyObject> {
649653
let new_dict = PyDict::new(py);
650654
let filter = AnyFilter::new();
@@ -653,29 +657,7 @@ fn serialize_pairs_python<'py>(
653657
let (k, v) = result?;
654658
let op_next = filter.key_filter(k, include, exclude)?;
655659
if let Some((next_include, next_exclude)) = op_next {
656-
let v = infer_to_python(v, next_include, next_exclude, extra)?;
657-
new_dict.set_item(k, v)?;
658-
}
659-
}
660-
Ok(new_dict.into_py(py))
661-
}
662-
663-
fn serialize_pairs_python_mode_json<'py>(
664-
py: Python,
665-
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
666-
include: Option<&PyAny>,
667-
exclude: Option<&PyAny>,
668-
extra: &Extra,
669-
) -> PyResult<PyObject> {
670-
let new_dict = PyDict::new(py);
671-
let filter = AnyFilter::new();
672-
673-
for result in pairs_iter {
674-
let (k, v) = result?;
675-
let op_next = filter.key_filter(k, include, exclude)?;
676-
if let Some((next_include, next_exclude)) = op_next {
677-
let k_str = infer_json_key(k, extra)?;
678-
let k = PyString::new(py, &k_str);
660+
let k = key_transform(k)?;
679661
let v = infer_to_python(v, next_include, next_exclude, extra)?;
680662
new_dict.set_item(k, v)?;
681663
}
@@ -685,13 +667,13 @@ fn serialize_pairs_python_mode_json<'py>(
685667

686668
fn serialize_pairs_json<'py, S: Serializer>(
687669
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
670+
iter_size: usize,
688671
serializer: S,
689672
include: Option<&PyAny>,
690673
exclude: Option<&PyAny>,
691674
extra: &Extra,
692675
) -> Result<S::Ok, S::Error> {
693-
let (_, expected) = pairs_iter.size_hint();
694-
let mut map = serializer.serialize_map(expected)?;
676+
let mut map = serializer.serialize_map(Some(iter_size))?;
695677
let filter = AnyFilter::new();
696678

697679
for result in pairs_iter {

‎src/serializers/shared.rs‎

Lines changed: 14 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::fmt::Debug;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::once_cell::GILOnceCell;
66
use pyo3::prelude::*;
7-
use pyo3::types::iter::PyDictIterator;
87
use pyo3::types::{PyDict, PyString};
98
use pyo3::{intern, PyTraverseError, PyVisit};
109

@@ -365,75 +364,25 @@ pub(crate) fn to_json_bytes(
365364
Ok(bytes)
366365
}
367366

368-
pub(super) struct DictIterator<'py> {
369-
dict_iter: PyDictIterator<'py>,
370-
}
371-
372-
impl<'py> DictIterator<'py> {
373-
pub fn new(dict: &'py PyDict) -> Self {
374-
Self { dict_iter: dict.iter() }
375-
}
376-
}
377-
378-
impl<'py> Iterator for DictIterator<'py> {
379-
type Item = PyResult<(&'py PyAny, &'py PyAny)>;
380-
381-
fn next(&mut self) -> Option<Self::Item> {
382-
self.dict_iter.next().map(Ok)
383-
}
384-
385-
fn size_hint(&self) -> (usize, Option<usize>) {
386-
self.dict_iter.size_hint()
387-
}
388-
}
389-
390-
pub(super) struct AnyDataclassIterator<'py> {
367+
pub(super) fn any_dataclass_iter<'py>(
391368
dataclass: &'py PyAny,
392-
fields_iter: PyDictIterator<'py>,
393-
field_type_marker: &'py PyAny,
394-
}
395-
396-
impl<'py> AnyDataclassIterator<'py> {
397-
pub fn new(dc: &'py PyAny) -> PyResult<Self> {
398-
let py = dc.py();
399-
let fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
400-
Ok(Self {
401-
dataclass: dc,
402-
fields_iter: fields.iter(),
403-
field_type_marker: get_field_marker(py)?,
404-
})
405-
}
406-
407-
fn _next(&mut self) -> PyResult<Option<(&'py PyAny, &'py PyAny)>> {
408-
if let Some((field_name, field)) = self.fields_iter.next() {
409-
let field_type = field.getattr(intern!(self.dataclass.py(), "_field_type"))?;
410-
if field_type.is(self.field_type_marker) {
411-
let field_name: &PyString = field_name.downcast()?;
412-
let value = self.dataclass.getattr(field_name)?;
413-
Ok(Some((field_name, value)))
414-
} else {
415-
self._next()
416-
}
369+
) -> PyResult<(impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'py, &PyDict)> {
370+
let py = dataclass.py();
371+
let fields: &PyDict = dataclass.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
372+
let field_type_marker = get_field_marker(py)?;
373+
374+
let next = move |(field_name, field): (&'py PyAny, &'py PyAny)| -> PyResult<Option<(&'py PyAny, &'py PyAny)>> {
375+
let field_type = field.getattr(intern!(py, "_field_type"))?;
376+
if field_type.is(field_type_marker) {
377+
let field_name: &PyString = field_name.downcast()?;
378+
let value = dataclass.getattr(field_name)?;
379+
Ok(Some((field_name, value)))
417380
} else {
418381
Ok(None)
419382
}
420-
}
421-
}
422-
423-
impl<'py> Iterator for AnyDataclassIterator<'py> {
424-
type Item = PyResult<(&'py PyAny, &'py PyAny)>;
425-
426-
fn next(&mut self) -> Option<Self::Item> {
427-
match self._next() {
428-
Ok(Some(v)) => Some(Ok(v)),
429-
Ok(None) => None,
430-
Err(e) => Some(Err(e)),
431-
}
432-
}
383+
};
433384

434-
fn size_hint(&self) -> (usize, Option<usize>) {
435-
(0, None)
436-
}
385+
Ok((fields.iter().filter_map(move |field| next(field).transpose()), fields))
437386
}
438387

439388
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();

‎src/serializers/type_serializers/dataclass.rs‎

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl TypeSerializer for DataclassSerializer {
141141
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
142142
let output_dict = fields_serializer.main_to_python(
143143
py,
144-
KnownDataclassIterator::new(&self.fields, value),
144+
known_dataclass_iter(&self.fields, value),
145145
include,
146146
exclude,
147147
dc_extra,
@@ -182,7 +182,7 @@ impl TypeSerializer for DataclassSerializer {
182182
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
183183
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
184184
let mut map = fields_serializer.main_serde_serialize(
185-
KnownDataclassIterator::new(&self.fields, value),
185+
known_dataclass_iter(&self.fields, value),
186186
expected_len,
187187
serializer,
188188
include,
@@ -211,36 +211,17 @@ impl TypeSerializer for DataclassSerializer {
211211
}
212212
}
213213

214-
pub struct KnownDataclassIterator<'a, 'py> {
215-
index: usize,
214+
fn known_dataclass_iter<'a, 'py>(
216215
fields: &'a [Py<PyString>],
217216
dataclass: &'py PyAny,
218-
}
219-
220-
impl<'a, 'py> KnownDataclassIterator<'a, 'py> {
221-
pub fn new(fields: &'a [Py<PyString>], dataclass: &'py PyAny) -> Self {
222-
Self {
223-
index: 0,
224-
fields,
225-
dataclass,
226-
}
227-
}
228-
}
229-
230-
impl<'a, 'py> Iterator for KnownDataclassIterator<'a, 'py> {
231-
type Item = PyResult<(&'py PyAny, &'py PyAny)>;
232-
233-
fn next(&mut self) -> Option<Self::Item> {
234-
if let Some(field) = self.fields.get(self.index) {
235-
self.index += 1;
236-
let py = self.dataclass.py();
237-
let field_ref = field.clone_ref(py).into_ref(py);
238-
match self.dataclass.getattr(field_ref) {
239-
Ok(value) => Some(Ok((field_ref, value))),
240-
Err(e) => Some(Err(e)),
241-
}
242-
} else {
243-
None
244-
}
245-
}
217+
) -> impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'a
218+
where
219+
'py: 'a,
220+
{
221+
let py = dataclass.py();
222+
fields.iter().map(move |field| {
223+
let field_ref = field.clone_ref(py).into_ref(py);
224+
let value = dataclass.getattr(field_ref)?;
225+
Ok((field_ref as &PyAny, value))
226+
})
246227
}

0 commit comments

Comments
 (0)