back to Claude Sonnet 3.5 - Fill-in summary
Claude Sonnet 3.5 - Fill-in: marshmallow
Failed to run pytest for the `tests` test suite
ImportError while loading conftest '/testbed/tests/conftest.py'.
tests/conftest.py:5: in <module>
from tests.base import Blog, User, UserSchema
tests/base.py:11: in <module>
from marshmallow import Schema, fields, missing, post_load, validate
src/marshmallow/__init__.py:17: in <module>
from marshmallow.schema import Schema, SchemaOpts
src/marshmallow/schema.py:15: in <module>
from marshmallow import fields as ma_fields
src/marshmallow/fields.py:18: in <module>
from marshmallow.utils import is_aware, is_collection, resolve_field_instance
E ImportError: cannot import name 'is_aware' from 'marshmallow.utils' (/testbed/src/marshmallow/utils.py)
Patch diff
diff --git a/src/marshmallow/decorators.py b/src/marshmallow/decorators.py
index d78f5be..885181e 100644
--- a/src/marshmallow/decorators.py
+++ b/src/marshmallow/decorators.py
@@ -84,7 +84,13 @@ def validates(field_name: str) ->Callable[..., Any]:
:param str field_name: Name of the field that the method validates.
"""
- pass
+ def decorator(fn):
+ @functools.wraps(fn)
+ def wrapper(self, value, **kwargs):
+ return fn(self, value, **kwargs)
+ wrapper.__marshmallow_hook__ = {VALIDATES: field_name}
+ return wrapper
+ return decorator
def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
@@ -109,7 +115,25 @@ def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(
+ validates_schema,
+ pass_many=pass_many,
+ pass_original=pass_original,
+ skip_on_field_errors=skip_on_field_errors,
+ )
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {
+ (VALIDATES_SCHEMA, pass_many): {
+ 'pass_original': pass_original,
+ 'skip_on_field_errors': skip_on_field_errors,
+ }
+ }
+ return wrapper
def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
@@ -124,7 +148,15 @@ def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(pre_dump, pass_many=pass_many)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {(PRE_DUMP, pass_many): {}}
+ return wrapper
def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
@@ -142,7 +174,15 @@ def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(post_dump, pass_many=pass_many, pass_original=pass_original)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {(POST_DUMP, pass_many): {'pass_original': pass_original}}
+ return wrapper
def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
@@ -158,7 +198,15 @@ def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(pre_load, pass_many=pass_many)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {(PRE_LOAD, pass_many): {}}
+ return wrapper
def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
@@ -177,7 +225,15 @@ def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(post_load, pass_many=pass_many, pass_original=pass_original)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {(POST_LOAD, pass_many): {'pass_original': pass_original}}
+ return wrapper
def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
@@ -192,4 +248,12 @@ def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
:return: Decorated function if supplied, else this decorator with its args
bound.
"""
- pass
+ if fn is None:
+ return functools.partial(set_hook, key=key, **kwargs)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kw):
+ return fn(*args, **kw)
+
+ wrapper.__marshmallow_hook__ = {key: kwargs}
+ return wrapper
diff --git a/src/marshmallow/error_store.py b/src/marshmallow/error_store.py
index a659aaf..25acd66 100644
--- a/src/marshmallow/error_store.py
+++ b/src/marshmallow/error_store.py
@@ -13,6 +13,16 @@ class ErrorStore:
def __init__(self):
self.errors = {}
+ def store_error(self, field_name, error):
+ """Store an error message."""
+ if field_name not in self.errors:
+ self.errors[field_name] = []
+ self.errors[field_name].append(error)
+
+ def get_errors(self):
+ """Return all stored errors."""
+ return self.errors
+
def merge_errors(errors1, errors2):
"""Deeply merge two error messages.
@@ -20,4 +30,17 @@ def merge_errors(errors1, errors2):
The format of ``errors1`` and ``errors2`` matches the ``message``
parameter of :exc:`marshmallow.exceptions.ValidationError`.
"""
- pass
+ if isinstance(errors1, dict) and isinstance(errors2, dict):
+ merged = errors1.copy()
+ for key, value in errors2.items():
+ if key in merged:
+ merged[key] = merge_errors(merged[key], value)
+ else:
+ merged[key] = value
+ return merged
+ elif isinstance(errors1, list) and isinstance(errors2, list):
+ return errors1 + errors2
+ elif isinstance(errors1, (str, int, float)) and isinstance(errors2, (str, int, float)):
+ return [errors1, errors2]
+ else:
+ return errors2
diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py
index 8656a56..0c30bcf 100644
--- a/src/marshmallow/fields.py
+++ b/src/marshmallow/fields.py
@@ -166,19 +166,37 @@ class Field(FieldABC):
:param callable accessor: A callable used to retrieve the value of `attr` from
the object `obj`. Defaults to `marshmallow.utils.get_value`.
"""
- pass
+ if accessor is None:
+ accessor = utils.get_value
+ return accessor(obj, attr, default)
def _validate(self, value):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
- pass
+ errors = []
+ for validator in self.validators:
+ try:
+ if validator(value) is False:
+ self.fail('validator_failed')
+ except ValidationError as error:
+ errors.extend(error.messages)
+ if errors:
+ raise ValidationError(errors)
def make_error(self, key: str, **kwargs) ->ValidationError:
"""Helper method to make a `ValidationError` with an error message
from ``self.error_messages``.
"""
- pass
+ try:
+ msg = self.error_messages[key]
+ except KeyError as error:
+ class_name = self.__class__.__name__
+ msg = f'"{key}" is not a valid error key for {class_name}'
+ raise ValueError(msg) from error
+ if isinstance(msg, str):
+ msg = msg.format(**kwargs)
+ return ValidationError(msg)
def fail(self, key: str, **kwargs):
"""Helper method that raises a `ValidationError` with an error message
@@ -187,13 +205,23 @@ class Field(FieldABC):
.. deprecated:: 3.0.0
Use `make_error <marshmallow.fields.Field.make_error>` instead.
"""
- pass
+ warnings.warn(
+ "Field.fail is deprecated. Use Field.make_error instead.",
+ DeprecationWarning,
+ stacklevel=2
+ )
+ raise self.make_error(key, **kwargs)
def _validate_missing(self, value):
"""Validate missing values. Raise a :exc:`ValidationError` if
`value` should be considered missing.
"""
- pass
+ if value is missing_:
+ if self.required:
+ raise self.make_error('required')
+ elif value is None:
+ if self.allow_none is False:
+ raise self.make_error('null')
def serialize(self, attr: str, obj: typing.Any, accessor: (typing.
Callable[[typing.Any, str, typing.Any], typing.Any] | None)=None,
@@ -206,7 +234,14 @@ class Field(FieldABC):
:param accessor: Function used to access values from ``obj``.
:param kwargs: Field-specific keyword arguments.
"""
- pass
+ if self.dump_only:
+ return self.dump_default
+
+ value = self.get_value(obj, attr, accessor=accessor)
+ if value is missing_:
+ return self.dump_default
+
+ return self._serialize(value, attr, obj, **kwargs)
def deserialize(self, value: typing.Any, attr: (str | None)=None, data:
(typing.Mapping[str, typing.Any] | None)=None, **kwargs):
@@ -219,7 +254,17 @@ class Field(FieldABC):
:raise ValidationError: If an invalid value is passed or if a required value
is missing.
"""
- pass
+ if self.load_only:
+ return self.load_default
+
+ self._validate_missing(value)
+ if value is missing_:
+ return self.load_default
+
+ value = self._deserialize(value, attr, data, **kwargs)
+ self._validate(value)
+
+ return value
def _bind_to_schema(self, field_name, schema):
"""Update field with values from its parent schema. Called by
diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py
index 1e6eabf..7851fa0 100644
--- a/src/marshmallow/schema.py
+++ b/src/marshmallow/schema.py
@@ -27,7 +27,12 @@ def _get_fields(attrs):
:param attrs: Mapping of class attributes
"""
- pass
+ fields = [
+ (field_name, field_obj)
+ for field_name, field_obj in attrs.items()
+ if isinstance(field_obj, ma_fields.Field)
+ ]
+ return fields
def _get_fields_by_mro(klass):
@@ -37,7 +42,13 @@ def _get_fields_by_mro(klass):
:param type klass: Class whose fields to retrieve
"""
- pass
+ fields = []
+ for base_class in klass.__mro__[1:]: # skip the class itself
+ if hasattr(base_class, '_declared_fields'):
+ fields += list(base_class._declared_fields.items())
+ else:
+ fields += _get_fields(base_class.__dict__)
+ return fields
class SchemaMeta(ABCMeta):
@@ -83,7 +94,16 @@ class SchemaMeta(ABCMeta):
:param inherited_fields: Inherited fields.
:param dict_cls: dict-like class to use for dict output Default to ``dict``.
"""
- pass
+ declared_fields = dict_cls()
+ for field_name, field_obj in inherited_fields + cls_fields:
+ if field_name in klass.opts.exclude:
+ continue
+ if field_name in declared_fields:
+ prev_obj = declared_fields[field_name]
+ if hasattr(prev_obj, 'resolve_field_instance'):
+ field_obj = prev_obj.resolve_field_instance(field_obj)
+ declared_fields[field_name] = field_obj
+ return declared_fields
def __init__(cls, name, bases, attrs):
super().__init__(name, bases, attrs)
@@ -97,7 +117,13 @@ class SchemaMeta(ABCMeta):
By doing this after constructing the class, we let standard inheritance
do all the hard work.
"""
- pass
+ hooks = defaultdict(list)
+ for attr_name in dir(cls):
+ attr = getattr(cls, attr_name)
+ if hasattr(attr, '__marshmallow_hook__'):
+ hook = getattr(attr, '__marshmallow_hook__')
+ hooks[hook.tag].append(attr_name)
+ return dict(hooks)
class SchemaOpts:
@@ -311,7 +337,10 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionadded:: 3.0.0
"""
- pass
+ attrs = fields.copy()
+ attrs['Meta'] = type('Meta', (), {'register': False})
+ schema_cls = type(name, (cls,), attrs)
+ return schema_cls
def handle_error(self, error: ValidationError, data: typing.Any, *,
many: bool, **kwargs):
@@ -327,7 +356,7 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0rc9
Receives `many` and `partial` (on deserialization) as keyword arguments.
"""
- pass
+ pass # Default implementation does nothing
def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):
"""Defines how to pull values from an object to serialize.
@@ -337,7 +366,7 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0a1
Changed position of ``obj`` and ``attr``.
"""
- pass
+ return get_value(obj, attr, default)
@staticmethod
def _call_and_store(getter_func, data, *, field_name, error_store,
diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py
index 1c71b57..0b04589 100644
--- a/src/marshmallow/utils.py
+++ b/src/marshmallow/utils.py
@@ -40,29 +40,34 @@ missing = _Missing()
def is_generator(obj) ->bool:
"""Return True if ``obj`` is a generator"""
- pass
+ return inspect.isgenerator(obj)
def is_iterable_but_not_string(obj) ->bool:
"""Return True if ``obj`` is an iterable object that isn't a string."""
- pass
+ return (
+ isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes))
+ )
def is_collection(obj) ->bool:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
- pass
+ return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)
def is_instance_or_subclass(val, class_) ->bool:
"""Return True if ``val`` is either a subclass or instance of ``class_``."""
- pass
+ try:
+ return issubclass(val, class_)
+ except TypeError:
+ return isinstance(val, class_)
def is_keyed_tuple(obj) ->bool:
"""Return True if ``obj`` has keyed tuple behavior, such as
namedtuples or SQLAlchemy's KeyedTuples.
"""
- pass
+ return isinstance(obj, tuple) and hasattr(obj, '_fields')
def pprint(obj, *args, **kwargs) ->None:
@@ -73,7 +78,15 @@ def pprint(obj, *args, **kwargs) ->None:
.. deprecated:: 3.7.0
marshmallow.pprint will be removed in marshmallow 4.
"""
- pass
+ warnings.warn(
+ "marshmallow.pprint is deprecated and will be removed in marshmallow 4.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ if isinstance(obj, collections.OrderedDict):
+ print(json.dumps(obj, indent=2))
+ else:
+ py_pprint(obj, *args, **kwargs)
def from_rfc(datestring: str) ->dt.datetime:
@@ -81,7 +94,7 @@ def from_rfc(datestring: str) ->dt.datetime:
https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950
"""
- pass
+ return parsedate_to_datetime(datestring)
def rfcformat(datetime: dt.datetime) ->str:
@@ -89,7 +102,7 @@ def rfcformat(datetime: dt.datetime) ->str:
:param datetime datetime: The datetime.
"""
- pass
+ return format_datetime(datetime)
_iso8601_datetime_re = re.compile(
@@ -104,7 +117,11 @@ _iso8601_time_re = re.compile(
def get_fixed_timezone(offset: (int | float | dt.timedelta)) ->dt.timezone:
"""Return a tzinfo instance with a fixed offset from UTC."""
- pass
+ if isinstance(offset, dt.timedelta):
+ offset = offset.total_seconds() // 60
+ sign = '-' if offset < 0 else '+'
+ h, m = divmod(abs(int(offset)), 60)
+ return dt.timezone(dt.timedelta(hours=h, minutes=m), f"{sign}{h:02d}:{m:02d}")
def from_iso_datetime(value):
@@ -113,7 +130,32 @@ def from_iso_datetime(value):
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
"""
- pass
+ match = _iso8601_datetime_re.match(value)
+ if not match:
+ raise ValueError(f"Not a valid ISO8601-formatted datetime string: {value}")
+
+ groups = match.groupdict()
+
+ groups['year'] = int(groups['year'])
+ groups['month'] = int(groups['month'])
+ groups['day'] = int(groups['day'])
+ groups['hour'] = int(groups['hour'])
+ groups['minute'] = int(groups['minute'])
+ groups['second'] = int(groups['second'] or 0)
+ groups['microsecond'] = int(groups['microsecond'] or 0)
+
+ tzinfo = None
+ if groups['tzinfo']:
+ if groups['tzinfo'] == 'Z':
+ tzinfo = dt.timezone.utc
+ else:
+ offset_mins = int(groups['tzinfo'][-2:]) if len(groups['tzinfo']) > 3 else 0
+ offset = 60 * int(groups['tzinfo'][1:3]) + offset_mins
+ if groups['tzinfo'][0] == '-':
+ offset = -offset
+ tzinfo = get_fixed_timezone(offset)
+
+ return dt.datetime(tzinfo=tzinfo, **groups)
def from_iso_time(value):
@@ -121,12 +163,33 @@ def from_iso_time(value):
This function doesn't support time zone offsets.
"""
- pass
+ match = _iso8601_time_re.match(value)
+ if not match:
+ raise ValueError(f"Not a valid ISO8601-formatted time string: {value}")
+
+ groups = match.groupdict()
+
+ groups['hour'] = int(groups['hour'])
+ groups['minute'] = int(groups['minute'])
+ groups['second'] = int(groups['second'] or 0)
+ groups['microsecond'] = int(groups['microsecond'] or 0)
+
+ return dt.time(**groups)
def from_iso_date(value):
"""Parse a string and return a datetime.date."""
- pass
+ match = _iso8601_date_re.match(value)
+ if not match:
+ raise ValueError(f"Not a valid ISO8601-formatted date string: {value}")
+
+ groups = match.groupdict()
+
+ return dt.date(
+ int(groups['year']),
+ int(groups['month']),
+ int(groups['day'])
+ )
def isoformat(datetime: dt.datetime) ->str:
@@ -134,7 +197,7 @@ def isoformat(datetime: dt.datetime) ->str:
:param datetime datetime: The datetime.
"""
- pass
+ return datetime.isoformat()
def pluck(dictlist: list[dict[str, typing.Any]], key: str):
@@ -145,7 +208,7 @@ def pluck(dictlist: list[dict[str, typing.Any]], key: str):
>>> pluck(dlist, 'id')
[1, 2]
"""
- pass
+ return [d.get(key) for d in dictlist]
def get_value(obj, key: (int | str), default=missing):
@@ -159,7 +222,26 @@ def get_value(obj, key: (int | str), default=missing):
`get_value` will never check the value `x.i`. Consider overriding
`marshmallow.fields.Field.get_value` in this case.
"""
- pass
+ if isinstance(key, int):
+ return _get_value_for_key(obj, key, default)
+
+ return _get_value_for_keys(obj, key.split('.'), default)
+
+def _get_value_for_keys(obj, keys, default):
+ if len(keys) == 1:
+ return _get_value_for_key(obj, keys[0], default)
+ return _get_value_for_keys(
+ _get_value_for_key(obj, keys[0], default), keys[1:], default
+ )
+
+def _get_value_for_key(obj, key, default):
+ try:
+ return obj[key]
+ except (KeyError, IndexError, TypeError, AttributeError):
+ try:
+ return getattr(obj, key)
+ except AttributeError:
+ return default
def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
@@ -173,12 +255,17 @@ def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
>>> d
{'foo': {'bar': 42}}
"""
- pass
+ keys = key.split('.')
+ for key in keys[:-1]:
+ dct = dct.setdefault(key, {})
+ dct[keys[-1]] = value
def callable_or_raise(obj):
"""Check that an object is callable, else raise a :exc:`TypeError`."""
- pass
+ if not callable(obj):
+ raise TypeError(f"Object {obj!r} is not callable.")
+ return obj
def get_func_args(func: typing.Callable) ->list[str]:
@@ -188,7 +275,19 @@ def get_func_args(func: typing.Callable) ->list[str]:
.. versionchanged:: 3.0.0a1
Do not return bound arguments, eg. ``self``.
"""
- pass
+ if isinstance(func, functools.partial):
+ return get_func_args(func.func)
+
+ if inspect.isfunction(func) or inspect.ismethod(func):
+ return list(inspect.signature(func).parameters.keys())
+
+ if inspect.isclass(func):
+ return get_func_args(func.__init__)
+
+ if callable(func):
+ return get_func_args(func.__call__)
+
+ raise TypeError(f"{func!r} is not a callable.")
def resolve_field_instance(cls_or_instance):
@@ -196,7 +295,19 @@ def resolve_field_instance(cls_or_instance):
:param type|Schema cls_or_instance: Marshmallow Schema class or instance.
"""
- pass
+ if isinstance(cls_or_instance, type):
+ if not issubclass(cls_or_instance, FieldABC):
+ raise FieldInstanceResolutionError(
+ f"The class {cls_or_instance} is not a subclass of "
+ "marshmallow.base.FieldABC"
+ )
+ return cls_or_instance()
+ if isinstance(cls_or_instance, FieldABC):
+ return cls_or_instance
+ raise FieldInstanceResolutionError(
+ f"{cls_or_instance!r} is not a subclass or instance of "
+ "marshmallow.base.FieldABC"
+ )
def timedelta_to_microseconds(value: dt.timedelta) ->int:
@@ -204,4 +315,4 @@ def timedelta_to_microseconds(value: dt.timedelta) ->int:
https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950
"""
- pass
+ return (value.days * 86400 + value.seconds) * 1000000 + value.microseconds
diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py
index 3cc3b97..c06287b 100644
--- a/src/marshmallow/validate.py
+++ b/src/marshmallow/validate.py
@@ -28,7 +28,7 @@ class Validator(ABC):
"""A string representation of the args passed to this validator. Used by
`__repr__`.
"""
- pass
+ return ''
@abstractmethod
def __call__(self, value: typing.Any) ->typing.Any:
@@ -419,7 +419,13 @@ class OneOf(Validator):
of an attribute of the choice objects. Defaults to `str()`
or `str()`.
"""
- pass
+ if callable(valuegetter):
+ getter = valuegetter
+ else:
+ getter = lambda x: getattr(x, valuegetter)
+
+ for choice, label in zip_longest(self.choices, self.labels):
+ yield getter(choice), label or str(choice)
class ContainsOnly(OneOf):
diff --git a/src/marshmallow/warnings.py b/src/marshmallow/warnings.py
index 0da3c50..23b6477 100644
--- a/src/marshmallow/warnings.py
+++ b/src/marshmallow/warnings.py
@@ -1,2 +1,10 @@
class RemovedInMarshmallow4Warning(DeprecationWarning):
- pass
+ """
+ Warning class to indicate functionality that will be removed in Marshmallow 4.
+
+ This warning is a subclass of DeprecationWarning and is used to notify users
+ about features or behaviors that are deprecated and will be removed in the
+ next major version (Marshmallow 4) of the library.
+ """
+ def __init__(self, message):
+ super().__init__(message)