Skip to content

Commit d93482e

Browse files
Fix pydantic 7715 (#1002)
Co-authored-by: David Montague <[email protected]>
1 parent 4622ed7 commit d93482e

File tree

4 files changed

+231
-31
lines changed

4 files changed

+231
-31
lines changed

‎src/validators/dataclass.rs‎

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,31 @@ impl Validator for DataclassArgsValidator {
232232
}
233233
// found neither, check if there is a default value, otherwise error
234234
(None, None) => {
235-
if let Some(value) =
236-
field
237-
.validator
238-
.default_value(py, Some(field.name.as_str()), state)?
239-
{
240-
set_item!(field, value);
241-
} else {
242-
errors.push(field.lookup_key.error(
243-
ErrorTypeDefaults::Missing,
244-
input,
245-
self.loc_by_alias,
246-
&field.name,
247-
));
235+
match field.validator.default_value(py, Some(field.name.as_str()), state) {
236+
Ok(Some(value)) => {
237+
// Default value exists, and passed validation if required
238+
set_item!(field, value);
239+
},
240+
Ok(None) => {
241+
// This means there was no default value
242+
errors.push(field.lookup_key.error(
243+
ErrorTypeDefaults::Missing,
244+
input,
245+
self.loc_by_alias,
246+
&field.name
247+
));
248+
},
249+
Err(ValError::Omit) => continue,
250+
Err(ValError::LineErrors(line_errors)) => {
251+
for err in line_errors {
252+
// Note: this will always use the field name even if there is an alias
253+
// However, we don't mind so much because this error can only happen if the
254+
// default value fails validation, which is arguably a developer error.
255+
// We could try to "fix" this in the future if desired.
256+
errors.push(err);
257+
}
258+
}
259+
Err(err) => return Err(err),
248260
}
249261
}
250262
}

‎src/validators/model_fields.rs‎

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,33 @@ impl Validator for ModelFieldsValidator {
211211
Err(err) => return ControlFlow::Break(err.into_owned(py)),
212212
}
213213
continue;
214-
} else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? {
215-
control_flow!(model_dict.set_item(&field.name_py, value))?;
216-
} else {
217-
errors.push(field.lookup_key.error(
218-
ErrorTypeDefaults::Missing,
219-
input,
220-
self.loc_by_alias,
221-
&field.name
222-
));
214+
}
215+
216+
match field.validator.default_value(py, Some(field.name.as_str()), state) {
217+
Ok(Some(value)) => {
218+
// Default value exists, and passed validation if required
219+
control_flow!(model_dict.set_item(&field.name_py, value))?;
220+
},
221+
Ok(None) => {
222+
// This means there was no default value
223+
errors.push(field.lookup_key.error(
224+
ErrorTypeDefaults::Missing,
225+
input,
226+
self.loc_by_alias,
227+
&field.name
228+
));
229+
},
230+
Err(ValError::Omit) => continue,
231+
Err(ValError::LineErrors(line_errors)) => {
232+
for err in line_errors {
233+
// Note: this will always use the field name even if there is an alias
234+
// However, we don't mind so much because this error can only happen if the
235+
// default value fails validation, which is arguably a developer error.
236+
// We could try to "fix" this in the future if desired.
237+
errors.push(err);
238+
}
239+
}
240+
Err(err) => return ControlFlow::Break(err),
223241
}
224242
}
225243
ControlFlow::Continue(())

‎src/validators/typed_dict.rs‎

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,35 @@ impl Validator for TypedDictValidator {
212212
Err(err) => return ControlFlow::Break(err.into_owned(py)),
213213
}
214214
continue;
215-
} else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? {
216-
control_flow!(output_dict.set_item(&field.name_py, value))?;
217-
} else if field.required {
218-
errors.push(field.lookup_key.error(
219-
ErrorTypeDefaults::Missing,
220-
input,
221-
self.loc_by_alias,
222-
&field.name
223-
));
215+
}
216+
217+
match field.validator.default_value(py, Some(field.name.as_str()), state) {
218+
Ok(Some(value)) => {
219+
// Default value exists, and passed validation if required
220+
control_flow!(output_dict.set_item(&field.name_py, value))?;
221+
},
222+
Ok(None) => {
223+
// This means there was no default value
224+
if (field.required) {
225+
errors.push(field.lookup_key.error(
226+
ErrorTypeDefaults::Missing,
227+
input,
228+
self.loc_by_alias,
229+
&field.name
230+
));
231+
}
232+
},
233+
Err(ValError::Omit) => continue,
234+
Err(ValError::LineErrors(line_errors)) => {
235+
for err in line_errors {
236+
// Note: this will always use the field name even if there is an alias
237+
// However, we don't mind so much because this error can only happen if the
238+
// default value fails validation, which is arguably a developer error.
239+
// We could try to "fix" this in the future if desired.
240+
errors.push(err);
241+
}
242+
}
243+
Err(err) => return ControlFlow::Break(err),
224244
}
225245
}
226246
ControlFlow::Continue(())

‎tests/validators/test_with_default.py‎

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,153 @@ def _validator(cls, v, info):
654654
gc.collect()
655655

656656
assert ref() is None
657+
658+
659+
validate_default_raises_examples = [
660+
(
661+
{},
662+
[
663+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
664+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
665+
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {}},
666+
],
667+
),
668+
(
669+
{'z': 'some str'},
670+
[
671+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
672+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
673+
],
674+
),
675+
(
676+
{'x': None},
677+
[
678+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
679+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
680+
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None}},
681+
],
682+
),
683+
(
684+
{'x': None, 'z': 'some str'},
685+
[
686+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
687+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
688+
],
689+
),
690+
(
691+
{'y': None},
692+
[
693+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
694+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
695+
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'y': None}},
696+
],
697+
),
698+
(
699+
{'y': None, 'z': 'some str'},
700+
[
701+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
702+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
703+
],
704+
),
705+
(
706+
{'x': None, 'y': None},
707+
[
708+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
709+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
710+
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None, 'y': None}},
711+
],
712+
),
713+
(
714+
{'x': None, 'y': None, 'z': 'some str'},
715+
[
716+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
717+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
718+
],
719+
),
720+
(
721+
{'x': 1, 'y': None, 'z': 'some str'},
722+
[
723+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1},
724+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
725+
],
726+
),
727+
(
728+
{'x': None, 'y': 1, 'z': 'some str'},
729+
[
730+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
731+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1},
732+
],
733+
),
734+
(
735+
{'x': 1, 'y': 1, 'z': 'some str'},
736+
[
737+
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1},
738+
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1},
739+
],
740+
),
741+
]
742+
743+
744+
@pytest.mark.parametrize(
745+
'core_schema_constructor,field_constructor',
746+
[
747+
(core_schema.model_fields_schema, core_schema.model_field),
748+
(core_schema.typed_dict_schema, core_schema.typed_dict_field),
749+
],
750+
)
751+
@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples)
752+
def test_validate_default_raises(
753+
core_schema_constructor: Union[core_schema.ModelFieldsSchema, core_schema.TypedDictSchema],
754+
field_constructor: Union[core_schema.model_field, core_schema.typed_dict_field],
755+
input_value: dict,
756+
expected: Any,
757+
) -> None:
758+
def _raise(ex: Exception) -> None:
759+
raise ex()
760+
761+
inner_schema = core_schema.no_info_after_validator_function(
762+
lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema())
763+
)
764+
765+
v = SchemaValidator(
766+
core_schema_constructor(
767+
{
768+
'x': field_constructor(
769+
core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
770+
),
771+
'y': field_constructor(
772+
core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
773+
),
774+
'z': field_constructor(core_schema.str_schema()),
775+
}
776+
)
777+
)
778+
779+
with pytest.raises(ValidationError) as exc_info:
780+
v.validate_python(input_value)
781+
assert exc_info.value.errors(include_url=False, include_context=False) == expected
782+
783+
784+
@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples)
785+
def test_validate_default_raises_dataclass(input_value: dict, expected: Any) -> None:
786+
def _raise(ex: Exception) -> None:
787+
raise ex()
788+
789+
inner_schema = core_schema.no_info_after_validator_function(
790+
lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema())
791+
)
792+
793+
x = core_schema.dataclass_field(
794+
name='x', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
795+
)
796+
y = core_schema.dataclass_field(
797+
name='y', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
798+
)
799+
z = core_schema.dataclass_field(name='z', schema=core_schema.str_schema())
800+
801+
v = SchemaValidator(core_schema.dataclass_args_schema('XYZ', [x, y, z]))
802+
803+
with pytest.raises(ValidationError) as exc_info:
804+
v.validate_python(input_value)
805+
806+
assert exc_info.value.errors(include_url=False, include_context=False) == expected

0 commit comments

Comments
 (0)