back to Reference (Gold) summary
Reference (Gold): marshmallow
Pytest Summary for test tests
status | count |
---|---|
passed | 1229 |
total | 1229 |
collected | 1229 |
Failed pytests:
Patch diff
diff --git a/src/marshmallow/base.py b/src/marshmallow/base.py
index 5849d2e..e82848d 100644
--- a/src/marshmallow/base.py
+++ b/src/marshmallow/base.py
@@ -7,16 +7,59 @@ These are necessary to avoid circular imports between schema.py and fields.py.
This module is treated as private API.
Users should not need to use this module directly.
"""
+
from __future__ import annotations
+
from abc import ABC, abstractmethod
class FieldABC(ABC):
"""Abstract base class from which all Field classes inherit."""
+
parent = None
name = None
root = None
+ @abstractmethod
+ def serialize(self, attr, obj, accessor=None):
+ pass
+
+ @abstractmethod
+ def deserialize(self, value):
+ pass
+
+ @abstractmethod
+ def _serialize(self, value, attr, obj, **kwargs):
+ pass
+
+ @abstractmethod
+ def _deserialize(self, value, attr, data, **kwargs):
+ pass
+
class SchemaABC(ABC):
"""Abstract base class from which all Schemas inherit."""
+
+ @abstractmethod
+ def dump(self, obj, *, many: bool | None = None):
+ pass
+
+ @abstractmethod
+ def dumps(self, obj, *, many: bool | None = None):
+ pass
+
+ @abstractmethod
+ def load(self, data, *, many: bool | None = None, partial=None, unknown=None):
+ pass
+
+ @abstractmethod
+ def loads(
+ self,
+ json_data,
+ *,
+ many: bool | None = None,
+ partial=None,
+ unknown=None,
+ **kwargs,
+ ):
+ pass
diff --git a/src/marshmallow/class_registry.py b/src/marshmallow/class_registry.py
index 249b898..810d115 100644
--- a/src/marshmallow/class_registry.py
+++ b/src/marshmallow/class_registry.py
@@ -7,16 +7,26 @@ class:`fields.Nested <marshmallow.fields.Nested>`.
This module is treated as private API.
Users should not need to use this module directly.
"""
+
from __future__ import annotations
+
import typing
+
from marshmallow.exceptions import RegistryError
+
if typing.TYPE_CHECKING:
from marshmallow import Schema
+
SchemaType = typing.Type[Schema]
-_registry = {}
+# {
+# <class_name>: <list of class objects>
+# <module_path_to_class>: <list of class objects>
+# }
+_registry = {} # type: dict[str, list[SchemaType]]
-def register(classname: str, cls: SchemaType) ->None:
+
+def register(classname: str, cls: SchemaType) -> None:
"""Add a class to the registry of serializer classes. When a class is
registered, an entry for both its classname and its full, module-qualified
path are added to the registry.
@@ -35,14 +45,50 @@ def register(classname: str, cls: SchemaType) ->None:
# }
"""
- pass
+ # Module where the class is located
+ module = cls.__module__
+ # Full module path to the class
+ # e.g. user.schemas.UserSchema
+ fullpath = ".".join([module, classname])
+ # If the class is already registered; need to check if the entries are
+ # in the same module as cls to avoid having multiple instances of the same
+ # class in the registry
+ if classname in _registry and not any(
+ each.__module__ == module for each in _registry[classname]
+ ):
+ _registry[classname].append(cls)
+ elif classname not in _registry:
+ _registry[classname] = [cls]
+ # Also register the full path
+ if fullpath not in _registry:
+ _registry.setdefault(fullpath, []).append(cls)
+ else:
+ # If fullpath does exist, replace existing entry
+ _registry[fullpath] = [cls]
+ return None
-def get_class(classname: str, all: bool=False) ->(list[SchemaType] | SchemaType
- ):
+
+def get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:
"""Retrieve a class from the registry.
:raises: marshmallow.exceptions.RegistryError if the class cannot be found
or if there are multiple entries for the given class name.
"""
- pass
+ try:
+ classes = _registry[classname]
+ except KeyError as error:
+ raise RegistryError(
+ f"Class with name {classname!r} was not found. You may need "
+ "to import the class."
+ ) from error
+ if len(classes) > 1:
+ if all:
+ return _registry[classname]
+ raise RegistryError(
+ f"Multiple classes with name {classname!r} "
+ "were found. Please use the full, "
+ "module-qualified path."
+ )
+ else:
+ return _registry[classname][0]
diff --git a/src/marshmallow/decorators.py b/src/marshmallow/decorators.py
index d78f5be..dafca95 100644
--- a/src/marshmallow/decorators.py
+++ b/src/marshmallow/decorators.py
@@ -64,32 +64,38 @@ Example: ::
If you need to guarantee order of different processing steps, you should put
them in the same processing method.
"""
+
from __future__ import annotations
+
import functools
from typing import Any, Callable, cast
-PRE_DUMP = 'pre_dump'
-POST_DUMP = 'post_dump'
-PRE_LOAD = 'pre_load'
-POST_LOAD = 'post_load'
-VALIDATES = 'validates'
-VALIDATES_SCHEMA = 'validates_schema'
+
+PRE_DUMP = "pre_dump"
+POST_DUMP = "post_dump"
+PRE_LOAD = "pre_load"
+POST_LOAD = "post_load"
+VALIDATES = "validates"
+VALIDATES_SCHEMA = "validates_schema"
class MarshmallowHook:
__marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None
-def validates(field_name: str) ->Callable[..., Any]:
+def validates(field_name: str) -> Callable[..., Any]:
"""Register a field validator.
:param str field_name: Name of the field that the method validates.
"""
- pass
+ return set_hook(None, VALIDATES, field_name=field_name)
-def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
- False, pass_original: bool=False, skip_on_field_errors: bool=True
- ) ->Callable[..., Any]:
+def validates_schema(
+ fn: Callable[..., Any] | None = None,
+ pass_many: bool = False,
+ pass_original: bool = False,
+ skip_on_field_errors: bool = True,
+) -> Callable[..., Any]:
"""Register a schema-level validator.
By default it receives a single object at a time, transparently handling the ``many``
@@ -109,11 +115,17 @@ def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ return set_hook(
+ fn,
+ (VALIDATES_SCHEMA, pass_many),
+ pass_original=pass_original,
+ skip_on_field_errors=skip_on_field_errors,
+ )
-def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
- ) ->Callable[..., Any]:
+def pre_dump(
+ fn: Callable[..., Any] | None = None, pass_many: bool = False
+) -> Callable[..., Any]:
"""Register a method to invoke before serializing an object. The method
receives the object to be serialized and returns the processed object.
@@ -124,11 +136,14 @@ def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ return set_hook(fn, (PRE_DUMP, pass_many))
-def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
- pass_original: bool=False) ->Callable[..., Any]:
+def post_dump(
+ fn: Callable[..., Any] | None = None,
+ pass_many: bool = False,
+ pass_original: bool = False,
+) -> Callable[..., Any]:
"""Register a method to invoke after serializing an object. The method
receives the serialized object and returns the processed object.
@@ -142,11 +157,12 @@ def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)
-def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
- ) ->Callable[..., Any]:
+def pre_load(
+ fn: Callable[..., Any] | None = None, pass_many: bool = False
+) -> Callable[..., Any]:
"""Register a method to invoke before deserializing an object. The method
receives the data to be deserialized and returns the processed data.
@@ -158,11 +174,14 @@ def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ return set_hook(fn, (PRE_LOAD, pass_many))
-def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
- pass_original: bool=False) ->Callable[..., Any]:
+def post_load(
+ fn: Callable[..., Any] | None = None,
+ pass_many: bool = False,
+ pass_original: bool = False,
+) -> Callable[..., Any]:
"""Register a method to invoke after deserializing an object. The method
receives the deserialized data and returns the processed data.
@@ -177,11 +196,12 @@ def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)
-def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
- **kwargs: Any) ->Callable[..., Any]:
+def set_hook(
+ fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any
+) -> Callable[..., Any]:
"""Mark decorated function as a hook to be picked up later.
You should not need to use this method directly.
@@ -192,4 +212,20 @@ def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
:return: Decorated function if supplied, else this decorator with its args
bound.
"""
- pass
+ # Allow using this as either a decorator or a decorator factory.
+ if fn is None:
+ return functools.partial(set_hook, key=key, **kwargs)
+
+ # Set a __marshmallow_hook__ attribute instead of wrapping in some class,
+ # because I still want this to end up as a normal (unbound) method.
+ function = cast(MarshmallowHook, fn)
+ try:
+ hook_config = function.__marshmallow_hook__
+ except AttributeError:
+ function.__marshmallow_hook__ = hook_config = {}
+ # Also save the kwargs for the tagged function on
+ # __marshmallow_hook__, keyed by (<tag>, <pass_many>)
+ if hook_config is not None:
+ hook_config[key] = kwargs
+
+ return fn
diff --git a/src/marshmallow/error_store.py b/src/marshmallow/error_store.py
index a659aaf..72b7037 100644
--- a/src/marshmallow/error_store.py
+++ b/src/marshmallow/error_store.py
@@ -5,14 +5,25 @@
This module is treated as private API.
Users should not need to use this module directly.
"""
+
from marshmallow.exceptions import SCHEMA
class ErrorStore:
-
def __init__(self):
+ #: Dictionary of errors stored during serialization
self.errors = {}
+ def store_error(self, messages, field_name=SCHEMA, index=None):
+ # field error -> store/merge error messages under field name key
+ # schema error -> if string or list, store/merge under _schema key
+ # -> if dict, store/merge with other top-level keys
+ if field_name != SCHEMA or not isinstance(messages, dict):
+ messages = {field_name: messages}
+ if index is not None:
+ messages = {index: messages}
+ self.errors = merge_errors(self.errors, messages)
+
def merge_errors(errors1, errors2):
"""Deeply merge two error messages.
@@ -20,4 +31,30 @@ def merge_errors(errors1, errors2):
The format of ``errors1`` and ``errors2`` matches the ``message``
parameter of :exc:`marshmallow.exceptions.ValidationError`.
"""
- pass
+ if not errors1:
+ return errors2
+ if not errors2:
+ return errors1
+ if isinstance(errors1, list):
+ if isinstance(errors2, list):
+ return errors1 + errors2
+ if isinstance(errors2, dict):
+ return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})
+ return errors1 + [errors2]
+ if isinstance(errors1, dict):
+ if isinstance(errors2, list):
+ return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})
+ if isinstance(errors2, dict):
+ errors = dict(errors1)
+ for key, val in errors2.items():
+ if key in errors:
+ errors[key] = merge_errors(errors[key], val)
+ else:
+ errors[key] = val
+ return errors
+ return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})
+ if isinstance(errors2, list):
+ return [errors1] + errors2
+ if isinstance(errors2, dict):
+ return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})
+ return [errors1, errors2]
diff --git a/src/marshmallow/exceptions.py b/src/marshmallow/exceptions.py
index 52e36c1..096b6bd 100644
--- a/src/marshmallow/exceptions.py
+++ b/src/marshmallow/exceptions.py
@@ -1,7 +1,11 @@
"""Exception classes for marshmallow-related errors."""
+
from __future__ import annotations
+
import typing
-SCHEMA = '_schema'
+
+# Key used for schema-level validation errors
+SCHEMA = "_schema"
class MarshmallowError(Exception):
@@ -21,18 +25,37 @@ class ValidationError(MarshmallowError):
:param valid_data: Valid (de)serialized data.
"""
- def __init__(self, message: (str | list | dict), field_name: str=SCHEMA,
- data: (typing.Mapping[str, typing.Any] | typing.Iterable[typing.
- Mapping[str, typing.Any]] | None)=None, valid_data: (list[dict[str,
- typing.Any]] | dict[str, typing.Any] | None)=None, **kwargs):
- self.messages = [message] if isinstance(message, (str, bytes)
- ) else message
+ def __init__(
+ self,
+ message: str | list | dict,
+ field_name: str = SCHEMA,
+ data: typing.Mapping[str, typing.Any]
+ | typing.Iterable[typing.Mapping[str, typing.Any]]
+ | None = None,
+ valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,
+ **kwargs,
+ ):
+ self.messages = [message] if isinstance(message, (str, bytes)) else message
self.field_name = field_name
self.data = data
self.valid_data = valid_data
self.kwargs = kwargs
super().__init__(message)
+ def normalized_messages(self):
+ if self.field_name == SCHEMA and isinstance(self.messages, dict):
+ return self.messages
+ return {self.field_name: self.messages}
+
+ @property
+ def messages_dict(self) -> dict[str, typing.Any]:
+ if not isinstance(self.messages, dict):
+ raise TypeError(
+ "cannot access 'messages_dict' when 'messages' is of type "
+ + type(self.messages).__name__
+ )
+ return self.messages
+
class RegistryError(NameError):
"""Raised when an invalid operation is performed on the serializer
diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py
index 8656a56..ceb32aa 100644
--- a/src/marshmallow/fields.py
+++ b/src/marshmallow/fields.py
@@ -1,5 +1,7 @@
"""Field classes for various types of data."""
+
from __future__ import annotations
+
import collections
import copy
import datetime as dt
@@ -12,20 +14,66 @@ import uuid
import warnings
from collections.abc import Mapping as _Mapping
from enum import Enum as EnumType
+
from marshmallow import class_registry, types, utils, validate
from marshmallow.base import FieldABC, SchemaABC
-from marshmallow.exceptions import FieldInstanceResolutionError, StringNotCollectionError, ValidationError
-from marshmallow.utils import is_aware, is_collection, resolve_field_instance
-from marshmallow.utils import missing as missing_
+from marshmallow.exceptions import (
+ FieldInstanceResolutionError,
+ StringNotCollectionError,
+ ValidationError,
+)
+from marshmallow.utils import (
+ is_aware,
+ is_collection,
+ resolve_field_instance,
+)
+from marshmallow.utils import (
+ missing as missing_,
+)
from marshmallow.validate import And, Length
from marshmallow.warnings import RemovedInMarshmallow4Warning
-__all__ = ['Field', 'Raw', 'Nested', 'Mapping', 'Dict', 'List', 'Tuple',
- 'String', 'UUID', 'Number', 'Integer', 'Decimal', 'Boolean', 'Float',
- 'DateTime', 'NaiveDateTime', 'AwareDateTime', 'Time', 'Date',
- 'TimeDelta', 'Url', 'URL', 'Email', 'IP', 'IPv4', 'IPv6', 'IPInterface',
- 'IPv4Interface', 'IPv6Interface', 'Enum', 'Method', 'Function', 'Str',
- 'Bool', 'Int', 'Constant', 'Pluck']
-_T = typing.TypeVar('_T')
+
+__all__ = [
+ "Field",
+ "Raw",
+ "Nested",
+ "Mapping",
+ "Dict",
+ "List",
+ "Tuple",
+ "String",
+ "UUID",
+ "Number",
+ "Integer",
+ "Decimal",
+ "Boolean",
+ "Float",
+ "DateTime",
+ "NaiveDateTime",
+ "AwareDateTime",
+ "Time",
+ "Date",
+ "TimeDelta",
+ "Url",
+ "URL",
+ "Email",
+ "IP",
+ "IPv4",
+ "IPv6",
+ "IPInterface",
+ "IPv4Interface",
+ "IPv6Interface",
+ "Enum",
+ "Method",
+ "Function",
+ "Str",
+ "Bool",
+ "Int",
+ "Constant",
+ "Pluck",
+]
+
+_T = typing.TypeVar("_T")
class Field(FieldABC):
@@ -89,34 +137,65 @@ class Field(FieldABC):
Add ``data_key`` parameter for the specifying the key in the input and
output data. This parameter replaced both ``load_from`` and ``dump_to``.
"""
+
+ # Some fields, such as Method fields and Function fields, are not expected
+ # to exist as attributes on the objects to serialize. Set this to False
+ # for those fields
_CHECK_ATTRIBUTE = True
- default_error_messages = {'required':
- 'Missing data for required field.', 'null':
- 'Field may not be null.', 'validator_failed': 'Invalid value.'}
-
- def __init__(self, *, load_default: typing.Any=missing_, missing:
- typing.Any=missing_, dump_default: typing.Any=missing_, default:
- typing.Any=missing_, data_key: (str | None)=None, attribute: (str |
- None)=None, validate: (None | typing.Callable[[typing.Any], typing.
- Any] | typing.Iterable[typing.Callable[[typing.Any], typing.Any]])=
- None, required: bool=False, allow_none: (bool | None)=None,
- load_only: bool=False, dump_only: bool=False, error_messages: (dict
- [str, str] | None)=None, metadata: (typing.Mapping[str, typing.Any] |
- None)=None, **additional_metadata) ->None:
+
+ #: Default error messages for various kinds of errors. The keys in this dictionary
+ #: are passed to `Field.make_error`. The values are error messages passed to
+ #: :exc:`marshmallow.exceptions.ValidationError`.
+ default_error_messages = {
+ "required": "Missing data for required field.",
+ "null": "Field may not be null.",
+ "validator_failed": "Invalid value.",
+ }
+
+ def __init__(
+ self,
+ *,
+ load_default: typing.Any = missing_,
+ missing: typing.Any = missing_,
+ dump_default: typing.Any = missing_,
+ default: typing.Any = missing_,
+ data_key: str | None = None,
+ attribute: str | None = None,
+ validate: (
+ None
+ | typing.Callable[[typing.Any], typing.Any]
+ | typing.Iterable[typing.Callable[[typing.Any], typing.Any]]
+ ) = None,
+ required: bool = False,
+ allow_none: bool | None = None,
+ load_only: bool = False,
+ dump_only: bool = False,
+ error_messages: dict[str, str] | None = None,
+ metadata: typing.Mapping[str, typing.Any] | None = None,
+ **additional_metadata,
+ ) -> None:
+ # handle deprecated `default` and `missing` parameters
if default is not missing_:
warnings.warn(
- "The 'default' argument to fields is deprecated. Use 'dump_default' instead."
- , RemovedInMarshmallow4Warning, stacklevel=2)
+ "The 'default' argument to fields is deprecated. "
+ "Use 'dump_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
if dump_default is missing_:
dump_default = default
if missing is not missing_:
warnings.warn(
- "The 'missing' argument to fields is deprecated. Use 'load_default' instead."
- , RemovedInMarshmallow4Warning, stacklevel=2)
+ "The 'missing' argument to fields is deprecated. "
+ "Use 'load_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
if load_default is missing_:
load_default = missing
self.dump_default = dump_default
self.load_default = load_default
+
self.attribute = attribute
self.data_key = data_key
self.validate = validate
@@ -128,32 +207,46 @@ class Field(FieldABC):
self.validators = list(validate)
else:
raise ValueError(
- "The 'validate' parameter must be a callable or a collection of callables."
- )
- self.allow_none = (load_default is None if allow_none is None else
- allow_none)
+ "The 'validate' parameter must be a callable "
+ "or a collection of callables."
+ )
+
+ # If allow_none is None and load_default is None
+ # None should be considered valid by default
+ self.allow_none = load_default is None if allow_none is None else allow_none
self.load_only = load_only
self.dump_only = dump_only
if required is True and load_default is not missing_:
- raise ValueError(
- "'load_default' must not be set for required fields.")
+ raise ValueError("'load_default' must not be set for required fields.")
self.required = required
+
metadata = metadata or {}
self.metadata = {**metadata, **additional_metadata}
if additional_metadata:
warnings.warn(
- f'Passing field metadata as keyword arguments is deprecated. Use the explicit `metadata=...` argument instead. Additional metadata: {additional_metadata}'
- , RemovedInMarshmallow4Warning, stacklevel=2)
- messages = {}
+ "Passing field metadata as keyword arguments is deprecated. Use the "
+ "explicit `metadata=...` argument instead. "
+ f"Additional metadata: {additional_metadata}",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+
+ # Collect default error message from self and parent classes
+ messages = {} # type: dict[str, str]
for cls in reversed(self.__class__.__mro__):
- messages.update(getattr(cls, 'default_error_messages', {}))
+ messages.update(getattr(cls, "default_error_messages", {}))
messages.update(error_messages or {})
self.error_messages = messages
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return (
- f'<fields.{self.__class__.__name__}(dump_default={self.dump_default!r}, attribute={self.attribute!r}, validate={self.validate}, required={self.required}, load_only={self.load_only}, dump_only={self.dump_only}, load_default={self.load_default}, allow_none={self.allow_none}, error_messages={self.error_messages})>'
- )
+ f"<fields.{self.__class__.__name__}(dump_default={self.dump_default!r}, "
+ f"attribute={self.attribute!r}, "
+ f"validate={self.validate}, required={self.required}, "
+ f"load_only={self.load_only}, dump_only={self.dump_only}, "
+ f"load_default={self.load_default}, allow_none={self.allow_none}, "
+ f"error_messages={self.error_messages})>"
+ )
def __deepcopy__(self, memo):
return copy.copy(self)
@@ -166,19 +259,36 @@ class Field(FieldABC):
:param callable accessor: A callable used to retrieve the value of `attr` from
the object `obj`. Defaults to `marshmallow.utils.get_value`.
"""
- pass
+ accessor_func = accessor or utils.get_value
+ check_key = attr if self.attribute is None else self.attribute
+ return accessor_func(obj, check_key, default)
def _validate(self, value):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
- pass
+ self._validate_all(value)
- def make_error(self, key: str, **kwargs) ->ValidationError:
+ @property
+ def _validate_all(self):
+ return And(*self.validators, error=self.error_messages["validator_failed"])
+
+ def make_error(self, key: str, **kwargs) -> ValidationError:
"""Helper method to make a `ValidationError` with an error message
from ``self.error_messages``.
"""
- pass
+ try:
+ msg = self.error_messages[key]
+ except KeyError as error:
+ class_name = self.__class__.__name__
+ message = (
+ f"ValidationError raised by `{class_name}`, but error key `{key}` does "
+ "not exist in the `error_messages` dictionary."
+ )
+ raise AssertionError(message) from error
+ if isinstance(msg, (str, bytes)):
+ msg = msg.format(**kwargs)
+ return ValidationError(msg)
def fail(self, key: str, **kwargs):
"""Helper method that raises a `ValidationError` with an error message
@@ -187,17 +297,30 @@ class Field(FieldABC):
.. deprecated:: 3.0.0
Use `make_error <marshmallow.fields.Field.make_error>` instead.
"""
- pass
+ warnings.warn(
+ f'`Field.fail` is deprecated. Use `raise self.make_error("{key}", ...)` instead.',
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ raise self.make_error(key=key, **kwargs)
def _validate_missing(self, value):
"""Validate missing values. Raise a :exc:`ValidationError` if
`value` should be considered missing.
"""
- pass
-
- def serialize(self, attr: str, obj: typing.Any, accessor: (typing.
- Callable[[typing.Any, str, typing.Any], typing.Any] | None)=None,
- **kwargs):
+ if value is missing_ and self.required:
+ raise self.make_error("required")
+ if value is None and not self.allow_none:
+ raise self.make_error("null")
+
+ def serialize(
+ self,
+ attr: str,
+ obj: typing.Any,
+ accessor: typing.Callable[[typing.Any, str, typing.Any], typing.Any]
+ | None = None,
+ **kwargs,
+ ):
"""Pulls the value for the given key from the object, applies the
field's formatting and returns the result.
@@ -206,10 +329,24 @@ class Field(FieldABC):
:param accessor: Function used to access values from ``obj``.
:param kwargs: Field-specific keyword arguments.
"""
- pass
-
- def deserialize(self, value: typing.Any, attr: (str | None)=None, data:
- (typing.Mapping[str, typing.Any] | None)=None, **kwargs):
+ if self._CHECK_ATTRIBUTE:
+ value = self.get_value(obj, attr, accessor=accessor)
+ if value is missing_:
+ default = self.dump_default
+ value = default() if callable(default) else default
+ if value is missing_:
+ return value
+ else:
+ value = None
+ return self._serialize(value, attr, obj, **kwargs)
+
+ def deserialize(
+ self,
+ value: typing.Any,
+ attr: str | None = None,
+ data: typing.Mapping[str, typing.Any] | None = None,
+ **kwargs,
+ ):
"""Deserialize ``value``.
:param value: The value to deserialize.
@@ -219,7 +356,19 @@ class Field(FieldABC):
:raise ValidationError: If an invalid value is passed or if a required value
is missing.
"""
- pass
+ # Validate required fields, deserialize, then validate
+ # deserialized value
+ self._validate_missing(value)
+ if value is missing_:
+ _miss = self.load_default
+ return _miss() if callable(_miss) else _miss
+ if self.allow_none and value is None:
+ return None
+ output = self._deserialize(value, attr, data, **kwargs)
+ self._validate(output)
+ return output
+
+ # Methods for concrete classes to override.
def _bind_to_schema(self, field_name, schema):
"""Update field with values from its parent schema. Called by
@@ -228,10 +377,15 @@ class Field(FieldABC):
:param str field_name: Field name set in schema.
:param Schema|Field schema: Parent object.
"""
- pass
-
- def _serialize(self, value: typing.Any, attr: (str | None), obj: typing
- .Any, **kwargs):
+ self.parent = self.parent or schema
+ self.name = self.name or field_name
+ self.root = self.root or (
+ self.parent.root if isinstance(self.parent, FieldABC) else self.parent
+ )
+
+ def _serialize(
+ self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs
+ ):
"""Serializes ``value`` to a basic Python datatype. Noop by default.
Concrete :class:`Field` classes should implement this method.
@@ -249,10 +403,15 @@ class Field(FieldABC):
:param dict kwargs: Field-specific keyword arguments.
:return: The serialized value
"""
- pass
-
- def _deserialize(self, value: typing.Any, attr: (str | None), data: (
- typing.Mapping[str, typing.Any] | None), **kwargs):
+ return value
+
+ def _deserialize(
+ self,
+ value: typing.Any,
+ attr: str | None,
+ data: typing.Mapping[str, typing.Any] | None,
+ **kwargs,
+ ):
"""Deserialize value. Concrete :class:`Field` classes should implement this method.
:param value: The value to be deserialized.
@@ -268,12 +427,56 @@ class Field(FieldABC):
.. versionchanged:: 3.0.0
Added ``**kwargs`` to signature.
"""
- pass
+ return value
+
+ # Properties
@property
def context(self):
"""The context dictionary for the parent :class:`Schema`."""
- pass
+ return self.parent.context
+
+ # the default and missing properties are provided for compatibility and
+ # emit warnings when they are accessed and set
+ @property
+ def default(self):
+ warnings.warn(
+ "The 'default' attribute of fields is deprecated. "
+ "Use 'dump_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ return self.dump_default
+
+ @default.setter
+ def default(self, value):
+ warnings.warn(
+ "The 'default' attribute of fields is deprecated. "
+ "Use 'dump_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ self.dump_default = value
+
+ @property
+ def missing(self):
+ warnings.warn(
+ "The 'missing' attribute of fields is deprecated. "
+ "Use 'load_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ return self.load_default
+
+ @missing.setter
+ def missing(self, value):
+ warnings.warn(
+ "The 'missing' attribute of fields is deprecated. "
+ "Use 'load_default' instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ self.load_default = value
class Raw(Field):
@@ -325,30 +528,46 @@ class Nested(Field):
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- default_error_messages = {'type': 'Invalid type.'}
-
- def __init__(self, nested: (SchemaABC | type | str | dict[str, Field |
- type] | typing.Callable[[], SchemaABC | type | dict[str, Field |
- type]]), *, dump_default: typing.Any=missing_, default: typing.Any=
- missing_, only: (types.StrSequenceOrSet | None)=None, exclude:
- types.StrSequenceOrSet=(), many: bool=False, unknown: (str | None)=
- None, **kwargs):
+
+ #: Default error messages.
+ default_error_messages = {"type": "Invalid type."}
+
+ def __init__(
+ self,
+ nested: SchemaABC
+ | type
+ | str
+ | dict[str, Field | type]
+ | typing.Callable[[], SchemaABC | type | dict[str, Field | type]],
+ *,
+ dump_default: typing.Any = missing_,
+ default: typing.Any = missing_,
+ only: types.StrSequenceOrSet | None = None,
+ exclude: types.StrSequenceOrSet = (),
+ many: bool = False,
+ unknown: str | None = None,
+ **kwargs,
+ ):
+ # Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
- raise StringNotCollectionError(
- '"only" should be a collection of strings.')
+ raise StringNotCollectionError('"only" should be a collection of strings.')
if not is_collection(exclude):
raise StringNotCollectionError(
- '"exclude" should be a collection of strings.')
- if nested == 'self':
+ '"exclude" should be a collection of strings.'
+ )
+ if nested == "self":
warnings.warn(
- "Passing 'self' to `Nested` is deprecated. Use `Nested(lambda: MySchema(...))` instead."
- , RemovedInMarshmallow4Warning, stacklevel=2)
+ "Passing 'self' to `Nested` is deprecated. "
+ "Use `Nested(lambda: MySchema(...))` instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
self.nested = nested
self.only = only
self.exclude = exclude
self.many = many
self.unknown = unknown
- self._schema = None
+ self._schema = None # Cached Schema instance
super().__init__(default=default, dump_default=dump_default, **kwargs)
@property
@@ -358,7 +577,86 @@ class Nested(Field):
.. versionchanged:: 1.0.0
Renamed from `serializer` to `schema`.
"""
- pass
+ if not self._schema:
+ # Inherit context from parent.
+ context = getattr(self.parent, "context", {})
+ if callable(self.nested) and not isinstance(self.nested, type):
+ nested = self.nested()
+ else:
+ nested = self.nested
+ if isinstance(nested, dict):
+ # defer the import of `marshmallow.schema` to avoid circular imports
+ from marshmallow.schema import Schema
+
+ nested = Schema.from_dict(nested)
+
+ if isinstance(nested, SchemaABC):
+ self._schema = copy.copy(nested)
+ self._schema.context.update(context)
+ # Respect only and exclude passed from parent and re-initialize fields
+ set_class = self._schema.set_class
+ if self.only is not None:
+ if self._schema.only is not None:
+ original = self._schema.only
+ else: # only=None -> all fields
+ original = self._schema.fields.keys()
+ self._schema.only = set_class(self.only) & set_class(original)
+ if self.exclude:
+ original = self._schema.exclude
+ self._schema.exclude = set_class(self.exclude) | set_class(original)
+ self._schema._init_fields()
+ else:
+ if isinstance(nested, type) and issubclass(nested, SchemaABC):
+ schema_class = nested
+ elif not isinstance(nested, (str, bytes)):
+ raise ValueError(
+ "`Nested` fields must be passed a "
+ f"`Schema`, not {nested.__class__}."
+ )
+ elif nested == "self":
+ schema_class = self.root.__class__
+ else:
+ schema_class = class_registry.get_class(nested)
+ self._schema = schema_class(
+ many=self.many,
+ only=self.only,
+ exclude=self.exclude,
+ context=context,
+ load_only=self._nested_normalized_option("load_only"),
+ dump_only=self._nested_normalized_option("dump_only"),
+ )
+ return self._schema
+
+ def _nested_normalized_option(self, option_name: str) -> list[str]:
+ nested_field = f"{self.name}."
+ return [
+ field.split(nested_field, 1)[1]
+ for field in getattr(self.root, option_name, set())
+ if field.startswith(nested_field)
+ ]
+
+ def _serialize(self, nested_obj, attr, obj, **kwargs):
+ # Load up the schema first. This allows a RegistryError to be raised
+ # if an invalid schema name was passed
+ schema = self.schema
+ if nested_obj is None:
+ return None
+ many = schema.many or self.many
+ return schema.dump(nested_obj, many=many)
+
+ def _test_collection(self, value):
+ many = self.schema.many or self.many
+ if many and not utils.is_collection(value):
+ raise self.make_error("type", input=value, type=value.__class__.__name__)
+
+ def _load(self, value, data, partial=None):
+ try:
+ valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
+ except ValidationError as error:
+ raise ValidationError(
+ error.messages, valid_data=error.valid_data
+ ) from error
+ return valid_data
def _deserialize(self, value, attr, data, partial=None, **kwargs):
"""Same as :meth:`Field._deserialize` with additional ``partial`` argument.
@@ -369,7 +667,8 @@ class Nested(Field):
.. versionchanged:: 3.0.0
Add ``partial`` parameter.
"""
- pass
+ self._test_collection(value)
+ return self._load(value, data, partial=partial)
class Pluck(Nested):
@@ -399,11 +698,36 @@ class Pluck(Nested):
:param kwargs: The same keyword arguments that :class:`Nested` receives.
"""
- def __init__(self, nested: (SchemaABC | type | str | typing.Callable[[],
- SchemaABC]), field_name: str, **kwargs):
+ def __init__(
+ self,
+ nested: SchemaABC | type | str | typing.Callable[[], SchemaABC],
+ field_name: str,
+ **kwargs,
+ ):
super().__init__(nested, only=(field_name,), **kwargs)
self.field_name = field_name
+ @property
+ def _field_data_key(self):
+ only_field = self.schema.fields[self.field_name]
+ return only_field.data_key or self.field_name
+
+ def _serialize(self, nested_obj, attr, obj, **kwargs):
+ ret = super()._serialize(nested_obj, attr, obj, **kwargs)
+ if ret is None:
+ return None
+ if self.many:
+ return utils.pluck(ret, key=self._field_data_key)
+ return ret[self._field_data_key]
+
+ def _deserialize(self, value, attr, data, partial=None, **kwargs):
+ self._test_collection(value)
+ if self.many:
+ value = [{self._field_data_key: v} for v in value]
+ else:
+ value = {self._field_data_key: value}
+ return self._load(value, data, partial=partial)
+
class List(Field):
"""A list field, composed with another `Field` class or
@@ -423,20 +747,53 @@ class List(Field):
.. versionchanged:: 3.0.0rc9
Does not serialize scalar values to single-item lists.
"""
- default_error_messages = {'invalid': 'Not a valid list.'}
- def __init__(self, cls_or_instance: (Field | type), **kwargs):
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid list."}
+
+ def __init__(self, cls_or_instance: Field | type, **kwargs):
super().__init__(**kwargs)
try:
self.inner = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError as error:
raise ValueError(
- 'The list elements must be a subclass or instance of marshmallow.base.FieldABC.'
- ) from error
+ "The list elements must be a subclass or instance of "
+ "marshmallow.base.FieldABC."
+ ) from error
if isinstance(self.inner, Nested):
self.only = self.inner.only
self.exclude = self.inner.exclude
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ self.inner = copy.deepcopy(self.inner)
+ self.inner._bind_to_schema(field_name, self)
+ if isinstance(self.inner, Nested):
+ self.inner.only = self.only
+ self.inner.exclude = self.exclude
+
+ def _serialize(self, value, attr, obj, **kwargs) -> list[typing.Any] | None:
+ if value is None:
+ return None
+ return [self.inner._serialize(each, attr, obj, **kwargs) for each in value]
+
+ def _deserialize(self, value, attr, data, **kwargs) -> list[typing.Any]:
+ if not utils.is_collection(value):
+ raise self.make_error("invalid")
+
+ result = []
+ errors = {}
+ for idx, each in enumerate(value):
+ try:
+ result.append(self.inner.deserialize(each, **kwargs))
+ except ValidationError as error:
+ if error.valid_data is not None:
+ result.append(error.valid_data)
+ errors.update({idx: error.messages})
+ if errors:
+ raise ValidationError(errors, valid_data=result)
+ return result
+
class Tuple(Field):
"""A tuple field, composed of a fixed number of other `Field` classes or
@@ -457,40 +814,118 @@ class Tuple(Field):
.. versionadded:: 3.0.0rc4
"""
- default_error_messages = {'invalid': 'Not a valid tuple.'}
+
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid tuple."}
def __init__(self, tuple_fields, *args, **kwargs):
super().__init__(*args, **kwargs)
if not utils.is_collection(tuple_fields):
raise ValueError(
- 'tuple_fields must be an iterable of Field classes or instances.'
- )
+ "tuple_fields must be an iterable of Field classes or " "instances."
+ )
+
try:
- self.tuple_fields = [resolve_field_instance(cls_or_instance) for
- cls_or_instance in tuple_fields]
+ self.tuple_fields = [
+ resolve_field_instance(cls_or_instance)
+ for cls_or_instance in tuple_fields
+ ]
except FieldInstanceResolutionError as error:
raise ValueError(
- 'Elements of "tuple_fields" must be subclasses or instances of marshmallow.base.FieldABC.'
- ) from error
+ 'Elements of "tuple_fields" must be subclasses or '
+ "instances of marshmallow.base.FieldABC."
+ ) from error
+
self.validate_length = Length(equal=len(self.tuple_fields))
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ new_tuple_fields = []
+ for field in self.tuple_fields:
+ field = copy.deepcopy(field)
+ field._bind_to_schema(field_name, self)
+ new_tuple_fields.append(field)
+
+ self.tuple_fields = new_tuple_fields
+
+ def _serialize(self, value, attr, obj, **kwargs) -> tuple | None:
+ if value is None:
+ return None
+
+ return tuple(
+ field._serialize(each, attr, obj, **kwargs)
+ for field, each in zip(self.tuple_fields, value)
+ )
+
+ def _deserialize(self, value, attr, data, **kwargs) -> tuple:
+ if not utils.is_collection(value):
+ raise self.make_error("invalid")
+
+ self.validate_length(value)
+
+ result = []
+ errors = {}
+
+ for idx, (field, each) in enumerate(zip(self.tuple_fields, value)):
+ try:
+ result.append(field.deserialize(each, **kwargs))
+ except ValidationError as error:
+ if error.valid_data is not None:
+ result.append(error.valid_data)
+ errors.update({idx: error.messages})
+ if errors:
+ raise ValidationError(errors, valid_data=result)
+
+ return tuple(result)
+
class String(Field):
"""A string field.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- default_error_messages = {'invalid': 'Not a valid string.',
- 'invalid_utf8': 'Not a valid utf-8 string.'}
+
+ #: Default error messages.
+ default_error_messages = {
+ "invalid": "Not a valid string.",
+ "invalid_utf8": "Not a valid utf-8 string.",
+ }
+
+ def _serialize(self, value, attr, obj, **kwargs) -> str | None:
+ if value is None:
+ return None
+ return utils.ensure_text_type(value)
+
+ def _deserialize(self, value, attr, data, **kwargs) -> typing.Any:
+ if not isinstance(value, (str, bytes)):
+ raise self.make_error("invalid")
+ try:
+ return utils.ensure_text_type(value)
+ except UnicodeDecodeError as error:
+ raise self.make_error("invalid_utf8") from error
class UUID(String):
"""A UUID field."""
- default_error_messages = {'invalid_uuid': 'Not a valid UUID.'}
- def _validated(self, value) ->(uuid.UUID | None):
+ #: Default error messages.
+ default_error_messages = {"invalid_uuid": "Not a valid UUID."}
+
+ def _validated(self, value) -> uuid.UUID | None:
"""Format the value or raise a :exc:`ValidationError` if an error occurs."""
- pass
+ if value is None:
+ return None
+ if isinstance(value, uuid.UUID):
+ return value
+ try:
+ if isinstance(value, bytes) and len(value) == 16:
+ return uuid.UUID(bytes=value)
+ return uuid.UUID(value)
+ except (ValueError, AttributeError, TypeError) as error:
+ raise self.make_error("invalid_uuid") from error
+
+ def _deserialize(self, value, attr, data, **kwargs) -> uuid.UUID | None:
+ return self._validated(value)
class Number(Field):
@@ -499,25 +934,49 @@ class Number(Field):
:param bool as_string: If `True`, format the serialized value as a string.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- num_type = float
- default_error_messages = {'invalid': 'Not a valid number.', 'too_large':
- 'Number too large.'}
- def __init__(self, *, as_string: bool=False, **kwargs):
+ num_type = float # type: typing.Type
+
+ #: Default error messages.
+ default_error_messages = {
+ "invalid": "Not a valid number.",
+ "too_large": "Number too large.",
+ }
+
+ def __init__(self, *, as_string: bool = False, **kwargs):
self.as_string = as_string
super().__init__(**kwargs)
- def _format_num(self, value) ->typing.Any:
+ def _format_num(self, value) -> typing.Any:
"""Return the number value for value, given this field's `num_type`."""
- pass
+ return self.num_type(value)
- def _validated(self, value) ->(_T | None):
+ def _validated(self, value) -> _T | None:
"""Format the value or raise a :exc:`ValidationError` if an error occurs."""
- pass
+ if value is None:
+ return None
+ # (value is True or value is False) is ~5x faster than isinstance(value, bool)
+ if value is True or value is False:
+ raise self.make_error("invalid", input=value)
+ try:
+ return self._format_num(value)
+ except (TypeError, ValueError) as error:
+ raise self.make_error("invalid", input=value) from error
+ except OverflowError as error:
+ raise self.make_error("too_large", input=value) from error
- def _serialize(self, value, attr, obj, **kwargs) ->(str | _T | None):
+ def _to_string(self, value) -> str:
+ return str(value)
+
+ def _serialize(self, value, attr, obj, **kwargs) -> str | _T | None:
"""Return a string if `self.as_string=True`, otherwise return this field's `num_type`."""
- pass
+ if value is None:
+ return None
+ ret = self._format_num(value) # type: _T
+ return self._to_string(ret) if self.as_string else ret
+
+ def _deserialize(self, value, attr, data, **kwargs) -> _T | None:
+ return self._validated(value)
class Integer(Number):
@@ -527,13 +986,22 @@ class Integer(Number):
Otherwise, any value castable to `int` is valid.
:param kwargs: The same keyword arguments that :class:`Number` receives.
"""
+
num_type = int
- default_error_messages = {'invalid': 'Not a valid integer.'}
- def __init__(self, *, strict: bool=False, **kwargs):
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid integer."}
+
+ def __init__(self, *, strict: bool = False, **kwargs):
self.strict = strict
super().__init__(**kwargs)
+ # override Number
+ def _validated(self, value):
+ if self.strict and not isinstance(value, numbers.Integral):
+ raise self.make_error("invalid", input=value)
+ return super()._validated(value)
+
class Float(Number):
"""A double as an IEEE-754 double precision string.
@@ -543,15 +1011,25 @@ class Float(Number):
:param bool as_string: If `True`, format the value as a string.
:param kwargs: The same keyword arguments that :class:`Number` receives.
"""
+
num_type = float
- default_error_messages = {'special':
- 'Special numeric values (nan or infinity) are not permitted.'}
- def __init__(self, *, allow_nan: bool=False, as_string: bool=False, **
- kwargs):
+ #: Default error messages.
+ default_error_messages = {
+ "special": "Special numeric values (nan or infinity) are not permitted."
+ }
+
+ def __init__(self, *, allow_nan: bool = False, as_string: bool = False, **kwargs):
self.allow_nan = allow_nan
super().__init__(as_string=as_string, **kwargs)
+ def _validated(self, value):
+ num = super()._validated(value)
+ if self.allow_nan is False:
+ if math.isnan(num) or num == float("inf") or num == float("-inf"):
+ raise self.make_error("special")
+ return num
+
class Decimal(Number):
"""A field that (de)serializes to the Python ``decimal.Decimal`` type.
@@ -589,18 +1067,54 @@ class Decimal(Number):
.. versionadded:: 1.2.0
"""
+
num_type = decimal.Decimal
- default_error_messages = {'special':
- 'Special numeric values (nan or infinity) are not permitted.'}
- def __init__(self, places: (int | None)=None, rounding: (str | None)=
- None, *, allow_nan: bool=False, as_string: bool=False, **kwargs):
- self.places = decimal.Decimal((0, (1,), -places)
- ) if places is not None else None
+ #: Default error messages.
+ default_error_messages = {
+ "special": "Special numeric values (nan or infinity) are not permitted."
+ }
+
+ def __init__(
+ self,
+ places: int | None = None,
+ rounding: str | None = None,
+ *,
+ allow_nan: bool = False,
+ as_string: bool = False,
+ **kwargs,
+ ):
+ self.places = (
+ decimal.Decimal((0, (1,), -places)) if places is not None else None
+ )
self.rounding = rounding
self.allow_nan = allow_nan
super().__init__(as_string=as_string, **kwargs)
+ # override Number
+ def _format_num(self, value):
+ num = decimal.Decimal(str(value))
+ if self.allow_nan:
+ if num.is_nan():
+ return decimal.Decimal("NaN") # avoid sNaN, -sNaN and -NaN
+ if self.places is not None and num.is_finite():
+ num = num.quantize(self.places, rounding=self.rounding)
+ return num
+
+ # override Number
+ def _validated(self, value):
+ try:
+ num = super()._validated(value)
+ except decimal.InvalidOperation as error:
+ raise self.make_error("invalid") from error
+ if not self.allow_nan and (num.is_nan() or num.is_infinite()):
+ raise self.make_error("special")
+ return num
+
+ # override Number
+ def _to_string(self, value):
+ return format(value, "f")
+
class Boolean(Field):
"""A boolean field.
@@ -612,20 +1126,92 @@ class Boolean(Field):
`marshmallow.fields.Boolean.falsy` will be used.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- truthy = {'t', 'T', 'true', 'True', 'TRUE', 'on', 'On', 'ON', 'y', 'Y',
- 'yes', 'Yes', 'YES', '1', 1}
- falsy = {'f', 'F', 'false', 'False', 'FALSE', 'off', 'Off', 'OFF', 'n',
- 'N', 'no', 'No', 'NO', '0', 0}
- default_error_messages = {'invalid': 'Not a valid boolean.'}
-
- def __init__(self, *, truthy: (set | None)=None, falsy: (set | None)=
- None, **kwargs):
+
+ #: Default truthy values.
+ truthy = {
+ "t",
+ "T",
+ "true",
+ "True",
+ "TRUE",
+ "on",
+ "On",
+ "ON",
+ "y",
+ "Y",
+ "yes",
+ "Yes",
+ "YES",
+ "1",
+ 1,
+ # Equal to 1
+ # True,
+ }
+ #: Default falsy values.
+ falsy = {
+ "f",
+ "F",
+ "false",
+ "False",
+ "FALSE",
+ "off",
+ "Off",
+ "OFF",
+ "n",
+ "N",
+ "no",
+ "No",
+ "NO",
+ "0",
+ 0,
+ # Equal to 0
+ # 0.0,
+ # False,
+ }
+
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid boolean."}
+
+ def __init__(
+ self,
+ *,
+ truthy: set | None = None,
+ falsy: set | None = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
+
if truthy is not None:
self.truthy = set(truthy)
if falsy is not None:
self.falsy = set(falsy)
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+
+ try:
+ if value in self.truthy:
+ return True
+ if value in self.falsy:
+ return False
+ except TypeError:
+ pass
+
+ return bool(value)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if not self.truthy:
+ return bool(value)
+ try:
+ if value in self.truthy:
+ return True
+ if value in self.falsy:
+ return False
+ except TypeError as error:
+ raise self.make_error("invalid", input=value) from error
+ raise self.make_error("invalid", input=value)
+
class DateTime(Field):
"""A formatted datetime string.
@@ -642,24 +1228,78 @@ class DateTime(Field):
.. versionchanged:: 3.19
Add timestamp as a format.
"""
- SERIALIZATION_FUNCS = {'iso': utils.isoformat, 'iso8601': utils.
- isoformat, 'rfc': utils.rfcformat, 'rfc822': utils.rfcformat,
- 'timestamp': utils.timestamp, 'timestamp_ms': utils.timestamp_ms}
- DESERIALIZATION_FUNCS = {'iso': utils.from_iso_datetime, 'iso8601':
- utils.from_iso_datetime, 'rfc': utils.from_rfc, 'rfc822': utils.
- from_rfc, 'timestamp': utils.from_timestamp, 'timestamp_ms': utils.
- from_timestamp_ms}
- DEFAULT_FORMAT = 'iso'
- OBJ_TYPE = 'datetime'
- SCHEMA_OPTS_VAR_NAME = 'datetimeformat'
- default_error_messages = {'invalid': 'Not a valid {obj_type}.',
- 'invalid_awareness': 'Not a valid {awareness} {obj_type}.',
- 'format': '"{input}" cannot be formatted as a {obj_type}.'}
-
- def __init__(self, format: (str | None)=None, **kwargs) ->None:
+
+ SERIALIZATION_FUNCS = {
+ "iso": utils.isoformat,
+ "iso8601": utils.isoformat,
+ "rfc": utils.rfcformat,
+ "rfc822": utils.rfcformat,
+ "timestamp": utils.timestamp,
+ "timestamp_ms": utils.timestamp_ms,
+ } # type: typing.Dict[str, typing.Callable[[typing.Any], str | float]]
+
+ DESERIALIZATION_FUNCS = {
+ "iso": utils.from_iso_datetime,
+ "iso8601": utils.from_iso_datetime,
+ "rfc": utils.from_rfc,
+ "rfc822": utils.from_rfc,
+ "timestamp": utils.from_timestamp,
+ "timestamp_ms": utils.from_timestamp_ms,
+ } # type: typing.Dict[str, typing.Callable[[str], typing.Any]]
+
+ DEFAULT_FORMAT = "iso"
+
+ OBJ_TYPE = "datetime"
+
+ SCHEMA_OPTS_VAR_NAME = "datetimeformat"
+
+ #: Default error messages.
+ default_error_messages = {
+ "invalid": "Not a valid {obj_type}.",
+ "invalid_awareness": "Not a valid {awareness} {obj_type}.",
+ "format": '"{input}" cannot be formatted as a {obj_type}.',
+ }
+
+ def __init__(self, format: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)
+ # Allow this to be None. It may be set later in the ``_serialize``
+ # or ``_deserialize`` methods. This allows a Schema to dynamically set the
+ # format, e.g. from a Meta option
self.format = format
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ self.format = (
+ self.format
+ or getattr(self.root.opts, self.SCHEMA_OPTS_VAR_NAME)
+ or self.DEFAULT_FORMAT
+ )
+
+ def _serialize(self, value, attr, obj, **kwargs) -> str | float | None:
+ if value is None:
+ return None
+ data_format = self.format or self.DEFAULT_FORMAT
+ format_func = self.SERIALIZATION_FUNCS.get(data_format)
+ if format_func:
+ return format_func(value)
+ return value.strftime(data_format)
+
+ def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
+ data_format = self.format or self.DEFAULT_FORMAT
+ func = self.DESERIALIZATION_FUNCS.get(data_format)
+ try:
+ if func:
+ return func(value)
+ return self._make_object_from_format(value, data_format)
+ except (TypeError, AttributeError, ValueError) as error:
+ raise self.make_error(
+ "invalid", input=value, obj_type=self.OBJ_TYPE
+ ) from error
+
+ @staticmethod
+ def _make_object_from_format(value, data_format) -> dt.datetime:
+ return dt.datetime.strptime(value, data_format)
+
class NaiveDateTime(DateTime):
"""A formatted naive datetime string.
@@ -673,13 +1313,31 @@ class NaiveDateTime(DateTime):
.. versionadded:: 3.0.0rc9
"""
- AWARENESS = 'naive'
- def __init__(self, format: (str | None)=None, *, timezone: (dt.timezone |
- None)=None, **kwargs) ->None:
+ AWARENESS = "naive"
+
+ def __init__(
+ self,
+ format: str | None = None,
+ *,
+ timezone: dt.timezone | None = None,
+ **kwargs,
+ ) -> None:
super().__init__(format=format, **kwargs)
self.timezone = timezone
+ def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
+ ret = super()._deserialize(value, attr, data, **kwargs)
+ if is_aware(ret):
+ if self.timezone is None:
+ raise self.make_error(
+ "invalid_awareness",
+ awareness=self.AWARENESS,
+ obj_type=self.OBJ_TYPE,
+ )
+ ret = ret.astimezone(self.timezone).replace(tzinfo=None)
+ return ret
+
class AwareDateTime(DateTime):
"""A formatted aware datetime string.
@@ -692,13 +1350,31 @@ class AwareDateTime(DateTime):
.. versionadded:: 3.0.0rc9
"""
- AWARENESS = 'aware'
- def __init__(self, format: (str | None)=None, *, default_timezone: (dt.
- tzinfo | None)=None, **kwargs) ->None:
+ AWARENESS = "aware"
+
+ def __init__(
+ self,
+ format: str | None = None,
+ *,
+ default_timezone: dt.tzinfo | None = None,
+ **kwargs,
+ ) -> None:
super().__init__(format=format, **kwargs)
self.default_timezone = default_timezone
+ def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
+ ret = super()._deserialize(value, attr, data, **kwargs)
+ if not is_aware(ret):
+ if self.default_timezone is None:
+ raise self.make_error(
+ "invalid_awareness",
+ awareness=self.AWARENESS,
+ obj_type=self.OBJ_TYPE,
+ )
+ ret = ret.replace(tzinfo=self.default_timezone)
+ return ret
+
class Time(DateTime):
"""A formatted time string.
@@ -709,13 +1385,20 @@ class Time(DateTime):
If `None`, defaults to "iso".
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- SERIALIZATION_FUNCS = {'iso': utils.to_iso_time, 'iso8601': utils.
- to_iso_time}
- DESERIALIZATION_FUNCS = {'iso': utils.from_iso_time, 'iso8601': utils.
- from_iso_time}
- DEFAULT_FORMAT = 'iso'
- OBJ_TYPE = 'time'
- SCHEMA_OPTS_VAR_NAME = 'timeformat'
+
+ SERIALIZATION_FUNCS = {"iso": utils.to_iso_time, "iso8601": utils.to_iso_time}
+
+ DESERIALIZATION_FUNCS = {"iso": utils.from_iso_time, "iso8601": utils.from_iso_time}
+
+ DEFAULT_FORMAT = "iso"
+
+ OBJ_TYPE = "time"
+
+ SCHEMA_OPTS_VAR_NAME = "timeformat"
+
+ @staticmethod
+ def _make_object_from_format(value, data_format):
+ return dt.datetime.strptime(value, data_format).time()
class Date(DateTime):
@@ -725,15 +1408,26 @@ class Date(DateTime):
If `None`, defaults to "iso".
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
- default_error_messages = {'invalid': 'Not a valid date.', 'format':
- '"{input}" cannot be formatted as a date.'}
- SERIALIZATION_FUNCS = {'iso': utils.to_iso_date, 'iso8601': utils.
- to_iso_date}
- DESERIALIZATION_FUNCS = {'iso': utils.from_iso_date, 'iso8601': utils.
- from_iso_date}
- DEFAULT_FORMAT = 'iso'
- OBJ_TYPE = 'date'
- SCHEMA_OPTS_VAR_NAME = 'dateformat'
+
+ #: Default error messages.
+ default_error_messages = {
+ "invalid": "Not a valid date.",
+ "format": '"{input}" cannot be formatted as a date.',
+ }
+
+ SERIALIZATION_FUNCS = {"iso": utils.to_iso_date, "iso8601": utils.to_iso_date}
+
+ DESERIALIZATION_FUNCS = {"iso": utils.from_iso_date, "iso8601": utils.from_iso_date}
+
+ DEFAULT_FORMAT = "iso"
+
+ OBJ_TYPE = "date"
+
+ SCHEMA_OPTS_VAR_NAME = "dateformat"
+
+ @staticmethod
+ def _make_object_from_format(value, data_format):
+ return dt.datetime.strptime(value, data_format).date()
class TimeDelta(Field):
@@ -768,32 +1462,77 @@ class TimeDelta(Field):
Allow (de)serialization to `float` through use of a new `serialization_type` parameter.
`int` is the default to retain previous behaviour.
"""
- DAYS = 'days'
- SECONDS = 'seconds'
- MICROSECONDS = 'microseconds'
- MILLISECONDS = 'milliseconds'
- MINUTES = 'minutes'
- HOURS = 'hours'
- WEEKS = 'weeks'
- default_error_messages = {'invalid': 'Not a valid period of time.',
- 'format': '{input!r} cannot be formatted as a timedelta.'}
-
- def __init__(self, precision: str=SECONDS, serialization_type: type[int |
- float]=int, **kwargs):
+
+ DAYS = "days"
+ SECONDS = "seconds"
+ MICROSECONDS = "microseconds"
+ MILLISECONDS = "milliseconds"
+ MINUTES = "minutes"
+ HOURS = "hours"
+ WEEKS = "weeks"
+
+ #: Default error messages.
+ default_error_messages = {
+ "invalid": "Not a valid period of time.",
+ "format": "{input!r} cannot be formatted as a timedelta.",
+ }
+
+ def __init__(
+ self,
+ precision: str = SECONDS,
+ serialization_type: type[int | float] = int,
+ **kwargs,
+ ):
precision = precision.lower()
- units = (self.DAYS, self.SECONDS, self.MICROSECONDS, self.
- MILLISECONDS, self.MINUTES, self.HOURS, self.WEEKS)
+ units = (
+ self.DAYS,
+ self.SECONDS,
+ self.MICROSECONDS,
+ self.MILLISECONDS,
+ self.MINUTES,
+ self.HOURS,
+ self.WEEKS,
+ )
+
if precision not in units:
- msg = 'The precision must be {} or "{}".'.format(', '.join([
- f'"{each}"' for each in units[:-1]]), units[-1])
+ msg = 'The precision must be {} or "{}".'.format(
+ ", ".join([f'"{each}"' for each in units[:-1]]), units[-1]
+ )
raise ValueError(msg)
+
if serialization_type not in (int, float):
- raise ValueError(
- 'The serialization type must be one of int or float')
+ raise ValueError("The serialization type must be one of int or float")
+
self.precision = precision
self.serialization_type = serialization_type
super().__init__(**kwargs)
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+
+ base_unit = dt.timedelta(**{self.precision: 1})
+
+ if self.serialization_type is int:
+ delta = utils.timedelta_to_microseconds(value)
+ unit = utils.timedelta_to_microseconds(base_unit)
+ return delta // unit
+ assert self.serialization_type is float
+ return value.total_seconds() / base_unit.total_seconds()
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ try:
+ value = self.serialization_type(value)
+ except (TypeError, ValueError) as error:
+ raise self.make_error("invalid") from error
+
+ kwargs = {self.precision: value}
+
+ try:
+ return dt.timedelta(**kwargs)
+ except OverflowError as error:
+ raise self.make_error("invalid") from error
+
class Mapping(Field):
"""An abstract class for objects with key-value pairs.
@@ -808,11 +1547,18 @@ class Mapping(Field):
.. versionadded:: 3.0.0rc4
"""
+
mapping_type = dict
- default_error_messages = {'invalid': 'Not a valid mapping type.'}
- def __init__(self, keys: (Field | type | None)=None, values: (Field |
- type | None)=None, **kwargs):
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid mapping type."}
+
+ def __init__(
+ self,
+ keys: Field | type | None = None,
+ values: Field | type | None = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
if keys is None:
self.key_field = None
@@ -821,8 +1567,10 @@ class Mapping(Field):
self.key_field = resolve_field_instance(keys)
except FieldInstanceResolutionError as error:
raise ValueError(
- '"keys" must be a subclass or instance of marshmallow.base.FieldABC.'
- ) from error
+ '"keys" must be a subclass or instance of '
+ "marshmallow.base.FieldABC."
+ ) from error
+
if values is None:
self.value_field = None
else:
@@ -830,12 +1578,94 @@ class Mapping(Field):
self.value_field = resolve_field_instance(values)
except FieldInstanceResolutionError as error:
raise ValueError(
- '"values" must be a subclass or instance of marshmallow.base.FieldABC.'
- ) from error
+ '"values" must be a subclass or instance of '
+ "marshmallow.base.FieldABC."
+ ) from error
if isinstance(self.value_field, Nested):
self.only = self.value_field.only
self.exclude = self.value_field.exclude
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ if self.value_field:
+ self.value_field = copy.deepcopy(self.value_field)
+ self.value_field._bind_to_schema(field_name, self)
+ if isinstance(self.value_field, Nested):
+ self.value_field.only = self.only
+ self.value_field.exclude = self.exclude
+ if self.key_field:
+ self.key_field = copy.deepcopy(self.key_field)
+ self.key_field._bind_to_schema(field_name, self)
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+ if not self.value_field and not self.key_field:
+ return self.mapping_type(value)
+
+ # Â Serialize keys
+ if self.key_field is None:
+ keys = {k: k for k in value.keys()}
+ else:
+ keys = {
+ k: self.key_field._serialize(k, None, None, **kwargs)
+ for k in value.keys()
+ }
+
+ # Â Serialize values
+ result = self.mapping_type()
+ if self.value_field is None:
+ for k, v in value.items():
+ if k in keys:
+ result[keys[k]] = v
+ else:
+ for k, v in value.items():
+ result[keys[k]] = self.value_field._serialize(v, None, None, **kwargs)
+
+ return result
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if not isinstance(value, _Mapping):
+ raise self.make_error("invalid")
+ if not self.value_field and not self.key_field:
+ return self.mapping_type(value)
+
+ errors = collections.defaultdict(dict)
+
+ # Â Deserialize keys
+ if self.key_field is None:
+ keys = {k: k for k in value.keys()}
+ else:
+ keys = {}
+ for key in value.keys():
+ try:
+ keys[key] = self.key_field.deserialize(key, **kwargs)
+ except ValidationError as error:
+ errors[key]["key"] = error.messages
+
+ # Â Deserialize values
+ result = self.mapping_type()
+ if self.value_field is None:
+ for k, v in value.items():
+ if k in keys:
+ result[keys[k]] = v
+ else:
+ for key, val in value.items():
+ try:
+ deser_val = self.value_field.deserialize(val, **kwargs)
+ except ValidationError as error:
+ errors[key]["value"] = error.messages
+ if error.valid_data is not None and key in keys:
+ result[keys[key]] = error.valid_data
+ else:
+ if key in keys:
+ result[keys[key]] = deser_val
+
+ if errors:
+ raise ValidationError(errors, valid_data=result)
+
+ return result
+
class Dict(Mapping):
"""A dict field. Supports dicts and dict-like objects. Extends
@@ -849,6 +1679,7 @@ class Dict(Mapping):
.. versionadded:: 2.1.0
"""
+
mapping_type = dict
@@ -862,18 +1693,32 @@ class Url(String):
``ftp``, and ``ftps`` are allowed.
:param kwargs: The same keyword arguments that :class:`String` receives.
"""
- default_error_messages = {'invalid': 'Not a valid URL.'}
- def __init__(self, *, relative: bool=False, absolute: bool=True,
- schemes: (types.StrSequenceOrSet | None)=None, require_tld: bool=
- True, **kwargs):
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid URL."}
+
+ def __init__(
+ self,
+ *,
+ relative: bool = False,
+ absolute: bool = True,
+ schemes: types.StrSequenceOrSet | None = None,
+ require_tld: bool = True,
+ **kwargs,
+ ):
super().__init__(**kwargs)
+
self.relative = relative
self.absolute = absolute
self.require_tld = require_tld
- validator = validate.URL(relative=self.relative, absolute=self.
- absolute, schemes=schemes, require_tld=self.require_tld, error=
- self.error_messages['invalid'])
+ # Insert validation into self.validators so that multiple errors can be stored.
+ validator = validate.URL(
+ relative=self.relative,
+ absolute=self.absolute,
+ schemes=schemes,
+ require_tld=self.require_tld,
+ error=self.error_messages["invalid"],
+ )
self.validators.insert(0, validator)
@@ -883,11 +1728,14 @@ class Email(String):
:param args: The same positional arguments that :class:`String` receives.
:param kwargs: The same keyword arguments that :class:`String` receives.
"""
- default_error_messages = {'invalid': 'Not a valid email address.'}
- def __init__(self, *args, **kwargs) ->None:
+ #: Default error messages.
+ default_error_messages = {"invalid": "Not a valid email address."}
+
+ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- validator = validate.Email(error=self.error_messages['invalid'])
+ # Insert validation into self.validators so that multiple errors can be stored.
+ validator = validate.Email(error=self.error_messages["invalid"])
self.validators.insert(0, validator)
@@ -899,20 +1747,43 @@ class IP(Field):
.. versionadded:: 3.8.0
"""
- default_error_messages = {'invalid_ip': 'Not a valid IP address.'}
- DESERIALIZATION_CLASS = None
+
+ default_error_messages = {"invalid_ip": "Not a valid IP address."}
+
+ DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type]
def __init__(self, *args, exploded=False, **kwargs):
super().__init__(*args, **kwargs)
self.exploded = exploded
+ def _serialize(self, value, attr, obj, **kwargs) -> str | None:
+ if value is None:
+ return None
+ if self.exploded:
+ return value.exploded
+ return value.compressed
+
+ def _deserialize(
+ self, value, attr, data, **kwargs
+ ) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
+ if value is None:
+ return None
+ try:
+ return (self.DESERIALIZATION_CLASS or ipaddress.ip_address)(
+ utils.ensure_text_type(value)
+ )
+ except (ValueError, TypeError) as error:
+ raise self.make_error("invalid_ip") from error
+
class IPv4(IP):
"""A IPv4 address field.
.. versionadded:: 3.8.0
"""
- default_error_messages = {'invalid_ip': 'Not a valid IPv4 address.'}
+
+ default_error_messages = {"invalid_ip": "Not a valid IPv4 address."}
+
DESERIALIZATION_CLASS = ipaddress.IPv4Address
@@ -921,7 +1792,9 @@ class IPv6(IP):
.. versionadded:: 3.8.0
"""
- default_error_messages = {'invalid_ip': 'Not a valid IPv6 address.'}
+
+ default_error_messages = {"invalid_ip": "Not a valid IPv6 address."}
+
DESERIALIZATION_CLASS = ipaddress.IPv6Address
@@ -938,26 +1811,48 @@ class IPInterface(Field):
:param bool exploded: If `True`, serialize ipv6 interface in long form, ie. with groups
consisting entirely of zeros included.
"""
- default_error_messages = {'invalid_ip_interface':
- 'Not a valid IP interface.'}
- DESERIALIZATION_CLASS = None
- def __init__(self, *args, exploded: bool=False, **kwargs):
+ default_error_messages = {"invalid_ip_interface": "Not a valid IP interface."}
+
+ DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type]
+
+ def __init__(self, *args, exploded: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.exploded = exploded
+ def _serialize(self, value, attr, obj, **kwargs) -> str | None:
+ if value is None:
+ return None
+ if self.exploded:
+ return value.exploded
+ return value.compressed
+
+ def _deserialize(self, value, attr, data, **kwargs) -> None | (
+ ipaddress.IPv4Interface | ipaddress.IPv6Interface
+ ):
+ if value is None:
+ return None
+ try:
+ return (self.DESERIALIZATION_CLASS or ipaddress.ip_interface)(
+ utils.ensure_text_type(value)
+ )
+ except (ValueError, TypeError) as error:
+ raise self.make_error("invalid_ip_interface") from error
+
class IPv4Interface(IPInterface):
"""A IPv4 Network Interface field."""
- default_error_messages = {'invalid_ip_interface':
- 'Not a valid IPv4 interface.'}
+
+ default_error_messages = {"invalid_ip_interface": "Not a valid IPv4 interface."}
+
DESERIALIZATION_CLASS = ipaddress.IPv4Interface
class IPv6Interface(IPInterface):
"""A IPv6 Network Interface field."""
- default_error_messages = {'invalid_ip_interface':
- 'Not a valid IPv6 interface.'}
+
+ default_error_messages = {"invalid_ip_interface": "Not a valid IPv6 interface."}
+
DESERIALIZATION_CLASS = ipaddress.IPv6Interface
@@ -974,17 +1869,29 @@ class Enum(Field):
.. versionadded:: 3.18.0
"""
- default_error_messages = {'unknown': 'Must be one of: {choices}.'}
- def __init__(self, enum: type[EnumType], *, by_value: (bool | Field |
- type)=False, **kwargs):
+ default_error_messages = {
+ "unknown": "Must be one of: {choices}.",
+ }
+
+ def __init__(
+ self,
+ enum: type[EnumType],
+ *,
+ by_value: bool | Field | type = False,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.enum = enum
self.by_value = by_value
+
+ # Serialization by name
if by_value is False:
self.field: Field = String()
- self.choices_text = ', '.join(str(self.field._serialize(m, None,
- None)) for m in enum.__members__)
+ self.choices_text = ", ".join(
+ str(self.field._serialize(m, None, None)) for m in enum.__members__
+ )
+ # Serialization by value
else:
if by_value is True:
self.field = Field()
@@ -993,10 +1900,33 @@ class Enum(Field):
self.field = resolve_field_instance(by_value)
except FieldInstanceResolutionError as error:
raise ValueError(
- '"by_value" must be either a bool or a subclass or instance of marshmallow.base.FieldABC.'
- ) from error
- self.choices_text = ', '.join(str(self.field._serialize(m.value,
- None, None)) for m in enum)
+ '"by_value" must be either a bool or a subclass or instance of '
+ "marshmallow.base.FieldABC."
+ ) from error
+ self.choices_text = ", ".join(
+ str(self.field._serialize(m.value, None, None)) for m in enum
+ )
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+ if self.by_value:
+ val = value.value
+ else:
+ val = value.name
+ return self.field._serialize(val, attr, obj, **kwargs)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ val = self.field._deserialize(value, attr, data, **kwargs)
+ if self.by_value:
+ try:
+ return self.enum(val)
+ except ValueError as error:
+ raise self.make_error("unknown", choices=self.choices_text) from error
+ try:
+ return getattr(self.enum, val)
+ except AttributeError as error:
+ raise self.make_error("unknown", choices=self.choices_text) from error
class Method(Field):
@@ -1019,18 +1949,47 @@ class Method(Field):
.. versionchanged:: 3.0.0
Removed ``method_name`` parameter.
"""
+
_CHECK_ATTRIBUTE = False
- def __init__(self, serialize: (str | None)=None, deserialize: (str |
- None)=None, **kwargs):
- kwargs['dump_only'] = bool(serialize) and not bool(deserialize)
- kwargs['load_only'] = bool(deserialize) and not bool(serialize)
+ def __init__(
+ self,
+ serialize: str | None = None,
+ deserialize: str | None = None,
+ **kwargs,
+ ):
+ # Set dump_only and load_only based on arguments
+ kwargs["dump_only"] = bool(serialize) and not bool(deserialize)
+ kwargs["load_only"] = bool(deserialize) and not bool(serialize)
super().__init__(**kwargs)
self.serialize_method_name = serialize
self.deserialize_method_name = deserialize
self._serialize_method = None
self._deserialize_method = None
+ def _bind_to_schema(self, field_name, schema):
+ if self.serialize_method_name:
+ self._serialize_method = utils.callable_or_raise(
+ getattr(schema, self.serialize_method_name)
+ )
+
+ if self.deserialize_method_name:
+ self._deserialize_method = utils.callable_or_raise(
+ getattr(schema, self.deserialize_method_name)
+ )
+
+ super()._bind_to_schema(field_name, schema)
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if self._serialize_method is not None:
+ return self._serialize_method(obj)
+ return missing_
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if self._deserialize_method is not None:
+ return self._deserialize_method(value)
+ return value
+
class Function(Field):
"""A field that takes the value returned by a function.
@@ -1054,18 +2013,45 @@ class Function(Field):
.. versionchanged:: 3.0.0a1
Removed ``func`` parameter.
"""
+
_CHECK_ATTRIBUTE = False
- def __init__(self, serialize: (None | typing.Callable[[typing.Any],
- typing.Any] | typing.Callable[[typing.Any, dict], typing.Any])=None,
- deserialize: (None | typing.Callable[[typing.Any], typing.Any] |
- typing.Callable[[typing.Any, dict], typing.Any])=None, **kwargs):
- kwargs['dump_only'] = bool(serialize) and not bool(deserialize)
- kwargs['load_only'] = bool(deserialize) and not bool(serialize)
+ def __init__(
+ self,
+ serialize: (
+ None
+ | typing.Callable[[typing.Any], typing.Any]
+ | typing.Callable[[typing.Any, dict], typing.Any]
+ ) = None,
+ deserialize: (
+ None
+ | typing.Callable[[typing.Any], typing.Any]
+ | typing.Callable[[typing.Any, dict], typing.Any]
+ ) = None,
+ **kwargs,
+ ):
+ # Set dump_only and load_only based on arguments
+ kwargs["dump_only"] = bool(serialize) and not bool(deserialize)
+ kwargs["load_only"] = bool(deserialize) and not bool(serialize)
super().__init__(**kwargs)
self.serialize_func = serialize and utils.callable_or_raise(serialize)
- self.deserialize_func = deserialize and utils.callable_or_raise(
- deserialize)
+ self.deserialize_func = deserialize and utils.callable_or_raise(deserialize)
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ return self._call_or_raise(self.serialize_func, obj, attr)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if self.deserialize_func:
+ return self._call_or_raise(self.deserialize_func, value, attr)
+ return value
+
+ def _call_or_raise(self, func, value, attr):
+ if len(utils.get_func_args(func)) > 1:
+ if self.parent.context is None:
+ msg = f"No context available for Function field {attr!r}"
+ raise ValidationError(msg)
+ return func(value, self.parent.context)
+ return func(value)
class Constant(Field):
@@ -1077,6 +2063,7 @@ class Constant(Field):
.. versionadded:: 2.0.0
"""
+
_CHECK_ATTRIBUTE = False
def __init__(self, constant: typing.Any, **kwargs):
@@ -1085,6 +2072,12 @@ class Constant(Field):
self.load_default = constant
self.dump_default = constant
+ def _serialize(self, value, *args, **kwargs):
+ return self.constant
+
+ def _deserialize(self, value, *args, **kwargs):
+ return self.constant
+
class Inferred(Field):
"""A field that infers how to serialize, based on the value type.
@@ -1097,9 +2090,24 @@ class Inferred(Field):
def __init__(self):
super().__init__()
+ # We memoize the fields to avoid creating and binding new fields
+ # every time on serialization.
self._field_cache = {}
+ def _serialize(self, value, attr, obj, **kwargs):
+ field_cls = self.root.TYPE_MAPPING.get(type(value))
+ if field_cls is None:
+ field = super()
+ else:
+ field = self._field_cache.get(field_cls)
+ if field is None:
+ field = field_cls()
+ field._bind_to_schema(self.name, self.parent)
+ self._field_cache[field_cls] = field
+ return field._serialize(value, attr, obj, **kwargs)
+
+# Aliases
URL = Url
Str = String
Bool = Boolean
diff --git a/src/marshmallow/orderedset.py b/src/marshmallow/orderedset.py
index 35553ec..7ce0723 100644
--- a/src/marshmallow/orderedset.py
+++ b/src/marshmallow/orderedset.py
@@ -1,12 +1,33 @@
+# OrderedSet
+# Copyright (c) 2009 Raymond Hettinger
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation files
+# (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge,
+# publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
from collections.abc import MutableSet
class OrderedSet(MutableSet):
-
def __init__(self, iterable=None):
self.end = end = []
- end += [None, end, end]
- self.map = {}
+ end += [None, end, end] # sentinel node for doubly linked list
+ self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
@@ -16,6 +37,18 @@ class OrderedSet(MutableSet):
def __contains__(self, key):
return key in self.map
+ def add(self, key):
+ if key not in self.map:
+ end = self.end
+ curr = end[1]
+ curr[2] = end[1] = self.map[key] = [key, curr, end]
+
+ def discard(self, key):
+ if key in self.map:
+ key, prev, next = self.map.pop(key)
+ prev[2] = next
+ next[1] = prev
+
def __iter__(self):
end = self.end
curr = end[2]
@@ -30,10 +63,17 @@ class OrderedSet(MutableSet):
yield curr[0]
curr = curr[1]
+ def pop(self, last=True):
+ if not self:
+ raise KeyError("set is empty")
+ key = self.end[1][0] if last else self.end[2][0]
+ self.discard(key)
+ return key
+
def __repr__(self):
if not self:
- return f'{self.__class__.__name__}()'
- return f'{self.__class__.__name__}({list(self)!r})'
+ return f"{self.__class__.__name__}()"
+ return f"{self.__class__.__name__}({list(self)!r})"
def __eq__(self, other):
if isinstance(other, OrderedSet):
@@ -41,9 +81,9 @@ class OrderedSet(MutableSet):
return set(self) == set(other)
-if __name__ == '__main__':
- s = OrderedSet('abracadaba')
- t = OrderedSet('simsalabim')
+if __name__ == "__main__":
+ s = OrderedSet("abracadaba")
+ t = OrderedSet("simsalabim")
print(s | t)
print(s & t)
print(s - t)
diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py
index 1e6eabf..23b43c4 100644
--- a/src/marshmallow/schema.py
+++ b/src/marshmallow/schema.py
@@ -1,5 +1,7 @@
"""The :class:`Schema` class, including its metaclass and options (class Meta)."""
+
from __future__ import annotations
+
import copy
import datetime as dt
import decimal
@@ -11,15 +13,34 @@ import warnings
from abc import ABCMeta
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
+
from marshmallow import base, class_registry, types
from marshmallow import fields as ma_fields
-from marshmallow.decorators import POST_DUMP, POST_LOAD, PRE_DUMP, PRE_LOAD, VALIDATES, VALIDATES_SCHEMA
+from marshmallow.decorators import (
+ POST_DUMP,
+ POST_LOAD,
+ PRE_DUMP,
+ PRE_LOAD,
+ VALIDATES,
+ VALIDATES_SCHEMA,
+)
from marshmallow.error_store import ErrorStore
from marshmallow.exceptions import StringNotCollectionError, ValidationError
from marshmallow.orderedset import OrderedSet
-from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, get_value, is_collection, is_instance_or_subclass, missing, set_value, validate_unknown_parameter_value
+from marshmallow.utils import (
+ EXCLUDE,
+ INCLUDE,
+ RAISE,
+ get_value,
+ is_collection,
+ is_instance_or_subclass,
+ missing,
+ set_value,
+ validate_unknown_parameter_value,
+)
from marshmallow.warnings import RemovedInMarshmallow4Warning
-_T = typing.TypeVar('_T')
+
+_T = typing.TypeVar("_T")
def _get_fields(attrs):
@@ -27,9 +48,15 @@ def _get_fields(attrs):
:param attrs: Mapping of class attributes
"""
- pass
+ return [
+ (field_name, field_value)
+ for field_name, field_value in attrs.items()
+ if is_instance_or_subclass(field_value, base.FieldABC)
+ ]
+# This function allows Schemas to inherit from non-Schema classes and ensures
+# inheritance according to the MRO
def _get_fields_by_mro(klass):
"""Collect fields from a class, following its method resolution order. The
class itself is excluded from the search; only its parents are checked. Get
@@ -37,7 +64,17 @@ def _get_fields_by_mro(klass):
:param type klass: Class whose fields to retrieve
"""
- pass
+ mro = inspect.getmro(klass)
+ # Loop over mro in reverse to maintain correct order of fields
+ return sum(
+ (
+ _get_fields(
+ getattr(base, "_declared_fields", base.__dict__),
+ )
+ for base in mro[:0:-1]
+ ),
+ [],
+ )
class SchemaMeta(ABCMeta):
@@ -48,31 +85,51 @@ class SchemaMeta(ABCMeta):
"""
def __new__(mcs, name, bases, attrs):
- meta = attrs.get('Meta')
- ordered = getattr(meta, 'ordered', False)
+ meta = attrs.get("Meta")
+ ordered = getattr(meta, "ordered", False)
if not ordered:
+ # Inherit 'ordered' option
+ # Warning: We loop through bases instead of MRO because we don't
+ # yet have access to the class object
+ # (i.e. can't call super before we have fields)
for base_ in bases:
- if hasattr(base_, 'Meta') and hasattr(base_.Meta, 'ordered'):
+ if hasattr(base_, "Meta") and hasattr(base_.Meta, "ordered"):
ordered = base_.Meta.ordered
break
else:
ordered = False
cls_fields = _get_fields(attrs)
+ # Remove fields from list of class attributes to avoid shadowing
+ # Schema attributes/methods in case of name conflict
for field_name, _ in cls_fields:
del attrs[field_name]
klass = super().__new__(mcs, name, bases, attrs)
inherited_fields = _get_fields_by_mro(klass)
+
meta = klass.Meta
+ # Set klass.opts in __new__ rather than __init__ so that it is accessible in
+ # get_declared_fields
klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)
+ # Add fields specified in the `include` class Meta option
cls_fields += list(klass.opts.include.items())
- klass._declared_fields = mcs.get_declared_fields(klass=klass,
- cls_fields=cls_fields, inherited_fields=inherited_fields,
- dict_cls=dict)
+
+ # Assign _declared_fields on class
+ klass._declared_fields = mcs.get_declared_fields(
+ klass=klass,
+ cls_fields=cls_fields,
+ inherited_fields=inherited_fields,
+ dict_cls=dict,
+ )
return klass
@classmethod
- def get_declared_fields(mcs, klass: type, cls_fields: list,
- inherited_fields: list, dict_cls: type=dict):
+ def get_declared_fields(
+ mcs,
+ klass: type,
+ cls_fields: list,
+ inherited_fields: list,
+ dict_cls: type = dict,
+ ):
"""Returns a dictionary of field_name => `Field` pairs declared on the class.
This is exposed mainly so that plugins can add additional fields, e.g. fields
computed from class Meta options.
@@ -83,7 +140,7 @@ class SchemaMeta(ABCMeta):
:param inherited_fields: Inherited fields.
:param dict_cls: dict-like class to use for dict output Default to ``dict``.
"""
- pass
+ return dict_cls(inherited_fields + cls_fields)
def __init__(cls, name, bases, attrs):
super().__init__(name, bases, attrs)
@@ -91,51 +148,84 @@ class SchemaMeta(ABCMeta):
class_registry.register(name, cls)
cls._hooks = cls.resolve_hooks()
- def resolve_hooks(cls) ->dict[types.Tag, list[str]]:
+ def resolve_hooks(cls) -> dict[types.Tag, list[str]]:
"""Add in the decorated processors
By doing this after constructing the class, we let standard inheritance
do all the hard work.
"""
- pass
+ mro = inspect.getmro(cls)
+
+ hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]
+
+ for attr_name in dir(cls):
+ # Need to look up the actual descriptor, not whatever might be
+ # bound to the class. This needs to come from the __dict__ of the
+ # declaring class.
+ for parent in mro:
+ try:
+ attr = parent.__dict__[attr_name]
+ except KeyError:
+ continue
+ else:
+ break
+ else:
+ # In case we didn't find the attribute and didn't break above.
+ # We should never hit this - it's just here for completeness
+ # to exclude the possibility of attr being undefined.
+ continue
+
+ try:
+ hook_config = attr.__marshmallow_hook__
+ except AttributeError:
+ pass
+ else:
+ for key in hook_config.keys():
+ # Use name here so we can get the bound method later, in
+ # case the processor was a descriptor or something.
+ hooks[key].append(attr_name)
+
+ return hooks
class SchemaOpts:
"""class Meta options for the :class:`Schema`. Defines defaults."""
- def __init__(self, meta, ordered: bool=False):
- self.fields = getattr(meta, 'fields', ())
+ def __init__(self, meta, ordered: bool = False):
+ self.fields = getattr(meta, "fields", ())
if not isinstance(self.fields, (list, tuple)):
- raise ValueError('`fields` option must be a list or tuple.')
- self.additional = getattr(meta, 'additional', ())
+ raise ValueError("`fields` option must be a list or tuple.")
+ self.additional = getattr(meta, "additional", ())
if not isinstance(self.additional, (list, tuple)):
- raise ValueError('`additional` option must be a list or tuple.')
+ raise ValueError("`additional` option must be a list or tuple.")
if self.fields and self.additional:
raise ValueError(
- 'Cannot set both `fields` and `additional` options for the same Schema.'
- )
- self.exclude = getattr(meta, 'exclude', ())
+ "Cannot set both `fields` and `additional` options"
+ " for the same Schema."
+ )
+ self.exclude = getattr(meta, "exclude", ())
if not isinstance(self.exclude, (list, tuple)):
- raise ValueError('`exclude` must be a list or tuple.')
- self.dateformat = getattr(meta, 'dateformat', None)
- self.datetimeformat = getattr(meta, 'datetimeformat', None)
- self.timeformat = getattr(meta, 'timeformat', None)
- if hasattr(meta, 'json_module'):
+ raise ValueError("`exclude` must be a list or tuple.")
+ self.dateformat = getattr(meta, "dateformat", None)
+ self.datetimeformat = getattr(meta, "datetimeformat", None)
+ self.timeformat = getattr(meta, "timeformat", None)
+ if hasattr(meta, "json_module"):
warnings.warn(
- 'The json_module class Meta option is deprecated. Use render_module instead.'
- , RemovedInMarshmallow4Warning, stacklevel=2)
- render_module = getattr(meta, 'json_module', json)
+ "The json_module class Meta option is deprecated. Use render_module instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ render_module = getattr(meta, "json_module", json)
else:
render_module = json
- self.render_module = getattr(meta, 'render_module', render_module)
- self.ordered = getattr(meta, 'ordered', ordered)
- self.index_errors = getattr(meta, 'index_errors', True)
- self.include = getattr(meta, 'include', {})
- self.load_only = getattr(meta, 'load_only', ())
- self.dump_only = getattr(meta, 'dump_only', ())
- self.unknown = validate_unknown_parameter_value(getattr(meta,
- 'unknown', RAISE))
- self.register = getattr(meta, 'register', True)
+ self.render_module = getattr(meta, "render_module", render_module)
+ self.ordered = getattr(meta, "ordered", ordered)
+ self.index_errors = getattr(meta, "index_errors", True)
+ self.include = getattr(meta, "include", {})
+ self.load_only = getattr(meta, "load_only", ())
+ self.dump_only = getattr(meta, "dump_only", ())
+ self.unknown = validate_unknown_parameter_value(getattr(meta, "unknown", RAISE))
+ self.register = getattr(meta, "register", True)
class Schema(base.SchemaABC, metaclass=SchemaMeta):
@@ -197,21 +287,39 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
`__accessor__` and `__error_handler__` are deprecated. Implement the
`handle_error` and `get_attribute` methods instead.
"""
- TYPE_MAPPING = {str: ma_fields.String, bytes: ma_fields.String, dt.
- datetime: ma_fields.DateTime, float: ma_fields.Float, bool:
- ma_fields.Boolean, tuple: ma_fields.Raw, list: ma_fields.Raw, set:
- ma_fields.Raw, int: ma_fields.Integer, uuid.UUID: ma_fields.UUID,
- dt.time: ma_fields.Time, dt.date: ma_fields.Date, dt.timedelta:
- ma_fields.TimeDelta, decimal.Decimal: ma_fields.Decimal}
- error_messages = {}
- _default_error_messages = {'type': 'Invalid input type.', 'unknown':
- 'Unknown field.'}
- OPTIONS_CLASS = SchemaOpts
+
+ TYPE_MAPPING = {
+ str: ma_fields.String,
+ bytes: ma_fields.String,
+ dt.datetime: ma_fields.DateTime,
+ float: ma_fields.Float,
+ bool: ma_fields.Boolean,
+ tuple: ma_fields.Raw,
+ list: ma_fields.Raw,
+ set: ma_fields.Raw,
+ int: ma_fields.Integer,
+ uuid.UUID: ma_fields.UUID,
+ dt.time: ma_fields.Time,
+ dt.date: ma_fields.Date,
+ dt.timedelta: ma_fields.TimeDelta,
+ decimal.Decimal: ma_fields.Decimal,
+ } # type: typing.Dict[type, typing.Type[ma_fields.Field]]
+ #: Overrides for default schema-level error messages
+ error_messages = {} # type: typing.Dict[str, str]
+
+ _default_error_messages = {
+ "type": "Invalid input type.",
+ "unknown": "Unknown field.",
+ } # type: typing.Dict[str, str]
+
+ OPTIONS_CLASS = SchemaOpts # type: type
+
set_class = OrderedSet
- opts = None
- _declared_fields = {}
- _hooks = {}
+ # These get set by SchemaMeta
+ opts = None # type: SchemaOpts
+ _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]
+ _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]
class Meta:
"""Options object for a Schema.
@@ -252,47 +360,67 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
usage is critical. Defaults to `True`.
"""
- def __init__(self, *, only: (types.StrSequenceOrSet | None)=None,
- exclude: types.StrSequenceOrSet=(), many: bool=False, context: (
- dict | None)=None, load_only: types.StrSequenceOrSet=(), dump_only:
- types.StrSequenceOrSet=(), partial: (bool | types.StrSequenceOrSet |
- None)=None, unknown: (str | None)=None):
+ def __init__(
+ self,
+ *,
+ only: types.StrSequenceOrSet | None = None,
+ exclude: types.StrSequenceOrSet = (),
+ many: bool = False,
+ context: dict | None = None,
+ load_only: types.StrSequenceOrSet = (),
+ dump_only: types.StrSequenceOrSet = (),
+ partial: bool | types.StrSequenceOrSet | None = None,
+ unknown: str | None = None,
+ ):
+ # Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
- raise StringNotCollectionError('"only" should be a list of strings'
- )
+ raise StringNotCollectionError('"only" should be a list of strings')
if not is_collection(exclude):
- raise StringNotCollectionError(
- '"exclude" should be a list of strings')
+ raise StringNotCollectionError('"exclude" should be a list of strings')
+ # copy declared fields from metaclass
self.declared_fields = copy.deepcopy(self._declared_fields)
self.many = many
self.only = only
self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(
- self.opts.exclude) | set(exclude)
+ self.opts.exclude
+ ) | set(exclude)
self.ordered = self.opts.ordered
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
self.partial = partial
- self.unknown = (self.opts.unknown if unknown is None else
- validate_unknown_parameter_value(unknown))
+ self.unknown = (
+ self.opts.unknown
+ if unknown is None
+ else validate_unknown_parameter_value(unknown)
+ )
self.context = context or {}
self._normalize_nested_options()
- self.fields = {}
- self.load_fields = {}
- self.dump_fields = {}
+ #: Dictionary mapping field_names -> :class:`Field` objects
+ self.fields = {} # type: typing.Dict[str, ma_fields.Field]
+ self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]
+ self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]
self._init_fields()
messages = {}
messages.update(self._default_error_messages)
for cls in reversed(self.__class__.__mro__):
- messages.update(getattr(cls, 'error_messages', {}))
+ messages.update(getattr(cls, "error_messages", {}))
messages.update(self.error_messages or {})
self.error_messages = messages
- def __repr__(self) ->str:
- return f'<{self.__class__.__name__}(many={self.many})>'
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__}(many={self.many})>"
+
+ @property
+ def dict_class(self) -> type:
+ return OrderedDict if self.ordered else dict
@classmethod
- def from_dict(cls, fields: dict[str, ma_fields.Field | type], *, name:
- str='GeneratedSchema') ->type:
+ def from_dict(
+ cls,
+ fields: dict[str, ma_fields.Field | type],
+ *,
+ name: str = "GeneratedSchema",
+ ) -> type:
"""Generate a `Schema` class given a dictionary of fields.
.. code-block:: python
@@ -311,10 +439,18 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionadded:: 3.0.0
"""
- pass
-
- def handle_error(self, error: ValidationError, data: typing.Any, *,
- many: bool, **kwargs):
+ attrs = fields.copy()
+ attrs["Meta"] = type(
+ "GeneratedMeta", (getattr(cls, "Meta", object),), {"register": False}
+ )
+ schema_cls = type(name, (cls,), attrs)
+ return schema_cls
+
+ ##### Override-able methods #####
+
+ def handle_error(
+ self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs
+ ):
"""Custom error handler function for the schema.
:param error: The `ValidationError` raised during (de)serialization.
@@ -337,11 +473,12 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0a1
Changed position of ``obj`` and ``attr``.
"""
- pass
+ return get_value(obj, attr, default)
+
+ ##### Serialization/Deserialization API #####
@staticmethod
- def _call_and_store(getter_func, data, *, field_name, error_store,
- index=None):
+ def _call_and_store(getter_func, data, *, field_name, error_store, index=None):
"""Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.
:param callable getter_func: Function for getting the serialized/deserialized
@@ -351,9 +488,16 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
:param int index: Index of the item being validated, if validating a collection,
otherwise `None`.
"""
- pass
-
- def _serialize(self, obj: (_T | typing.Iterable[_T]), *, many: bool=False):
+ try:
+ value = getter_func(data)
+ except ValidationError as error:
+ error_store.store_error(error.messages, field_name, index=index)
+ # When a Nested field fails validation, the marshalled data is stored
+ # on the ValidationError's valid_data attribute
+ return error.valid_data or missing
+ return value
+
+ def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):
"""Serialize ``obj``.
:param obj: The object(s) to serialize.
@@ -363,9 +507,21 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 1.0.0
Renamed from ``marshal``.
"""
- pass
-
- def dump(self, obj: typing.Any, *, many: (bool | None)=None):
+ if many and obj is not None:
+ return [
+ self._serialize(d, many=False)
+ for d in typing.cast(typing.Iterable[_T], obj)
+ ]
+ ret = self.dict_class()
+ for attr_name, field_obj in self.dump_fields.items():
+ value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)
+ if value is missing:
+ continue
+ key = field_obj.data_key if field_obj.data_key is not None else attr_name
+ ret[key] = value
+ return ret
+
+ def dump(self, obj: typing.Any, *, many: bool | None = None):
"""Serialize an object to native Python data types according to this
Schema's fields.
@@ -382,10 +538,24 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0rc9
Validation no longer occurs upon serialization.
"""
- pass
+ many = self.many if many is None else bool(many)
+ if self._has_processors(PRE_DUMP):
+ processed_obj = self._invoke_dump_processors(
+ PRE_DUMP, obj, many=many, original_data=obj
+ )
+ else:
+ processed_obj = obj
- def dumps(self, obj: typing.Any, *args, many: (bool | None)=None, **kwargs
- ):
+ result = self._serialize(processed_obj, many=many)
+
+ if self._has_processors(POST_DUMP):
+ result = self._invoke_dump_processors(
+ POST_DUMP, result, many=many, original_data=obj
+ )
+
+ return result
+
+ def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):
"""Same as :meth:`dump`, except return a JSON-encoded string.
:param obj: The object to serialize.
@@ -399,12 +569,22 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
A :exc:`ValidationError <marshmallow.exceptions.ValidationError>` is raised
if ``obj`` is invalid.
"""
- pass
-
- def _deserialize(self, data: (typing.Mapping[str, typing.Any] | typing.
- Iterable[typing.Mapping[str, typing.Any]]), *, error_store:
- ErrorStore, many: bool=False, partial=None, unknown=RAISE, index=None
- ) ->(_T | list[_T]):
+ serialized = self.dump(obj, many=many)
+ return self.opts.render_module.dumps(serialized, *args, **kwargs)
+
+ def _deserialize(
+ self,
+ data: (
+ typing.Mapping[str, typing.Any]
+ | typing.Iterable[typing.Mapping[str, typing.Any]]
+ ),
+ *,
+ error_store: ErrorStore,
+ many: bool = False,
+ partial=None,
+ unknown=RAISE,
+ index=None,
+ ) -> _T | list[_T]:
"""Deserialize ``data``.
:param dict data: The data to deserialize.
@@ -420,12 +600,105 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
serializing a collection, otherwise `None`.
:return: A dictionary of the deserialized data.
"""
- pass
-
- def load(self, data: (typing.Mapping[str, typing.Any] | typing.Iterable
- [typing.Mapping[str, typing.Any]]), *, many: (bool | None)=None,
- partial: (bool | types.StrSequenceOrSet | None)=None, unknown: (str |
- None)=None):
+ index_errors = self.opts.index_errors
+ index = index if index_errors else None
+ if many:
+ if not is_collection(data):
+ error_store.store_error([self.error_messages["type"]], index=index)
+ ret_l = [] # type: typing.List[_T]
+ else:
+ ret_l = [
+ typing.cast(
+ _T,
+ self._deserialize(
+ typing.cast(typing.Mapping[str, typing.Any], d),
+ error_store=error_store,
+ many=False,
+ partial=partial,
+ unknown=unknown,
+ index=idx,
+ ),
+ )
+ for idx, d in enumerate(data)
+ ]
+ return ret_l
+ ret_d = self.dict_class()
+ # Check data is a dict
+ if not isinstance(data, Mapping):
+ error_store.store_error([self.error_messages["type"]], index=index)
+ else:
+ partial_is_collection = is_collection(partial)
+ for attr_name, field_obj in self.load_fields.items():
+ field_name = (
+ field_obj.data_key if field_obj.data_key is not None else attr_name
+ )
+ raw_value = data.get(field_name, missing)
+ if raw_value is missing:
+ # Ignore missing field if we're allowed to.
+ if partial is True or (
+ partial_is_collection and attr_name in partial
+ ):
+ continue
+ d_kwargs = {}
+ # Allow partial loading of nested schemas.
+ if partial_is_collection:
+ prefix = field_name + "."
+ len_prefix = len(prefix)
+ sub_partial = [
+ f[len_prefix:] for f in partial if f.startswith(prefix)
+ ]
+ d_kwargs["partial"] = sub_partial
+ elif partial is not None:
+ d_kwargs["partial"] = partial
+
+ def getter(
+ val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs
+ ):
+ return field_obj.deserialize(
+ val,
+ field_name,
+ data,
+ **d_kwargs,
+ )
+
+ value = self._call_and_store(
+ getter_func=getter,
+ data=raw_value,
+ field_name=field_name,
+ error_store=error_store,
+ index=index,
+ )
+ if value is not missing:
+ key = field_obj.attribute or attr_name
+ set_value(ret_d, key, value)
+ if unknown != EXCLUDE:
+ fields = {
+ field_obj.data_key if field_obj.data_key is not None else field_name
+ for field_name, field_obj in self.load_fields.items()
+ }
+ for key in set(data) - fields:
+ value = data[key]
+ if unknown == INCLUDE:
+ ret_d[key] = value
+ elif unknown == RAISE:
+ error_store.store_error(
+ [self.error_messages["unknown"]],
+ key,
+ (index if index_errors else None),
+ )
+ return ret_d
+
+ def load(
+ self,
+ data: (
+ typing.Mapping[str, typing.Any]
+ | typing.Iterable[typing.Mapping[str, typing.Any]]
+ ),
+ *,
+ many: bool | None = None,
+ partial: bool | types.StrSequenceOrSet | None = None,
+ unknown: str | None = None,
+ ):
"""Deserialize a data structure to an object defined by this Schema's fields.
:param data: The data to deserialize.
@@ -446,11 +719,19 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
A :exc:`ValidationError <marshmallow.exceptions.ValidationError>` is raised
if invalid data are passed.
"""
- pass
-
- def loads(self, json_data: str, *, many: (bool | None)=None, partial: (
- bool | types.StrSequenceOrSet | None)=None, unknown: (str | None)=
- None, **kwargs):
+ return self._do_load(
+ data, many=many, partial=partial, unknown=unknown, postprocess=True
+ )
+
+ def loads(
+ self,
+ json_data: str,
+ *,
+ many: bool | None = None,
+ partial: bool | types.StrSequenceOrSet | None = None,
+ unknown: str | None = None,
+ **kwargs,
+ ):
"""Same as :meth:`load`, except it takes a JSON string as input.
:param json_data: A JSON string of the data to deserialize.
@@ -471,12 +752,39 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
A :exc:`ValidationError <marshmallow.exceptions.ValidationError>` is raised
if invalid data are passed.
"""
- pass
-
- def validate(self, data: (typing.Mapping[str, typing.Any] | typing.
- Iterable[typing.Mapping[str, typing.Any]]), *, many: (bool | None)=
- None, partial: (bool | types.StrSequenceOrSet | None)=None) ->dict[
- str, list[str]]:
+ data = self.opts.render_module.loads(json_data, **kwargs)
+ return self.load(data, many=many, partial=partial, unknown=unknown)
+
+ def _run_validator(
+ self,
+ validator_func,
+ output,
+ *,
+ original_data,
+ error_store,
+ many,
+ partial,
+ pass_original,
+ index=None,
+ ):
+ try:
+ if pass_original: # Pass original, raw data (before unmarshalling)
+ validator_func(output, original_data, partial=partial, many=many)
+ else:
+ validator_func(output, partial=partial, many=many)
+ except ValidationError as err:
+ error_store.store_error(err.messages, err.field_name, index=index)
+
+ def validate(
+ self,
+ data: (
+ typing.Mapping[str, typing.Any]
+ | typing.Iterable[typing.Mapping[str, typing.Any]]
+ ),
+ *,
+ many: bool | None = None,
+ partial: bool | types.StrSequenceOrSet | None = None,
+ ) -> dict[str, list[str]]:
"""Validate `data` against the schema, returning a dictionary of
validation errors.
@@ -491,12 +799,26 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionadded:: 1.1.0
"""
- pass
-
- def _do_load(self, data: (typing.Mapping[str, typing.Any] | typing.
- Iterable[typing.Mapping[str, typing.Any]]), *, many: (bool | None)=
- None, partial: (bool | types.StrSequenceOrSet | None)=None, unknown:
- (str | None)=None, postprocess: bool=True):
+ try:
+ self._do_load(data, many=many, partial=partial, postprocess=False)
+ except ValidationError as exc:
+ return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)
+ return {}
+
+ ##### Private Helpers #####
+
+ def _do_load(
+ self,
+ data: (
+ typing.Mapping[str, typing.Any]
+ | typing.Iterable[typing.Mapping[str, typing.Any]]
+ ),
+ *,
+ many: bool | None = None,
+ partial: bool | types.StrSequenceOrSet | None = None,
+ unknown: str | None = None,
+ postprocess: bool = True,
+ ):
"""Deserialize `data`, returning the deserialized result.
This method is private API.
@@ -513,41 +835,394 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
:param postprocess: Whether to run post_load methods..
:return: Deserialized data
"""
- pass
-
- def _normalize_nested_options(self) ->None:
+ error_store = ErrorStore()
+ errors = {} # type: dict[str, list[str]]
+ many = self.many if many is None else bool(many)
+ unknown = (
+ self.unknown
+ if unknown is None
+ else validate_unknown_parameter_value(unknown)
+ )
+ if partial is None:
+ partial = self.partial
+ # Run preprocessors
+ if self._has_processors(PRE_LOAD):
+ try:
+ processed_data = self._invoke_load_processors(
+ PRE_LOAD, data, many=many, original_data=data, partial=partial
+ )
+ except ValidationError as err:
+ errors = err.normalized_messages()
+ result = None # type: list | dict | None
+ else:
+ processed_data = data
+ if not errors:
+ # Deserialize data
+ result = self._deserialize(
+ processed_data,
+ error_store=error_store,
+ many=many,
+ partial=partial,
+ unknown=unknown,
+ )
+ # Run field-level validation
+ self._invoke_field_validators(
+ error_store=error_store, data=result, many=many
+ )
+ # Run schema-level validation
+ if self._has_processors(VALIDATES_SCHEMA):
+ field_errors = bool(error_store.errors)
+ self._invoke_schema_validators(
+ error_store=error_store,
+ pass_many=True,
+ data=result,
+ original_data=data,
+ many=many,
+ partial=partial,
+ field_errors=field_errors,
+ )
+ self._invoke_schema_validators(
+ error_store=error_store,
+ pass_many=False,
+ data=result,
+ original_data=data,
+ many=many,
+ partial=partial,
+ field_errors=field_errors,
+ )
+ errors = error_store.errors
+ # Run post processors
+ if not errors and postprocess and self._has_processors(POST_LOAD):
+ try:
+ result = self._invoke_load_processors(
+ POST_LOAD,
+ result,
+ many=many,
+ original_data=data,
+ partial=partial,
+ )
+ except ValidationError as err:
+ errors = err.normalized_messages()
+ if errors:
+ exc = ValidationError(errors, data=data, valid_data=result)
+ self.handle_error(exc, data, many=many, partial=partial)
+ raise exc
+
+ return result
+
+ def _normalize_nested_options(self) -> None:
"""Apply then flatten nested schema options.
This method is private API.
"""
- pass
-
- def __apply_nested_option(self, option_name, field_names, set_operation
- ) ->None:
+ if self.only is not None:
+ # Apply the only option to nested fields.
+ self.__apply_nested_option("only", self.only, "intersection")
+ # Remove the child field names from the only option.
+ self.only = self.set_class([field.split(".", 1)[0] for field in self.only])
+ if self.exclude:
+ # Apply the exclude option to nested fields.
+ self.__apply_nested_option("exclude", self.exclude, "union")
+ # Remove the parent field names from the exclude option.
+ self.exclude = self.set_class(
+ [field for field in self.exclude if "." not in field]
+ )
+
+ def __apply_nested_option(self, option_name, field_names, set_operation) -> None:
"""Apply nested options to nested fields"""
- pass
-
- def _init_fields(self) ->None:
+ # Split nested field names on the first dot.
+ nested_fields = [name.split(".", 1) for name in field_names if "." in name]
+ # Partition the nested field names by parent field.
+ nested_options = defaultdict(list) # type: defaultdict
+ for parent, nested_names in nested_fields:
+ nested_options[parent].append(nested_names)
+ # Apply the nested field options.
+ for key, options in iter(nested_options.items()):
+ new_options = self.set_class(options)
+ original_options = getattr(self.declared_fields[key], option_name, ())
+ if original_options:
+ if set_operation == "union":
+ new_options |= self.set_class(original_options)
+ if set_operation == "intersection":
+ new_options &= self.set_class(original_options)
+ setattr(self.declared_fields[key], option_name, new_options)
+
+ def _init_fields(self) -> None:
"""Update self.fields, self.load_fields, and self.dump_fields based on schema options.
This method is private API.
"""
- pass
+ if self.opts.fields:
+ available_field_names = self.set_class(self.opts.fields)
+ else:
+ available_field_names = self.set_class(self.declared_fields.keys())
+ if self.opts.additional:
+ available_field_names |= self.set_class(self.opts.additional)
+
+ invalid_fields = self.set_class()
+
+ if self.only is not None:
+ # Return only fields specified in only option
+ field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)
+
+ invalid_fields |= field_names - available_field_names
+ else:
+ field_names = available_field_names
+
+ # If "exclude" option or param is specified, remove those fields.
+ if self.exclude:
+ # Note that this isn't available_field_names, since we want to
+ # apply "only" for the actual calculation.
+ field_names = field_names - self.exclude
+ invalid_fields |= self.exclude - available_field_names
+
+ if invalid_fields:
+ message = f"Invalid fields for {self}: {invalid_fields}."
+ raise ValueError(message)
+
+ fields_dict = self.dict_class()
+ for field_name in field_names:
+ field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())
+ self._bind_field(field_name, field_obj)
+ fields_dict[field_name] = field_obj
+
+ load_fields, dump_fields = self.dict_class(), self.dict_class()
+ for field_name, field_obj in fields_dict.items():
+ if not field_obj.dump_only:
+ load_fields[field_name] = field_obj
+ if not field_obj.load_only:
+ dump_fields[field_name] = field_obj
+
+ dump_data_keys = [
+ field_obj.data_key if field_obj.data_key is not None else name
+ for name, field_obj in dump_fields.items()
+ ]
+ if len(dump_data_keys) != len(set(dump_data_keys)):
+ data_keys_duplicates = {
+ x for x in dump_data_keys if dump_data_keys.count(x) > 1
+ }
+ raise ValueError(
+ "The data_key argument for one or more fields collides "
+ "with another field's name or data_key argument. "
+ "Check the following field names and "
+ f"data_key arguments: {list(data_keys_duplicates)}"
+ )
+ load_attributes = [obj.attribute or name for name, obj in load_fields.items()]
+ if len(load_attributes) != len(set(load_attributes)):
+ attributes_duplicates = {
+ x for x in load_attributes if load_attributes.count(x) > 1
+ }
+ raise ValueError(
+ "The attribute argument for one or more fields collides "
+ "with another field's name or attribute argument. "
+ "Check the following field names and "
+ f"attribute arguments: {list(attributes_duplicates)}"
+ )
- def on_bind_field(self, field_name: str, field_obj: ma_fields.Field
- ) ->None:
+ self.fields = fields_dict
+ self.dump_fields = dump_fields
+ self.load_fields = load_fields
+
+ def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
"""Hook to modify a field when it is bound to the `Schema`.
No-op by default.
"""
- pass
+ return None
- def _bind_field(self, field_name: str, field_obj: ma_fields.Field) ->None:
+ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
"""Bind field to the schema, setting any necessary attributes on the
field (e.g. parent and name).
Also set field load_only and dump_only values if field_name was
specified in ``class Meta``.
"""
- pass
+ if field_name in self.load_only:
+ field_obj.load_only = True
+ if field_name in self.dump_only:
+ field_obj.dump_only = True
+ try:
+ field_obj._bind_to_schema(field_name, self)
+ except TypeError as error:
+ # Field declared as a class, not an instance. Ignore type checking because
+ # we handle unsupported arg types, i.e. this is dead code from
+ # the type checker's perspective.
+ if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):
+ msg = (
+ f'Field for "{field_name}" must be declared as a '
+ "Field instance, not a class. "
+ f'Did you mean "fields.{field_obj.__name__}()"?' # type: ignore
+ )
+ raise TypeError(msg) from error
+ raise error
+ self.on_bind_field(field_name, field_obj)
+
+ def _has_processors(self, tag) -> bool:
+ return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])
+
+ def _invoke_dump_processors(
+ self, tag: str, data, *, many: bool, original_data=None
+ ):
+ # The pass_many post-dump processors may do things like add an envelope, so
+ # invoke those after invoking the non-pass_many processors which will expect
+ # to get a list of items.
+ data = self._invoke_processors(
+ tag, pass_many=False, data=data, many=many, original_data=original_data
+ )
+ data = self._invoke_processors(
+ tag, pass_many=True, data=data, many=many, original_data=original_data
+ )
+ return data
+
+ def _invoke_load_processors(
+ self,
+ tag: str,
+ data,
+ *,
+ many: bool,
+ original_data,
+ partial: bool | types.StrSequenceOrSet | None,
+ ):
+ # This has to invert the order of the dump processors, so run the pass_many
+ # processors first.
+ data = self._invoke_processors(
+ tag,
+ pass_many=True,
+ data=data,
+ many=many,
+ original_data=original_data,
+ partial=partial,
+ )
+ data = self._invoke_processors(
+ tag,
+ pass_many=False,
+ data=data,
+ many=many,
+ original_data=original_data,
+ partial=partial,
+ )
+ return data
+
+ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):
+ for attr_name in self._hooks[VALIDATES]:
+ validator = getattr(self, attr_name)
+ validator_kwargs = validator.__marshmallow_hook__[VALIDATES]
+ field_name = validator_kwargs["field_name"]
+
+ try:
+ field_obj = self.fields[field_name]
+ except KeyError as error:
+ if field_name in self.declared_fields:
+ continue
+ raise ValueError(f'"{field_name}" field does not exist.') from error
+
+ data_key = (
+ field_obj.data_key if field_obj.data_key is not None else field_name
+ )
+ if many:
+ for idx, item in enumerate(data):
+ try:
+ value = item[field_obj.attribute or field_name]
+ except KeyError:
+ pass
+ else:
+ validated_value = self._call_and_store(
+ getter_func=validator,
+ data=value,
+ field_name=data_key,
+ error_store=error_store,
+ index=(idx if self.opts.index_errors else None),
+ )
+ if validated_value is missing:
+ data[idx].pop(field_name, None)
+ else:
+ try:
+ value = data[field_obj.attribute or field_name]
+ except KeyError:
+ pass
+ else:
+ validated_value = self._call_and_store(
+ getter_func=validator,
+ data=value,
+ field_name=data_key,
+ error_store=error_store,
+ )
+ if validated_value is missing:
+ data.pop(field_name, None)
+
+ def _invoke_schema_validators(
+ self,
+ *,
+ error_store: ErrorStore,
+ pass_many: bool,
+ data,
+ original_data,
+ many: bool,
+ partial: bool | types.StrSequenceOrSet | None,
+ field_errors: bool = False,
+ ):
+ for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:
+ validator = getattr(self, attr_name)
+ validator_kwargs = validator.__marshmallow_hook__[
+ (VALIDATES_SCHEMA, pass_many)
+ ]
+ if field_errors and validator_kwargs["skip_on_field_errors"]:
+ continue
+ pass_original = validator_kwargs.get("pass_original", False)
+
+ if many and not pass_many:
+ for idx, (item, orig) in enumerate(zip(data, original_data)):
+ self._run_validator(
+ validator,
+ item,
+ original_data=orig,
+ error_store=error_store,
+ many=many,
+ partial=partial,
+ index=idx,
+ pass_original=pass_original,
+ )
+ else:
+ self._run_validator(
+ validator,
+ data,
+ original_data=original_data,
+ error_store=error_store,
+ many=many,
+ pass_original=pass_original,
+ partial=partial,
+ )
+
+ def _invoke_processors(
+ self,
+ tag: str,
+ *,
+ pass_many: bool,
+ data,
+ many: bool,
+ original_data=None,
+ **kwargs,
+ ):
+ key = (tag, pass_many)
+ for attr_name in self._hooks[key]:
+ # This will be a bound method.
+ processor = getattr(self, attr_name)
+
+ processor_kwargs = processor.__marshmallow_hook__[key]
+ pass_original = processor_kwargs.get("pass_original", False)
+
+ if many and not pass_many:
+ if pass_original:
+ data = [
+ processor(item, original, many=many, **kwargs)
+ for item, original in zip(data, original_data)
+ ]
+ else:
+ data = [processor(item, many=many, **kwargs) for item in data]
+ else:
+ if pass_original:
+ data = processor(data, original_data, many=many, **kwargs)
+ else:
+ data = processor(data, many=many, **kwargs)
+ return data
-BaseSchema = Schema
+BaseSchema = Schema # for backwards compatibility
diff --git a/src/marshmallow/types.py b/src/marshmallow/types.py
index 43103ae..ce31c05 100644
--- a/src/marshmallow/types.py
+++ b/src/marshmallow/types.py
@@ -4,7 +4,9 @@
This module is provisional. Types may be modified, added, and removed between minor releases.
"""
+
import typing
+
StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]
diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py
index 1c71b57..a5fe726 100644
--- a/src/marshmallow/utils.py
+++ b/src/marshmallow/utils.py
@@ -1,5 +1,7 @@
"""Utility methods for marshmallow."""
+
from __future__ import annotations
+
import collections
import datetime as dt
import functools
@@ -11,17 +13,18 @@ import warnings
from collections.abc import Mapping
from email.utils import format_datetime, parsedate_to_datetime
from pprint import pprint as py_pprint
+
from marshmallow.base import FieldABC
from marshmallow.exceptions import FieldInstanceResolutionError
from marshmallow.warnings import RemovedInMarshmallow4Warning
-EXCLUDE = 'exclude'
-INCLUDE = 'include'
-RAISE = 'raise'
+
+EXCLUDE = "exclude"
+INCLUDE = "include"
+RAISE = "raise"
_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}
class _Missing:
-
def __bool__(self):
return False
@@ -32,40 +35,46 @@ class _Missing:
return self
def __repr__(self):
- return '<marshmallow.missing>'
+ return "<marshmallow.missing>"
+# Singleton value that indicates that a field's value is missing from input
+# dict passed to :meth:`Schema.load`. If the field's value is not required,
+# it's ``default`` value is used.
missing = _Missing()
-def is_generator(obj) ->bool:
+def is_generator(obj) -> bool:
"""Return True if ``obj`` is a generator"""
- pass
+ return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)
-def is_iterable_but_not_string(obj) ->bool:
+def is_iterable_but_not_string(obj) -> bool:
"""Return True if ``obj`` is an iterable object that isn't a string."""
- pass
+ return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj)
-def is_collection(obj) ->bool:
+def is_collection(obj) -> bool:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
- pass
+ return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)
-def is_instance_or_subclass(val, class_) ->bool:
+def is_instance_or_subclass(val, class_) -> bool:
"""Return True if ``val`` is either a subclass or instance of ``class_``."""
- pass
+ try:
+ return issubclass(val, class_)
+ except TypeError:
+ return isinstance(val, class_)
-def is_keyed_tuple(obj) ->bool:
+def is_keyed_tuple(obj) -> bool:
"""Return True if ``obj`` has keyed tuple behavior, such as
namedtuples or SQLAlchemy's KeyedTuples.
"""
- pass
+ return isinstance(obj, tuple) and hasattr(obj, "_fields")
-def pprint(obj, *args, **kwargs) ->None:
+def pprint(obj, *args, **kwargs) -> None:
"""Pretty-printing function that can pretty-print OrderedDicts
like regular dictionaries. Useful for printing the output of
:meth:`marshmallow.Schema.dump`.
@@ -73,38 +82,65 @@ def pprint(obj, *args, **kwargs) ->None:
.. deprecated:: 3.7.0
marshmallow.pprint will be removed in marshmallow 4.
"""
- pass
+ warnings.warn(
+ "marshmallow's pprint function is deprecated and will be removed in marshmallow 4.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ if isinstance(obj, collections.OrderedDict):
+ print(json.dumps(obj, *args, **kwargs))
+ else:
+ py_pprint(obj, *args, **kwargs)
+
+# https://stackoverflow.com/a/27596917
+def is_aware(datetime: dt.datetime) -> bool:
+ return (
+ datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None
+ )
-def from_rfc(datestring: str) ->dt.datetime:
+
+def from_rfc(datestring: str) -> dt.datetime:
"""Parse a RFC822-formatted datetime string and return a datetime object.
https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950
"""
- pass
+ return parsedate_to_datetime(datestring)
-def rfcformat(datetime: dt.datetime) ->str:
+def rfcformat(datetime: dt.datetime) -> str:
"""Return the RFC822-formatted representation of a datetime object.
:param datetime datetime: The datetime.
"""
- pass
+ return format_datetime(datetime)
+
+# Hat tip to Django for ISO8601 deserialization functions
_iso8601_datetime_re = re.compile(
- '(?P<year>\\d{4})-(?P<month>\\d{1,2})-(?P<day>\\d{1,2})[T ](?P<hour>\\d{1,2}):(?P<minute>\\d{1,2})(?::(?P<second>\\d{1,2})(?:\\.(?P<microsecond>\\d{1,6})\\d{0,6})?)?(?P<tzinfo>Z|[+-]\\d{2}(?::?\\d{2})?)?$'
- )
-_iso8601_date_re = re.compile(
- '(?P<year>\\d{4})-(?P<month>\\d{1,2})-(?P<day>\\d{1,2})$')
+ r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
+ r"[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
+ r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
+ r"(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$"
+)
+
+_iso8601_date_re = re.compile(r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$")
+
_iso8601_time_re = re.compile(
- '(?P<hour>\\d{1,2}):(?P<minute>\\d{1,2})(?::(?P<second>\\d{1,2})(?:\\.(?P<microsecond>\\d{1,6})\\d{0,6})?)?'
- )
+ r"(?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
+ r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
+)
-def get_fixed_timezone(offset: (int | float | dt.timedelta)) ->dt.timezone:
+def get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:
"""Return a tzinfo instance with a fixed offset from UTC."""
- pass
+ if isinstance(offset, dt.timedelta):
+ offset = offset.total_seconds() // 60
+ sign = "-" if offset < 0 else "+"
+ hhmm = "%02d%02d" % divmod(abs(offset), 60)
+ name = sign + hhmm
+ return dt.timezone(dt.timedelta(minutes=offset), name)
def from_iso_datetime(value):
@@ -113,7 +149,23 @@ def from_iso_datetime(value):
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
"""
- pass
+ match = _iso8601_datetime_re.match(value)
+ if not match:
+ raise ValueError("Not a valid ISO8601-formatted datetime string")
+ kw = match.groupdict()
+ kw["microsecond"] = kw["microsecond"] and kw["microsecond"].ljust(6, "0")
+ tzinfo = kw.pop("tzinfo")
+ if tzinfo == "Z":
+ tzinfo = dt.timezone.utc
+ elif tzinfo is not None:
+ offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0
+ offset = 60 * int(tzinfo[1:3]) + offset_mins
+ if tzinfo[0] == "-":
+ offset = -offset
+ tzinfo = get_fixed_timezone(offset)
+ kw = {k: int(v) for k, v in kw.items() if v is not None}
+ kw["tzinfo"] = tzinfo
+ return dt.datetime(**kw)
def from_iso_time(value):
@@ -121,20 +173,79 @@ def from_iso_time(value):
This function doesn't support time zone offsets.
"""
- pass
+ match = _iso8601_time_re.match(value)
+ if not match:
+ raise ValueError("Not a valid ISO8601-formatted time string")
+ kw = match.groupdict()
+ kw["microsecond"] = kw["microsecond"] and kw["microsecond"].ljust(6, "0")
+ kw = {k: int(v) for k, v in kw.items() if v is not None}
+ return dt.time(**kw)
def from_iso_date(value):
"""Parse a string and return a datetime.date."""
- pass
+ match = _iso8601_date_re.match(value)
+ if not match:
+ raise ValueError("Not a valid ISO8601-formatted date string")
+ kw = {k: int(v) for k, v in match.groupdict().items()}
+ return dt.date(**kw)
+
+
+def from_timestamp(value: typing.Any) -> dt.datetime:
+ if value is True or value is False:
+ raise ValueError("Not a valid POSIX timestamp")
+ value = float(value)
+ if value < 0:
+ raise ValueError("Not a valid POSIX timestamp")
+
+ # Load a timestamp with utc as timezone to prevent using system timezone.
+ # Then set timezone to None, to let the Field handle adding timezone info.
+ try:
+ return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)
+ except OverflowError as exc:
+ raise ValueError("Timestamp is too large") from exc
+ except OSError as exc:
+ raise ValueError("Error converting value to datetime") from exc
+
+
+def from_timestamp_ms(value: typing.Any) -> dt.datetime:
+ value = float(value)
+ return from_timestamp(value / 1000)
+
+def timestamp(
+ value: dt.datetime,
+) -> float:
+ if not is_aware(value):
+ # When a date is naive, use UTC as zone info to prevent using system timezone.
+ value = value.replace(tzinfo=dt.timezone.utc)
+ return value.timestamp()
-def isoformat(datetime: dt.datetime) ->str:
+
+def timestamp_ms(value: dt.datetime) -> float:
+ return timestamp(value) * 1000
+
+
+def isoformat(datetime: dt.datetime) -> str:
"""Return the ISO8601-formatted representation of a datetime object.
:param datetime datetime: The datetime.
"""
- pass
+ return datetime.isoformat()
+
+
+def to_iso_time(time: dt.time) -> str:
+ return dt.time.isoformat(time)
+
+
+def to_iso_date(date: dt.date) -> str:
+ return dt.date.isoformat(date)
+
+
+def ensure_text_type(val: str | bytes) -> str:
+ if isinstance(val, bytes):
+ val = val.decode("utf-8")
+ return str(val)
def pluck(dictlist: list[dict[str, typing.Any]], key: str):
@@ -145,10 +256,13 @@ def pluck(dictlist: list[dict[str, typing.Any]], key: str):
>>> pluck(dlist, 'id')
[1, 2]
"""
- pass
+ return [d[key] for d in dictlist]
+
+# Various utilities for pulling keyed values from objects
-def get_value(obj, key: (int | str), default=missing):
+
+def get_value(obj, key: int | str, default=missing):
"""Helper for pulling a keyed value off various types of objects. Fields use
this method by default to access attributes of the source object. For object `x`
and attribute `i`, this method first tries to access `x[i]`, and then falls back to
@@ -159,7 +273,29 @@ def get_value(obj, key: (int | str), default=missing):
`get_value` will never check the value `x.i`. Consider overriding
`marshmallow.fields.Field.get_value` in this case.
"""
- pass
+ if not isinstance(key, int) and "." in key:
+ return _get_value_for_keys(obj, key.split("."), default)
+ else:
+ return _get_value_for_key(obj, key, default)
+
+
+def _get_value_for_keys(obj, keys, default):
+ if len(keys) == 1:
+ return _get_value_for_key(obj, keys[0], default)
+ else:
+ return _get_value_for_keys(
+ _get_value_for_key(obj, keys[0], default), keys[1:], default
+ )
+
+
+def _get_value_for_key(obj, key, default):
+ if not hasattr(obj, "__getitem__"):
+ return getattr(obj, key, default)
+
+ try:
+ return obj[key]
+ except (KeyError, IndexError, TypeError, AttributeError):
+ return getattr(obj, key, default)
def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
@@ -173,22 +309,42 @@ def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
>>> d
{'foo': {'bar': 42}}
"""
- pass
+ if "." in key:
+ head, rest = key.split(".", 1)
+ target = dct.setdefault(head, {})
+ if not isinstance(target, dict):
+ raise ValueError(
+ f"Cannot set {key} in {head} " f"due to existing value: {target}"
+ )
+ set_value(target, rest, value)
+ else:
+ dct[key] = value
def callable_or_raise(obj):
"""Check that an object is callable, else raise a :exc:`TypeError`."""
- pass
+ if not callable(obj):
+ raise TypeError(f"Object {obj!r} is not callable.")
+ return obj
+
+
+def _signature(func: typing.Callable) -> list[str]:
+ return list(inspect.signature(func).parameters.keys())
-def get_func_args(func: typing.Callable) ->list[str]:
+def get_func_args(func: typing.Callable) -> list[str]:
"""Given a callable, return a list of argument names. Handles
`functools.partial` objects and class-based callables.
.. versionchanged:: 3.0.0a1
Do not return bound arguments, eg. ``self``.
"""
- pass
+ if inspect.isfunction(func) or inspect.ismethod(func):
+ return _signature(func)
+ if isinstance(func, functools.partial):
+ return _signature(func.func)
+ # Callable class
+ return _signature(func)
def resolve_field_instance(cls_or_instance):
@@ -196,12 +352,27 @@ def resolve_field_instance(cls_or_instance):
:param type|Schema cls_or_instance: Marshmallow Schema class or instance.
"""
- pass
+ if isinstance(cls_or_instance, type):
+ if not issubclass(cls_or_instance, FieldABC):
+ raise FieldInstanceResolutionError
+ return cls_or_instance()
+ else:
+ if not isinstance(cls_or_instance, FieldABC):
+ raise FieldInstanceResolutionError
+ return cls_or_instance
-def timedelta_to_microseconds(value: dt.timedelta) ->int:
+def timedelta_to_microseconds(value: dt.timedelta) -> int:
"""Compute the total microseconds of a timedelta
https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950
"""
- pass
+ return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds
+
+
+def validate_unknown_parameter_value(obj: typing.Any) -> str:
+ if obj not in _UNKNOWN_VALUES:
+ raise ValueError(
+ f"Object {obj!r} is not a valid value for the 'unknown' parameter"
+ )
+ return obj
diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py
index 3cc3b97..e4536d8 100644
--- a/src/marshmallow/validate.py
+++ b/src/marshmallow/validate.py
@@ -1,13 +1,17 @@
"""Validation classes for various types of data."""
+
from __future__ import annotations
+
import re
import typing
from abc import ABC, abstractmethod
from itertools import zip_longest
from operator import attrgetter
+
from marshmallow import types
from marshmallow.exceptions import ValidationError
-_T = typing.TypeVar('_T')
+
+_T = typing.TypeVar("_T")
class Validator(ABC):
@@ -17,22 +21,23 @@ class Validator(ABC):
This class does not provide any validation behavior. It is only used to
add a useful `__repr__` implementation for validators.
"""
- error = None
- def __repr__(self) ->str:
+ error = None # type: str | None
+
+ def __repr__(self) -> str:
args = self._repr_args()
- args = f'{args}, ' if args else ''
- return f'<{self.__class__.__name__}({args}error={self.error!r})>'
+ args = f"{args}, " if args else ""
- def _repr_args(self) ->str:
+ return f"<{self.__class__.__name__}({args}error={self.error!r})>"
+
+ def _repr_args(self) -> str:
"""A string representation of the args passed to this validator. Used by
`__repr__`.
"""
- pass
+ return ""
@abstractmethod
- def __call__(self, value: typing.Any) ->typing.Any:
- ...
+ def __call__(self, value: typing.Any) -> typing.Any: ...
class And(Validator):
@@ -55,13 +60,17 @@ class And(Validator):
:param validators: Validators to combine.
:param error: Error message to use when a validator returns ``False``.
"""
- default_error_message = 'Invalid value.'
- def __init__(self, *validators: types.Validator, error: (str | None)=None):
+ default_error_message = "Invalid value."
+
+ def __init__(self, *validators: types.Validator, error: str | None = None):
self.validators = tuple(validators)
- self.error = error or self.default_error_message
+ self.error = error or self.default_error_message # type: str
- def __call__(self, value: typing.Any) ->typing.Any:
+ def _repr_args(self) -> str:
+ return f"validators={self.validators!r}"
+
+ def __call__(self, value: typing.Any) -> typing.Any:
errors = []
kwargs = {}
for validator in self.validators:
@@ -74,6 +83,7 @@ class And(Validator):
if isinstance(err.messages, dict):
errors.append(err.messages)
else:
+ # FIXME : Get rid of cast
errors.extend(typing.cast(list, err.messages))
if errors:
raise ValidationError(errors, **kwargs)
@@ -92,47 +102,121 @@ class URL(Validator):
:param require_tld: Whether to reject non-FQDN hostnames.
"""
-
class RegexMemoizer:
-
def __init__(self):
self._memoized = {}
- def __call__(self, relative: bool, absolute: bool, require_tld: bool
- ) ->typing.Pattern:
- key = relative, absolute, require_tld
+ def _regex_generator(
+ self, relative: bool, absolute: bool, require_tld: bool
+ ) -> typing.Pattern:
+ hostname_variants = [
+ # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle
+ # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars
+ (
+ r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+"
+ r"(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)"
+ ),
+ # or the special string 'localhost'
+ r"localhost",
+ # or IPv4
+ r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}",
+ # or IPv6
+ r"\[[A-F0-9]*:[A-F0-9:]+\]",
+ ]
+ if not require_tld:
+ # allow dotless hostnames
+ hostname_variants.append(r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.?)")
+
+ absolute_part = "".join(
+ (
+ # scheme (e.g. 'https://', 'ftp://', etc)
+ # this is validated separately against allowed schemes, so in the regex
+ # we simply want to capture its existence
+ r"(?:[a-z0-9\.\-\+]*)://",
+ # userinfo, for URLs encoding authentication
+ # e.g. 'ftp://foo:bar@ftp.example.org/'
+ r"(?:(?:[a-z0-9\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?",
+ # netloc, the hostname/domain part of the URL plus the optional port
+ r"(?:",
+ "|".join(hostname_variants),
+ r")",
+ r"(?::\d+)?",
+ )
+ )
+ relative_part = r"(?:/?|[/?]\S+)\Z"
+
+ if relative:
+ if absolute:
+ parts: tuple[str, ...] = (
+ r"^(",
+ absolute_part,
+ r")?",
+ relative_part,
+ )
+ else:
+ parts = (r"^", relative_part)
+ else:
+ parts = (r"^", absolute_part, relative_part)
+
+ return re.compile("".join(parts), re.IGNORECASE)
+
+ def __call__(
+ self, relative: bool, absolute: bool, require_tld: bool
+ ) -> typing.Pattern:
+ key = (relative, absolute, require_tld)
if key not in self._memoized:
- self._memoized[key] = self._regex_generator(relative,
- absolute, require_tld)
+ self._memoized[key] = self._regex_generator(
+ relative, absolute, require_tld
+ )
+
return self._memoized[key]
+
_regex = RegexMemoizer()
- default_message = 'Not a valid URL.'
- default_schemes = {'http', 'https', 'ftp', 'ftps'}
- def __init__(self, *, relative: bool=False, absolute: bool=True,
- schemes: (types.StrSequenceOrSet | None)=None, require_tld: bool=
- True, error: (str | None)=None):
+ default_message = "Not a valid URL."
+ default_schemes = {"http", "https", "ftp", "ftps"}
+
+ def __init__(
+ self,
+ *,
+ relative: bool = False,
+ absolute: bool = True,
+ schemes: types.StrSequenceOrSet | None = None,
+ require_tld: bool = True,
+ error: str | None = None,
+ ):
if not relative and not absolute:
raise ValueError(
- 'URL validation cannot set both relative and absolute to False.'
- )
+ "URL validation cannot set both relative and absolute to False."
+ )
self.relative = relative
self.absolute = absolute
- self.error = error or self.default_message
+ self.error = error or self.default_message # type: str
self.schemes = schemes or self.default_schemes
self.require_tld = require_tld
- def __call__(self, value: str) ->str:
+ def _repr_args(self) -> str:
+ return f"relative={self.relative!r}, absolute={self.absolute!r}"
+
+ def _format_error(self, value) -> str:
+ return self.error.format(input=value)
+
+ def __call__(self, value: str) -> str:
message = self._format_error(value)
if not value:
raise ValidationError(message)
- if '://' in value:
- scheme = value.split('://')[0].lower()
+
+ # Check first if the scheme is valid
+ if "://" in value:
+ scheme = value.split("://")[0].lower()
if scheme not in self.schemes:
raise ValidationError(message)
+
regex = self._regex(self.relative, self.absolute, self.require_tld)
+
if not regex.search(value):
raise ValidationError(message)
+
return value
@@ -142,35 +226,57 @@ class Email(Validator):
:param error: Error message to raise in case of a validation error. Can be
interpolated with `{input}`.
"""
+
USER_REGEX = re.compile(
- '(^[-!#$%&\'*+/=?^`{}|~\\w]+(\\.[-!#$%&\'*+/=?^`{}|~\\w]+)*\\Z|^"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]|\\\\[\\001-\\011\\013\\014\\016-\\177])*"\\Z)'
- , re.IGNORECASE | re.UNICODE)
+ r"(^[-!#$%&'*+/=?^`{}|~\w]+(\.[-!#$%&'*+/=?^`{}|~\w]+)*\Z" # dot-atom
+ # quoted-string
+ r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]'
+ r'|\\[\001-\011\013\014\016-\177])*"\Z)',
+ re.IGNORECASE | re.UNICODE,
+ )
+
DOMAIN_REGEX = re.compile(
- '(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\\Z|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\]\\Z'
- , re.IGNORECASE | re.UNICODE)
- DOMAIN_WHITELIST = 'localhost',
- default_message = 'Not a valid email address.'
+ # domain
+ r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+"
+ r"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\Z"
+ # literal form, ipv4 address (SMTP 4.1.3)
+ r"|^\[(25[0-5]|2[0-4]\d|[0-1]?\d?\d)"
+ r"(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}\]\Z",
+ re.IGNORECASE | re.UNICODE,
+ )
+
+ DOMAIN_WHITELIST = ("localhost",)
- def __init__(self, *, error: (str | None)=None):
- self.error = error or self.default_message
+ default_message = "Not a valid email address."
- def __call__(self, value: str) ->str:
+ def __init__(self, *, error: str | None = None):
+ self.error = error or self.default_message # type: str
+
+ def _format_error(self, value: str) -> str:
+ return self.error.format(input=value)
+
+ def __call__(self, value: str) -> str:
message = self._format_error(value)
- if not value or '@' not in value:
+
+ if not value or "@" not in value:
raise ValidationError(message)
- user_part, domain_part = value.rsplit('@', 1)
+
+ user_part, domain_part = value.rsplit("@", 1)
+
if not self.USER_REGEX.match(user_part):
raise ValidationError(message)
+
if domain_part not in self.DOMAIN_WHITELIST:
if not self.DOMAIN_REGEX.match(domain_part):
try:
- domain_part = domain_part.encode('idna').decode('ascii')
+ domain_part = domain_part.encode("idna").decode("ascii")
except UnicodeError:
pass
else:
if self.DOMAIN_REGEX.match(domain_part):
return value
raise ValidationError(message)
+
return value
@@ -192,40 +298,62 @@ class Range(Validator):
:param error: Error message to raise in case of a validation error.
Can be interpolated with `{input}`, `{min}` and `{max}`.
"""
- message_min = 'Must be {min_op} {{min}}.'
- message_max = 'Must be {max_op} {{max}}.'
- message_all = 'Must be {min_op} {{min}} and {max_op} {{max}}.'
- message_gte = 'greater than or equal to'
- message_gt = 'greater than'
- message_lte = 'less than or equal to'
- message_lt = 'less than'
-
- def __init__(self, min=None, max=None, *, min_inclusive: bool=True,
- max_inclusive: bool=True, error: (str | None)=None):
+
+ message_min = "Must be {min_op} {{min}}."
+ message_max = "Must be {max_op} {{max}}."
+ message_all = "Must be {min_op} {{min}} and {max_op} {{max}}."
+
+ message_gte = "greater than or equal to"
+ message_gt = "greater than"
+ message_lte = "less than or equal to"
+ message_lt = "less than"
+
+ def __init__(
+ self,
+ min=None,
+ max=None,
+ *,
+ min_inclusive: bool = True,
+ max_inclusive: bool = True,
+ error: str | None = None,
+ ):
self.min = min
self.max = max
self.error = error
self.min_inclusive = min_inclusive
self.max_inclusive = max_inclusive
- self.message_min = self.message_min.format(min_op=self.message_gte if
- self.min_inclusive else self.message_gt)
- self.message_max = self.message_max.format(max_op=self.message_lte if
- self.max_inclusive else self.message_lt)
- self.message_all = self.message_all.format(min_op=self.message_gte if
- self.min_inclusive else self.message_gt, max_op=self.
- message_lte if self.max_inclusive else self.message_lt)
-
- def __call__(self, value: _T) ->_T:
- if self.min is not None and (value < self.min if self.min_inclusive
- else value <= self.min):
- message = (self.message_min if self.max is None else self.
- message_all)
+
+ # interpolate messages based on bound inclusivity
+ self.message_min = self.message_min.format(
+ min_op=self.message_gte if self.min_inclusive else self.message_gt
+ )
+ self.message_max = self.message_max.format(
+ max_op=self.message_lte if self.max_inclusive else self.message_lt
+ )
+ self.message_all = self.message_all.format(
+ min_op=self.message_gte if self.min_inclusive else self.message_gt,
+ max_op=self.message_lte if self.max_inclusive else self.message_lt,
+ )
+
+ def _repr_args(self) -> str:
+ return f"min={self.min!r}, max={self.max!r}, min_inclusive={self.min_inclusive!r}, max_inclusive={self.max_inclusive!r}"
+
+ def _format_error(self, value: _T, message: str) -> str:
+ return (self.error or message).format(input=value, min=self.min, max=self.max)
+
+ def __call__(self, value: _T) -> _T:
+ if self.min is not None and (
+ value < self.min if self.min_inclusive else value <= self.min
+ ):
+ message = self.message_min if self.max is None else self.message_all
raise ValidationError(self._format_error(value, message))
- if self.max is not None and (value > self.max if self.max_inclusive
- else value >= self.max):
- message = (self.message_max if self.min is None else self.
- message_all)
+
+ if self.max is not None and (
+ value > self.max if self.max_inclusive else value >= self.max
+ ):
+ message = self.message_max if self.min is None else self.message_all
raise ValidationError(self._format_error(value, message))
+
return value
@@ -243,37 +371,55 @@ class Length(Validator):
:param error: Error message to raise in case of a validation error.
Can be interpolated with `{input}`, `{min}` and `{max}`.
"""
- message_min = 'Shorter than minimum length {min}.'
- message_max = 'Longer than maximum length {max}.'
- message_all = 'Length must be between {min} and {max}.'
- message_equal = 'Length must be {equal}.'
- def __init__(self, min: (int | None)=None, max: (int | None)=None, *,
- equal: (int | None)=None, error: (str | None)=None):
+ message_min = "Shorter than minimum length {min}."
+ message_max = "Longer than maximum length {max}."
+ message_all = "Length must be between {min} and {max}."
+ message_equal = "Length must be {equal}."
+
+ def __init__(
+ self,
+ min: int | None = None,
+ max: int | None = None,
+ *,
+ equal: int | None = None,
+ error: str | None = None,
+ ):
if equal is not None and any([min, max]):
raise ValueError(
- 'The `equal` parameter was provided, maximum or minimum parameter must not be provided.'
- )
+ "The `equal` parameter was provided, maximum or "
+ "minimum parameter must not be provided."
+ )
+
self.min = min
self.max = max
self.error = error
self.equal = equal
- def __call__(self, value: typing.Sized) ->typing.Sized:
+ def _repr_args(self) -> str:
+ return f"min={self.min!r}, max={self.max!r}, equal={self.equal!r}"
+
+ def _format_error(self, value: typing.Sized, message: str) -> str:
+ return (self.error or message).format(
+ input=value, min=self.min, max=self.max, equal=self.equal
+ )
+
+ def __call__(self, value: typing.Sized) -> typing.Sized:
length = len(value)
+
if self.equal is not None:
if length != self.equal:
- raise ValidationError(self._format_error(value, self.
- message_equal))
+ raise ValidationError(self._format_error(value, self.message_equal))
return value
+
if self.min is not None and length < self.min:
- message = (self.message_min if self.max is None else self.
- message_all)
+ message = self.message_min if self.max is None else self.message_all
raise ValidationError(self._format_error(value, message))
+
if self.max is not None and length > self.max:
- message = (self.message_max if self.min is None else self.
- message_all)
+ message = self.message_max if self.min is None else self.message_all
raise ValidationError(self._format_error(value, message))
+
return value
@@ -285,13 +431,20 @@ class Equal(Validator):
:param error: Error message to raise in case of a validation error.
Can be interpolated with `{input}` and `{other}`.
"""
- default_message = 'Must be equal to {other}.'
- def __init__(self, comparable, *, error: (str | None)=None):
+ default_message = "Must be equal to {other}."
+
+ def __init__(self, comparable, *, error: str | None = None):
self.comparable = comparable
- self.error = error or self.default_message
+ self.error = error or self.default_message # type: str
+
+ def _repr_args(self) -> str:
+ return f"comparable={self.comparable!r}"
- def __call__(self, value: _T) ->_T:
+ def _format_error(self, value: _T) -> str:
+ return self.error.format(input=value, other=self.comparable)
+
+ def __call__(self, value: _T) -> _T:
if value != self.comparable:
raise ValidationError(self._format_error(value))
return value
@@ -311,25 +464,37 @@ class Regexp(Validator):
:param error: Error message to raise in case of a validation error.
Can be interpolated with `{input}` and `{regex}`.
"""
- default_message = 'String does not match expected pattern.'
- def __init__(self, regex: (str | bytes | typing.Pattern), flags: int=0,
- *, error: (str | None)=None):
- self.regex = re.compile(regex, flags) if isinstance(regex, (str, bytes)
- ) else regex
- self.error = error or self.default_message
+ default_message = "String does not match expected pattern."
+
+ def __init__(
+ self,
+ regex: str | bytes | typing.Pattern,
+ flags: int = 0,
+ *,
+ error: str | None = None,
+ ):
+ self.regex = (
+ re.compile(regex, flags) if isinstance(regex, (str, bytes)) else regex
+ )
+ self.error = error or self.default_message # type: str
+
+ def _repr_args(self) -> str:
+ return f"regex={self.regex!r}"
+
+ def _format_error(self, value: str | bytes) -> str:
+ return self.error.format(input=value, regex=self.regex.pattern)
@typing.overload
- def __call__(self, value: str) ->str:
- ...
+ def __call__(self, value: str) -> str: ...
@typing.overload
- def __call__(self, value: bytes) ->bytes:
- ...
+ def __call__(self, value: bytes) -> bytes: ...
def __call__(self, value):
if self.regex.match(value) is None:
raise ValidationError(self._format_error(value))
+
return value
@@ -344,17 +509,26 @@ class Predicate(Validator):
Can be interpolated with `{input}` and `{method}`.
:param kwargs: Additional keyword arguments to pass to the method.
"""
- default_message = 'Invalid input.'
- def __init__(self, method: str, *, error: (str | None)=None, **kwargs):
+ default_message = "Invalid input."
+
+ def __init__(self, method: str, *, error: str | None = None, **kwargs):
self.method = method
- self.error = error or self.default_message
+ self.error = error or self.default_message # type: str
self.kwargs = kwargs
- def __call__(self, value: typing.Any) ->typing.Any:
+ def _repr_args(self) -> str:
+ return f"method={self.method!r}, kwargs={self.kwargs!r}"
+
+ def _format_error(self, value: typing.Any) -> str:
+ return self.error.format(input=value, method=self.method)
+
+ def __call__(self, value: typing.Any) -> typing.Any:
method = getattr(value, self.method)
+
if not method(**self.kwargs):
raise ValidationError(self._format_error(value))
+
return value
@@ -365,19 +539,27 @@ class NoneOf(Validator):
:param error: Error message to raise in case of a validation error. Can be
interpolated using `{input}` and `{values}`.
"""
- default_message = 'Invalid input.'
- def __init__(self, iterable: typing.Iterable, *, error: (str | None)=None):
+ default_message = "Invalid input."
+
+ def __init__(self, iterable: typing.Iterable, *, error: str | None = None):
self.iterable = iterable
- self.values_text = ', '.join(str(each) for each in self.iterable)
- self.error = error or self.default_message
+ self.values_text = ", ".join(str(each) for each in self.iterable)
+ self.error = error or self.default_message # type: str
- def __call__(self, value: typing.Any) ->typing.Any:
+ def _repr_args(self) -> str:
+ return f"iterable={self.iterable!r}"
+
+ def _format_error(self, value) -> str:
+ return self.error.format(input=value, values=self.values_text)
+
+ def __call__(self, value: typing.Any) -> typing.Any:
try:
if value in self.iterable:
raise ValidationError(self._format_error(value))
except TypeError:
pass
+
return value
@@ -389,26 +571,43 @@ class OneOf(Validator):
:param error: Error message to raise in case of a validation error. Can be
interpolated with `{input}`, `{choices}` and `{labels}`.
"""
- default_message = 'Must be one of: {choices}.'
- def __init__(self, choices: typing.Iterable, labels: (typing.Iterable[
- str] | None)=None, *, error: (str | None)=None):
+ default_message = "Must be one of: {choices}."
+
+ def __init__(
+ self,
+ choices: typing.Iterable,
+ labels: typing.Iterable[str] | None = None,
+ *,
+ error: str | None = None,
+ ):
self.choices = choices
- self.choices_text = ', '.join(str(choice) for choice in self.choices)
+ self.choices_text = ", ".join(str(choice) for choice in self.choices)
self.labels = labels if labels is not None else []
- self.labels_text = ', '.join(str(label) for label in self.labels)
- self.error = error or self.default_message
+ self.labels_text = ", ".join(str(label) for label in self.labels)
+ self.error = error or self.default_message # type: str
- def __call__(self, value: typing.Any) ->typing.Any:
+ def _repr_args(self) -> str:
+ return f"choices={self.choices!r}, labels={self.labels!r}"
+
+ def _format_error(self, value) -> str:
+ return self.error.format(
+ input=value, choices=self.choices_text, labels=self.labels_text
+ )
+
+ def __call__(self, value: typing.Any) -> typing.Any:
try:
if value not in self.choices:
raise ValidationError(self._format_error(value))
except TypeError as error:
raise ValidationError(self._format_error(value)) from error
+
return value
- def options(self, valuegetter: (str | typing.Callable[[typing.Any],
- typing.Any])=str) ->typing.Iterable[tuple[typing.Any, str]]:
+ def options(
+ self,
+ valuegetter: str | typing.Callable[[typing.Any], typing.Any] = str,
+ ) -> typing.Iterable[tuple[typing.Any, str]]:
"""Return a generator over the (value, label) pairs, where value
is a string associated with each choice. This convenience method
is useful to populate, for instance, a form select field.
@@ -419,7 +618,10 @@ class OneOf(Validator):
of an attribute of the choice objects. Defaults to `str()`
or `str()`.
"""
- pass
+ valuegetter = valuegetter if callable(valuegetter) else attrgetter(valuegetter)
+ pairs = zip_longest(self.choices, self.labels, fillvalue="")
+
+ return ((valuegetter(choice), label) for choice, label in pairs)
class ContainsOnly(OneOf):
@@ -437,10 +639,15 @@ class ContainsOnly(OneOf):
Empty input is considered valid. Use `validate.Length(min=1) <marshmallow.validate.Length>`
to validate against empty inputs.
"""
- default_message = (
- 'One or more of the choices you made was not in: {choices}.')
- def __call__(self, value: typing.Sequence[_T]) ->typing.Sequence[_T]:
+ default_message = "One or more of the choices you made was not in: {choices}."
+
+ def _format_error(self, value) -> str:
+ value_text = ", ".join(str(val) for val in value)
+ return super()._format_error(value_text)
+
+ def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
+ # We can't use set.issubset because does not handle unhashable types
for val in value:
if val not in self.choices:
raise ValidationError(self._format_error(value))
@@ -457,9 +664,14 @@ class ContainsNoneOf(NoneOf):
.. versionadded:: 3.6.0
"""
- default_message = 'One or more of the choices you made was in: {values}.'
- def __call__(self, value: typing.Sequence[_T]) ->typing.Sequence[_T]:
+ default_message = "One or more of the choices you made was in: {values}."
+
+ def _format_error(self, value) -> str:
+ value_text = ", ".join(str(val) for val in value)
+ return super()._format_error(value_text)
+
+ def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
for val in value:
if val in self.iterable:
raise ValidationError(self._format_error(value))