|
| 1 | +use std::borrow::Cow; |
| 2 | + |
| 3 | +use crate::build_tools::py_schema_err; |
| 4 | +use pyo3::intern; |
| 5 | +use pyo3::prelude::*; |
| 6 | +use pyo3::types::{PyDict, PyType}; |
| 7 | + |
| 8 | +use crate::definitions::DefinitionsBuilder; |
| 9 | +use crate::serializers::errors::py_err_se_err; |
| 10 | +use crate::serializers::infer::{infer_json_key, infer_serialize, infer_to_python}; |
| 11 | +use crate::tools::SchemaDict; |
| 12 | + |
| 13 | +use super::float::FloatSerializer; |
| 14 | +use super::simple::IntSerializer; |
| 15 | +use super::string::StrSerializer; |
| 16 | +use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer}; |
| 17 | + |
| 18 | +#[derive(Debug, Clone)] |
| 19 | +pub struct EnumSerializer { |
| 20 | + class: Py<PyType>, |
| 21 | + serializer: Option<Box<CombinedSerializer>>, |
| 22 | +} |
| 23 | + |
| 24 | +impl BuildSerializer for EnumSerializer { |
| 25 | + const EXPECTED_TYPE: &'static str = "enum"; |
| 26 | + |
| 27 | + fn build( |
| 28 | + schema: &Bound<'_, PyDict>, |
| 29 | + config: Option<&Bound<'_, PyDict>>, |
| 30 | + _definitions: &mut DefinitionsBuilder<CombinedSerializer>, |
| 31 | + ) -> PyResult<CombinedSerializer> { |
| 32 | + let sub_type: Option<String> = schema.get_as(intern!(schema.py(), "sub_type"))?; |
| 33 | + |
| 34 | + let serializer = match sub_type.as_deref() { |
| 35 | + Some("int") => Some(Box::new(IntSerializer::new().into())), |
| 36 | + Some("str") => Some(Box::new(StrSerializer::new().into())), |
| 37 | + Some("float") => Some(Box::new(FloatSerializer::new(schema.py(), config)?.into())), |
| 38 | + Some(_) => return py_schema_err!("`sub_type` must be one of: 'int', 'str', 'float' or None"), |
| 39 | + None => None, |
| 40 | + }; |
| 41 | + Ok(Self { |
| 42 | + class: schema.get_as_req(intern!(schema.py(), "cls"))?, |
| 43 | + serializer, |
| 44 | + } |
| 45 | + .into()) |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +impl_py_gc_traverse!(EnumSerializer { serializer }); |
| 50 | + |
| 51 | +impl TypeSerializer for EnumSerializer { |
| 52 | + fn to_python( |
| 53 | + &self, |
| 54 | + value: &Bound<'_, PyAny>, |
| 55 | + include: Option<&Bound<'_, PyAny>>, |
| 56 | + exclude: Option<&Bound<'_, PyAny>>, |
| 57 | + extra: &Extra, |
| 58 | + ) -> PyResult<PyObject> { |
| 59 | + let py = value.py(); |
| 60 | + if value.is_exact_instance(self.class.bind(py)) { |
| 61 | + // if we're in JSON mode, we need to get the value attribute and serialize that |
| 62 | + if extra.mode.is_json() { |
| 63 | + let dot_value = value.getattr(intern!(py, "value"))?; |
| 64 | + match self.serializer { |
| 65 | + Some(ref s) => s.to_python(&dot_value, include, exclude, extra), |
| 66 | + None => infer_to_python(&dot_value, include, exclude, extra), |
| 67 | + } |
| 68 | + } else { |
| 69 | + // if we're not in JSON mode, we assume the value is safe to return directly |
| 70 | + Ok(value.into_py(py)) |
| 71 | + } |
| 72 | + } else { |
| 73 | + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; |
| 74 | + infer_to_python(value, include, exclude, extra) |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { |
| 79 | + let py = key.py(); |
| 80 | + if key.is_exact_instance(self.class.bind(py)) { |
| 81 | + let dot_value = key.getattr(intern!(py, "value"))?; |
| 82 | + let k = match self.serializer { |
| 83 | + Some(ref s) => s.json_key(&dot_value, extra), |
| 84 | + None => infer_json_key(&dot_value, extra), |
| 85 | + }?; |
| 86 | + // since dot_value is a local reference, we need to allocate it and returned an |
| 87 | + // owned variant of cow. |
| 88 | + Ok(Cow::Owned(k.into_owned())) |
| 89 | + } else { |
| 90 | + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; |
| 91 | + infer_json_key(key, extra) |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + fn serde_serialize<S: serde::ser::Serializer>( |
| 96 | + &self, |
| 97 | + value: &Bound<'_, PyAny>, |
| 98 | + serializer: S, |
| 99 | + include: Option<&Bound<'_, PyAny>>, |
| 100 | + exclude: Option<&Bound<'_, PyAny>>, |
| 101 | + extra: &Extra, |
| 102 | + ) -> Result<S::Ok, S::Error> { |
| 103 | + if value.is_exact_instance(self.class.bind(value.py())) { |
| 104 | + let dot_value = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?; |
| 105 | + match self.serializer { |
| 106 | + Some(ref s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra), |
| 107 | + None => infer_serialize(&dot_value, serializer, include, exclude, extra), |
| 108 | + } |
| 109 | + } else { |
| 110 | + extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?; |
| 111 | + infer_serialize(value, serializer, include, exclude, extra) |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + fn get_name(&self) -> &str { |
| 116 | + Self::EXPECTED_TYPE |
| 117 | + } |
| 118 | + |
| 119 | + fn retry_with_lax_check(&self) -> bool { |
| 120 | + match self.serializer { |
| 121 | + Some(ref s) => s.retry_with_lax_check(), |
| 122 | + None => false, |
| 123 | + } |
| 124 | + } |
| 125 | +} |
0 commit comments