Back to the Claude Sonnet 3.5 - Base summary
Claude Sonnet 3.5 - Base: marshmallow
Failed to run pytest on the `tests` suite (collection aborted while loading conftest):
ImportError while loading conftest '/testbed/tests/conftest.py'.
tests/conftest.py:5: in <module>
from tests.base import Blog, User, UserSchema
tests/base.py:11: in <module>
from marshmallow import Schema, fields, missing, post_load, validate
src/marshmallow/__init__.py:17: in <module>
from marshmallow.schema import Schema, SchemaOpts
src/marshmallow/schema.py:15: in <module>
from marshmallow import fields as ma_fields
src/marshmallow/fields.py:18: in <module>
from marshmallow.utils import is_aware, is_collection, resolve_field_instance
E ImportError: cannot import name 'is_aware' from 'marshmallow.utils' (/testbed/src/marshmallow/utils.py)
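Patch diff (note: the run above fails at import time because `src/marshmallow/fields.py` imports `is_aware` from `marshmallow.utils`, and the patched `utils.py` does not define it, so no test executes against the code below):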
diff --git a/src/marshmallow/class_registry.py b/src/marshmallow/class_registry.py
index 249b898..a7b38e8 100644
--- a/src/marshmallow/class_registry.py
+++ b/src/marshmallow/class_registry.py
@@ -35,14 +35,29 @@ def register(classname: str, cls: SchemaType) ->None:
# }
"""
- pass
+ global _registry
+ _registry[classname] = [cls]
+ _registry[f"{cls.__module__}.{cls.__name__}"] = [cls]
-def get_class(classname: str, all: bool=False) ->(list[SchemaType] | SchemaType
- ):
+def get_class(classname: str, all: bool=False) ->(list[SchemaType] | SchemaType):
"""Retrieve a class from the registry.
:raises: marshmallow.exceptions.RegistryError if the class cannot be found
or if there are multiple entries for the given class name.
"""
- pass
+ try:
+ classes = _registry[classname]
+ except KeyError:
+ raise RegistryError(f"Class with name {classname!r} was not found.")
+
+ if all:
+ return classes
+
+ if len(classes) > 1:
+ raise RegistryError(
+ f"Multiple classes with name {classname!r} were found. "
+ "Please use the full, module-qualified path."
+ )
+
+ return classes[0]
diff --git a/src/marshmallow/decorators.py b/src/marshmallow/decorators.py
index d78f5be..e2944df 100644
--- a/src/marshmallow/decorators.py
+++ b/src/marshmallow/decorators.py
@@ -84,7 +84,13 @@ def validates(field_name: str) ->Callable[..., Any]:
:param str field_name: Name of the field that the method validates.
"""
- pass
+ def decorator(fn):
+ @functools.wraps(fn)
+ def wrapper(self, value, **kwargs):
+ return fn(self, value, **kwargs)
+ wrapper.__marshmallow_hook__ = {VALIDATES: field_name}
+ return wrapper
+ return decorator
def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
@@ -109,7 +115,25 @@ def validates_schema(fn: (Callable[..., Any] | None)=None, pass_many: bool=
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ if fn is None:
+ return functools.partial(
+ validates_schema,
+ pass_many=pass_many,
+ pass_original=pass_original,
+ skip_on_field_errors=skip_on_field_errors,
+ )
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper.__marshmallow_hook__ = {
+ (VALIDATES_SCHEMA, pass_many): {
+ 'pass_original': pass_original,
+ 'skip_on_field_errors': skip_on_field_errors,
+ }
+ }
+ return wrapper
def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
@@ -124,7 +148,7 @@ def pre_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ return set_hook(fn, (PRE_DUMP, pass_many))
def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
@@ -142,7 +166,7 @@ def post_dump(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
- pass
+ return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)
def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
@@ -158,7 +182,7 @@ def pre_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ return set_hook(fn, (PRE_LOAD, pass_many))
def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
@@ -177,7 +201,7 @@ def post_load(fn: (Callable[..., Any] | None)=None, pass_many: bool=False,
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
- pass
+ return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)
def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
@@ -192,4 +216,11 @@ def set_hook(fn: (Callable[..., Any] | None), key: (tuple[str, bool] | str),
:return: Decorated function if supplied, else this decorator with its args
bound.
"""
- pass
+ def decorator(func):
+ func.__marshmallow_hook__ = {key: kwargs or True}
+ return func
+
+ if fn is None:
+ return decorator
+ else:
+ return decorator(fn)
diff --git a/src/marshmallow/error_store.py b/src/marshmallow/error_store.py
index a659aaf..d5a60d5 100644
--- a/src/marshmallow/error_store.py
+++ b/src/marshmallow/error_store.py
@@ -13,6 +13,16 @@ class ErrorStore:
def __init__(self):
self.errors = {}
+ def store_error(self, field_name, error):
+ """Store an error message for a field."""
+ if field_name not in self.errors:
+ self.errors[field_name] = []
+ self.errors[field_name].append(error)
+
+ def get_errors(self):
+ """Return all stored errors."""
+ return self.errors
+
def merge_errors(errors1, errors2):
"""Deeply merge two error messages.
@@ -20,4 +30,15 @@ def merge_errors(errors1, errors2):
The format of ``errors1`` and ``errors2`` matches the ``message``
parameter of :exc:`marshmallow.exceptions.ValidationError`.
"""
- pass
+ merged = errors1.copy()
+ for key, value in errors2.items():
+ if key in merged:
+ if isinstance(merged[key], dict) and isinstance(value, dict):
+ merged[key] = merge_errors(merged[key], value)
+ elif isinstance(merged[key], list) and isinstance(value, list):
+ merged[key].extend(value)
+ else:
+ merged[key] = [merged[key], value]
+ else:
+ merged[key] = value
+ return merged
diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py
index 8656a56..c3a5d4a 100644
--- a/src/marshmallow/fields.py
+++ b/src/marshmallow/fields.py
@@ -166,19 +166,37 @@ class Field(FieldABC):
:param callable accessor: A callable used to retrieve the value of `attr` from
the object `obj`. Defaults to `marshmallow.utils.get_value`.
"""
- pass
+ accessor_func = accessor or utils.get_value
+ return accessor_func(obj, attr, default)
def _validate(self, value):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
- pass
+ errors = []
+ for validator in self.validators:
+ try:
+ if validator(value) is False:
+ self.fail('validator_failed')
+ except ValidationError as error:
+ errors.extend(error.messages)
+ if errors:
+ raise ValidationError(errors)
def make_error(self, key: str, **kwargs) ->ValidationError:
"""Helper method to make a `ValidationError` with an error message
from ``self.error_messages``.
"""
- pass
+ try:
+ msg = self.error_messages[key]
+ except KeyError as error:
+ class_name = self.__class__.__name__
+ message = (f'Error key "{key}" does not exist for field "{class_name}".'
+ f' Available keys are {", ".join(self.error_messages.keys())}.')
+ raise KeyError(message) from error
+ if isinstance(msg, str):
+ msg = msg.format(**kwargs)
+ return ValidationError(msg)
def fail(self, key: str, **kwargs):
"""Helper method that raises a `ValidationError` with an error message
@@ -187,13 +205,23 @@ class Field(FieldABC):
.. deprecated:: 3.0.0
Use `make_error <marshmallow.fields.Field.make_error>` instead.
"""
- pass
+ warnings.warn(
+ "Field.fail is deprecated. Use Field.make_error instead.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2
+ )
+ raise self.make_error(key, **kwargs)
def _validate_missing(self, value):
"""Validate missing values. Raise a :exc:`ValidationError` if
`value` should be considered missing.
"""
- pass
+ if value is missing_:
+ if self.required:
+ raise self.make_error('required')
+ elif value is None:
+ if self.allow_none is False:
+ raise self.make_error('null')
def serialize(self, attr: str, obj: typing.Any, accessor: (typing.
Callable[[typing.Any, str, typing.Any], typing.Any] | None)=None,
@@ -206,7 +234,14 @@ class Field(FieldABC):
:param accessor: Function used to access values from ``obj``.
:param kwargs: Field-specific keyword arguments.
"""
- pass
+ if self.dump_only:
+ return self.dump_default
+
+ value = self.get_value(obj, attr, accessor=accessor)
+ if value is missing_:
+ return self.dump_default
+
+ return self._serialize(value, attr, obj, **kwargs)
def deserialize(self, value: typing.Any, attr: (str | None)=None, data:
(typing.Mapping[str, typing.Any] | None)=None, **kwargs):
@@ -219,7 +254,17 @@ class Field(FieldABC):
:raise ValidationError: If an invalid value is passed or if a required value
is missing.
"""
- pass
+ if self.load_only:
+ return self.load_default
+
+ self._validate_missing(value)
+ if value is missing_:
+ return self.load_default
+
+ value = self._deserialize(value, attr, data, **kwargs)
+ self._validate(value)
+
+ return value
def _bind_to_schema(self, field_name, schema):
"""Update field with values from its parent schema. Called by
@@ -228,7 +273,13 @@ class Field(FieldABC):
:param str field_name: Field name set in schema.
:param Schema|Field schema: Parent object.
"""
- pass
+ self.parent = self.schema = schema
+ self.name = field_name
+ self.root = schema.root
+ # Allow fields to override their data key
+ if self.data_key is None:
+ self.data_key = field_name
+ self.metadata.setdefault('name', field_name)
def _serialize(self, value: typing.Any, attr: (str | None), obj: typing
.Any, **kwargs):
@@ -249,7 +300,7 @@ class Field(FieldABC):
:param dict kwargs: Field-specific keyword arguments.
:return: The serialized value
"""
- pass
+ return value
def _deserialize(self, value: typing.Any, attr: (str | None), data: (
typing.Mapping[str, typing.Any] | None), **kwargs):
@@ -268,12 +319,12 @@ class Field(FieldABC):
.. versionchanged:: 3.0.0
Added ``**kwargs`` to signature.
"""
- pass
+ return value
@property
def context(self):
"""The context dictionary for the parent :class:`Schema`."""
- pass
+ return self.parent.context if self.parent else {}
class Raw(Field):
diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py
index 1e6eabf..38ad968 100644
--- a/src/marshmallow/schema.py
+++ b/src/marshmallow/schema.py
@@ -311,7 +311,10 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionadded:: 3.0.0
"""
- pass
+ attrs = fields.copy()
+ attrs['Meta'] = type('Meta', (), {'register': False})
+ schema_cls = type(name, (Schema,), attrs)
+ return schema_cls
def handle_error(self, error: ValidationError, data: typing.Any, *,
many: bool, **kwargs):
@@ -327,7 +330,7 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0rc9
Receives `many` and `partial` (on deserialization) as keyword arguments.
"""
- pass
+ pass # Default implementation does nothing
def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):
"""Defines how to pull values from an object to serialize.
@@ -337,7 +340,7 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
.. versionchanged:: 3.0.0a1
Changed position of ``obj`` and ``attr``.
"""
- pass
+ return get_value(obj, attr, default)
@staticmethod
def _call_and_store(getter_func, data, *, field_name, error_store,
@@ -351,7 +354,12 @@ class Schema(base.SchemaABC, metaclass=SchemaMeta):
:param int index: Index of the item being validated, if validating a collection,
otherwise `None`.
"""
- pass
+ try:
+ value = getter_func(data)
+ except ValidationError as error:
+ error_store.store_error(error.messages, field_name, index=index)
+ return missing
+ return value
def _serialize(self, obj: (_T | typing.Iterable[_T]), *, many: bool=False):
"""Serialize ``obj``.
diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py
index 1c71b57..e9ec424 100644
--- a/src/marshmallow/utils.py
+++ b/src/marshmallow/utils.py
@@ -40,29 +40,35 @@ missing = _Missing()
def is_generator(obj) ->bool:
"""Return True if ``obj`` is a generator"""
- pass
+ return inspect.isgenerator(obj)
def is_iterable_but_not_string(obj) ->bool:
"""Return True if ``obj`` is an iterable object that isn't a string."""
- pass
+ return (
+ isinstance(obj, collections.abc.Iterable) and
+ not isinstance(obj, (str, bytes))
+ )
def is_collection(obj) ->bool:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
- pass
+ return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)
def is_instance_or_subclass(val, class_) ->bool:
"""Return True if ``val`` is either a subclass or instance of ``class_``."""
- pass
+ try:
+ return issubclass(val, class_)
+ except TypeError:
+ return isinstance(val, class_)
def is_keyed_tuple(obj) ->bool:
"""Return True if ``obj`` has keyed tuple behavior, such as
namedtuples or SQLAlchemy's KeyedTuples.
"""
- pass
+ return isinstance(obj, tuple) and hasattr(obj, '_fields')
def pprint(obj, *args, **kwargs) ->None:
@@ -73,7 +79,15 @@ def pprint(obj, *args, **kwargs) ->None:
.. deprecated:: 3.7.0
marshmallow.pprint will be removed in marshmallow 4.
"""
- pass
+ warnings.warn(
+ "marshmallow.pprint is deprecated and will be removed in marshmallow 4.",
+ RemovedInMarshmallow4Warning,
+ stacklevel=2,
+ )
+ if isinstance(obj, collections.OrderedDict):
+ py_pprint(dict(obj), *args, **kwargs)
+ else:
+ py_pprint(obj, *args, **kwargs)
def from_rfc(datestring: str) ->dt.datetime:
@@ -81,7 +95,7 @@ def from_rfc(datestring: str) ->dt.datetime:
https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950
"""
- pass
+ return parsedate_to_datetime(datestring)
def rfcformat(datetime: dt.datetime) ->str:
@@ -89,7 +103,7 @@ def rfcformat(datetime: dt.datetime) ->str:
:param datetime datetime: The datetime.
"""
- pass
+ return format_datetime(datetime)
_iso8601_datetime_re = re.compile(
@@ -104,7 +118,9 @@ _iso8601_time_re = re.compile(
def get_fixed_timezone(offset: (int | float | dt.timedelta)) ->dt.timezone:
"""Return a tzinfo instance with a fixed offset from UTC."""
- pass
+ if isinstance(offset, dt.timedelta):
+ offset = offset.total_seconds()
+ return dt.timezone(dt.timedelta(seconds=int(offset)))
def from_iso_datetime(value):
@@ -113,7 +129,32 @@ def from_iso_datetime(value):
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
"""
- pass
+ match = _iso8601_datetime_re.match(value)
+ if not match:
+ raise ValueError('Not a valid ISO8601-formatted datetime string')
+
+ groups = match.groupdict()
+
+ groups['year'] = int(groups['year'])
+ groups['month'] = int(groups['month'])
+ groups['day'] = int(groups['day'])
+ groups['hour'] = int(groups['hour'])
+ groups['minute'] = int(groups['minute'])
+ groups['second'] = int(groups['second'] or 0)
+ groups['microsecond'] = int(groups['microsecond'] or 0)
+
+ if groups['tzinfo'] == 'Z':
+ tzinfo = dt.timezone.utc
+ elif groups['tzinfo']:
+ offset_mins = int(groups['tzinfo'][-2:]) if len(groups['tzinfo']) > 3 else 0
+ offset = 60 * int(groups['tzinfo'][1:3]) + offset_mins
+ if groups['tzinfo'][0] == '-':
+ offset = -offset
+ tzinfo = get_fixed_timezone(offset * 60)
+ else:
+ tzinfo = None
+
+ return dt.datetime(tzinfo=tzinfo, **groups)
def from_iso_time(value):
@@ -121,12 +162,33 @@ def from_iso_time(value):
This function doesn't support time zone offsets.
"""
- pass
+ match = _iso8601_time_re.match(value)
+ if not match:
+ raise ValueError('Not a valid ISO8601-formatted time string')
+
+ groups = match.groupdict()
+
+ groups['hour'] = int(groups['hour'])
+ groups['minute'] = int(groups['minute'])
+ groups['second'] = int(groups['second'] or 0)
+ groups['microsecond'] = int(groups['microsecond'] or 0)
+
+ return dt.time(**groups)
def from_iso_date(value):
"""Parse a string and return a datetime.date."""
- pass
+ match = _iso8601_date_re.match(value)
+ if not match:
+ raise ValueError('Not a valid ISO8601-formatted date string')
+
+ groups = match.groupdict()
+
+ return dt.date(
+ int(groups['year']),
+ int(groups['month']),
+ int(groups['day'])
+ )
def isoformat(datetime: dt.datetime) ->str:
@@ -134,7 +196,7 @@ def isoformat(datetime: dt.datetime) ->str:
:param datetime datetime: The datetime.
"""
- pass
+ return datetime.isoformat()
def pluck(dictlist: list[dict[str, typing.Any]], key: str):
@@ -145,7 +207,7 @@ def pluck(dictlist: list[dict[str, typing.Any]], key: str):
>>> pluck(dlist, 'id')
[1, 2]
"""
- pass
+ return [d.get(key) for d in dictlist]
def get_value(obj, key: (int | str), default=missing):
@@ -159,7 +221,19 @@ def get_value(obj, key: (int | str), default=missing):
`get_value` will never check the value `x.i`. Consider overriding
`marshmallow.fields.Field.get_value` in this case.
"""
- pass
+ if isinstance(key, int):
+ return obj[key] if isinstance(obj, collections.abc.Sequence) else default
+
+ try:
+ return obj[key]
+ except (KeyError, AttributeError, IndexError, TypeError):
+ pass
+
+ # Fall back to getattr for dotted paths
+ if '.' in key:
+ return functools.reduce(lambda o, k: get_value(o, k, default), key.split('.'), obj)
+
+ return getattr(obj, key, default)
def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
@@ -173,12 +247,26 @@ def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
>>> d
{'foo': {'bar': 42}}
"""
- pass
+ if '.' in key:
+ head, rest = key.split('.', 1)
+ target = dct.setdefault(head, {})
+ if not isinstance(target, dict):
+ raise ValueError(
+ "Cannot set '{rest}' in '{head}' ({target}) "
+ "which is not a dict".format(
+ rest=rest, head=head, target=target
+ )
+ )
+ set_value(target, rest, value)
+ else:
+ dct[key] = value
def callable_or_raise(obj):
"""Check that an object is callable, else raise a :exc:`TypeError`."""
- pass
+ if not callable(obj):
+ raise TypeError('Object {!r} is not callable.'.format(obj))
+ return obj
def get_func_args(func: typing.Callable) ->list[str]:
@@ -188,7 +276,12 @@ def get_func_args(func: typing.Callable) ->list[str]:
.. versionchanged:: 3.0.0a1
Do not return bound arguments, eg. ``self``.
"""
- pass
+ if isinstance(func, functools.partial):
+ return get_func_args(func.func)
+ if inspect.isfunction(func) or inspect.ismethod(func):
+ return list(inspect.signature(func).parameters.keys())
+ # Callable class
+ return list(inspect.signature(func.__call__).parameters.keys())[1:]
def resolve_field_instance(cls_or_instance):
@@ -196,7 +289,19 @@ def resolve_field_instance(cls_or_instance):
:param type|Schema cls_or_instance: Marshmallow Schema class or instance.
"""
- pass
+ if isinstance(cls_or_instance, type):
+ if not issubclass(cls_or_instance, FieldABC):
+ raise FieldInstanceResolutionError(
+ 'The class "{}" is not a subclass of '
+ 'marshmallow.base.FieldABC'.format(cls_or_instance.__name__)
+ )
+ return cls_or_instance()
+ if not isinstance(cls_or_instance, FieldABC):
+ raise FieldInstanceResolutionError(
+ 'The object "{}" is not an instance of '
+ 'marshmallow.base.FieldABC'.format(cls_or_instance)
+ )
+ return cls_or_instance
def timedelta_to_microseconds(value: dt.timedelta) ->int:
@@ -204,4 +309,4 @@ def timedelta_to_microseconds(value: dt.timedelta) ->int:
https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950
"""
- pass
+ return (value.days * 86400 + value.seconds) * 1000000 + value.microseconds
diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py
index 3cc3b97..928d96e 100644
--- a/src/marshmallow/validate.py
+++ b/src/marshmallow/validate.py
@@ -105,6 +105,20 @@ class URL(Validator):
self._memoized[key] = self._regex_generator(relative,
absolute, require_tld)
return self._memoized[key]
+
+ def _regex_generator(self, relative: bool, absolute: bool, require_tld: bool
+ ) ->typing.Pattern:
+ return re.compile(
+ r"".join([
+ r"^",
+ r"(" if relative else r"",
+ r"(?:[a-z0-9\.\-\+]*)://" if absolute else r"",
+ r"(?:[^/:]+)" if not require_tld else r"(?:[^/:]+\.)+[^/:]{2,}",
+ r"(?::\d+)?(?:/?|[/?]\S+)$",
+ r")?" if relative else r"",
+ ]),
+ re.IGNORECASE
+ )
_regex = RegexMemoizer()
default_message = 'Not a valid URL.'
default_schemes = {'http', 'https', 'ftp', 'ftps'}
@@ -419,7 +433,15 @@ class OneOf(Validator):
of an attribute of the choice objects. Defaults to `str()`
or `str()`.
"""
- pass
+ if callable(valuegetter):
+ getter = valuegetter
+ elif isinstance(valuegetter, str):
+ getter = attrgetter(valuegetter)
+ else:
+ getter = str
+
+ for choice, label in zip_longest(self.choices, self.labels):
+ yield getter(choice), label or str(choice)
class ContainsOnly(OneOf):