back to Reference (Gold) summary
Reference (Gold): voluptuous
Pytest Summary for test tests
status | count |
---|---|
passed | 149 |
total | 149 |
collected | 149 |
Failed pytests:
Patch diff
diff --git a/voluptuous/error.py b/voluptuous/error.py
index f72fbe7..9dab943 100644
--- a/voluptuous/error.py
+++ b/voluptuous/error.py
@@ -1,5 +1,8 @@
+# fmt: off
import typing
+# fmt: on
+
class Error(Exception):
"""Base validation exception."""
@@ -19,35 +22,70 @@ class Invalid(Error):
"""
- def __init__(self, message: str, path: typing.Optional[typing.List[
- typing.Hashable]]=None, error_message: typing.Optional[str]=None,
- error_type: typing.Optional[str]=None) ->None:
+ def __init__(
+ self,
+ message: str,
+ path: typing.Optional[typing.List[typing.Hashable]] = None,
+ error_message: typing.Optional[str] = None,
+ error_type: typing.Optional[str] = None,
+ ) -> None:
Error.__init__(self, message)
self._path = path or []
self._error_message = error_message or message
self.error_type = error_type
- def __str__(self) ->str:
- path = ' @ data[%s]' % ']['.join(map(repr, self.path)
- ) if self.path else ''
+ @property
+ def msg(self) -> str:
+ return self.args[0]
+
+ @property
+ def path(self) -> typing.List[typing.Hashable]:
+ return self._path
+
+ @property
+ def error_message(self) -> str:
+ return self._error_message
+
+ def __str__(self) -> str:
+ path = ' @ data[%s]' % ']['.join(map(repr, self.path)) if self.path else ''
output = Exception.__str__(self)
if self.error_type:
output += ' for ' + self.error_type
return output + path
+ def prepend(self, path: typing.List[typing.Hashable]) -> None:
+ self._path = path + self.path
-class MultipleInvalid(Invalid):
- def __init__(self, errors: typing.Optional[typing.List[Invalid]]=None
- ) ->None:
+class MultipleInvalid(Invalid):
+ def __init__(self, errors: typing.Optional[typing.List[Invalid]] = None) -> None:
self.errors = errors[:] if errors else []
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return 'MultipleInvalid(%r)' % self.errors
- def __str__(self) ->str:
+ @property
+ def msg(self) -> str:
+ return self.errors[0].msg
+
+ @property
+ def path(self) -> typing.List[typing.Hashable]:
+ return self.errors[0].path
+
+ @property
+ def error_message(self) -> str:
+ return self.errors[0].error_message
+
+ def add(self, error: Invalid) -> None:
+ self.errors.append(error)
+
+ def __str__(self) -> str:
return str(self.errors[0])
+ def prepend(self, path: typing.List[typing.Hashable]) -> None:
+ for error in self.errors:
+ error.prepend(path)
+
class RequiredFieldInvalid(Invalid):
"""Required field was missing."""
@@ -171,9 +209,11 @@ class ExactSequenceInvalid(Invalid):
class NotEnoughValid(Invalid):
"""The value did not pass enough validations."""
+
pass
class TooManyValid(Invalid):
"""The value passed more than expected validations."""
+
pass
diff --git a/voluptuous/humanize.py b/voluptuous/humanize.py
index 2902871..eabfd02 100644
--- a/voluptuous/humanize.py
+++ b/voluptuous/humanize.py
@@ -1,14 +1,57 @@
+# fmt: off
import typing
+
from voluptuous import Invalid, MultipleInvalid
from voluptuous.error import Error
from voluptuous.schema_builder import Schema
+
+# fmt: on
+
MAX_VALIDATION_ERROR_ITEM_LENGTH = 500
-def humanize_error(data, validation_error: Invalid, max_sub_error_length:
- int=MAX_VALIDATION_ERROR_ITEM_LENGTH) ->str:
+def _nested_getitem(
+ data: typing.Any, path: typing.List[typing.Hashable]
+) -> typing.Optional[typing.Any]:
+ for item_index in path:
+ try:
+ data = data[item_index]
+ except (KeyError, IndexError, TypeError):
+ # The index is not present in the dictionary, list or other
+ # indexable, or data is not subscriptable at all
+ return None
+ return data
+
+
+def humanize_error(
+ data,
+ validation_error: Invalid,
+ max_sub_error_length: int = MAX_VALIDATION_ERROR_ITEM_LENGTH,
+) -> str:
"""Provide a more helpful + complete validation error message than that provided automatically
Invalid and MultipleInvalid do not include the offending value in error messages,
and MultipleInvalid.__str__ only provides the first error.
"""
- pass
+ if isinstance(validation_error, MultipleInvalid):
+ return '\n'.join(
+ sorted(
+ humanize_error(data, sub_error, max_sub_error_length)
+ for sub_error in validation_error.errors
+ )
+ )
+ else:
+ offending_item_summary = repr(_nested_getitem(data, validation_error.path))
+ if len(offending_item_summary) > max_sub_error_length:
+ offending_item_summary = (
+ offending_item_summary[: max_sub_error_length - 3] + '...'
+ )
+ return '%s. Got %s' % (validation_error, offending_item_summary)
+
+
+def validate_with_humanized_errors(
+ data, schema: Schema, max_sub_error_length: int = MAX_VALIDATION_ERROR_ITEM_LENGTH
+) -> typing.Any:
+ try:
+ return schema(data)
+ except (Invalid, MultipleInvalid) as e:
+ raise Error(humanize_error(data, e, max_sub_error_length))
diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py
index de2b53c..cdeb514 100644
--- a/voluptuous/schema_builder.py
+++ b/voluptuous/schema_builder.py
@@ -1,4 +1,6 @@
+# fmt: off
from __future__ import annotations
+
import collections
import inspect
import itertools
@@ -8,15 +10,23 @@ import typing
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache, wraps
+
from voluptuous import error as er
from voluptuous.error import Error
-PREVENT_EXTRA = 0
-ALLOW_EXTRA = 1
-REMOVE_EXTRA = 2
+# fmt: on
-class Undefined(object):
+# options for extra keys
+PREVENT_EXTRA = 0 # any extra key not in schema will raise an error
+ALLOW_EXTRA = 1 # extra keys not in schema will be included in output
+REMOVE_EXTRA = 2 # extra keys not in schema will be excluded from output
+
+
+def _isnamedtuple(obj):
+ return isinstance(obj, tuple) and hasattr(obj, '_fields')
+
+class Undefined(object):
def __nonzero__(self):
return False
@@ -25,19 +35,56 @@ class Undefined(object):
UNDEFINED = Undefined()
+
+
+def Self() -> None:
+ raise er.SchemaError('"Self" should never be called')
+
+
DefaultFactory = typing.Union[Undefined, typing.Callable[[], typing.Any]]
-def Extra(_) ->None:
+def default_factory(value) -> DefaultFactory:
+ if value is UNDEFINED or callable(value):
+ return value
+ return lambda: value
+
+
+@contextmanager
+def raises(
+ exc, msg: typing.Optional[str] = None, regex: typing.Optional[re.Pattern] = None
+) -> Generator[None, None, None]:
+ try:
+ yield
+ except exc as e:
+ if msg is not None:
+ assert str(e) == msg, '%r != %r' % (str(e), msg)
+ if regex is not None:
+ assert re.search(regex, str(e)), '%r does not match %r' % (str(e), regex)
+ else:
+ raise AssertionError(f"Did not raise exception {exc.__name__}")
+
+
+def Extra(_) -> None:
"""Allow keys in the data that are not present in the schema."""
- pass
+ raise er.SchemaError('"Extra" should never be called')
+# As extra() is never called there's no way to catch references to the
+# deprecated object, so we just leave an alias here instead.
extra = Extra
-primitive_types = bool, bytes, int, str, float, complex
-Schemable = typing.Union['Schema', 'Object', collections.abc.Mapping, list,
- tuple, frozenset, set, bool, bytes, int, str, float, complex, type,
- object, dict, None, typing.Callable]
+
+primitive_types = (bool, bytes, int, str, float, complex)
+
+# fmt: off
+Schemable = typing.Union[
+ 'Schema', 'Object',
+ collections.abc.Mapping,
+ list, tuple, frozenset, set,
+ bool, bytes, int, str, float, complex,
+ type, object, dict, None, typing.Callable
+]
+# fmt: on
class Schema(object):
@@ -61,11 +108,16 @@ class Schema(object):
>>> assert v != v2
"""
- _extra_to_name = {REMOVE_EXTRA: 'REMOVE_EXTRA', ALLOW_EXTRA:
- 'ALLOW_EXTRA', PREVENT_EXTRA: 'PREVENT_EXTRA'}
- def __init__(self, schema: Schemable, required: bool=False, extra: int=
- PREVENT_EXTRA) ->None:
+ _extra_to_name = {
+ REMOVE_EXTRA: 'REMOVE_EXTRA',
+ ALLOW_EXTRA: 'ALLOW_EXTRA',
+ PREVENT_EXTRA: 'PREVENT_EXTRA',
+ }
+
+ def __init__(
+ self, schema: Schemable, required: bool = False, extra: int = PREVENT_EXTRA
+ ) -> None:
"""Create a new Schema.
:param schema: Validation schema. See :module:`voluptuous` for details.
@@ -82,11 +134,11 @@ class Schema(object):
"""
self.schema: typing.Any = schema
self.required = required
- self.extra = int(extra)
+ self.extra = int(extra) # ensure the value is an integer
self._compiled = self._compile(schema)
@classmethod
- def infer(cls, data, **kwargs) ->Schema:
+ def infer(cls, data, **kwargs) -> Schema:
"""Create a Schema from concrete data (e.g. an API response).
For example, this will take a dict like:
@@ -113,7 +165,20 @@ class Schema(object):
Note: only very basic inference is supported.
"""
- pass
+
+ def value_to_schema_type(value):
+ if isinstance(value, dict):
+ if len(value) == 0:
+ return dict
+ return {k: value_to_schema_type(v) for k, v in value.items()}
+ if isinstance(value, list):
+ if len(value) == 0:
+ return list
+ else:
+ return [value_to_schema_type(v) for v in value]
+ return type(value)
+
+ return cls(value_to_schema_type(data), **kwargs)
def __eq__(self, other):
if not isinstance(other, Schema):
@@ -121,15 +186,18 @@ class Schema(object):
return other.schema == self.schema
def __ne__(self, other):
- return not self == other
+ return not (self == other)
def __str__(self):
return str(self.schema)
def __repr__(self):
- return '<Schema(%s, extra=%s, required=%s) object at 0x%x>' % (self
- .schema, self._extra_to_name.get(self.extra, '??'), self.
- required, id(self))
+ return "<Schema(%s, extra=%s, required=%s) object at 0x%x>" % (
+ self.schema,
+ self._extra_to_name.get(self.extra, '??'),
+ self.required,
+ id(self),
+ )
def __call__(self, data):
"""Validate data against this schema."""
@@ -139,10 +207,183 @@ class Schema(object):
raise
except er.Invalid as e:
raise er.MultipleInvalid([e])
+ # return self.validate([], self.schema, data)
+
+ def _compile(self, schema):
+ if schema is Extra:
+ return lambda _, v: v
+ if schema is Self:
+ return lambda p, v: self._compiled(p, v)
+ elif hasattr(schema, "__voluptuous_compile__"):
+ return schema.__voluptuous_compile__(self)
+ if isinstance(schema, Object):
+ return self._compile_object(schema)
+ if isinstance(schema, collections.abc.Mapping):
+ return self._compile_dict(schema)
+ elif isinstance(schema, list):
+ return self._compile_list(schema)
+ elif isinstance(schema, tuple):
+ return self._compile_tuple(schema)
+ elif isinstance(schema, (frozenset, set)):
+ return self._compile_set(schema)
+ type_ = type(schema)
+ if inspect.isclass(schema):
+ type_ = schema
+ if type_ in (*primitive_types, object, type(None)) or callable(schema):
+ return _compile_scalar(schema)
+ raise er.SchemaError('unsupported schema data type %r' % type(schema).__name__)
def _compile_mapping(self, schema, invalid_msg=None):
"""Create validator for given mapping."""
- pass
+ invalid_msg = invalid_msg or 'mapping value'
+
+ # Keys that may be required
+ all_required_keys = set(
+ key
+ for key in schema
+ if key is not Extra
+ and (
+ (self.required and not isinstance(key, (Optional, Remove)))
+ or isinstance(key, Required)
+ )
+ )
+
+ # Keys that may have defaults
+ all_default_keys = set(
+ key
+ for key in schema
+ if isinstance(key, Required) or isinstance(key, Optional)
+ )
+
+ _compiled_schema = {}
+ for skey, svalue in schema.items():
+ new_key = self._compile(skey)
+ new_value = self._compile(svalue)
+ _compiled_schema[skey] = (new_key, new_value)
+
+ candidates = list(_iterate_mapping_candidates(_compiled_schema))
+
+ # After we have the list of candidates in the correct order, we want to apply some optimization so that each
+ # key in the data being validated will be matched against the relevant schema keys only.
+ # No point in matching against different keys
+ additional_candidates = []
+ candidates_by_key = {}
+ for skey, (ckey, cvalue) in candidates:
+ if type(skey) in primitive_types:
+ candidates_by_key.setdefault(skey, []).append((skey, (ckey, cvalue)))
+ elif isinstance(skey, Marker) and type(skey.schema) in primitive_types:
+ candidates_by_key.setdefault(skey.schema, []).append(
+ (skey, (ckey, cvalue))
+ )
+ else:
+ # These are wildcards such as 'int', 'str', 'Remove' and others which should be applied to all keys
+ additional_candidates.append((skey, (ckey, cvalue)))
+
+ def validate_mapping(path, iterable, out):
+ required_keys = all_required_keys.copy()
+
+ # Build a map of all provided key-value pairs.
+ # The type(out) is used to retain ordering in case an ordered
+ # map type is provided as input.
+ key_value_map = type(out)()
+ for key, value in iterable:
+ key_value_map[key] = value
+
+ # Insert default values for non-existing keys.
+ for key in all_default_keys:
+ if (
+ not isinstance(key.default, Undefined)
+ and key.schema not in key_value_map
+ ):
+ # A default value has been specified for this missing
+ # key, insert it.
+ key_value_map[key.schema] = key.default()
+
+ errors = []
+ for key, value in key_value_map.items():
+ key_path = path + [key]
+ remove_key = False
+
+ # Optimization: validate against the matching key first, then fall back to the rest
+ relevant_candidates = itertools.chain(
+ candidates_by_key.get(key, []), additional_candidates
+ )
+
+ # compare each given key/value against all compiled key/values
+ # schema key, (compiled key, compiled value)
+ error = None
+ for skey, (ckey, cvalue) in relevant_candidates:
+ try:
+ new_key = ckey(key_path, key)
+ except er.Invalid as e:
+ if len(e.path) > len(key_path):
+ raise
+ if not error or len(e.path) > len(error.path):
+ error = e
+ continue
+ # Backtracking is not performed once a key is selected, so if
+ # the value is invalid we immediately throw an exception.
+ exception_errors = []
+ # check if the key is marked for removal
+ is_remove = new_key is Remove
+ try:
+ cval = cvalue(key_path, value)
+ # include if it's not marked for removal
+ if not is_remove:
+ out[new_key] = cval
+ else:
+ remove_key = True
+ continue
+ except er.MultipleInvalid as e:
+ exception_errors.extend(e.errors)
+ except er.Invalid as e:
+ exception_errors.append(e)
+
+ if exception_errors:
+ if is_remove or remove_key:
+ continue
+ for err in exception_errors:
+ if len(err.path) <= len(key_path):
+ err.error_type = invalid_msg
+ errors.append(err)
+ # If there is a validation error for a required
+ # key, this means that the key was provided.
+ # Discard the required key so it does not
+ # create an additional, noisy exception.
+ required_keys.discard(skey)
+ break
+
+ # Key and value okay, mark as found in case it was
+ # a Required() field.
+ required_keys.discard(skey)
+
+ break
+ else:
+ if remove_key:
+ # remove key
+ continue
+ elif self.extra == ALLOW_EXTRA:
+ out[key] = value
+ elif error:
+ errors.append(error)
+ elif self.extra != REMOVE_EXTRA:
+ errors.append(er.Invalid('extra keys not allowed', key_path))
+ # else REMOVE_EXTRA: ignore the key so it's removed from output
+
+ # for any required keys left that weren't found and don't have defaults:
+ for key in required_keys:
+ msg = (
+ key.msg
+ if hasattr(key, 'msg') and key.msg
+ else 'required key not provided'
+ )
+ errors.append(er.RequiredFieldInvalid(msg, path + [key]))
+ if errors:
+ raise er.MultipleInvalid(errors)
+
+ return out
+
+ return validate_mapping
def _compile_object(self, schema):
"""Validate an object.
@@ -162,7 +403,17 @@ class Schema(object):
... validate(Structure(one='three'))
"""
- pass
+ base_validate = self._compile_mapping(schema, invalid_msg='object value')
+
+ def validate_object(path, data):
+ if schema.cls is not UNDEFINED and not isinstance(data, schema.cls):
+ raise er.ObjectInvalid('expected a {0!r}'.format(schema.cls), path)
+ iterable = _iterate_object(data)
+ iterable = filter(lambda item: item[1] is not None, iterable)
+ out = base_validate(path, iterable, {})
+ return type(data)(**out)
+
+ return validate_object
def _compile_dict(self, schema):
"""Validate a dictionary.
@@ -240,7 +491,64 @@ class Schema(object):
"expected str for dictionary value @ data['adict']['strfield']"]
"""
- pass
+ base_validate = self._compile_mapping(schema, invalid_msg='dictionary value')
+
+ groups_of_exclusion = {}
+ groups_of_inclusion = {}
+ for node in schema:
+ if isinstance(node, Exclusive):
+ g = groups_of_exclusion.setdefault(node.group_of_exclusion, [])
+ g.append(node)
+ elif isinstance(node, Inclusive):
+ g = groups_of_inclusion.setdefault(node.group_of_inclusion, [])
+ g.append(node)
+
+ def validate_dict(path, data):
+ if not isinstance(data, dict):
+ raise er.DictInvalid('expected a dictionary', path)
+
+ errors = []
+ for label, group in groups_of_exclusion.items():
+ exists = False
+ for exclusive in group:
+ if exclusive.schema in data:
+ if exists:
+ msg = (
+ exclusive.msg
+ if hasattr(exclusive, 'msg') and exclusive.msg
+ else "two or more values in the same group of exclusion '%s'"
+ % label
+ )
+ next_path = path + [VirtualPathComponent(label)]
+ errors.append(er.ExclusiveInvalid(msg, next_path))
+ break
+ exists = True
+
+ if errors:
+ raise er.MultipleInvalid(errors)
+
+ for label, group in groups_of_inclusion.items():
+ included = [node.schema in data for node in group]
+ if any(included) and not all(included):
+ msg = (
+ "some but not all values in the same group of inclusion '%s'"
+ % label
+ )
+ for g in group:
+ if hasattr(g, 'msg') and g.msg:
+ msg = g.msg
+ break
+ next_path = path + [VirtualPathComponent(label)]
+ errors.append(er.InclusiveInvalid(msg, next_path))
+ break
+
+ if errors:
+ raise er.MultipleInvalid(errors)
+
+ out = data.__class__()
+ return base_validate(path, data.items(), out)
+
+ return validate_dict
def _compile_sequence(self, schema, seq_type):
"""Validate a sequence type.
@@ -255,7 +563,49 @@ class Schema(object):
>>> validator([1])
[1]
"""
- pass
+ _compiled = [self._compile(s) for s in schema]
+ seq_type_name = seq_type.__name__
+
+ def validate_sequence(path, data):
+ if not isinstance(data, seq_type):
+ raise er.SequenceTypeInvalid('expected a %s' % seq_type_name, path)
+
+ # Empty seq schema, reject any data.
+ if not schema:
+ if data:
+ raise er.MultipleInvalid(
+ [er.ValueInvalid('not a valid value', path if path else data)]
+ )
+ return data
+
+ out = []
+ invalid = None
+ errors = []
+ index_path = UNDEFINED
+ for i, value in enumerate(data):
+ index_path = path + [i]
+ invalid = None
+ for validate in _compiled:
+ try:
+ cval = validate(index_path, value)
+ if cval is not Remove: # do not include Remove values
+ out.append(cval)
+ break
+ except er.Invalid as e:
+ if len(e.path) > len(index_path):
+ raise
+ invalid = e
+ else:
+ errors.append(invalid)
+ if errors:
+ raise er.MultipleInvalid(errors)
+
+ if _isnamedtuple(data):
+ return type(data)(*out)
+ else:
+ return type(data)(out)
+
+ return validate_sequence
def _compile_tuple(self, schema):
"""Validate a tuple.
@@ -270,7 +620,7 @@ class Schema(object):
>>> validator((1,))
(1,)
"""
- pass
+ return self._compile_sequence(schema, tuple)
def _compile_list(self, schema):
"""Validate a list.
@@ -285,7 +635,7 @@ class Schema(object):
>>> validator([1])
[1]
"""
- pass
+ return self._compile_sequence(schema, list)
def _compile_set(self, schema):
"""Validate a set.
@@ -300,10 +650,39 @@ class Schema(object):
>>> with raises(er.MultipleInvalid, 'invalid value in set'):
... validator(set(['a']))
"""
- pass
-
- def extend(self, schema: Schemable, required: typing.Optional[bool]=
- None, extra: typing.Optional[int]=None) ->Schema:
+ type_ = type(schema)
+ type_name = type_.__name__
+
+ def validate_set(path, data):
+ if not isinstance(data, type_):
+ raise er.Invalid('expected a %s' % type_name, path)
+
+ _compiled = [self._compile(s) for s in schema]
+ errors = []
+ for value in data:
+ for validate in _compiled:
+ try:
+ validate(path, value)
+ break
+ except er.Invalid:
+ pass
+ else:
+ invalid = er.Invalid('invalid value in %s' % type_name, path)
+ errors.append(invalid)
+
+ if errors:
+ raise er.MultipleInvalid(errors)
+
+ return data
+
+ return validate_set
+
+ def extend(
+ self,
+ schema: Schemable,
+ required: typing.Optional[bool] = None,
+ extra: typing.Optional[int] = None,
+ ) -> Schema:
"""Create a new `Schema` by merging this and the provided `schema`.
Neither this `Schema` nor the provided `schema` are modified. The
@@ -316,7 +695,51 @@ class Schema(object):
:param required: if set, overrides `required` of this `Schema`
:param extra: if set, overrides `extra` of this `Schema`
"""
- pass
+
+ assert isinstance(self.schema, dict) and isinstance(
+ schema, dict
+ ), 'Both schemas must be dictionary-based'
+
+ result = self.schema.copy()
+
+ # returns the key that may have been passed as an argument to Marker constructor
+ def key_literal(key):
+ return key.schema if isinstance(key, Marker) else key
+
+ # build a map that takes the key literals to the needed objects
+ # literal -> Required|Optional|literal
+ result_key_map = dict((key_literal(key), key) for key in result)
+
+ # for each item in the extension schema, replace duplicates
+ # or add new keys
+ for key, value in schema.items():
+ # if the key is already in the dictionary, we need to replace it
+ # transform key to literal before checking presence
+ if key_literal(key) in result_key_map:
+ result_key = result_key_map[key_literal(key)]
+ result_value = result[result_key]
+
+ # if both are dictionaries, we need to extend recursively
+ # create the new extended sub schema, then remove the old key and add the new one
+ if isinstance(result_value, dict) and isinstance(value, dict):
+ new_value = Schema(result_value).extend(value).schema
+ del result[result_key]
+ result[key] = new_value
+ # one or the other or both are not sub-schemas, simple replacement is fine
+ # remove old key and add new one
+ else:
+ del result[result_key]
+ result[key] = value
+
+ # key is new and can simply be added
+ else:
+ result[key] = value
+
+ # recompile and return a new schema object of the same class as this one
+ result_cls = type(self)
+ result_required = required if required is not None else self.required
+ result_extra = extra if extra is not None else self.extra
+ return result_cls(result, required=result_required, extra=result_extra)
def _compile_scalar(schema):
@@ -338,12 +761,78 @@ def _compile_scalar(schema):
>>> with raises(er.Invalid, 'not a valid value'):
... _compile_scalar(lambda v: float(v))([], 'a')
"""
- pass
+ if inspect.isclass(schema):
+
+ def validate_instance(path, data):
+ if isinstance(data, schema):
+ return data
+ else:
+ msg = 'expected %s' % schema.__name__
+ raise er.TypeInvalid(msg, path)
+
+ return validate_instance
+
+ if callable(schema):
+
+ def validate_callable(path, data):
+ try:
+ return schema(data)
+ except ValueError:
+ raise er.ValueInvalid('not a valid value', path)
+ except er.Invalid as e:
+ e.prepend(path)
+ raise
+
+ return validate_callable
+
+ def validate_value(path, data):
+ if data != schema:
+ raise er.ScalarInvalid('not a valid value', path)
+ return data
+
+ return validate_value
def _compile_itemsort():
- """return sort function of mappings"""
- pass
+ '''return sort function of mappings'''
+
+ def is_extra(key_):
+ return key_ is Extra
+
+ def is_remove(key_):
+ return isinstance(key_, Remove)
+
+ def is_marker(key_):
+ return isinstance(key_, Marker)
+
+ def is_type(key_):
+ return inspect.isclass(key_)
+
+ def is_callable(key_):
+ return callable(key_)
+
+ # priority list for map sorting (in order of checking)
+ # We want Extra to match last, because it's a catch-all. On the other hand,
+ # Remove markers should match first (since invalid values will not
+ # raise an Error, instead the validator will check if other schemas match
+ # the same value).
+ priority = [
+ (1, is_remove), # Remove highest priority after values
+ (2, is_marker), # then other Markers
+ (4, is_type), # types/classes lowest before Extra
+ (3, is_callable), # callables after markers
+ (5, is_extra), # Extra lowest priority
+ ]
+
+ def item_priority(item_):
+ key_ = item_[0]
+ for i, check_ in priority:
+ if check_(key_):
+ return i
+ # values have highest priorities
+ return 0
+
+ return item_priority
_sort_item = _compile_itemsort()
@@ -351,7 +840,10 @@ _sort_item = _compile_itemsort()
def _iterate_mapping_candidates(schema):
"""Iterate over schema in a meaningful order."""
- pass
+ # Without this, Extra might appear first in the iterator, and fail to
+ # validate a key even though it's a Required that has its own validation,
+ # generating a false positive.
+ return sorted(schema.items(), key=_sort_item)
def _iterate_object(obj):
@@ -359,7 +851,23 @@ def _iterate_object(obj):
defined __slots__.
"""
- pass
+ d = {}
+ try:
+ d = vars(obj)
+ except TypeError:
+ # maybe we have named tuple here?
+ if hasattr(obj, '_asdict'):
+ d = obj._asdict()
+ for item in d.items():
+ yield item
+ try:
+ slots = obj.__slots__
+ except AttributeError:
+ pass
+ else:
+ for key in slots:
+ if key != '__dict__':
+ yield (key, getattr(obj, key))
class Msg(object):
@@ -391,11 +899,16 @@ class Msg(object):
... assert isinstance(e.errors[0], er.RangeInvalid)
"""
- def __init__(self, schema: Schemable, msg: str, cls: typing.Optional[
- typing.Type[Error]]=None) ->None:
+ def __init__(
+ self,
+ schema: Schemable,
+ msg: str,
+ cls: typing.Optional[typing.Type[Error]] = None,
+ ) -> None:
if cls and not issubclass(cls, er.Invalid):
raise er.SchemaError(
- 'Msg can only use subclases of Invalid as custom class')
+ "Msg can only use subclases of Invalid as custom class"
+ )
self._schema = schema
self.schema = Schema(schema)
self.msg = msg
@@ -417,13 +930,12 @@ class Msg(object):
class Object(dict):
"""Indicate that we should work with attributes, not keys."""
- def __init__(self, schema: typing.Any, cls: object=UNDEFINED) ->None:
+ def __init__(self, schema: typing.Any, cls: object = UNDEFINED) -> None:
self.cls = cls
super(Object, self).__init__(schema)
class VirtualPathComponent(str):
-
def __str__(self):
return '<' + self + '>'
@@ -437,15 +949,20 @@ class Marker(object):
`description` is an optional field, unused by Voluptuous itself, but can be
introspected by any external tool, for example to generate schema documentation.
"""
- __slots__ = 'schema', '_schema', 'msg', 'description', '__hash__'
- def __init__(self, schema_: Schemable, msg: typing.Optional[str]=None,
- description: (typing.Any | None)=None) ->None:
+ __slots__ = ('schema', '_schema', 'msg', 'description', '__hash__')
+
+ def __init__(
+ self,
+ schema_: Schemable,
+ msg: typing.Optional[str] = None,
+ description: typing.Any | None = None,
+ ) -> None:
self.schema: typing.Any = schema_
self._schema = Schema(schema_)
self.msg = msg
self.description = description
- self.__hash__ = cache(lambda : hash(schema_))
+ self.__hash__ = cache(lambda: hash(schema_)) # type: ignore[method-assign]
def __call__(self, v):
try:
@@ -470,7 +987,7 @@ class Marker(object):
return self.schema == other
def __ne__(self, other):
- return not self.schema == other
+ return not (self.schema == other)
class Optional(Marker):
@@ -496,11 +1013,14 @@ class Optional(Marker):
{'key2': 'value'}
"""
- def __init__(self, schema: Schemable, msg: typing.Optional[str]=None,
- default: typing.Any=UNDEFINED, description: (typing.Any | None)=None
- ) ->None:
- super(Optional, self).__init__(schema, msg=msg, description=description
- )
+ def __init__(
+ self,
+ schema: Schemable,
+ msg: typing.Optional[str] = None,
+ default: typing.Any = UNDEFINED,
+ description: typing.Any | None = None,
+ ) -> None:
+ super(Optional, self).__init__(schema, msg=msg, description=description)
self.default = default_factory(default)
@@ -540,11 +1060,14 @@ class Exclusive(Optional):
... 'social': {'social_network': 'barfoo', 'token': 'tEMp'}})
"""
- def __init__(self, schema: Schemable, group_of_exclusion: str, msg:
- typing.Optional[str]=None, description: (typing.Any | None)=None
- ) ->None:
- super(Exclusive, self).__init__(schema, msg=msg, description=
- description)
+ def __init__(
+ self,
+ schema: Schemable,
+ group_of_exclusion: str,
+ msg: typing.Optional[str] = None,
+ description: typing.Any | None = None,
+ ) -> None:
+ super(Exclusive, self).__init__(schema, msg=msg, description=description)
self.group_of_exclusion = group_of_exclusion
@@ -590,11 +1113,17 @@ class Inclusive(Optional):
True
"""
- def __init__(self, schema: Schemable, group_of_inclusion: str, msg:
- typing.Optional[str]=None, description: (typing.Any | None)=None,
- default: typing.Any=UNDEFINED) ->None:
- super(Inclusive, self).__init__(schema, msg=msg, default=default,
- description=description)
+ def __init__(
+ self,
+ schema: Schemable,
+ group_of_inclusion: str,
+ msg: typing.Optional[str] = None,
+ description: typing.Any | None = None,
+ default: typing.Any = UNDEFINED,
+ ) -> None:
+ super(Inclusive, self).__init__(
+ schema, msg=msg, default=default, description=description
+ )
self.group_of_inclusion = group_of_inclusion
@@ -613,11 +1142,14 @@ class Required(Marker):
{'key': []}
"""
- def __init__(self, schema: Schemable, msg: typing.Optional[str]=None,
- default: typing.Any=UNDEFINED, description: (typing.Any | None)=None
- ) ->None:
- super(Required, self).__init__(schema, msg=msg, description=description
- )
+ def __init__(
+ self,
+ schema: Schemable,
+ msg: typing.Optional[str] = None,
+ default: typing.Any = UNDEFINED,
+ description: typing.Any | None = None,
+ ) -> None:
+ super(Required, self).__init__(schema, msg=msg, description=description)
self.default = default_factory(default)
@@ -636,21 +1168,27 @@ class Remove(Marker):
[1, 2, 3, 5, '7']
"""
- def __init__(self, schema_: Schemable, msg: typing.Optional[str]=None,
- description: (typing.Any | None)=None) ->None:
+ def __init__(
+ self,
+ schema_: Schemable,
+ msg: typing.Optional[str] = None,
+ description: typing.Any | None = None,
+ ) -> None:
super().__init__(schema_, msg, description)
- self.__hash__ = cache(lambda : object.__hash__(self))
+ self.__hash__ = cache(lambda: object.__hash__(self)) # type: ignore[method-assign]
def __call__(self, schema: Schemable):
super(Remove, self).__call__(schema)
return self.__class__
def __repr__(self):
- return 'Remove(%r)' % (self.schema,)
+ return "Remove(%r)" % (self.schema,)
-def message(default: typing.Optional[str]=None, cls: typing.Optional[typing
- .Type[Error]]=None) ->typing.Callable:
+def message(
+ default: typing.Optional[str] = None,
+ cls: typing.Optional[typing.Type[Error]] = None,
+) -> typing.Callable:
"""Convenience decorator to allow functions to provide a message.
Set a default message:
@@ -678,20 +1216,56 @@ def message(default: typing.Optional[str]=None, cls: typing.Optional[typing
... except er.MultipleInvalid as e:
... assert isinstance(e.errors[0], IntegerInvalid)
"""
- pass
+ if cls and not issubclass(cls, er.Invalid):
+ raise er.SchemaError(
+ "message can only use subclases of Invalid as custom class"
+ )
+
+ def decorator(f):
+ @wraps(f)
+ def check(msg=None, clsoverride=None):
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ try:
+ return f(*args, **kwargs)
+ except ValueError:
+ raise (clsoverride or cls or er.ValueInvalid)(
+ msg or default or 'invalid value'
+ )
+
+ return wrapper
+
+ return check
+
+ return decorator
def _args_to_dict(func, args):
"""Returns argument names as values as key-value pairs."""
- pass
+ if sys.version_info >= (3, 0):
+ arg_count = func.__code__.co_argcount
+ arg_names = func.__code__.co_varnames[:arg_count]
+ else:
+ arg_count = func.func_code.co_argcount
+ arg_names = func.func_code.co_varnames[:arg_count]
+
+ arg_value_list = list(args)
+ arguments = dict(
+ (arg_name, arg_value_list[i])
+ for i, arg_name in enumerate(arg_names)
+ if i < len(arg_value_list)
+ )
+ return arguments
def _merge_args_with_kwargs(args_dict, kwargs_dict):
"""Merge args with kwargs."""
- pass
+ ret = args_dict.copy()
+ ret.update(kwargs_dict)
+ return ret
-def validate(*a, **kw) ->typing.Callable:
+def validate(*a, **kw) -> typing.Callable:
"""Decorator for validating arguments of a function against a given schema.
Set restrictions for arguments:
@@ -707,4 +1281,35 @@ def validate(*a, **kw) ->typing.Callable:
... return arg1 * 2
"""
- pass
+ RETURNS_KEY = '__return__'
+
+ def validate_schema_decorator(func):
+ returns_defined = False
+ returns = None
+
+ schema_args_dict = _args_to_dict(func, a)
+ schema_arguments = _merge_args_with_kwargs(schema_args_dict, kw)
+
+ if RETURNS_KEY in schema_arguments:
+ returns_defined = True
+ returns = schema_arguments[RETURNS_KEY]
+ del schema_arguments[RETURNS_KEY]
+
+ input_schema = (
+ Schema(schema_arguments, extra=ALLOW_EXTRA)
+ if len(schema_arguments) != 0
+ else lambda x: x
+ )
+ output_schema = Schema(returns) if returns_defined else lambda x: x
+
+ @wraps(func)
+ def func_wrapper(*args, **kwargs):
+ args_dict = _args_to_dict(func, args)
+ arguments = _merge_args_with_kwargs(args_dict, kwargs)
+ validated_arguments = input_schema(arguments)
+ output = func(**validated_arguments)
+ return output_schema(output)
+
+ return func_wrapper
+
+ return validate_schema_decorator
diff --git a/voluptuous/util.py b/voluptuous/util.py
index fe15b1a..0bf9302 100644
--- a/voluptuous/util.py
+++ b/voluptuous/util.py
@@ -1,59 +1,65 @@
+# F401: "imported but unused"
+# fmt: off
import typing
-from voluptuous import validators
-from voluptuous.error import Invalid, LiteralInvalid, TypeInvalid
-from voluptuous.schema_builder import DefaultFactory
-from voluptuous.schema_builder import Schema, default_factory, raises
+
+from voluptuous import validators # noqa: F401
+from voluptuous.error import Invalid, LiteralInvalid, TypeInvalid # noqa: F401
+from voluptuous.schema_builder import DefaultFactory # noqa: F401
+from voluptuous.schema_builder import Schema, default_factory, raises # noqa: F401
+
+# fmt: on
+
__author__ = 'tusharmakkar08'
-def Lower(v: str) ->str:
+def Lower(v: str) -> str:
"""Transform a string to lower case.
>>> s = Schema(Lower)
>>> s('HI')
'hi'
"""
- pass
+ return str(v).lower()
-def Upper(v: str) ->str:
+def Upper(v: str) -> str:
"""Transform a string to upper case.
>>> s = Schema(Upper)
>>> s('hi')
'HI'
"""
- pass
+ return str(v).upper()
-def Capitalize(v: str) ->str:
+def Capitalize(v: str) -> str:
"""Capitalise a string.
>>> s = Schema(Capitalize)
>>> s('hello world')
'Hello world'
"""
- pass
+ return str(v).capitalize()
-def Title(v: str) ->str:
+def Title(v: str) -> str:
"""Title case a string.
>>> s = Schema(Title)
>>> s('hello world')
'Hello World'
"""
- pass
+ return str(v).title()
-def Strip(v: str) ->str:
+def Strip(v: str) -> str:
"""Strip whitespace from a string.
>>> s = Schema(Strip)
>>> s(' hello world ')
'hello world'
"""
- pass
+ return str(v).strip()
class DefaultTo(object):
@@ -67,7 +73,7 @@ class DefaultTo(object):
[]
"""
- def __init__(self, default_value, msg: typing.Optional[str]=None) ->None:
+ def __init__(self, default_value, msg: typing.Optional[str] = None) -> None:
self.default_value = default_factory(default_value)
self.msg = msg
@@ -90,7 +96,7 @@ class SetTo(object):
42
"""
- def __init__(self, value) ->None:
+ def __init__(self, value) -> None:
self.value = default_factory(value)
def __call__(self, v):
@@ -112,15 +118,14 @@ class Set(object):
... s([set([1, 2]), set([3, 4])])
"""
- def __init__(self, msg: typing.Optional[str]=None) ->None:
+ def __init__(self, msg: typing.Optional[str] = None) -> None:
self.msg = msg
def __call__(self, v):
try:
set_v = set(v)
except Exception as e:
- raise TypeInvalid(self.msg or 'cannot be presented as set: {0}'
- .format(e))
+ raise TypeInvalid(self.msg or 'cannot be presented as set: {0}'.format(e))
return set_v
def __repr__(self):
@@ -128,14 +133,12 @@ class Set(object):
class Literal(object):
-
- def __init__(self, lit) ->None:
+ def __init__(self, lit) -> None:
self.lit = lit
- def __call__(self, value, msg: typing.Optional[str]=None):
+ def __call__(self, value, msg: typing.Optional[str] = None):
if self.lit != value:
- raise LiteralInvalid(msg or '%s not match for %s' % (value,
- self.lit))
+ raise LiteralInvalid(msg or '%s not match for %s' % (value, self.lit))
else:
return self.lit
diff --git a/voluptuous/validators.py b/voluptuous/validators.py
index 88b50f6..d385260 100644
--- a/voluptuous/validators.py
+++ b/voluptuous/validators.py
@@ -1,4 +1,6 @@
+# fmt: off
from __future__ import annotations
+
import datetime
import os
import re
@@ -6,30 +8,73 @@ import sys
import typing
from decimal import Decimal, InvalidOperation
from functools import wraps
-from voluptuous.error import AllInvalid, AnyInvalid, BooleanInvalid, CoerceInvalid, ContainsInvalid, DateInvalid, DatetimeInvalid, DirInvalid, EmailInvalid, ExactSequenceInvalid, FalseInvalid, FileInvalid, InInvalid, Invalid, LengthInvalid, MatchInvalid, MultipleInvalid, NotEnoughValid, NotInInvalid, PathInvalid, RangeInvalid, TooManyValid, TrueInvalid, TypeInvalid, UrlInvalid
-from voluptuous.schema_builder import Schema, Schemable, message, raises
+
+from voluptuous.error import (
+ AllInvalid, AnyInvalid, BooleanInvalid, CoerceInvalid, ContainsInvalid, DateInvalid,
+ DatetimeInvalid, DirInvalid, EmailInvalid, ExactSequenceInvalid, FalseInvalid,
+ FileInvalid, InInvalid, Invalid, LengthInvalid, MatchInvalid, MultipleInvalid,
+ NotEnoughValid, NotInInvalid, PathInvalid, RangeInvalid, TooManyValid, TrueInvalid,
+ TypeInvalid, UrlInvalid,
+)
+
+# F401: flake8 complains about 'raises' not being used, but it is used in doctests
+from voluptuous.schema_builder import Schema, Schemable, message, raises # noqa: F401
+
if typing.TYPE_CHECKING:
from _typeshed import SupportsAllComparisons
+
+# fmt: on
+
+
Enum: typing.Union[type, None]
try:
from enum import Enum
except ImportError:
Enum = None
+
+
if sys.version_info >= (3,):
import urllib.parse as urlparse
+
basestring = str
else:
import urlparse
+
+# Taken from https://github.com/kvesteri/validators/blob/master/validators/email.py
+# fmt: off
USER_REGEX = re.compile(
- '(?:(^[-!#$%&\'*+/=?^_`{}|~0-9A-Z]+(\\.[-!#$%&\'*+/=?^_`{}|~0-9A-Z]+)*$|^"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]|\\\\[\\001-\\011\\013\\014\\016-\\177])*"$))\\Z'
- , re.IGNORECASE)
+ # start anchor, because fullmatch is not available in python 2.7
+ "(?:"
+ # dot-atom
+ r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+"
+ r"(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*$"
+ # quoted-string
+ r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|'
+ r"""\\[\001-\011\013\014\016-\177])*"$)"""
+ # end anchor, because fullmatch is not available in python 2.7
+ r")\Z",
+ re.IGNORECASE,
+)
DOMAIN_REGEX = re.compile(
- '(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+(?:[A-Z]{2,6}\\.?|[A-Z0-9-]{2,}\\.?$)|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\]$)\\Z'
- , re.IGNORECASE)
+ # start anchor, because fullmatch is not available in python 2.7
+ "(?:"
+ # domain
+ r'(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+'
+ # tld
+ r'(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?$)'
+ # literal form, ipv4 address (SMTP 4.1.3)
+ r'|^\[(25[0-5]|2[0-4]\d|[0-1]?\d?\d)'
+ r'(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}\]$'
+ # end anchor, because fullmatch is not available in python 2.7
+ r")\Z",
+ re.IGNORECASE,
+)
+# fmt: on
+
__author__ = 'tusharmakkar08'
-def truth(f: typing.Callable) ->typing.Callable:
+def truth(f: typing.Callable) -> typing.Callable:
"""Convenience decorator to convert truth functions into validators.
>>> @truth
@@ -41,7 +86,15 @@ def truth(f: typing.Callable) ->typing.Callable:
>>> with raises(MultipleInvalid, 'not a valid value'):
... validate('/notavaliddir')
"""
- pass
+
+ @wraps(f)
+ def check(v):
+ t = f(v)
+ if not t:
+ raise ValueError
+ return v
+
+ return check
class Coerce(object):
@@ -65,8 +118,11 @@ class Coerce(object):
... validate('foo')
"""
- def __init__(self, type: typing.Union[type, typing.Callable], msg:
- typing.Optional[str]=None) ->None:
+ def __init__(
+ self,
+ type: typing.Union[type, typing.Callable],
+ msg: typing.Optional[str] = None,
+ ) -> None:
self.type = type
self.msg = msg
self.type_name = type.__name__
@@ -75,10 +131,9 @@ class Coerce(object):
try:
return self.type(v)
except (ValueError, TypeError, InvalidOperation):
- msg = self.msg or 'expected %s' % self.type_name
+ msg = self.msg or ('expected %s' % self.type_name)
if not self.msg and Enum and issubclass(self.type, Enum):
- msg += ' or one of %s' % str([e.value for e in self.type])[1:-1
- ]
+ msg += " or one of %s" % str([e.value for e in self.type])[1:-1]
raise CoerceInvalid(msg)
def __repr__(self):
@@ -109,7 +164,7 @@ def IsTrue(v):
... except MultipleInvalid as e:
... assert isinstance(e.errors[0], TrueInvalid)
"""
- pass
+ return v
@message('value was not false', cls=FalseInvalid)
@@ -129,7 +184,9 @@ def IsFalse(v):
... except MultipleInvalid as e:
... assert isinstance(e.errors[0], FalseInvalid)
"""
- pass
+ if v:
+ raise ValueError
+ return v
@message('expected boolean', cls=BooleanInvalid)
@@ -153,7 +210,14 @@ def Boolean(v):
... except MultipleInvalid as e:
... assert isinstance(e.errors[0], BooleanInvalid)
"""
- pass
+ if isinstance(v, basestring):
+ v = v.lower()
+ if v in ('1', 'true', 'yes', 'on', 'enable'):
+ return True
+ if v in ('0', 'false', 'no', 'off', 'disable'):
+ return False
+ raise ValueError
+ return bool(v)
class _WithSubValidators(object):
@@ -164,14 +228,15 @@ class _WithSubValidators(object):
sub-validators are compiled by the parent `Schema`.
"""
- def __init__(self, *validators, msg=None, required=False, discriminant=
- None, **kwargs) ->None:
+ def __init__(
+ self, *validators, msg=None, required=False, discriminant=None, **kwargs
+ ) -> None:
self.validators = validators
self.msg = msg
self.required = required
self.discriminant = discriminant
- def __voluptuous_compile__(self, schema: Schema) ->typing.Callable:
+ def __voluptuous_compile__(self, schema: Schema) -> typing.Callable:
self._compiled = []
old_required = schema.required
self.schema = schema
@@ -181,12 +246,32 @@ class _WithSubValidators(object):
schema.required = old_required
return self._run
+ def _run(self, path: typing.List[typing.Hashable], value):
+ if self.discriminant is not None:
+ self._compiled = [
+ self.schema._compile(v)
+ for v in self.discriminant(value, self.validators)
+ ]
+
+ return self._exec(self._compiled, value, path)
+
def __call__(self, v):
return self._exec((Schema(val) for val in self.validators), v)
def __repr__(self):
- return '%s(%s, msg=%r)' % (self.__class__.__name__, ', '.join(repr(
- v) for v in self.validators), self.msg)
+ return '%s(%s, msg=%r)' % (
+ self.__class__.__name__,
+ ", ".join(repr(v) for v in self.validators),
+ self.msg,
+ )
+
+ def _exec(
+ self,
+ funcs: typing.Iterable,
+ v,
+ path: typing.Optional[typing.List[typing.Hashable]] = None,
+ ):
+ raise NotImplementedError()
class Any(_WithSubValidators):
@@ -214,7 +299,24 @@ class Any(_WithSubValidators):
... validate(4)
"""
+ def _exec(self, funcs, v, path=None):
+ error = None
+ for func in funcs:
+ try:
+ if path is None:
+ return func(v)
+ else:
+ return func(path, v)
+ except Invalid as e:
+ if error is None or len(e.path) > len(error.path):
+ error = e
+ else:
+ if error:
+ raise error if self.msg is None else AnyInvalid(self.msg, path=path)
+ raise AnyInvalid(self.msg or 'no valid value found', path=path)
+
+# Convenience alias
Or = Any
@@ -239,7 +341,24 @@ class Union(_WithSubValidators):
Without the discriminant, the exception would be "extra keys not allowed @ data['b_val']"
"""
+ def _exec(self, funcs, v, path=None):
+ error = None
+ for func in funcs:
+ try:
+ if path is None:
+ return func(v)
+ else:
+ return func(path, v)
+ except Invalid as e:
+ if error is None or len(e.path) > len(error.path):
+ error = e
+ else:
+ if error:
+ raise error if self.msg is None else AnyInvalid(self.msg, path=path)
+ raise AnyInvalid(self.msg or 'no valid value found', path=path)
+
+# Convenience alias
Switch = Union
@@ -256,7 +375,19 @@ class All(_WithSubValidators):
10
"""
+ def _exec(self, funcs, v, path=None):
+ try:
+ for func in funcs:
+ if path is None:
+ v = func(v)
+ else:
+ v = func(path, v)
+ except Invalid as e:
+ raise e if self.msg is None else AllInvalid(self.msg, path=path)
+ return v
+
+# Convenience alias
And = All
@@ -279,8 +410,9 @@ class Match(object):
'0x123ef4'
"""
- def __init__(self, pattern: typing.Union[re.Pattern, str], msg: typing.
- Optional[str]=None) ->None:
+ def __init__(
+ self, pattern: typing.Union[re.Pattern, str], msg: typing.Optional[str] = None
+ ) -> None:
if isinstance(pattern, basestring):
pattern = re.compile(pattern)
self.pattern = pattern
@@ -290,11 +422,12 @@ class Match(object):
try:
match = self.pattern.match(v)
except TypeError:
- raise MatchInvalid('expected string or buffer')
+ raise MatchInvalid("expected string or buffer")
if not match:
- raise MatchInvalid(self.msg or
- 'does not match regular expression {}'.format(self.pattern.
- pattern))
+ raise MatchInvalid(
+ self.msg
+ or 'does not match regular expression {}'.format(self.pattern.pattern)
+ )
return v
def __repr__(self):
@@ -310,8 +443,12 @@ class Replace(object):
'I say goodbye'
"""
- def __init__(self, pattern: typing.Union[re.Pattern, str], substitution:
- str, msg: typing.Optional[str]=None) ->None:
+ def __init__(
+ self,
+ pattern: typing.Union[re.Pattern, str],
+ substitution: str,
+ msg: typing.Optional[str] = None,
+ ) -> None:
if isinstance(pattern, basestring):
pattern = re.compile(pattern)
self.pattern = pattern
@@ -322,8 +459,18 @@ class Replace(object):
return self.pattern.sub(self.substitution, v)
def __repr__(self):
- return 'Replace(%r, %r, msg=%r)' % (self.pattern.pattern, self.
- substitution, self.msg)
+ return 'Replace(%r, %r, msg=%r)' % (
+ self.pattern.pattern,
+ self.substitution,
+ self.msg,
+ )
+
+
+def _url_validation(v: str) -> urlparse.ParseResult:
+ parsed = urlparse.urlparse(v)
+ if not parsed.scheme or not parsed.netloc:
+ raise UrlInvalid("must have a URL scheme and host")
+ return parsed
@message('expected an email address', cls=EmailInvalid)
@@ -340,7 +487,16 @@ def Email(v):
>>> s('t@x.com')
't@x.com'
"""
- pass
+ try:
+ if not v or "@" not in v:
+ raise EmailInvalid("Invalid email address")
+ user_part, domain_part = v.rsplit('@', 1)
+
+ if not (USER_REGEX.match(user_part) and DOMAIN_REGEX.match(domain_part)):
+ raise EmailInvalid("Invalid email address")
+ return v
+ except: # noqa: E722
+ raise ValueError
@message('expected a fully qualified domain name URL', cls=UrlInvalid)
@@ -353,7 +509,13 @@ def FqdnUrl(v):
>>> s('http://w3.org')
'http://w3.org'
"""
- pass
+ try:
+ parsed_url = _url_validation(v)
+ if "." not in parsed_url.netloc:
+ raise UrlInvalid("must have a domain name in URL")
+ return v
+ except: # noqa: E722
+ raise ValueError
@message('expected a URL', cls=UrlInvalid)
@@ -366,7 +528,11 @@ def Url(v):
>>> s('http://w3.org')
'http://w3.org'
"""
- pass
+ try:
+ _url_validation(v)
+ return v
+ except: # noqa: E722
+ raise ValueError
@message('Not a file', cls=FileInvalid)
@@ -381,7 +547,14 @@ def IsFile(v):
>>> with raises(FileInvalid, 'Not a file'):
... IsFile()(None)
"""
- pass
+ try:
+ if v:
+ v = str(v)
+ return os.path.isfile(v)
+ else:
+ raise FileInvalid('Not a file')
+ except TypeError:
+ raise FileInvalid('Not a file')
@message('Not a directory', cls=DirInvalid)
@@ -394,7 +567,14 @@ def IsDir(v):
>>> with raises(DirInvalid, 'Not a directory'):
... IsDir()(None)
"""
- pass
+ try:
+ if v:
+ v = str(v)
+ return os.path.isdir(v)
+ else:
+ raise DirInvalid("Not a directory")
+ except TypeError:
+ raise DirInvalid("Not a directory")
@message('path does not exist', cls=PathInvalid)
@@ -409,10 +589,17 @@ def PathExists(v):
>>> with raises(PathInvalid, 'Not a Path'):
... PathExists()(None)
"""
- pass
+ try:
+ if v:
+ v = str(v)
+ return os.path.exists(v)
+ else:
+ raise PathInvalid("Not a Path")
+ except TypeError:
+ raise PathInvalid("Not a Path")
-def Maybe(validator: Schemable, msg: typing.Optional[str]=None):
+def Maybe(validator: Schemable, msg: typing.Optional[str] = None):
"""Validate that the object matches given validator or is None.
:raises Invalid: If the value does not match the given validator and is not
@@ -425,7 +612,7 @@ def Maybe(validator: Schemable, msg: typing.Optional[str]=None):
... s("string")
"""
- pass
+ return Any(None, validator, msg=msg)
class Range(object):
@@ -449,9 +636,14 @@ class Range(object):
... Schema(Range(max=10, max_included=False))(20)
"""
- def __init__(self, min: (SupportsAllComparisons | None)=None, max: (
- SupportsAllComparisons | None)=None, min_included: bool=True,
- max_included: bool=True, msg: typing.Optional[str]=None) ->None:
+ def __init__(
+ self,
+ min: SupportsAllComparisons | None = None,
+ max: SupportsAllComparisons | None = None,
+ min_included: bool = True,
+ max_included: bool = True,
+ msg: typing.Optional[str] = None,
+ ) -> None:
self.min = min
self.max = max
self.min_included = min_included
@@ -462,28 +654,41 @@ class Range(object):
try:
if self.min_included:
if self.min is not None and not v >= self.min:
- raise RangeInvalid(self.msg or
- 'value must be at least %s' % self.min)
- elif self.min is not None and not v > self.min:
- raise RangeInvalid(self.msg or
- 'value must be higher than %s' % self.min)
+ raise RangeInvalid(
+ self.msg or 'value must be at least %s' % self.min
+ )
+ else:
+ if self.min is not None and not v > self.min:
+ raise RangeInvalid(
+ self.msg or 'value must be higher than %s' % self.min
+ )
if self.max_included:
if self.max is not None and not v <= self.max:
- raise RangeInvalid(self.msg or
- 'value must be at most %s' % self.max)
- elif self.max is not None and not v < self.max:
- raise RangeInvalid(self.msg or
- 'value must be lower than %s' % self.max)
+ raise RangeInvalid(
+ self.msg or 'value must be at most %s' % self.max
+ )
+ else:
+ if self.max is not None and not v < self.max:
+ raise RangeInvalid(
+ self.msg or 'value must be lower than %s' % self.max
+ )
+
return v
+
+ # Objects that lack a partial ordering, e.g. None or strings will raise TypeError
except TypeError:
- raise RangeInvalid(self.msg or
- 'invalid value or type (must have a partial ordering)')
+ raise RangeInvalid(
+ self.msg or 'invalid value or type (must have a partial ordering)'
+ )
def __repr__(self):
- return (
- 'Range(min=%r, max=%r, min_included=%r, max_included=%r, msg=%r)' %
- (self.min, self.max, self.min_included, self.max_included, self
- .msg))
+ return 'Range(min=%r, max=%r, min_included=%r, max_included=%r, msg=%r)' % (
+ self.min,
+ self.max,
+ self.min_included,
+ self.max_included,
+ self.msg,
+ )
class Clamp(object):
@@ -500,9 +705,12 @@ class Clamp(object):
0
"""
- def __init__(self, min: (SupportsAllComparisons | None)=None, max: (
- SupportsAllComparisons | None)=None, msg: typing.Optional[str]=None
- ) ->None:
+ def __init__(
+ self,
+ min: SupportsAllComparisons | None = None,
+ max: SupportsAllComparisons | None = None,
+ msg: typing.Optional[str] = None,
+ ) -> None:
self.min = min
self.max = max
self.msg = msg
@@ -514,9 +722,12 @@ class Clamp(object):
if self.max is not None and v > self.max:
v = self.max
return v
+
+ # Objects that lack a partial ordering, e.g. None or strings will raise TypeError
except TypeError:
- raise RangeInvalid(self.msg or
- 'invalid value or type (must have a partial ordering)')
+ raise RangeInvalid(
+ self.msg or 'invalid value or type (must have a partial ordering)'
+ )
def __repr__(self):
return 'Clamp(min=%s, max=%s)' % (self.min, self.max)
@@ -525,9 +736,12 @@ class Clamp(object):
class Length(object):
"""The length of a value must be in a certain range."""
- def __init__(self, min: (SupportsAllComparisons | None)=None, max: (
- SupportsAllComparisons | None)=None, msg: typing.Optional[str]=None
- ) ->None:
+ def __init__(
+ self,
+ min: SupportsAllComparisons | None = None,
+ max: SupportsAllComparisons | None = None,
+ msg: typing.Optional[str] = None,
+ ) -> None:
self.min = min
self.max = max
self.msg = msg
@@ -535,12 +749,16 @@ class Length(object):
def __call__(self, v):
try:
if self.min is not None and len(v) < self.min:
- raise LengthInvalid(self.msg or
- 'length of value must be at least %s' % self.min)
+ raise LengthInvalid(
+ self.msg or 'length of value must be at least %s' % self.min
+ )
if self.max is not None and len(v) > self.max:
- raise LengthInvalid(self.msg or
- 'length of value must be at most %s' % self.max)
+ raise LengthInvalid(
+ self.msg or 'length of value must be at most %s' % self.max
+ )
return v
+
+ # Objects that have no length e.g. None or strings will raise TypeError
except TypeError:
raise RangeInvalid(self.msg or 'invalid value or type')
@@ -550,10 +768,12 @@ class Length(object):
class Datetime(object):
"""Validate that the value matches the datetime format."""
+
DEFAULT_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
- def __init__(self, format: typing.Optional[str]=None, msg: typing.
- Optional[str]=None) ->None:
+ def __init__(
+ self, format: typing.Optional[str] = None, msg: typing.Optional[str] = None
+ ) -> None:
self.format = format or self.DEFAULT_FORMAT
self.msg = msg
@@ -561,8 +781,9 @@ class Datetime(object):
try:
datetime.datetime.strptime(v, self.format)
except (TypeError, ValueError):
- raise DatetimeInvalid(self.msg or
- 'value does not match expected format %s' % self.format)
+ raise DatetimeInvalid(
+ self.msg or 'value does not match expected format %s' % self.format
+ )
return v
def __repr__(self):
@@ -571,14 +792,16 @@ class Datetime(object):
class Date(Datetime):
"""Validate that the value matches the date format."""
+
DEFAULT_FORMAT = '%Y-%m-%d'
def __call__(self, v):
try:
datetime.datetime.strptime(v, self.format)
except (TypeError, ValueError):
- raise DateInvalid(self.msg or
- 'value does not match expected format %s' % self.format)
+ raise DateInvalid(
+ self.msg or 'value does not match expected format %s' % self.format
+ )
return v
def __repr__(self):
@@ -588,8 +811,9 @@ class Date(Datetime):
class In(object):
"""Validate that a value is in a collection."""
- def __init__(self, container: typing.Container, msg: typing.Optional[
- str]=None) ->None:
+ def __init__(
+ self, container: typing.Container, msg: typing.Optional[str] = None
+ ) -> None:
self.container = container
self.msg = msg
@@ -600,11 +824,14 @@ class In(object):
check = True
if check:
try:
- raise InInvalid(self.msg or
- f'value must be one of {sorted(self.container)}')
+ raise InInvalid(
+ self.msg or f'value must be one of {sorted(self.container)}'
+ )
except TypeError:
- raise InInvalid(self.msg or
- f'value must be one of {sorted(self.container, key=str)}')
+ raise InInvalid(
+ self.msg
+ or f'value must be one of {sorted(self.container, key=str)}'
+ )
return v
def __repr__(self):
@@ -614,8 +841,9 @@ class In(object):
class NotIn(object):
"""Validate that a value is not in a collection."""
- def __init__(self, container: typing.Iterable, msg: typing.Optional[str
- ]=None) ->None:
+ def __init__(
+ self, container: typing.Iterable, msg: typing.Optional[str] = None
+ ) -> None:
self.container = container
self.msg = msg
@@ -626,12 +854,14 @@ class NotIn(object):
check = True
if check:
try:
- raise NotInInvalid(self.msg or
- f'value must not be one of {sorted(self.container)}')
+ raise NotInInvalid(
+ self.msg or f'value must not be one of {sorted(self.container)}'
+ )
except TypeError:
- raise NotInInvalid(self.msg or
- f'value must not be one of {sorted(self.container, key=str)}'
- )
+ raise NotInInvalid(
+ self.msg
+ or f'value must not be one of {sorted(self.container, key=str)}'
+ )
return v
def __repr__(self):
@@ -648,7 +878,7 @@ class Contains(object):
... s([3, 2])
"""
- def __init__(self, item, msg: typing.Optional[str]=None) ->None:
+ def __init__(self, item, msg: typing.Optional[str] = None) -> None:
self.item = item
self.msg = msg
@@ -681,8 +911,12 @@ class ExactSequence(object):
('hourly_report', 10, [], [])
"""
- def __init__(self, validators: typing.Iterable[Schemable], msg: typing.
- Optional[str]=None, **kwargs) ->None:
+ def __init__(
+ self,
+ validators: typing.Iterable[Schemable],
+ msg: typing.Optional[str] = None,
+ **kwargs,
+ ) -> None:
self.validators = validators
self.msg = msg
self._schemas = [Schema(val, **kwargs) for val in validators]
@@ -693,12 +927,11 @@ class ExactSequence(object):
try:
v = type(v)(schema(x) for x, schema in zip(v, self._schemas))
except Invalid as e:
- raise (e if self.msg is None else ExactSequenceInvalid(self.msg))
+ raise e if self.msg is None else ExactSequenceInvalid(self.msg)
return v
def __repr__(self):
- return 'ExactSequence([%s])' % ', '.join(repr(v) for v in self.
- validators)
+ return 'ExactSequence([%s])' % ", ".join(repr(v) for v in self.validators)
class Unique(object):
@@ -727,20 +960,18 @@ class Unique(object):
... s('aabbc')
"""
- def __init__(self, msg: typing.Optional[str]=None) ->None:
+ def __init__(self, msg: typing.Optional[str] = None) -> None:
self.msg = msg
def __call__(self, v):
try:
set_v = set(v)
except TypeError as e:
- raise TypeInvalid(self.msg or
- 'contains unhashable elements: {0}'.format(e))
+ raise TypeInvalid(self.msg or 'contains unhashable elements: {0}'.format(e))
if len(set_v) != len(v):
seen = set()
dupes = list(set(x for x in v if x in seen or seen.add(x)))
- raise Invalid(self.msg or 'contains duplicate items: {0}'.
- format(dupes))
+ raise Invalid(self.msg or 'contains duplicate items: {0}'.format(dupes))
return v
def __repr__(self):
@@ -763,15 +994,16 @@ class Equal(object):
... s('foo')
"""
- def __init__(self, target, msg: typing.Optional[str]=None) ->None:
+ def __init__(self, target, msg: typing.Optional[str] = None) -> None:
self.target = target
self.msg = msg
def __call__(self, v):
if v != self.target:
- raise Invalid(self.msg or
- 'Values are not equal: value:{} != target:{}'.format(v,
- self.target))
+ raise Invalid(
+ self.msg
+ or 'Values are not equal: value:{} != target:{}'.format(v, self.target)
+ )
return v
def __repr__(self):
@@ -793,8 +1025,12 @@ class Unordered(object):
[1, 'foo']
"""
- def __init__(self, validators: typing.Iterable[Schemable], msg: typing.
- Optional[str]=None, **kwargs) ->None:
+ def __init__(
+ self,
+ validators: typing.Iterable[Schemable],
+ msg: typing.Optional[str] = None,
+ **kwargs,
+ ) -> None:
self.validators = validators
self.msg = msg
self._schemas = [Schema(val, **kwargs) for val in validators]
@@ -802,10 +1038,15 @@ class Unordered(object):
def __call__(self, v):
if not isinstance(v, (list, tuple)):
raise Invalid(self.msg or 'Value {} is not sequence!'.format(v))
+
if len(v) != len(self._schemas):
- raise Invalid(self.msg or
- 'List lengths differ, value:{} != target:{}'.format(len(v),
- len(self._schemas)))
+ raise Invalid(
+ self.msg
+ or 'List lengths differ, value:{} != target:{}'.format(
+ len(v), len(self._schemas)
+ )
+ )
+
consumed = set()
missing = []
for index, value in enumerate(v):
@@ -823,20 +1064,31 @@ class Unordered(object):
break
if not found:
missing.append((index, value))
+
if len(missing) == 1:
el = missing[0]
- raise Invalid(self.msg or
- 'Element #{} ({}) is not valid against any validator'.
- format(el[0], el[1]))
+ raise Invalid(
+ self.msg
+ or 'Element #{} ({}) is not valid against any validator'.format(
+ el[0], el[1]
+ )
+ )
elif missing:
- raise MultipleInvalid([Invalid(self.msg or
- 'Element #{} ({}) is not valid against any validator'.
- format(el[0], el[1])) for el in missing])
+ raise MultipleInvalid(
+ [
+ Invalid(
+ self.msg
+ or 'Element #{} ({}) is not valid against any validator'.format(
+ el[0], el[1]
+ )
+ )
+ for el in missing
+ ]
+ )
return v
def __repr__(self):
- return 'Unordered([{}])'.format(', '.join(repr(v) for v in self.
- validators))
+ return 'Unordered([{}])'.format(", ".join(repr(v) for v in self.validators))
class Number(object):
@@ -854,9 +1106,13 @@ class Number(object):
Decimal('1234.01')
"""
- def __init__(self, precision: typing.Optional[int]=None, scale: typing.
- Optional[int]=None, msg: typing.Optional[str]=None, yield_decimal:
- bool=False) ->None:
+ def __init__(
+ self,
+ precision: typing.Optional[int] = None,
+ scale: typing.Optional[int] = None,
+ msg: typing.Optional[str] = None,
+ yield_decimal: bool = False,
+ ) -> None:
self.precision = precision
self.scale = scale
self.msg = msg
@@ -868,33 +1124,56 @@ class Number(object):
:return: Decimal number
"""
precision, scale, decimal_num = self._get_precision_scale(v)
- if (self.precision is not None and self.scale is not None and
- precision != self.precision and scale != self.scale):
- raise Invalid(self.msg or
- 'Precision must be equal to %s, and Scale must be equal to %s'
- % (self.precision, self.scale))
+
+ if (
+ self.precision is not None
+ and self.scale is not None
+ and precision != self.precision
+ and scale != self.scale
+ ):
+ raise Invalid(
+ self.msg
+ or "Precision must be equal to %s, and Scale must be equal to %s"
+ % (self.precision, self.scale)
+ )
else:
if self.precision is not None and precision != self.precision:
- raise Invalid(self.msg or 'Precision must be equal to %s' %
- self.precision)
+ raise Invalid(
+ self.msg or "Precision must be equal to %s" % self.precision
+ )
+
if self.scale is not None and scale != self.scale:
- raise Invalid(self.msg or 'Scale must be equal to %s' %
- self.scale)
+ raise Invalid(self.msg or "Scale must be equal to %s" % self.scale)
+
if self.yield_decimal:
return decimal_num
else:
return v
def __repr__(self):
- return 'Number(precision=%s, scale=%s, msg=%s)' % (self.precision,
- self.scale, self.msg)
+ return 'Number(precision=%s, scale=%s, msg=%s)' % (
+ self.precision,
+ self.scale,
+ self.msg,
+ )
- def _get_precision_scale(self, number) ->typing.Tuple[int, int, Decimal]:
+ def _get_precision_scale(self, number) -> typing.Tuple[int, int, Decimal]:
"""
:param number:
:return: tuple(precision, scale, decimal_number)
"""
- pass
+ try:
+ decimal_num = Decimal(number)
+ except InvalidOperation:
+ raise Invalid(self.msg or 'Value must be a number enclosed with string')
+
+ exp = decimal_num.as_tuple().exponent
+ if isinstance(exp, int):
+ return (len(decimal_num.as_tuple().digits), -exp, decimal_num)
+ else:
+ # TODO: handle infinity and NaN
+ # raise Invalid(self.msg or 'Value has no precision')
+ raise TypeError("infinity and NaN have no precision")
class SomeOf(_WithSubValidators):
@@ -921,17 +1200,49 @@ class SomeOf(_WithSubValidators):
... validate(6.2)
"""
- def __init__(self, validators: typing.List[Schemable], min_valid:
- typing.Optional[int]=None, max_valid: typing.Optional[int]=None, **
- kwargs) ->None:
- assert min_valid is not None or max_valid is not None, 'when using "%s" you should specify at least one of min_valid and max_valid' % (
- type(self).__name__,)
+ def __init__(
+ self,
+ validators: typing.List[Schemable],
+ min_valid: typing.Optional[int] = None,
+ max_valid: typing.Optional[int] = None,
+ **kwargs,
+ ) -> None:
+ assert min_valid is not None or max_valid is not None, (
+ 'when using "%s" you should specify at least one of min_valid and max_valid'
+ % (type(self).__name__,)
+ )
self.min_valid = min_valid or 0
self.max_valid = max_valid or len(validators)
super(SomeOf, self).__init__(*validators, **kwargs)
+ def _exec(self, funcs, v, path=None):
+ errors = []
+ funcs = list(funcs)
+ for func in funcs:
+ try:
+ if path is None:
+ v = func(v)
+ else:
+ v = func(path, v)
+ except Invalid as e:
+ errors.append(e)
+
+ passed_count = len(funcs) - len(errors)
+ if self.min_valid <= passed_count <= self.max_valid:
+ return v
+
+ msg = self.msg
+ if not msg:
+ msg = ', '.join(map(str, errors))
+
+ if passed_count > self.max_valid:
+ raise TooManyValid(msg)
+ raise NotEnoughValid(msg)
+
def __repr__(self):
- return (
- 'SomeOf(min_valid=%s, validators=[%s], max_valid=%s, msg=%r)' %
- (self.min_valid, ', '.join(repr(v) for v in self.validators),
- self.max_valid, self.msg))
+ return 'SomeOf(min_valid=%s, validators=[%s], max_valid=%s, msg=%r)' % (
+ self.min_valid,
+ ", ".join(repr(v) for v in self.validators),
+ self.max_valid,
+ self.msg,
+ )