Skip to content

Commit 16370b4

Browse files
committed
Implement retry_with_lax_check in build_simple_serializer
1 parent 0303d7f commit 16370b4

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

‎src/serializers/type_serializers/simple.rs‎

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl TypeSerializer for NoneSerializer {
8686
}
8787

8888
macro_rules! build_simple_serializer {
89-
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => {
89+
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => {
9090
#[derive(Debug, Clone)]
9191
pub struct $struct_name;
9292

@@ -164,6 +164,10 @@ macro_rules! build_simple_serializer {
164164
fn get_name(&self) -> &str {
165165
Self::EXPECTED_TYPE
166166
}
167+
168+
fn retry_with_lax_check(&self) -> bool {
169+
$subtypes_allowed
170+
}
167171
}
168172
};
169173
}
@@ -172,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult<Cow<str>> {
172176
Ok(key.str()?.to_string_lossy())
173177
}
174178

175-
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key);
179+
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true);
176180

177181
pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
178182
let v = if key.is_true().unwrap_or(false) {
@@ -183,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
183187
Ok(Cow::Borrowed(v))
184188
}
185189

186-
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key);
190+
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false);

‎tests/serializers/test_union.py‎

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,14 +517,19 @@ class Item(BaseModel):
517517
EXAMPLE_UUID = uuid.uuid4()
518518

519519

520-
@pytest.mark.parametrize('order', ['direct', 'inverse'])
520+
class IntSubclass(int):
521+
pass
522+
523+
524+
@pytest.mark.parametrize('reverse', [False, True])
521525
@pytest.mark.parametrize(
522526
'core_schema_left,core_schema_right,input_value,expected_value',
523527
[
524528
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
525529
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
526530
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
527531
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
532+
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
528533
(
529534
core_schema.decimal_schema(),
530535
core_schema.int_schema(),
@@ -538,6 +543,18 @@ class Item(BaseModel):
538543
Decimal('1.'),
539544
Decimal('1.'),
540545
),
546+
(
547+
core_schema.decimal_schema(),
548+
core_schema.str_schema(),
549+
Decimal('_1'),
550+
Decimal('_1'),
551+
),
552+
(
553+
core_schema.decimal_schema(),
554+
core_schema.str_schema(),
555+
'_1',
556+
'_1',
557+
),
541558
(
542559
core_schema.uuid_schema(),
543560
core_schema.str_schema(),
@@ -553,24 +570,25 @@ class Item(BaseModel):
553570
],
554571
)
555572
def test_union_serializer_picks_exact_type_over_subclass(
556-
core_schema_left, core_schema_right, input_value, expected_value, order
573+
core_schema_left, core_schema_right, input_value, expected_value, reverse
557574
):
558575
s = SchemaSerializer(
559576
core_schema.union_schema(
560-
[core_schema_left, core_schema_right] if order == 'direct' else [core_schema_right, core_schema_left]
577+
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
561578
)
562579
)
563580
assert s.to_python(input_value) == expected_value
564581

565582

566-
@pytest.mark.parametrize('order', ['direct', 'inverse'])
583+
@pytest.mark.parametrize('reverse', [False, True])
567584
@pytest.mark.parametrize(
568585
'core_schema_left,core_schema_right,input_value,expected_value',
569586
[
570587
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
571588
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
572589
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
573590
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
591+
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
574592
(
575593
core_schema.decimal_schema(),
576594
core_schema.int_schema(),
@@ -584,14 +602,26 @@ def test_union_serializer_picks_exact_type_over_subclass(
584602
Decimal('1.'),
585603
'1',
586604
),
605+
(
606+
core_schema.decimal_schema(),
607+
core_schema.str_schema(),
608+
Decimal('_1'),
609+
'1',
610+
),
611+
(
612+
core_schema.decimal_schema(),
613+
core_schema.str_schema(),
614+
'_1',
615+
'_1',
616+
),
587617
],
588618
)
589619
def test_union_serializer_picks_exact_type_over_subclass_json(
590-
core_schema_left, core_schema_right, input_value, expected_value, order
620+
core_schema_left, core_schema_right, input_value, expected_value, reverse
591621
):
592622
s = SchemaSerializer(
593623
core_schema.union_schema(
594-
[core_schema_left, core_schema_right] if order == 'direct' else [core_schema_right, core_schema_left]
624+
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
595625
)
596626
)
597627
assert s.to_python(input_value, mode='json') == expected_value

0 commit comments

Comments
 (0)