Skip to content

Commit

Permalink
Add validate.And (#1777)
Browse files Browse the repository at this point in the history
* Add validate.And

* Rename _validate_callable to _validate_all

* Test composing And

* Add and_; fix passing a generator to And

* Remove and_
  • Loading branch information
sloria authored and bonastreyair committed Apr 19, 2021
1 parent b58773a commit a7054db
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog

Features:

- Add ``validate.And`` (:issue:`1768`).
Thanks :user:`rugleb` for the suggestion.
- Let ``Field``s be accessed by name as ``Schema`` attributes (:pr:`1631`).
- Add a `NoDuplicates` validator in ``marshmallow.validate`` (:pr:`1793`).

Expand Down
22 changes: 6 additions & 16 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
StringNotCollectionError,
FieldInstanceResolutionError,
)
from marshmallow.validate import Validator, Length
from marshmallow.validate import And, Length
from marshmallow.warnings import RemovedInMarshmallow4Warning

__all__ = [
Expand Down Expand Up @@ -242,21 +242,11 @@ def _validate(self, value):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
errors = []
kwargs = {}
for validator in self.validators:
try:
r = validator(value)
if not isinstance(validator, Validator) and r is False:
raise self.make_error("validator_failed")
except ValidationError as err:
kwargs.update(err.kwargs)
if isinstance(err.messages, dict):
errors.append(err.messages)
else:
errors.extend(err.messages)
if errors:
raise ValidationError(errors, **kwargs)
self._validate_all(value)

@property
def _validate_all(self):
return And(*self.validators, error=self.error_messages["validator_failed"])

def make_error(self, key: str, **kwargs) -> ValidationError:
"""Helper method to make a `ValidationError` with an error message
Expand Down
1 change: 1 addition & 0 deletions src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@

StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.Set[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]
50 changes: 50 additions & 0 deletions src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,56 @@ def __call__(self, value: typing.Any) -> typing.Any:
...


class And(Validator):
"""Compose multiple validators and combine their error messages.
Example: ::
from marshmallow import validate, ValidationError
def is_even(value):
if value % 2 != 0:
raise ValidationError("Not an even value.")
validator = validate.And(validate.Range(min=0), is_even)
validator(-1)
# ValidationError: ['Must be greater than or equal to 0.', 'Not an even value.']
:param validators: Validators to combine.
:param error: Error message to use when a validator returns ``False``.
"""

default_error_message = "Invalid value."

def __init__(
self, *validators: types.Validator, error: typing.Optional[str] = None
):
self.validators = tuple(validators)
self.error = error or self.default_error_message # type: str

def _repr_args(self) -> str:
return "validators={!r}".format(self.validators)

def __call__(self, value: typing.Any) -> typing.Any:
errors = []
kwargs = {}
for validator in self.validators:
try:
r = validator(value)
if not isinstance(validator, Validator) and r is False:
raise ValidationError(self.error)
except ValidationError as err:
kwargs.update(err.kwargs)
if isinstance(err.messages, dict):
errors.append(err.messages)
else:
# FIXME : Get rid of cast
errors.extend(typing.cast(list, err.messages))
if errors:
raise ValidationError(errors, **kwargs)
return value


class URL(Validator):
"""Validate a URL.
Expand Down
5 changes: 5 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,11 @@ def test_email_field_deserialization(self):
field.deserialize("invalidemail")
assert excinfo.value.args[0][0] == "Not a valid email address."

field = fields.Email(validate=[validate.Length(min=12)])
with pytest.raises(ValidationError) as excinfo:
field.deserialize("[email protected]")
assert excinfo.value.args[0][0] == "Shorter than minimum length 12."

# regression test for https://github.com/marshmallow-code/marshmallow/issues/1400
def test_email_field_non_list_validators(self):
field = fields.Email(validate=(validate.Length(min=9),))
Expand Down
22 changes: 22 additions & 0 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,28 @@ def test_containsnoneof_mixing_types():
validate.ContainsNoneOf([1, 2, 3])((1,))


def is_even(value):
if value % 2 != 0:
raise ValidationError("Not an even value.")


def test_and():
validator = validate.And(validate.Range(min=0), is_even)
assert validator(2)
with pytest.raises(ValidationError) as excinfo:
validator(-1)
errors = excinfo.value.messages
assert errors == ["Must be greater than or equal to 0.", "Not an even value."]

validator_with_composition = validate.And(validator, validate.Range(max=6))
assert validator_with_composition(4)
with pytest.raises(ValidationError) as excinfo:
validator_with_composition(7)

errors = excinfo.value.messages
assert errors == ["Not an even value.", "Must be less than or equal to 6."]


def test_noduplicates():
class Mock:
def __init__(self, name):
Expand Down

0 comments on commit a7054db

Please sign in to comment.