Reference (Gold): graphene
Pytest summary for graphene
| status    | count |
| --------- | ----- |
| passed    | 440   |
| failed    | 7     |
| total     | 447   |
| collected | 447   |
Failed tests:
test_custom_global_id.py::TestUUIDGlobalID::test_str_schema_correct
```
self =

    def test_str_schema_correct(self):
        """
        Check that the schema has the expected and custom node interface and user type
        and that they both use UUIDs
        """
>       parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema))
E       AttributeError: 'TestUUIDGlobalID' object has no attribute 'schema'

graphene/relay/tests/test_custom_global_id.py:45: AttributeError
```
test_custom_global_id.py::TestUUIDGlobalID::test_get_by_id
```
self =

    def test_get_by_id(self):
        query = """query userById($id: UUID!) {
            user(id: $id) {
                id
                name
            }
        }"""
        # UUID need to be converted to string for serialization
        result = graphql_sync(
>           self.graphql_schema,
            query,
            variable_values={"id": str(self.user_list[0]["id"])},
        )
E       AttributeError: 'TestUUIDGlobalID' object has no attribute 'graphql_schema'

graphene/relay/tests/test_custom_global_id.py:70: AttributeError
```
test_custom_global_id.py::TestSimpleGlobalID::test_str_schema_correct
```
self =

    def test_str_schema_correct(self):
        """
        Check that the schema has the expected and custom node interface and user type
        and that they both use UUIDs
        """
>       parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema))
E       AttributeError: 'TestSimpleGlobalID' object has no attribute 'schema'

graphene/relay/tests/test_custom_global_id.py:113: AttributeError
```
test_custom_global_id.py::TestSimpleGlobalID::test_get_by_id
```
self =

    def test_get_by_id(self):
        query = """query {
            user(id: "my global primary key in clear 3") {
                id
                name
            }
        }"""
>       result = graphql_sync(self.graphql_schema, query)
E       AttributeError: 'TestSimpleGlobalID' object has no attribute 'graphql_schema'

graphene/relay/tests/test_custom_global_id.py:136: AttributeError
```
test_custom_global_id.py::TestCustomGlobalID::test_str_schema_correct
```
self =

    def test_str_schema_correct(self):
        """
        Check that the schema has the expected and custom node interface and user type
        and that they both use UUIDs
        """
>       parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema))
E       AttributeError: 'TestCustomGlobalID' object has no attribute 'schema'

graphene/relay/tests/test_custom_global_id.py:192: AttributeError
```
test_custom_global_id.py::TestCustomGlobalID::test_get_by_id
```
self =

    def test_get_by_id(self):
        query = """query {
            user(id: 2) {
                id
                name
            }
        }"""
>       result = graphql_sync(self.graphql_schema, query)
E       AttributeError: 'TestCustomGlobalID' object has no attribute 'graphql_schema'

graphene/relay/tests/test_custom_global_id.py:215: AttributeError
```
test_custom_global_id.py::TestIncompleteCustomGlobalID::test_must_define_to_global_id
```
self =

    def test_must_define_to_global_id(self):
        """
        Test that if the `to_global_id` method is not defined, we can query
        the object, but we can't request its ID.
        """

        class CustomGlobalIDType(BaseGlobalIDType):
            graphene_type = Int

            @classmethod
            def resolve_global_id(cls, info, global_id):
                _type = info.return_type.graphene_type._meta.name
                return _type, global_id

        class CustomNode(Node):
            class Meta:
                global_id_type = CustomGlobalIDType

        class User(ObjectType):
            class Meta:
                interfaces = [CustomNode]

            name = String()

            @classmethod
            def get_node(cls, _type, _id):
                return self.users[_id]

        class RootQuery(ObjectType):
            user = CustomNode.Field(User)

        self.schema = Schema(query=RootQuery, types=[User])
        self.graphql_schema = self.schema.graphql_schema

        query = """query {
            user(id: 2) {
                name
            }
        }"""
        result = graphql_sync(self.graphql_schema, query)
>       assert not result.errors
E       assert not [GraphQLError("'TestIncompleteCustomGlobalID' object has no attribute 'users'", locations=[SourceLocation(line=2, column=13)], path=['user'])]
E        +  where [GraphQLError("'TestIncompleteCustomGlobalID' object has no attribute 'users'", locations=[SourceLocation(line=2, column=13)], path=['user'])] = ExecutionResult(data={'user': None}, errors=[GraphQLError("'TestIncompleteCustomGlobalID' object has no attribute 'users'", locations=[SourceLocation(line=2, column=13)], path=['user'])]).errors

graphene/relay/tests/test_custom_global_id.py:270: AssertionError
```
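All seven failures share one pattern: the test methods read attributes (`self.schema`, `self.graphql_schema`, `self.user_list`, `self.users`) that were never set on the test class, so the first access raises `AttributeError`. In the last failure the error surfaces inside the resolver instead, because `get_node` is a `classmethod` of a locally defined `User` type and the `self` it closes over is the test instance, which has no `users` attribute. A minimal sketch of the kind of class-level setup these tests appear to assume (names and fixture data are hypothetical; the real suite lives in `graphene/relay/tests/test_custom_global_id.py`):

```python
from uuid import uuid4

from graphene import ObjectType, Schema, String
from graphene.relay import Node


class TestUUIDGlobalID:
    @classmethod
    def setup_class(cls):
        # Hypothetical fixture data: the tests read self.user_list[0]["id"].
        cls.user_list = [{"id": uuid4(), "name": "First"}]

        class User(ObjectType):
            class Meta:
                interfaces = (Node,)

            name = String()

        class Query(ObjectType):
            user = Node.Field(User)

        # The failing tests expect both of these attributes to exist
        # before any test method runs.
        cls.schema = Schema(query=Query, types=[User])
        cls.graphql_schema = cls.schema.graphql_schema
```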
Patch diff
diff --git a/graphene/pyutils/dataclasses.py b/graphene/pyutils/dataclasses.py
index f1ec952..1a47452 100644
--- a/graphene/pyutils/dataclasses.py
+++ b/graphene/pyutils/dataclasses.py
@@ -1,37 +1,185 @@
+# This is a polyfill for dataclasses
+# https://docs.python.org/3/library/dataclasses.html
+# Original PEP proposal: PEP 557
+# https://www.python.org/dev/peps/pep-0557/
import re
import sys
import copy
import types
import inspect
import keyword
-__all__ = ['dataclass', 'field', 'Field', 'FrozenInstanceError', 'InitVar',
- 'MISSING', 'fields', 'asdict', 'astuple', 'make_dataclass', 'replace',
- 'is_dataclass']
-
+__all__ = [
+ "dataclass",
+ "field",
+ "Field",
+ "FrozenInstanceError",
+ "InitVar",
+ "MISSING",
+ # Helper functions.
+ "fields",
+ "asdict",
+ "astuple",
+ "make_dataclass",
+ "replace",
+ "is_dataclass",
+]
+
+# Conditions for adding methods. The boxes indicate what action the
+# dataclass decorator takes. For all of these tables, when I talk
+# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
+# referring to the arguments to the @dataclass decorator. When
+# checking if a dunder method already exists, I mean check for an
+# entry in the class's __dict__. I never check to see if an attribute
+# is defined in a base class.
+
+# Key:
+# +=========+=========================================+
+# + Value | Meaning |
+# +=========+=========================================+
+# | <blank> | No action: no method is added. |
+# +---------+-----------------------------------------+
+# | add | Generated method is added. |
+# +---------+-----------------------------------------+
+# | raise | TypeError is raised. |
+# +---------+-----------------------------------------+
+# | None | Attribute is set to None. |
+# +=========+=========================================+
+
+# __init__
+#
+# +--- init= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __init__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __repr__
+#
+# +--- repr= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __repr__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+
+# __setattr__
+# __delattr__
+#
+# +--- frozen= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because not adding these methods would break the "frozen-ness"
+# of the class.
+
+# __eq__
+#
+# +--- eq= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __eq__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __lt__
+# __le__
+# __gt__
+# __ge__
+#
+# +--- order= parameter
+# |
+# v | | |
+# | no | yes | <--- class has any comparison method in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because to allow this case would interfere with using
+# functools.total_ordering.
+
+# __hash__
+
+# +------------------- unsafe_hash= parameter
+# | +----------- eq= parameter
+# | | +--- frozen= parameter
+# | | |
+# v v v | | |
+# | no | yes | <--- class has explicitly defined __hash__
+# +=======+=======+=======+========+========+
+# | False | False | False | | | No __eq__, use the base class __hash__
+# +-------+-------+-------+--------+--------+
+# | False | False | True | | | No __eq__, use the base class __hash__
+# +-------+-------+-------+--------+--------+
+# | False | True | False | None | | <-- the default, not hashable
+# +-------+-------+-------+--------+--------+
+# | False | True | True | add | | Frozen, so hashable, allows override
+# +-------+-------+-------+--------+--------+
+# | True | False | False | add | raise | Has no __eq__, but hashable
+# +-------+-------+-------+--------+--------+
+# | True | False | True | add | raise | Has no __eq__, but hashable
+# +-------+-------+-------+--------+--------+
+# | True | True | False | add | raise | Not frozen, but hashable
+# +-------+-------+-------+--------+--------+
+# | True | True | True | add | raise | Frozen, so hashable
+# +=======+=======+=======+========+========+
+# For boxes that are blank, __hash__ is untouched and therefore
+# inherited from the base class. If the base is object, then
+# id-based hashing is used.
+#
+# Note that a class may already have __hash__=None if it specified an
+# __eq__ method in the class body (not one that was created by
+# @dataclass).
+#
+# See _hash_action (below) for a coded version of this table.
+
+
+# Raised when an attempt is made to modify a frozen class.
class FrozenInstanceError(AttributeError):
pass
+# A sentinel object for default values to signal that a default
+# factory will be used. This is given a nice repr() which will appear
+# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
-
def __repr__(self):
- return '<factory>'
+ return "<factory>"
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
-
+# A sentinel object to detect if a parameter is supplied or not. Use
+# a class to give it a better repr.
class _MISSING_TYPE:
pass
MISSING = _MISSING_TYPE()
-_EMPTY_METADATA = types.MappingProxyType({})
+# Since most per-field metadata will be unused, create an empty
+# read-only proxy that can be shared among all fields.
+_EMPTY_METADATA = types.MappingProxyType({})
+# Markers for the various kinds of fields and pseudo-fields.
class _FIELD_BASE:
-
def __init__(self, name):
self.name = name
@@ -39,17 +187,29 @@ class _FIELD_BASE:
return self.name
-_FIELD = _FIELD_BASE('_FIELD')
-_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
-_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')
-_FIELDS = '__dataclass_fields__'
-_PARAMS = '__dataclass_params__'
-_POST_INIT_NAME = '__post_init__'
-_MODULE_IDENTIFIER_RE = re.compile('^(?:\\s*(\\w+)\\s*\\.)?\\s*(\\w+)')
+_FIELD = _FIELD_BASE("_FIELD")
+_FIELD_CLASSVAR = _FIELD_BASE("_FIELD_CLASSVAR")
+_FIELD_INITVAR = _FIELD_BASE("_FIELD_INITVAR")
+# The name of an attribute on the class where we store the Field
+# objects. Also used to check if a class is a Data Class.
+_FIELDS = "__dataclass_fields__"
-class _InitVarMeta(type):
+# The name of an attribute on the class that stores the parameters to
+# @dataclass.
+_PARAMS = "__dataclass_params__"
+
+# The name of the function, that if it exists, is called at the end of
+# __init__.
+_POST_INIT_NAME = "__post_init__"
+
+# String regex that string annotations for ClassVar or InitVar must match.
+# Allows "identifier.identifier[" or "identifier[".
+# https://bugs.python.org/issue33453 for details.
+_MODULE_IDENTIFIER_RE = re.compile(r"^(?:\s*(\w+)\s*\.)?\s*(\w+)")
+
+class _InitVarMeta(type):
def __getitem__(self, params):
return self
@@ -58,12 +218,31 @@ class InitVar(metaclass=_InitVarMeta):
pass
+# Instances of Field are only ever created from within this module,
+# and only from the field() function, although Field instances are
+# exposed externally as (conceptually) read-only objects.
+#
+# name and type are filled in after the fact, not in __init__.
+# They're not known at the time this class is instantiated, but it's
+# convenient if they're available later.
+#
+# When cls._FIELDS is filled in with a list of Field objects, the name
+# and type fields will have been populated.
class Field:
- __slots__ = ('name', 'type', 'default', 'default_factory', 'repr',
- 'hash', 'init', 'compare', 'metadata', '_field_type')
-
- def __init__(self, default, default_factory, init, repr, hash, compare,
- metadata):
+ __slots__ = (
+ "name",
+ "type",
+ "default",
+ "default_factory",
+ "repr",
+ "hash",
+ "init",
+ "compare",
+ "metadata",
+ "_field_type", # Private: not to be used by user code.
+ )
+
+ def __init__(self, default, default_factory, init, repr, hash, compare, metadata):
self.name = None
self.type = None
self.default = default
@@ -72,23 +251,47 @@ class Field:
self.repr = repr
self.hash = hash
self.compare = compare
- self.metadata = _EMPTY_METADATA if metadata is None or len(metadata
- ) == 0 else types.MappingProxyType(metadata)
+ self.metadata = (
+ _EMPTY_METADATA
+ if metadata is None or len(metadata) == 0
+ else types.MappingProxyType(metadata)
+ )
self._field_type = None
def __repr__(self):
return (
- f'Field(name={self.name!r},type={self.type!r},default={self.default!r},default_factory={self.default_factory!r},init={self.init!r},repr={self.repr!r},hash={self.hash!r},compare={self.compare!r},metadata={self.metadata!r},_field_type={self._field_type})'
- )
-
+ "Field("
+ f"name={self.name!r},"
+ f"type={self.type!r},"
+ f"default={self.default!r},"
+ f"default_factory={self.default_factory!r},"
+ f"init={self.init!r},"
+ f"repr={self.repr!r},"
+ f"hash={self.hash!r},"
+ f"compare={self.compare!r},"
+ f"metadata={self.metadata!r},"
+ f"_field_type={self._field_type}"
+ ")"
+ )
+
+ # This is used to support the PEP 487 __set_name__ protocol in the
+ # case where we're using a field that contains a descriptor as a
+ # default value. For details on __set_name__, see
+ # https://www.python.org/dev/peps/pep-0487/#implementation-details.
+ #
+ # Note that in _process_class, this Field object is overwritten
+ # with the default value, so the end result is a descriptor that
+ # had __set_name__ called on it at the right time.
def __set_name__(self, owner, name):
- func = getattr(type(self.default), '__set_name__', None)
+ func = getattr(type(self.default), "__set_name__", None)
if func:
+ # There is a __set_name__ method on the descriptor, call
+ # it.
func(self.default, owner, name)
class _DataclassParams:
- __slots__ = 'init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen'
+ __slots__ = ("init", "repr", "eq", "order", "unsafe_hash", "frozen")
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
self.init = init
@@ -100,12 +303,30 @@ class _DataclassParams:
def __repr__(self):
return (
- f'_DataclassParams(init={self.init!r},repr={self.repr!r},eq={self.eq!r},order={self.order!r},unsafe_hash={self.unsafe_hash!r},frozen={self.frozen!r})'
- )
-
-
-def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
- hash=None, compare=True, metadata=None):
+ "_DataclassParams("
+ f"init={self.init!r},"
+ f"repr={self.repr!r},"
+ f"eq={self.eq!r},"
+ f"order={self.order!r},"
+ f"unsafe_hash={self.unsafe_hash!r},"
+ f"frozen={self.frozen!r}"
+ ")"
+ )
+
+
+# This function is used instead of exposing Field creation directly,
+# so that a type checker can be told (via overloads) that this is a
+# function whose type depends on its parameters.
+def field(
+ *,
+ default=MISSING,
+ default_factory=MISSING,
+ init=True,
+ repr=True,
+ hash=None,
+ compare=True,
+ metadata=None,
+):
"""Return an object to identify dataclass fields.
default is the default value of the field. default_factory is a
@@ -119,22 +340,631 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
It is an error to specify both default and default_factory.
"""
- pass
-
-_hash_action = {(False, False, False, False): None, (False, False, False,
- True): None, (False, False, True, False): None, (False, False, True,
- True): None, (False, True, False, False): _hash_set_none, (False, True,
- False, True): None, (False, True, True, False): _hash_add, (False, True,
- True, True): None, (True, False, False, False): _hash_add, (True, False,
- False, True): _hash_exception, (True, False, True, False): _hash_add, (
- True, False, True, True): _hash_exception, (True, True, False, False):
- _hash_add, (True, True, False, True): _hash_exception, (True, True,
- True, False): _hash_add, (True, True, True, True): _hash_exception}
-
-
-def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
- unsafe_hash=False, frozen=False):
+ if default is not MISSING and default_factory is not MISSING:
+ raise ValueError("cannot specify both default and default_factory")
+ return Field(default, default_factory, init, repr, hash, compare, metadata)
+
+
+def _tuple_str(obj_name, fields):
+ # Return a string representing each field of obj_name as a tuple
+ # member. So, if fields is ['x', 'y'] and obj_name is "self",
+ # return "(self.x,self.y)".
+
+ # Special case for the 0-tuple.
+ if not fields:
+ return "()"
+ # Note the trailing comma, needed if this turns out to be a 1-tuple.
+ return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
+
+
+def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
+ # Note that we mutate locals when exec() is called. Caller
+ # beware! The only callers are internal to this module, so no
+ # worries about external callers.
+ if locals is None:
+ locals = {}
+ return_annotation = ""
+ if return_type is not MISSING:
+ locals["_return_type"] = return_type
+ return_annotation = "->_return_type"
+ args = ",".join(args)
+ body = "\n".join(f" {b}" for b in body)
+
+ # Compute the text of the entire function.
+ txt = f"def {name}({args}){return_annotation}:\n{body}"
+
+ exec(txt, globals, locals)
+ return locals[name]
+
+
+def _field_assign(frozen, name, value, self_name):
+ # If we're a frozen class, then assign to our fields in __init__
+ # via object.__setattr__. Otherwise, just use a simple
+ # assignment.
+ #
+ # self_name is what "self" is called in this function: don't
+ # hard-code "self", since that might be a field name.
+ if frozen:
+ return f"object.__setattr__({self_name},{name!r},{value})"
+ return f"{self_name}.{name}={value}"
+
+
+def _field_init(f, frozen, globals, self_name):
+ # Return the text of the line in the body of __init__ that will
+ # initialize this field.
+
+ default_name = f"_dflt_{f.name}"
+ if f.default_factory is not MISSING:
+ if f.init:
+ # This field has a default factory. If a parameter is
+ # given, use it. If not, call the factory.
+ globals[default_name] = f.default_factory
+ value = (
+ f"{default_name}() "
+ f"if {f.name} is _HAS_DEFAULT_FACTORY "
+ f"else {f.name}"
+ )
+ else:
+ # This is a field that's not in the __init__ params, but
+ # has a default factory function. It needs to be
+ # initialized here by calling the factory function,
+ # because there's no other way to initialize it.
+
+ # For a field initialized with a default=defaultvalue, the
+ # class dict just has the default value
+ # (cls.fieldname=defaultvalue). But that won't work for a
+ # default factory, the factory must be called in __init__
+ # and we must assign that to self.fieldname. We can't
+ # fall back to the class dict's value, both because it's
+ # not set, and because it might be different per-class
+ # (which, after all, is why we have a factory function!).
+
+ globals[default_name] = f.default_factory
+ value = f"{default_name}()"
+ else:
+ # No default factory.
+ if f.init:
+ if f.default is MISSING:
+ # There's no default, just do an assignment.
+ value = f.name
+ elif f.default is not MISSING:
+ globals[default_name] = f.default
+ value = f.name
+ else:
+ # This field does not need initialization. Signify that
+ # to the caller by returning None.
+ return None
+ # Only test this now, so that we can create variables for the
+ # default. However, return None to signify that we're not going
+ # to actually do the assignment statement for InitVars.
+ if f._field_type == _FIELD_INITVAR:
+ return None
+ # Now, actually generate the field assignment.
+ return _field_assign(frozen, f.name, value, self_name)
+
+
+def _init_param(f):
+ # Return the __init__ parameter string for this field. For
+ # example, the equivalent of 'x:int=3' (except instead of 'int',
+ # reference a variable set to int, and instead of '3', reference a
+ # variable set to 3).
+ if f.default is MISSING and f.default_factory is MISSING:
+ # There's no default, and no default_factory, just output the
+ # variable name and type.
+ default = ""
+ elif f.default is not MISSING:
+ # There's a default, this will be the name that's used to look
+ # it up.
+ default = f"=_dflt_{f.name}"
+ elif f.default_factory is not MISSING:
+ # There's a factory function. Set a marker.
+ default = "=_HAS_DEFAULT_FACTORY"
+ return f"{f.name}:_type_{f.name}{default}"
+
+
+def _init_fn(fields, frozen, has_post_init, self_name):
+ # fields contains both real fields and InitVar pseudo-fields.
+
+ # Make sure we don't have fields without defaults following fields
+ # with defaults. This actually would be caught when exec-ing the
+ # function source code, but catching it here gives a better error
+ # message, and future-proofs us in case we build up the function
+ # using ast.
+ seen_default = False
+ for f in fields:
+ # Only consider fields in the __init__ call.
+ if f.init:
+ if not (f.default is MISSING and f.default_factory is MISSING):
+ seen_default = True
+ elif seen_default:
+ raise TypeError(
+ f"non-default argument {f.name!r} " "follows default argument"
+ )
+ globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY}
+
+ body_lines = []
+ for f in fields:
+ line = _field_init(f, frozen, globals, self_name)
+ # line is None means that this field doesn't require
+ # initialization (it's a pseudo-field). Just skip it.
+ if line:
+ body_lines.append(line)
+ # Does this class have a post-init function?
+ if has_post_init:
+ params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR)
+ body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})")
+ # If no body lines, use 'pass'.
+ if not body_lines:
+ body_lines = ["pass"]
+ locals = {f"_type_{f.name}": f.type for f in fields}
+ return _create_fn(
+ "__init__",
+ [self_name] + [_init_param(f) for f in fields if f.init],
+ body_lines,
+ locals=locals,
+ globals=globals,
+ return_type=None,
+ )
+
+
+def _repr_fn(fields):
+ return _create_fn(
+ "__repr__",
+ ("self",),
+ [
+ 'return self.__class__.__qualname__ + f"('
+ + ", ".join([f"{f.name}={{self.{f.name}!r}}" for f in fields])
+ + ')"'
+ ],
+ )
+
+
+def _frozen_get_del_attr(cls, fields):
+ # XXX: globals is modified on the first call to _create_fn, then
+ # the modified version is used in the second call. Is this okay?
+ globals = {"cls": cls, "FrozenInstanceError": FrozenInstanceError}
+ if fields:
+ fields_str = "(" + ",".join(repr(f.name) for f in fields) + ",)"
+ else:
+ # Special case for the zero-length tuple.
+ fields_str = "()"
+ return (
+ _create_fn(
+ "__setattr__",
+ ("self", "name", "value"),
+ (
+ f"if type(self) is cls or name in {fields_str}:",
+ ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+ f"super(cls, self).__setattr__(name, value)",
+ ),
+ globals=globals,
+ ),
+ _create_fn(
+ "__delattr__",
+ ("self", "name"),
+ (
+ f"if type(self) is cls or name in {fields_str}:",
+ ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
+ f"super(cls, self).__delattr__(name)",
+ ),
+ globals=globals,
+ ),
+ )
+
+
+def _cmp_fn(name, op, self_tuple, other_tuple):
+ # Create a comparison function. If the fields in the object are
+ # named 'x' and 'y', then self_tuple is the string
+ # '(self.x,self.y)' and other_tuple is the string
+ # '(other.x,other.y)'.
+
+ return _create_fn(
+ name,
+ ("self", "other"),
+ [
+ "if other.__class__ is self.__class__:",
+ f" return {self_tuple}{op}{other_tuple}",
+ "return NotImplemented",
+ ],
+ )
+
+
+def _hash_fn(fields):
+ self_tuple = _tuple_str("self", fields)
+ return _create_fn("__hash__", ("self",), [f"return hash({self_tuple})"])
+
+
+def _is_classvar(a_type, typing):
+ # This test uses a typing internal class, but it's the best way to
+ # test if this is a ClassVar.
+ return type(a_type) is typing._ClassVar
+
+
+def _is_initvar(a_type, dataclasses):
+ # The module we're checking against is the module we're
+ # currently in (dataclasses.py).
+ return a_type is dataclasses.InitVar
+
+
+def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
+ # Given a type annotation string, does it refer to a_type in
+ # a_module? For example, when checking that annotation denotes a
+ # ClassVar, then a_module is typing, and a_type is
+ # typing.ClassVar.
+
+ # It's possible to look up a_module given a_type, but it involves
+ # looking in sys.modules (again!), and seems like a waste since
+ # the caller already knows a_module.
+
+ # - annotation is a string type annotation
+ # - cls is the class that this annotation was found in
+ # - a_module is the module we want to match
+ # - a_type is the type in that module we want to match
+ # - is_type_predicate is a function called with (obj, a_module)
+ # that determines if obj is of the desired type.
+
+ # Since this test does not do a local namespace lookup (and
+ # instead only a module (global) lookup), there are some things it
+ # gets wrong.
+
+ # With string annotations, cv0 will be detected as a ClassVar:
+ # CV = ClassVar
+ # @dataclass
+ # class C0:
+ # cv0: CV
+
+ # But in this example cv1 will not be detected as a ClassVar:
+ # @dataclass
+ # class C1:
+ # CV = ClassVar
+ # cv1: CV
+
+ # In C1, the code in this function (_is_type) will look up "CV" in
+ # the module and not find it, so it will not consider cv1 as a
+ # ClassVar. This is a fairly obscure corner case, and the best
+ # way to fix it would be to eval() the string "CV" with the
+ # correct global and local namespaces. However that would involve
+ # a eval() penalty for every single field of every dataclass
+ # that's defined. It was judged not worth it.
+
+ match = _MODULE_IDENTIFIER_RE.match(annotation)
+ if match:
+ ns = None
+ module_name = match.group(1)
+ if not module_name:
+ # No module name, assume the class's module did
+ # "from dataclasses import InitVar".
+ ns = sys.modules.get(cls.__module__).__dict__
+ else:
+ # Look up module_name in the class's module.
+ module = sys.modules.get(cls.__module__)
+ if module and module.__dict__.get(module_name) is a_module:
+ ns = sys.modules.get(a_type.__module__).__dict__
+ if ns and is_type_predicate(ns.get(match.group(2)), a_module):
+ return True
+ return False
+
+
+def _get_field(cls, a_name, a_type):
+ # Return a Field object for this field name and type. ClassVars
+ # and InitVars are also returned, but marked as such (see
+ # f._field_type).
+
+ # If the default value isn't derived from Field, then it's only a
+ # normal default value. Convert it to a Field().
+ default = getattr(cls, a_name, MISSING)
+ if isinstance(default, Field):
+ f = default
+ else:
+ if isinstance(default, types.MemberDescriptorType):
+ # This is a field in __slots__, so it has no default value.
+ default = MISSING
+ f = field(default=default)
+ # Only at this point do we know the name and the type. Set them.
+ f.name = a_name
+ f.type = a_type
+
+ # Assume it's a normal field until proven otherwise. We're next
+ # going to decide if it's a ClassVar or InitVar, everything else
+ # is just a normal field.
+ f._field_type = _FIELD
+
+ # In addition to checking for actual types here, also check for
+ # string annotations. get_type_hints() won't always work for us
+ # (see https://github.com/python/typing/issues/508 for example),
+ # plus it's expensive and would require an eval for every string
+ # annotation. So, make a best effort to see if this is a ClassVar
+ # or InitVar using regexes and checking that the thing referenced
+ # is actually of the correct type.
+
+ # For the complete discussion, see https://bugs.python.org/issue33453
+
+ # If typing has not been imported, then it's impossible for any
+ # annotation to be a ClassVar. So, only look for ClassVar if
+ # typing has been imported by any module (not necessarily cls's
+ # module).
+ typing = sys.modules.get("typing")
+ if typing:
+ if _is_classvar(a_type, typing) or (
+ isinstance(f.type, str)
+ and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar)
+ ):
+ f._field_type = _FIELD_CLASSVAR
+ # If the type is InitVar, or if it's a matching string annotation,
+ # then it's an InitVar.
+ if f._field_type is _FIELD:
+ # The module we're checking against is the module we're
+ # currently in (dataclasses.py).
+ dataclasses = sys.modules[__name__]
+ if _is_initvar(a_type, dataclasses) or (
+ isinstance(f.type, str)
+ and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar)
+ ):
+ f._field_type = _FIELD_INITVAR
+ # Validations for individual fields. This is delayed until now,
+ # instead of in the Field() constructor, since only here do we
+ # know the field name, which allows for better error reporting.
+
+ # Special restrictions for ClassVar and InitVar.
+ if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
+ if f.default_factory is not MISSING:
+ raise TypeError(f"field {f.name} cannot have a " "default factory")
+ # Should I check for other field settings? default_factory
+ # seems the most serious to check for. Maybe add others. For
+ # example, how about init=False (or really,
+ # init=<not-the-default-init-value>)? It makes no sense for
+ # ClassVar and InitVar to specify init=<anything>.
+ # For real fields, disallow mutable defaults for known types.
+ if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
+ raise ValueError(
+ f"mutable default {type(f.default)} for field "
+ f"{f.name} is not allowed: use default_factory"
+ )
+ return f
+
+
+def _set_new_attribute(cls, name, value):
+ # Never overwrites an existing attribute. Returns True if the
+ # attribute already exists.
+ if name in cls.__dict__:
+ return True
+ setattr(cls, name, value)
+ return False
+
+
+# Decide if/how we're going to create a hash function. Key is
+# (unsafe_hash, eq, frozen, does-hash-exist). Value is the action to
+# take. The common case is to do nothing, so instead of providing a
+# function that is a no-op, use None to signify that.
+
+
+def _hash_set_none(cls, fields):
+ return None
+
+
+def _hash_add(cls, fields):
+ flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
+ return _hash_fn(flds)
+
+
+def _hash_exception(cls, fields):
+ # Raise an exception.
+ raise TypeError(f"Cannot overwrite attribute __hash__ " f"in class {cls.__name__}")
+
+
+#
+# +-------------------------------------- unsafe_hash?
+# | +------------------------------- eq?
+# | | +------------------------ frozen?
+# | | | +---------------- has-explicit-hash?
+# | | | |
+# | | | | +------- action
+# | | | | |
+# v v v v v
+_hash_action = {
+ (False, False, False, False): None,
+ (False, False, False, True): None,
+ (False, False, True, False): None,
+ (False, False, True, True): None,
+ (False, True, False, False): _hash_set_none,
+ (False, True, False, True): None,
+ (False, True, True, False): _hash_add,
+ (False, True, True, True): None,
+ (True, False, False, False): _hash_add,
+ (True, False, False, True): _hash_exception,
+ (True, False, True, False): _hash_add,
+ (True, False, True, True): _hash_exception,
+ (True, True, False, False): _hash_add,
+ (True, True, False, True): _hash_exception,
+ (True, True, True, False): _hash_add,
+ (True, True, True, True): _hash_exception,
+}
+# See https://bugs.python.org/issue32929#msg312829 for an if-statement
+# version of this table.
+
+
+def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
+ # Now that dicts retain insertion order, there's no reason to use
+ # an ordered dict. I am leveraging that ordering here, because
+ # derived class fields overwrite base class fields, but the order
+ # is defined by the base class, which is found first.
+ fields = {}
+
+ setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, unsafe_hash, frozen))
+
+ # Find our base classes in reverse MRO order, and exclude
+ # ourselves. In reversed order so that more derived classes
+ # override earlier field definitions in base classes. As long as
+ # we're iterating over them, see if any are frozen.
+ any_frozen_base = False
+ has_dataclass_bases = False
+ for b in cls.__mro__[-1:0:-1]:
+ # Only process classes that have been processed by our
+ # decorator. That is, they have a _FIELDS attribute.
+ base_fields = getattr(b, _FIELDS, None)
+ if base_fields:
+ has_dataclass_bases = True
+ for f in base_fields.values():
+ fields[f.name] = f
+ if getattr(b, _PARAMS).frozen:
+ any_frozen_base = True
+ # Annotations that are defined in this class (not in base
+ # classes). If __annotations__ isn't present, then this class
+ # adds no new annotations. We use this to compute fields that are
+ # added by this class.
+ #
+ # Fields are found from cls_annotations, which is guaranteed to be
+ # ordered. Default values are from class attributes, if a field
+ # has a default. If the default value is a Field(), then it
+ # contains additional info beyond (and possibly including) the
+ # actual default value. Pseudo-fields ClassVars and InitVars are
+ # included, despite the fact that they're not real fields. That's
+ # dealt with later.
+ cls_annotations = cls.__dict__.get("__annotations__", {})
+
+ # Now find fields in our class. While doing so, validate some
+ # things, and set the default values (as class attributes) where
+ # we can.
+ cls_fields = [
+ _get_field(cls, name, type_) for name, type_ in cls_annotations.items()
+ ]
+ for f in cls_fields:
+ fields[f.name] = f
+
+ # If the class attribute (which is the default value for this
+ # field) exists and is of type 'Field', replace it with the
+ # real default. This is so that normal class introspection
+ # sees a real default value, not a Field.
+ if isinstance(getattr(cls, f.name, None), Field):
+ if f.default is MISSING:
+ # If there's no default, delete the class attribute.
+ # This happens if we specify field(repr=False), for
+ # example (that is, we specified a field object, but
+ # no default value). Also if we're using a default
+ # factory. The class attribute should not be set at
+ # all in the post-processed class.
+ delattr(cls, f.name)
+ else:
+ setattr(cls, f.name, f.default)
+ # Do we have any Field members that don't also have annotations?
+ for name, value in cls.__dict__.items():
+ if isinstance(value, Field) and name not in cls_annotations:
+ raise TypeError(f"{name!r} is a field but has no type annotation")
+ # Check rules that apply if we are derived from any dataclasses.
+ if has_dataclass_bases:
+ # Raise an exception if any of our bases are frozen, but we're not.
+ if any_frozen_base and not frozen:
+ raise TypeError("cannot inherit non-frozen dataclass from a " "frozen one")
+ # Raise an exception if we're frozen, but none of our bases are.
+ if not any_frozen_base and frozen:
+ raise TypeError("cannot inherit frozen dataclass from a " "non-frozen one")
+ # Remember all of the fields on our class (including bases). This
+ # also marks this class as being a dataclass.
+ setattr(cls, _FIELDS, fields)
+
+ # Was this class defined with an explicit __hash__? Note that if
+ # __eq__ is defined in this class, then python will automatically
+ # set __hash__ to None. This is a heuristic, as it's possible
+ # that such a __hash__ == None was not auto-generated, but it's
+ # close enough.
+ class_hash = cls.__dict__.get("__hash__", MISSING)
+ has_explicit_hash = not (
+ class_hash is MISSING or (class_hash is None and "__eq__" in cls.__dict__)
+ )
+
+ # If we're generating ordering methods, we must be generating the
+ # eq methods.
+ if order and not eq:
+ raise ValueError("eq must be true if order is true")
+ if init:
+ # Does this class have a post-init function?
+ has_post_init = hasattr(cls, _POST_INIT_NAME)
+
+ # Include InitVars and regular fields (so, not ClassVars).
+ flds = [f for f in fields.values() if f._field_type in (_FIELD, _FIELD_INITVAR)]
+ _set_new_attribute(
+ cls,
+ "__init__",
+ _init_fn(
+ flds,
+ frozen,
+ has_post_init,
+ # The name to use for the "self"
+ # param in __init__. Use "self"
+ # if possible.
+ "__dataclass_self__" if "self" in fields else "self",
+ ),
+ )
+ # Get the fields as a list, and include only real fields. This is
+ # used in all of the following methods.
+ field_list = [f for f in fields.values() if f._field_type is _FIELD]
+
+ if repr:
+ flds = [f for f in field_list if f.repr]
+ _set_new_attribute(cls, "__repr__", _repr_fn(flds))
+ if eq:
+ # Create __eq__ method. There's no need for a __ne__ method,
+ # since python will call __eq__ and negate it.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str("self", flds)
+ other_tuple = _tuple_str("other", flds)
+ _set_new_attribute(
+ cls, "__eq__", _cmp_fn("__eq__", "==", self_tuple, other_tuple)
+ )
+ if order:
+ # Create and set the ordering methods.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str("self", flds)
+ other_tuple = _tuple_str("other", flds)
+ for name, op in [
+ ("__lt__", "<"),
+ ("__le__", "<="),
+ ("__gt__", ">"),
+ ("__ge__", ">="),
+ ]:
+ if _set_new_attribute(
+ cls, name, _cmp_fn(name, op, self_tuple, other_tuple)
+ ):
+ raise TypeError(
+ f"Cannot overwrite attribute {name} "
+ f"in class {cls.__name__}. Consider using "
+ "functools.total_ordering"
+ )
+ if frozen:
+ for fn in _frozen_get_del_attr(cls, field_list):
+ if _set_new_attribute(cls, fn.__name__, fn):
+ raise TypeError(
+ f"Cannot overwrite attribute {fn.__name__} "
+ f"in class {cls.__name__}"
+ )
+ # Decide if/how we're going to create a hash function.
+ hash_action = _hash_action[
+ bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash
+ ]
+ if hash_action:
+ # No need to call _set_new_attribute here, since by the time
+ # we're here the overwriting is unconditional.
+ cls.__hash__ = hash_action(cls, field_list)
+ if not getattr(cls, "__doc__"):
+ # Create a class doc-string.
+ cls.__doc__ = cls.__name__ + str(inspect.signature(cls)).replace(" -> None", "")
+ return cls
+
+
+# _cls should never be specified by keyword, so start it with an
+# underscore. The presence of _cls is used to detect if this
+# decorator is being called with parameters or not.
+def dataclass(
+ _cls=None,
+ *,
+ init=True,
+ repr=True,
+ eq=True,
+ order=False,
+ unsafe_hash=False,
+ frozen=False,
+):
"""Returns the same class as was passed in, with dunder methods
added based on the fields defined in the class.
@@ -146,7 +976,16 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
__hash__() method function is added. If frozen is true, fields may
not be assigned to after instance creation.
"""
- pass
+
+ def wrap(cls):
+ return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
+
+ # See if we're being called as @dataclass or @dataclass().
+ if _cls is None:
+ # We're called with parens.
+ return wrap
+ # We're called as @dataclass without parens.
+ return wrap(_cls)
def fields(class_or_instance):
@@ -155,18 +994,26 @@ def fields(class_or_instance):
Accepts a dataclass or an instance of one. Tuple elements are of
type Field.
"""
- pass
+
+ # Might it be worth caching this, per class?
+ try:
+ fields = getattr(class_or_instance, _FIELDS)
+ except AttributeError:
+ raise TypeError("must be called with a dataclass type or instance")
+ # Exclude pseudo-fields. Note that fields is sorted by insertion
+ # order, so the order of the tuple is as the fields were defined.
+ return tuple(f for f in fields.values() if f._field_type is _FIELD)
def _is_dataclass_instance(obj):
"""Returns True if obj is an instance of a dataclass."""
- pass
+ return not isinstance(obj, type) and hasattr(obj, _FIELDS)
def is_dataclass(obj):
"""Returns True if obj is a dataclass or an instance of a
dataclass."""
- pass
+ return hasattr(obj, _FIELDS)
def asdict(obj, *, dict_factory=dict):
@@ -188,7 +1035,27 @@ def asdict(obj, *, dict_factory=dict):
dataclass instances. This will also look into built-in containers:
tuples, lists, and dicts.
"""
- pass
+ if not _is_dataclass_instance(obj):
+ raise TypeError("asdict() should be called on dataclass instances")
+ return _asdict_inner(obj, dict_factory)
+
+
+def _asdict_inner(obj, dict_factory):
+ if _is_dataclass_instance(obj):
+ result = []
+ for f in fields(obj):
+ value = _asdict_inner(getattr(obj, f.name), dict_factory)
+ result.append((f.name, value))
+ return dict_factory(result)
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
+ elif isinstance(obj, dict):
+ return type(obj)(
+ (_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
+ for k, v in obj.items()
+ )
+ else:
+ return copy.deepcopy(obj)
def astuple(obj, *, tuple_factory=tuple):
@@ -209,11 +1076,43 @@ def astuple(obj, *, tuple_factory=tuple):
dataclass instances. This will also look into built-in containers:
tuples, lists, and dicts.
"""
- pass
-
-def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
- repr=True, eq=True, order=False, unsafe_hash=False, frozen=False):
+ if not _is_dataclass_instance(obj):
+ raise TypeError("astuple() should be called on dataclass instances")
+ return _astuple_inner(obj, tuple_factory)
+
+
+def _astuple_inner(obj, tuple_factory):
+ if _is_dataclass_instance(obj):
+ result = []
+ for f in fields(obj):
+ value = _astuple_inner(getattr(obj, f.name), tuple_factory)
+ result.append(value)
+ return tuple_factory(result)
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
+ elif isinstance(obj, dict):
+ return type(obj)(
+ (_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
+ for k, v in obj.items()
+ )
+ else:
+ return copy.deepcopy(obj)
+
+
+def make_dataclass(
+ cls_name,
+ fields,
+ *,
+ bases=(),
+ namespace=None,
+ init=True,
+ repr=True,
+ eq=True,
+ order=False,
+ unsafe_hash=False,
+ frozen=False,
+):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an iterable
@@ -236,7 +1135,48 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
dataclass().
"""
- pass
+
+ if namespace is None:
+ namespace = {}
+ else:
+ # Copy namespace since we're going to mutate it.
+ namespace = namespace.copy()
+ # While we're looking through the field names, validate that they
+ # are identifiers, are not keywords, and not duplicates.
+ seen = set()
+ anns = {}
+ for item in fields:
+ if isinstance(item, str):
+ name = item
+ tp = "typing.Any"
+ elif len(item) == 2:
+ (name, tp) = item
+ elif len(item) == 3:
+ name, tp, spec = item
+ namespace[name] = spec
+ else:
+ raise TypeError(f"Invalid field: {item!r}")
+ if not isinstance(name, str) or not name.isidentifier():
+ raise TypeError(f"Field names must be valid identifers: {name!r}")
+ if keyword.iskeyword(name):
+ raise TypeError(f"Field names must not be keywords: {name!r}")
+ if name in seen:
+ raise TypeError(f"Field name duplicated: {name!r}")
+ seen.add(name)
+ anns[name] = tp
+ namespace["__annotations__"] = anns
+ # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
+ # of generic dataclasses.
+ cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
+ return dataclass(
+ cls,
+ init=init,
+ repr=repr,
+ eq=eq,
+ order=order,
+ unsafe_hash=unsafe_hash,
+ frozen=frozen,
+ )
def replace(obj, **changes):
@@ -253,4 +1193,30 @@ def replace(obj, **changes):
c1 = replace(c, x=3)
assert c1.x == 3 and c1.y == 2
"""
- pass
+
+ # We're going to mutate 'changes', but that's okay because it's a
+ # new dict, even if called with 'replace(obj, **my_changes)'.
+
+ if not _is_dataclass_instance(obj):
+ raise TypeError("replace() should be called on dataclass instances")
+ # It's an error to have init=False fields in 'changes'.
+ # If a field is not in 'changes', read its value from the provided obj.
+
+ for f in getattr(obj, _FIELDS).values():
+ if not f.init:
+ # Error if this field is specified in changes.
+ if f.name in changes:
+ raise ValueError(
+ f"field {f.name} is declared with "
+ "init=False, it cannot be specified with "
+ "replace()"
+ )
+ continue
+ if f.name not in changes:
+ changes[f.name] = getattr(obj, f.name)
+ # Create the new object, which calls __init__() and
+ # __post_init__() (if defined), using all of the init fields we've
+ # added and/or left in 'changes'. If there are values supplied in
+ # changes that aren't fields, this will correctly raise a
+ # TypeError.
+ return obj.__class__(**changes)
diff --git a/graphene/pyutils/version.py b/graphene/pyutils/version.py
index 7f16d40..8a3be07 100644
--- a/graphene/pyutils/version.py
+++ b/graphene/pyutils/version.py
@@ -1,24 +1,58 @@
from __future__ import unicode_literals
+
import datetime
import os
import subprocess
def get_version(version=None):
- """Returns a PEP 440-compliant version number from VERSION."""
- pass
+ "Returns a PEP 440-compliant version number from VERSION."
+ version = get_complete_version(version)
+
+ # Now build the two parts of the version number:
+ # main = X.Y[.Z]
+ # sub = .devN - for pre-alpha releases
+ # | {a|b|rc}N - for alpha, beta, and rc releases
+
+ main = get_main_version(version)
+
+ sub = ""
+ if version[3] == "alpha" and version[4] == 0:
+ git_changeset = get_git_changeset()
+ sub = ".dev%s" % git_changeset if git_changeset else ".dev"
+ elif version[3] != "final":
+ mapping = {"alpha": "a", "beta": "b", "rc": "rc"}
+ sub = mapping[version[3]] + str(version[4])
+
+ return str(main + sub)
def get_main_version(version=None):
- """Returns main version (X.Y[.Z]) from VERSION."""
- pass
+ "Returns main version (X.Y[.Z]) from VERSION."
+ version = get_complete_version(version)
+ parts = 2 if version[2] == 0 else 3
+ return ".".join(str(x) for x in version[:parts])
def get_complete_version(version=None):
"""Returns a tuple of the graphene version. If version argument is non-empty,
then checks for correctness of the tuple provided.
"""
- pass
+ if version is None:
+ from graphene import VERSION as version
+ else:
+ assert len(version) == 5
+ assert version[3] in ("alpha", "beta", "rc", "final")
+
+ return version
+
+
+def get_docs_version(version=None):
+ version = get_complete_version(version)
+ if version[3] != "final":
+ return "dev"
+ else:
+ return "%d.%d" % version[:2]
def get_git_changeset():
@@ -27,4 +61,18 @@ def get_git_changeset():
This value isn't guaranteed to be unique, but collisions are very unlikely,
so it's sufficient for generating the development version numbers.
"""
- pass
+ repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ try:
+ git_log = subprocess.Popen(
+ "git log --pretty=format:%ct --quiet -1 HEAD",
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ cwd=repo_dir,
+ universal_newlines=True,
+ )
+ timestamp = git_log.communicate()[0]
+ timestamp = datetime.datetime.utcfromtimestamp(int(timestamp))
+ except:
+ return None
+ return timestamp.strftime("%Y%m%d%H%M%S")
diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py
index b1ab0bf..cc7d2da 100644
--- a/graphene/relay/connection.py
+++ b/graphene/relay/connection.py
@@ -2,7 +2,9 @@ import re
from collections.abc import Iterable
from functools import partial
from typing import Type
+
from graphql_relay import connection_from_array
+
from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeOptions
@@ -10,26 +12,72 @@ from ..utils.thenables import maybe_thenable
from .node import is_node, AbstractNode
-class PageInfo(ObjectType):
+def get_edge_class(
+ connection_class: Type["Connection"],
+ _node: Type[AbstractNode],
+ base_name: str,
+ strict_types: bool = False,
+):
+ edge_class = getattr(connection_class, "Edge", None)
+ class EdgeBase:
+ node = Field(
+ NonNull(_node) if strict_types else _node,
+ description="The item at the end of the edge",
+ )
+ cursor = String(required=True, description="A cursor for use in pagination")
- class Meta:
- description = (
- 'The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.'
- )
- has_next_page = Boolean(required=True, name='hasNextPage', description=
- 'When paginating forwards, are there more items?')
- has_previous_page = Boolean(required=True, name='hasPreviousPage',
- description='When paginating backwards, are there more items?')
- start_cursor = String(name='startCursor', description=
- 'When paginating backwards, the cursor to continue.')
- end_cursor = String(name='endCursor', description=
- 'When paginating forwards, the cursor to continue.')
+ class EdgeMeta:
+ description = f"A Relay edge containing a `{base_name}` and its cursor."
+
+ edge_name = f"{base_name}Edge"
+
+ edge_bases = [edge_class, EdgeBase] if edge_class else [EdgeBase]
+ if not isinstance(edge_class, ObjectType):
+ edge_bases = [*edge_bases, ObjectType]
+
+ return type(edge_name, tuple(edge_bases), {"Meta": EdgeMeta})
+class PageInfo(ObjectType):
+ class Meta:
+ description = (
+ "The Relay compliant `PageInfo` type, containing data necessary to"
+ " paginate this connection."
+ )
+
+ has_next_page = Boolean(
+ required=True,
+ name="hasNextPage",
+ description="When paginating forwards, are there more items?",
+ )
+
+ has_previous_page = Boolean(
+ required=True,
+ name="hasPreviousPage",
+ description="When paginating backwards, are there more items?",
+ )
+
+ start_cursor = String(
+ name="startCursor",
+ description="When paginating backwards, the cursor to continue.",
+ )
+
+ end_cursor = String(
+ name="endCursor",
+ description="When paginating forwards, the cursor to continue.",
+ )
+
+
+# noinspection PyPep8Naming
def page_info_adapter(startCursor, endCursor, hasPreviousPage, hasNextPage):
"""Adapter for creating PageInfo instances"""
- pass
+ return PageInfo(
+ start_cursor=startCursor,
+ end_cursor=endCursor,
+ has_previous_page=hasPreviousPage,
+ has_next_page=hasNextPage,
+ )
class ConnectionOptions(ObjectTypeOptions):
@@ -37,55 +85,116 @@ class ConnectionOptions(ObjectTypeOptions):
class Connection(ObjectType):
-
-
class Meta:
abstract = True
@classmethod
- def __init_subclass_with_meta__(cls, node=None, name=None, strict_types
- =False, _meta=None, **options):
+ def __init_subclass_with_meta__(
+ cls, node=None, name=None, strict_types=False, _meta=None, **options
+ ):
if not _meta:
_meta = ConnectionOptions(cls)
- assert node, f'You have to provide a node in {cls.__name__}.Meta'
- assert isinstance(node, NonNull) or issubclass(node, (Scalar, Enum,
- ObjectType, Interface, Union, NonNull)
- ), f'Received incompatible node "{node}" for Connection {cls.__name__}.'
- base_name = re.sub('Connection$', '', name or cls.__name__
- ) or node._meta.name
+ assert node, f"You have to provide a node in {cls.__name__}.Meta"
+ assert isinstance(node, NonNull) or issubclass(
+ node, (Scalar, Enum, ObjectType, Interface, Union, NonNull)
+ ), f'Received incompatible node "{node}" for Connection {cls.__name__}.'
+
+ base_name = re.sub("Connection$", "", name or cls.__name__) or node._meta.name
if not name:
- name = f'{base_name}Connection'
- options['name'] = name
+ name = f"{base_name}Connection"
+
+ options["name"] = name
+
_meta.node = node
+
if not _meta.fields:
_meta.fields = {}
- if 'page_info' not in _meta.fields:
- _meta.fields['page_info'] = Field(PageInfo, name='pageInfo',
- required=True, description=
- 'Pagination data for this connection.')
- if 'edges' not in _meta.fields:
- edge_class = get_edge_class(cls, node, base_name, strict_types)
+
+ if "page_info" not in _meta.fields:
+ _meta.fields["page_info"] = Field(
+ PageInfo,
+ name="pageInfo",
+ required=True,
+ description="Pagination data for this connection.",
+ )
+
+ if "edges" not in _meta.fields:
+ edge_class = get_edge_class(cls, node, base_name, strict_types) # type: ignore
cls.Edge = edge_class
- _meta.fields['edges'] = Field(NonNull(List(NonNull(edge_class) if
- strict_types else edge_class)), description=
- 'Contains the nodes in this connection.')
- return super(Connection, cls).__init_subclass_with_meta__(_meta=
- _meta, **options)
+ _meta.fields["edges"] = Field(
+ NonNull(List(NonNull(edge_class) if strict_types else edge_class)),
+ description="Contains the nodes in this connection.",
+ )
+ return super(Connection, cls).__init_subclass_with_meta__(
+ _meta=_meta, **options
+ )
+
+# noinspection PyPep8Naming
def connection_adapter(cls, edges, pageInfo):
"""Adapter for creating Connection instances"""
- pass
+ return cls(edges=edges, page_info=pageInfo)
class IterableConnectionField(Field):
-
def __init__(self, type_, *args, **kwargs):
- kwargs.setdefault('before', String())
- kwargs.setdefault('after', String())
- kwargs.setdefault('first', Int())
- kwargs.setdefault('last', Int())
+ kwargs.setdefault("before", String())
+ kwargs.setdefault("after", String())
+ kwargs.setdefault("first", Int())
+ kwargs.setdefault("last", Int())
super(IterableConnectionField, self).__init__(type_, *args, **kwargs)
+ @property
+ def type(self):
+ type_ = super(IterableConnectionField, self).type
+ connection_type = type_
+ if isinstance(type_, NonNull):
+ connection_type = type_.of_type
+
+ if is_node(connection_type):
+ raise Exception(
+ "ConnectionFields now need a explicit ConnectionType for Nodes.\n"
+ "Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#node-connections"
+ )
+
+ assert issubclass(
+ connection_type, Connection
+ ), f'{self.__class__.__name__} type has to be a subclass of Connection. Received "{connection_type}".'
+ return type_
+
+ @classmethod
+ def resolve_connection(cls, connection_type, args, resolved):
+ if isinstance(resolved, connection_type):
+ return resolved
+
+ assert isinstance(resolved, Iterable), (
+ f"Resolved value from the connection field has to be an iterable or instance of {connection_type}. "
+ f'Received "{resolved}"'
+ )
+ connection = connection_from_array(
+ resolved,
+ args,
+ connection_type=partial(connection_adapter, connection_type),
+ edge_type=connection_type.Edge,
+ page_info_type=page_info_adapter,
+ )
+ connection.iterable = resolved
+ return connection
+
+ @classmethod
+ def connection_resolver(cls, resolver, connection_type, root, info, **args):
+ resolved = resolver(root, info, **args)
+
+ if isinstance(connection_type, NonNull):
+ connection_type = connection_type.of_type
+
+ on_resolve = partial(cls.resolve_connection, connection_type, args)
+ return maybe_thenable(resolved, on_resolve)
+
+ def wrap_resolve(self, parent_resolver):
+ resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
+ return partial(self.connection_resolver, resolver, self.type)
+
ConnectionField = IterableConnectionField
diff --git a/graphene/relay/id_type.py b/graphene/relay/id_type.py
index 6278c7e..fb5c30e 100644
--- a/graphene/relay/id_type.py
+++ b/graphene/relay/id_type.py
@@ -1,6 +1,8 @@
from graphql_relay import from_global_id, to_global_id
+
from ..types import ID, UUID
from ..types.base import BaseType
+
from typing import Type
@@ -8,15 +10,45 @@ class BaseGlobalIDType:
"""
Base class that define the required attributes/method for a type.
"""
- graphene_type = ID
+
+ graphene_type = ID # type: Type[BaseType]
+
+ @classmethod
+ def resolve_global_id(cls, info, global_id):
+ # return _type, _id
+ raise NotImplementedError
+
+ @classmethod
+ def to_global_id(cls, _type, _id):
+ # return _id
+ raise NotImplementedError
class DefaultGlobalIDType(BaseGlobalIDType):
"""
Default global ID type: base64 encoded version of "<node type name>: <node id>".
"""
+
graphene_type = ID
+ @classmethod
+ def resolve_global_id(cls, info, global_id):
+ try:
+ _type, _id = from_global_id(global_id)
+ if not _type:
+ raise ValueError("Invalid Global ID")
+ return _type, _id
+ except Exception as e:
+ raise Exception(
+ f'Unable to parse global ID "{global_id}". '
+ 'Make sure it is a base64 encoded string in the format: "TypeName:id". '
+ f"Exception message: {e}"
+ )
+
+ @classmethod
+ def to_global_id(cls, _type, _id):
+ return to_global_id(_type, _id)
+
class SimpleGlobalIDType(BaseGlobalIDType):
"""
@@ -24,12 +56,32 @@ class SimpleGlobalIDType(BaseGlobalIDType):
To be used carefully as the user is responsible for ensuring that the IDs are indeed global
(otherwise it could cause request caching issues).
"""
+
graphene_type = ID
+ @classmethod
+ def resolve_global_id(cls, info, global_id):
+ _type = info.return_type.graphene_type._meta.name
+ return _type, global_id
+
+ @classmethod
+ def to_global_id(cls, _type, _id):
+ return _id
+
class UUIDGlobalIDType(BaseGlobalIDType):
"""
UUID global ID type.
By definition UUIDs are global, so they are used as they are.
"""
+
graphene_type = UUID
+
+ @classmethod
+ def resolve_global_id(cls, info, global_id):
+ _type = info.return_type.graphene_type._meta.name
+ return _type, global_id
+
+ @classmethod
+ def to_global_id(cls, _type, _id):
+ return _id
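
A custom ID type only has to fill in the two hooks declared on `BaseGlobalIDType`. A minimal sketch, assuming a hypothetical clear-text `"<TypeName>/<id>"` encoding (the `PrefixedGlobalIDType`/`PrefixedNode` names are invented for illustration):

```python
from graphene import ID
from graphene.relay import BaseGlobalIDType, Node


class PrefixedGlobalIDType(BaseGlobalIDType):
    """Hypothetical ID type that encodes "<TypeName>/<id>" in clear text."""

    graphene_type = ID

    @classmethod
    def resolve_global_id(cls, info, global_id):
        # Split "Photo/42" back into ("Photo", "42").
        _type, _, _id = global_id.partition("/")
        return _type, _id

    @classmethod
    def to_global_id(cls, _type, _id):
        return f"{_type}/{_id}"


class PrefixedNode(Node):
    class Meta:
        global_id_type = PrefixedGlobalIDType
```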
diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py
index 1ea7347..2f4a4b7 100644
--- a/graphene/relay/mutation.py
+++ b/graphene/relay/mutation.py
@@ -1,37 +1,66 @@
import re
+
from ..types import Field, InputObjectType, String
from ..types.mutation import Mutation
from ..utils.thenables import maybe_thenable
class ClientIDMutation(Mutation):
-
-
class Meta:
abstract = True
@classmethod
- def __init_subclass_with_meta__(cls, output=None, input_fields=None,
- arguments=None, name=None, **options):
- input_class = getattr(cls, 'Input', None)
- base_name = re.sub('Payload$', '', name or cls.__name__)
+ def __init_subclass_with_meta__(
+ cls, output=None, input_fields=None, arguments=None, name=None, **options
+ ):
+ input_class = getattr(cls, "Input", None)
+ base_name = re.sub("Payload$", "", name or cls.__name__)
+
assert not output, "Can't specify any output"
assert not arguments, "Can't specify any arguments"
- bases = InputObjectType,
+
+ bases = (InputObjectType,)
if input_class:
- bases += input_class,
+ bases += (input_class,)
+
if not input_fields:
input_fields = {}
- cls.Input = type(f'{base_name}Input', bases, dict(input_fields,
- client_mutation_id=String(name='clientMutationId')))
- arguments = dict(input=cls.Input(required=True))
- mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
- if (cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.
- __func__):
- assert mutate_and_get_payload, f'{name or cls.__name__}.mutate_and_get_payload method is required in a ClientIDMutation.'
+
+ cls.Input = type(
+ f"{base_name}Input",
+ bases,
+ dict(input_fields, client_mutation_id=String(name="clientMutationId")),
+ )
+
+ arguments = dict(
+ input=cls.Input(required=True)
+ # 'client_mutation_id': String(name='clientMutationId')
+ )
+ mutate_and_get_payload = getattr(cls, "mutate_and_get_payload", None)
+ if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
+ assert mutate_and_get_payload, (
+ f"{name or cls.__name__}.mutate_and_get_payload method is required"
+ " in a ClientIDMutation."
+ )
+
if not name:
- name = f'{base_name}Payload'
- super(ClientIDMutation, cls).__init_subclass_with_meta__(output=
- None, arguments=arguments, name=name, **options)
- cls._meta.fields['client_mutation_id'] = Field(String, name=
- 'clientMutationId')
+ name = f"{base_name}Payload"
+
+ super(ClientIDMutation, cls).__init_subclass_with_meta__(
+ output=None, arguments=arguments, name=name, **options
+ )
+ cls._meta.fields["client_mutation_id"] = Field(String, name="clientMutationId")
+
+ @classmethod
+ def mutate(cls, root, info, input):
+ def on_resolve(payload):
+ try:
+ payload.client_mutation_id = input.get("client_mutation_id")
+ except Exception:
+ raise Exception(
+ f"Cannot set client_mutation_id in the payload object {repr(payload)}"
+ )
+ return payload
+
+ result = cls.mutate_and_get_payload(root, info, **input)
+ return maybe_thenable(result, on_resolve)
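
The restored `mutate` classmethod is what lets subclasses implement only `mutate_and_get_payload`. A minimal usage sketch (the `IntroduceShip` mutation is illustrative):

```python
import graphene
from graphene import relay


class IntroduceShip(relay.ClientIDMutation):
    class Input:
        ship_name = graphene.String(required=True)

    ship = graphene.String()

    @classmethod
    def mutate_and_get_payload(cls, root, info, ship_name, client_mutation_id=None):
        # The generated IntroduceShipInput already carries clientMutationId;
        # mutate() copies it back onto this payload once it resolves.
        return IntroduceShip(ship=ship_name)
```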
diff --git a/graphene/relay/node.py b/graphene/relay/node.py
index 138a893..5443828 100644
--- a/graphene/relay/node.py
+++ b/graphene/relay/node.py
@@ -1,5 +1,6 @@
from functools import partial
from inspect import isclass
+
from ..types import Field, Interface, ObjectType
from ..types.interface import InterfaceOptions
from ..types.utils import get_type
@@ -10,49 +11,125 @@ def is_node(objecttype):
"""
Check if the given objecttype has Node as an interface
"""
- pass
+ if not isclass(objecttype):
+ return False
+ if not issubclass(objecttype, ObjectType):
+ return False
+
+ return any(issubclass(i, Node) for i in objecttype._meta.interfaces)
-class GlobalID(Field):
- def __init__(self, node=None, parent_type=None, required=True,
- global_id_type=DefaultGlobalIDType, *args, **kwargs):
- super(GlobalID, self).__init__(global_id_type.graphene_type, *args,
- required=required, **kwargs)
+class GlobalID(Field):
+ def __init__(
+ self,
+ node=None,
+ parent_type=None,
+ required=True,
+ global_id_type=DefaultGlobalIDType,
+ *args,
+ **kwargs,
+ ):
+ super(GlobalID, self).__init__(
+ global_id_type.graphene_type, required=required, *args, **kwargs
+ )
self.node = node or Node
self.parent_type_name = parent_type._meta.name if parent_type else None
+ @staticmethod
+ def id_resolver(parent_resolver, node, root, info, parent_type_name=None, **args):
+ type_id = parent_resolver(root, info, **args)
+ parent_type_name = parent_type_name or info.parent_type.name
+ return node.to_global_id(parent_type_name, type_id) # root._meta.name
+
+ def wrap_resolve(self, parent_resolver):
+ return partial(
+ self.id_resolver,
+ parent_resolver,
+ self.node,
+ parent_type_name=self.parent_type_name,
+ )
-class NodeField(Field):
+class NodeField(Field):
def __init__(self, node, type_=False, **kwargs):
- assert issubclass(node, Node), 'NodeField can only operate in Nodes'
+        assert issubclass(node, Node), "NodeField can only operate on Nodes"
self.node_type = node
self.field_type = type_
global_id_type = node._meta.global_id_type
- super(NodeField, self).__init__(type_ or node, id=global_id_type.
- graphene_type(required=True, description='The ID of the object'
- ), **kwargs)
+ super(NodeField, self).__init__(
+ # If we don't specify a type, the field type will be the node interface
+ type_ or node,
+ id=global_id_type.graphene_type(
+ required=True, description="The ID of the object"
+ ),
+ **kwargs,
+ )
-class AbstractNode(Interface):
+ def wrap_resolve(self, parent_resolver):
+ return partial(self.node_type.node_resolver, get_type(self.field_type))
+class AbstractNode(Interface):
class Meta:
abstract = True
@classmethod
- def __init_subclass_with_meta__(cls, global_id_type=DefaultGlobalIDType,
- **options):
- assert issubclass(global_id_type, BaseGlobalIDType
- ), 'Custom ID type need to be implemented as a subclass of BaseGlobalIDType.'
+ def __init_subclass_with_meta__(cls, global_id_type=DefaultGlobalIDType, **options):
+ assert issubclass(
+ global_id_type, BaseGlobalIDType
+        ), "Custom ID type needs to be implemented as a subclass of BaseGlobalIDType."
_meta = InterfaceOptions(cls)
_meta.global_id_type = global_id_type
- _meta.fields = {'id': GlobalID(cls, global_id_type=global_id_type,
- description='The ID of the object')}
- super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta,
- **options)
+ _meta.fields = {
+ "id": GlobalID(
+ cls, global_id_type=global_id_type, description="The ID of the object"
+ )
+ }
+ super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
+ @classmethod
+ def resolve_global_id(cls, info, global_id):
+ return cls._meta.global_id_type.resolve_global_id(info, global_id)
class Node(AbstractNode):
"""An object with an ID"""
+
+ @classmethod
+ def Field(cls, *args, **kwargs): # noqa: N802
+ return NodeField(cls, *args, **kwargs)
+
+ @classmethod
+ def node_resolver(cls, only_type, root, info, id):
+ return cls.get_node_from_global_id(info, id, only_type=only_type)
+
+ @classmethod
+ def get_node_from_global_id(cls, info, global_id, only_type=None):
+ _type, _id = cls.resolve_global_id(info, global_id)
+
+ graphene_type = info.schema.get_type(_type)
+ if graphene_type is None:
+ raise Exception(f'Relay Node "{_type}" not found in schema')
+
+ graphene_type = graphene_type.graphene_type
+
+ if only_type:
+ assert (
+ graphene_type == only_type
+ ), f"Must receive a {only_type._meta.name} id."
+
+ # We make sure the ObjectType implements the "Node" interface
+ if cls not in graphene_type._meta.interfaces:
+ raise Exception(
+ f'ObjectType "{_type}" does not implement the "{cls}" interface.'
+ )
+
+ get_node = getattr(graphene_type, "get_node", None)
+ if get_node:
+ return get_node(info, _id)
+
+ @classmethod
+ def to_global_id(cls, type_, id):
+ return cls._meta.global_id_type.to_global_id(type_, id)
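
End to end, `Node.Field()` plus a `get_node` classmethod is all an ObjectType needs for global-ID lookup: `node_resolver` decodes the ID via `resolve_global_id`, checks the interface, and dispatches to `get_node`. A minimal sketch (the `Photo` type and `PHOTOS` store are illustrative):

```python
import graphene
from graphene import relay

PHOTOS = {"1": "sunset.png"}


class Photo(graphene.ObjectType):
    class Meta:
        interfaces = (relay.Node,)

    filename = graphene.String()

    @classmethod
    def get_node(cls, info, id):
        # Invoked by get_node_from_global_id once the global ID is
        # decoded back into ("Photo", "1").
        filename = PHOTOS.get(id)
        return Photo(id=id, filename=filename) if filename else None


class Query(graphene.ObjectType):
    node = relay.Node.Field()


schema = graphene.Schema(query=Query)
```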
diff --git a/graphene/types/argument.py b/graphene/types/argument.py
index 4e25b9c..d9283c4 100644
--- a/graphene/types/argument.py
+++ b/graphene/types/argument.py
@@ -1,5 +1,6 @@
from itertools import chain
from graphql import Undefined
+
from .dynamic import Dynamic
from .mountedtype import MountedType
from .structures import NonNull
@@ -41,21 +42,79 @@ class Argument(MountedType):
set if the argument is required (see spec).
"""
- def __init__(self, type_, default_value=Undefined, deprecation_reason=
- None, description=None, name=None, required=False,
- _creation_counter=None):
+ def __init__(
+ self,
+ type_,
+ default_value=Undefined,
+ deprecation_reason=None,
+ description=None,
+ name=None,
+ required=False,
+ _creation_counter=None,
+ ):
super(Argument, self).__init__(_creation_counter=_creation_counter)
+
if required:
- assert deprecation_reason is None, f'Argument {name} is required, cannot deprecate it.'
+ assert (
+ deprecation_reason is None
+ ), f"Argument {name} is required, cannot deprecate it."
type_ = NonNull(type_)
+
self.name = name
self._type = type_
self.default_value = default_value
self.description = description
self.deprecation_reason = deprecation_reason
+ @property
+ def type(self):
+ return get_type(self._type)
+
def __eq__(self, other):
- return isinstance(other, Argument) and (self.name == other.name and
- self.type == other.type and self.default_value == other.
- default_value and self.description == other.description and
- self.deprecation_reason == other.deprecation_reason)
+ return isinstance(other, Argument) and (
+ self.name == other.name
+ and self.type == other.type
+ and self.default_value == other.default_value
+ and self.description == other.description
+ and self.deprecation_reason == other.deprecation_reason
+ )
+
+
+def to_arguments(args, extra_args=None):
+ from .unmountedtype import UnmountedType
+ from .field import Field
+ from .inputfield import InputField
+
+ if extra_args:
+ extra_args = sorted(extra_args.items(), key=lambda f: f[1])
+ else:
+ extra_args = []
+ iter_arguments = chain(args.items(), extra_args)
+ arguments = {}
+ for default_name, arg in iter_arguments:
+ if isinstance(arg, Dynamic):
+ arg = arg.get_type()
+ if arg is None:
+ # If the Dynamic type returned None
+ # then we skip the Argument
+ continue
+
+ if isinstance(arg, UnmountedType):
+ arg = Argument.mounted(arg)
+
+ if isinstance(arg, (InputField, Field)):
+ raise ValueError(
+ f"Expected {default_name} to be Argument, "
+ f"but received {type(arg).__name__}. Try using Argument({arg.type})."
+ )
+
+ if not isinstance(arg, Argument):
+ raise ValueError(f'Unknown argument "{default_name}".')
+
+ arg_name = default_name or arg.name
+ assert (
+ arg_name not in arguments
+    ), f'More than one Argument has the same name "{arg_name}".'
+ arguments[arg_name] = arg
+
+ return arguments
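
`to_arguments` is what normalizes a field's argument mapping: unmounted scalars get mounted as `Argument`s, `Dynamic`s returning `None` are skipped, and stray `Field`/`InputField` instances are rejected. A small sketch of the mounting behavior:

```python
import graphene
from graphene.types.argument import Argument, to_arguments

args = to_arguments(
    {
        # An unmounted scalar is mounted as an Argument...
        "name": graphene.String(required=True),
        # ...while explicit Arguments pass through unchanged.
        "limit": Argument(graphene.Int, default_value=10),
    }
)
assert isinstance(args["name"].type, graphene.NonNull)  # required= wraps in NonNull
assert args["limit"].default_value == 10
```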
diff --git a/graphene/types/base.py b/graphene/types/base.py
index 6483a0d..84cb377 100644
--- a/graphene/types/base.py
+++ b/graphene/types/base.py
@@ -1,15 +1,20 @@
from typing import Type
+
from ..utils.subclass_with_meta import SubclassWithMeta, SubclassWithMeta_Meta
from ..utils.trim_docstring import trim_docstring
class BaseOptions:
- name = None
- description = None
- _frozen = False
+ name = None # type: str
+ description = None # type: str
+
+ _frozen = False # type: bool
def __init__(self, class_type):
- self.class_type = class_type
+ self.class_type = class_type # type: Type
+
+ def freeze(self):
+ self._frozen = True
def __setattr__(self, name, value):
if not self._frozen:
@@ -18,18 +23,22 @@ class BaseOptions:
raise Exception(f"Can't modify frozen Options {self}")
def __repr__(self):
- return f'<{self.__class__.__name__} name={repr(self.name)}>'
+ return f"<{self.__class__.__name__} name={repr(self.name)}>"
BaseTypeMeta = SubclassWithMeta_Meta
class BaseType(SubclassWithMeta):
+ @classmethod
+ def create_type(cls, class_name, **options):
+ return type(class_name, (cls,), {"Meta": options})
@classmethod
- def __init_subclass_with_meta__(cls, name=None, description=None, _meta
- =None, **_kwargs):
- assert '_meta' not in cls.__dict__, "Can't assign meta directly"
+ def __init_subclass_with_meta__(
+ cls, name=None, description=None, _meta=None, **_kwargs
+ ):
+ assert "_meta" not in cls.__dict__, "Can't assign meta directly"
if not _meta:
return
_meta.name = name or cls.__name__
diff --git a/graphene/types/base64.py b/graphene/types/base64.py
index 4f5949a..69bb338 100644
--- a/graphene/types/base64.py
+++ b/graphene/types/base64.py
@@ -1,7 +1,9 @@
from binascii import Error as _Error
from base64 import b64decode, b64encode
+
from graphql.error import GraphQLError
from graphql.language import StringValueNode, print_ast
+
from .scalars import Scalar
@@ -9,3 +11,33 @@ class Base64(Scalar):
"""
The `Base64` scalar type represents a base64-encoded String.
"""
+
+ @staticmethod
+ def serialize(value):
+ if not isinstance(value, bytes):
+ if isinstance(value, str):
+ value = value.encode("utf-8")
+ else:
+ value = str(value).encode("utf-8")
+ return b64encode(value).decode("utf-8")
+
+ @classmethod
+ def parse_literal(cls, node, _variables=None):
+ if not isinstance(node, StringValueNode):
+ raise GraphQLError(
+ f"Base64 cannot represent non-string value: {print_ast(node)}"
+ )
+ return cls.parse_value(node.value)
+
+ @staticmethod
+ def parse_value(value):
+ if not isinstance(value, bytes):
+ if not isinstance(value, str):
+ raise GraphQLError(
+ f"Base64 cannot represent non-string value: {repr(value)}"
+ )
+ value = value.encode("utf-8")
+ try:
+ return b64decode(value, validate=True).decode("utf-8")
+ except _Error:
+ raise GraphQLError(f"Base64 cannot decode value: {repr(value)}")
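
A quick round trip through the restored `Base64` scalar, using its static methods directly:

```python
from graphene.types.base64 import Base64

token = Base64.serialize("ping")           # 'cGluZw=='
assert Base64.parse_value(token) == "ping"
```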
diff --git a/graphene/types/datetime.py b/graphene/types/datetime.py
index a473f89..d4f7447 100644
--- a/graphene/types/datetime.py
+++ b/graphene/types/datetime.py
@@ -1,8 +1,11 @@
from __future__ import absolute_import
+
import datetime
+
from aniso8601 import parse_date, parse_datetime, parse_time
from graphql.error import GraphQLError
from graphql.language import StringValueNode, print_ast
+
from .scalars import Scalar
@@ -13,6 +16,33 @@ class Date(Scalar):
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
"""
+ @staticmethod
+ def serialize(date):
+ if isinstance(date, datetime.datetime):
+ date = date.date()
+ if not isinstance(date, datetime.date):
+ raise GraphQLError(f"Date cannot represent value: {repr(date)}")
+ return date.isoformat()
+
+ @classmethod
+ def parse_literal(cls, node, _variables=None):
+ if not isinstance(node, StringValueNode):
+ raise GraphQLError(
+ f"Date cannot represent non-string value: {print_ast(node)}"
+ )
+ return cls.parse_value(node.value)
+
+ @staticmethod
+ def parse_value(value):
+ if isinstance(value, datetime.date):
+ return value
+ if not isinstance(value, str):
+ raise GraphQLError(f"Date cannot represent non-string value: {repr(value)}")
+ try:
+ return parse_date(value)
+ except ValueError:
+ raise GraphQLError(f"Date cannot represent value: {repr(value)}")
+
class DateTime(Scalar):
"""
@@ -21,6 +51,33 @@ class DateTime(Scalar):
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
"""
+ @staticmethod
+ def serialize(dt):
+ if not isinstance(dt, (datetime.datetime, datetime.date)):
+ raise GraphQLError(f"DateTime cannot represent value: {repr(dt)}")
+ return dt.isoformat()
+
+ @classmethod
+ def parse_literal(cls, node, _variables=None):
+ if not isinstance(node, StringValueNode):
+ raise GraphQLError(
+ f"DateTime cannot represent non-string value: {print_ast(node)}"
+ )
+ return cls.parse_value(node.value)
+
+ @staticmethod
+ def parse_value(value):
+ if isinstance(value, datetime.datetime):
+ return value
+ if not isinstance(value, str):
+ raise GraphQLError(
+ f"DateTime cannot represent non-string value: {repr(value)}"
+ )
+ try:
+ return parse_datetime(value)
+ except ValueError:
+ raise GraphQLError(f"DateTime cannot represent value: {repr(value)}")
+
class Time(Scalar):
"""
@@ -28,3 +85,28 @@ class Time(Scalar):
specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
"""
+
+ @staticmethod
+ def serialize(time):
+ if not isinstance(time, datetime.time):
+ raise GraphQLError(f"Time cannot represent value: {repr(time)}")
+ return time.isoformat()
+
+ @classmethod
+ def parse_literal(cls, node, _variables=None):
+ if not isinstance(node, StringValueNode):
+ raise GraphQLError(
+ f"Time cannot represent non-string value: {print_ast(node)}"
+ )
+ return cls.parse_value(node.value)
+
+ @classmethod
+ def parse_value(cls, value):
+ if isinstance(value, datetime.time):
+ return value
+ if not isinstance(value, str):
+ raise GraphQLError(f"Time cannot represent non-string value: {repr(value)}")
+ try:
+ return parse_time(value)
+ except ValueError:
+ raise GraphQLError(f"Time cannot represent value: {repr(value)}")
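
The three date/time scalars share the same shape: strict type checks on `serialize`, ISO 8601 parsing (via aniso8601) on `parse_value`. A small sketch exercising each:

```python
import datetime

from graphene.types.datetime import Date, DateTime, Time

assert Date.serialize(datetime.date(2024, 1, 31)) == "2024-01-31"
assert DateTime.parse_value("2024-01-31T09:30:00") == datetime.datetime(2024, 1, 31, 9, 30)
assert Time.parse_value("09:30:00") == datetime.time(9, 30)
```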
diff --git a/graphene/types/decimal.py b/graphene/types/decimal.py
index 5607802..0c6ccc9 100644
--- a/graphene/types/decimal.py
+++ b/graphene/types/decimal.py
@@ -1,7 +1,10 @@
from __future__ import absolute_import
+
from decimal import Decimal as _Decimal
+
from graphql import Undefined
from graphql.language.ast import StringValueNode, IntValueNode
+
from .scalars import Scalar
@@ -9,3 +12,25 @@ class Decimal(Scalar):
"""
The `Decimal` scalar type represents a python Decimal.
"""
+
+ @staticmethod
+ def serialize(dec):
+ if isinstance(dec, str):
+ dec = _Decimal(dec)
+ assert isinstance(
+ dec, _Decimal
+        ), f'Received incompatible Decimal "{repr(dec)}"'
+ return str(dec)
+
+ @classmethod
+ def parse_literal(cls, node, _variables=None):
+ if isinstance(node, (StringValueNode, IntValueNode)):
+ return cls.parse_value(node.value)
+ return Undefined
+
+ @staticmethod
+ def parse_value(value):
+ try:
+ return _Decimal(value)
+ except Exception:
+ return Undefined
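
Note that `Decimal.parse_value` deliberately swallows bad input and returns `Undefined` rather than raising; a short sketch:

```python
from decimal import Decimal as _Decimal

from graphql import Undefined

from graphene.types.decimal import Decimal

assert Decimal.serialize("3.14") == "3.14"           # strings are coerced first
assert Decimal.parse_value("3.14") == _Decimal("3.14")
assert Decimal.parse_value(object()) is Undefined    # bad input maps to Undefined
```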
diff --git a/graphene/types/definitions.py b/graphene/types/definitions.py
index 12b2a86..ac574be 100644
--- a/graphene/types/definitions.py
+++ b/graphene/types/definitions.py
@@ -1,5 +1,13 @@
from enum import Enum as PyEnum
-from graphql import GraphQLEnumType, GraphQLInputObjectType, GraphQLInterfaceType, GraphQLObjectType, GraphQLScalarType, GraphQLUnionType
+
+from graphql import (
+ GraphQLEnumType,
+ GraphQLInputObjectType,
+ GraphQLInterfaceType,
+ GraphQLObjectType,
+ GraphQLScalarType,
+ GraphQLUnionType,
+)
class GrapheneGraphQLType:
@@ -9,7 +17,7 @@ class GrapheneGraphQLType:
"""
def __init__(self, *args, **kwargs):
- self.graphene_type = kwargs.pop('graphene_type')
+ self.graphene_type = kwargs.pop("graphene_type")
super(GrapheneGraphQLType, self).__init__(*args, **kwargs)
def __copy__(self):
@@ -35,7 +43,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):
class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
- pass
+ def serialize(self, value):
+ if not isinstance(value, PyEnum):
+ enum = self.graphene_type._meta.enum
+ try:
+ # Try and get enum by value
+ value = enum(value)
+ except ValueError:
+ # Try and get enum by name
+ try:
+ value = enum[value]
+ except KeyError:
+ pass
+ return super(GrapheneEnumType, self).serialize(value)
class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):
diff --git a/graphene/types/dynamic.py b/graphene/types/dynamic.py
index 4818d92..3bb2b0f 100644
--- a/graphene/types/dynamic.py
+++ b/graphene/types/dynamic.py
@@ -1,5 +1,6 @@
import inspect
from functools import partial
+
from .mountedtype import MountedType
@@ -14,3 +15,8 @@ class Dynamic(MountedType):
assert inspect.isfunction(type_) or isinstance(type_, partial)
self.type = type_
self.with_schema = with_schema
+
+ def get_type(self, schema=None):
+ if schema and self.with_schema:
+ return self.type(schema=schema)
+ return self.type()
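
`Dynamic.get_type` is what makes lazily declared fields possible: the wrapped callable runs at schema build time, not at class definition time. A minimal sketch of the usual self-reference pattern (the `Person` type is illustrative):

```python
import graphene
from graphene import Dynamic


class Person(graphene.ObjectType):
    name = graphene.String()
    # The lambda defers type resolution until the schema is built,
    # which is how self- and cross-references avoid import cycles.
    best_friend = Dynamic(lambda: graphene.Field(Person))
```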
diff --git a/graphene/types/enum.py b/graphene/types/enum.py
index cce1873..d3469a1 100644
--- a/graphene/types/enum.py
+++ b/graphene/types/enum.py
@@ -1,43 +1,79 @@
from enum import Enum as PyEnum
+
from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta
+
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
+
+
+def eq_enum(self, other):
+ if isinstance(other, self.__class__):
+ return self is other
+ return self.value is other
+
+
+def hash_enum(self):
+ return hash(self.name)
+
+
EnumType = type(PyEnum)
class EnumOptions(BaseOptions):
- enum = None
+ enum = None # type: Enum
deprecation_reason = None
class EnumMeta(SubclassWithMeta_Meta):
-
def __new__(cls, name_, bases, classdict, **options):
enum_members = dict(classdict, __eq__=eq_enum, __hash__=hash_enum)
- enum_members.pop('Meta', None)
+ # We remove the Meta attribute from the class to not collide
+ # with the enum values.
+ enum_members.pop("Meta", None)
enum = PyEnum(cls.__name__, enum_members)
- obj = SubclassWithMeta_Meta.__new__(cls, name_, bases, dict(
- classdict, __enum__=enum), **options)
+ obj = SubclassWithMeta_Meta.__new__(
+ cls, name_, bases, dict(classdict, __enum__=enum), **options
+ )
globals()[name_] = obj.__enum__
return obj
+ def get(cls, value):
+ return cls._meta.enum(value)
+
def __getitem__(cls, value):
return cls._meta.enum[value]
- def __prepare__(name, bases, **kwargs):
+ def __prepare__(name, bases, **kwargs): # noqa: N805
return {}
- def __call__(cls, *args, **kwargs):
+ def __call__(cls, *args, **kwargs): # noqa: N805
if cls is Enum:
- description = kwargs.pop('description', None)
- deprecation_reason = kwargs.pop('deprecation_reason', None)
- return cls.from_enum(PyEnum(*args, **kwargs), description=
- description, deprecation_reason=deprecation_reason)
+ description = kwargs.pop("description", None)
+ deprecation_reason = kwargs.pop("deprecation_reason", None)
+ return cls.from_enum(
+ PyEnum(*args, **kwargs),
+ description=description,
+ deprecation_reason=deprecation_reason,
+ )
return super(EnumMeta, cls).__call__(*args, **kwargs)
+ # return cls._meta.enum(*args, **kwargs)
def __iter__(cls):
return cls._meta.enum.__iter__()
+ def from_enum(
+ cls, enum, name=None, description=None, deprecation_reason=None
+ ): # noqa: N805
+ name = name or enum.__name__
+ description = description or enum.__doc__ or "An enumeration."
+ meta_dict = {
+ "enum": enum,
+ "description": description,
+ "deprecation_reason": deprecation_reason,
+ }
+ meta_class = type("Meta", (object,), meta_dict)
+ return type(name, (Enum,), {"Meta": meta_class})
+
class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
"""
@@ -69,9 +105,10 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
if not _meta:
_meta = EnumOptions(cls)
_meta.enum = enum or cls.__enum__
- _meta.deprecation_reason = options.pop('deprecation_reason', None)
+ _meta.deprecation_reason = options.pop("deprecation_reason", None)
for key, value in _meta.enum.__members__.items():
setattr(cls, key, value)
+
super(Enum, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@classmethod
@@ -80,4 +117,4 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
This function is called when the unmounted type (Enum instance)
is mounted (as a Field, InputField or Argument)
"""
- pass
+ return cls
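
With `from_enum` and the metaclass `__call__` restored, both declaration styles work and are backed by a `PyEnum` under the hood:

```python
from enum import Enum as PyEnum

import graphene


class Color(PyEnum):
    RED = 1
    GREEN = 2


# Wrap an existing PyEnum...
GColor = graphene.Enum.from_enum(Color, description="Available colors")
# ...or declare one functionally via EnumMeta.__call__.
Episode = graphene.Enum("Episode", [("NEWHOPE", 4), ("EMPIRE", 5)])
```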
diff --git a/graphene/types/field.py b/graphene/types/field.py
index a23e927..dafb04b 100644
--- a/graphene/types/field.py
+++ b/graphene/types/field.py
@@ -1,6 +1,7 @@
import inspect
from collections.abc import Mapping
from functools import partial
+
from .argument import Argument, to_arguments
from .mountedtype import MountedType
from .resolver import default_resolver
@@ -8,9 +9,17 @@ from .structures import NonNull
from .unmountedtype import UnmountedType
from .utils import get_type
from ..utils.deprecated import warn_deprecation
+
base_type = type
+def source_resolver(source, root, info, **args):
+ resolved = default_resolver(source, None, root, info, **args)
+ if inspect.isfunction(resolved) or inspect.ismethod(resolved):
+ return resolved()
+ return resolved
+
+
class Field(MountedType):
"""
Makes a field available on an ObjectType in the GraphQL schema. Any type can be mounted as a
@@ -54,24 +63,44 @@ class Field(MountedType):
additional arguments to mount on the field.
"""
- def __init__(self, type_, args=None, resolver=None, source=None,
- deprecation_reason=None, name=None, description=None, required=
- False, _creation_counter=None, default_value=None, **extra_args):
+ def __init__(
+ self,
+ type_,
+ args=None,
+ resolver=None,
+ source=None,
+ deprecation_reason=None,
+ name=None,
+ description=None,
+ required=False,
+ _creation_counter=None,
+ default_value=None,
+ **extra_args,
+ ):
super(Field, self).__init__(_creation_counter=_creation_counter)
- assert not args or isinstance(args, Mapping
- ), f'Arguments in a field have to be a mapping, received "{args}".'
- assert not (source and resolver
- ), 'A Field cannot have a source and a resolver in at the same time.'
- assert not callable(default_value
- ), f'The default value can not be a function but received "{base_type(default_value)}".'
+ assert not args or isinstance(
+ args, Mapping
+ ), f'Arguments in a field have to be a mapping, received "{args}".'
+ assert not (
+ source and resolver
+        ), "A Field cannot have a source and a resolver at the same time."
+ assert not callable(
+ default_value
+ ), f'The default value can not be a function but received "{base_type(default_value)}".'
+
if required:
type_ = NonNull(type_)
+
+ # Check if name is actually an argument of the field
if isinstance(name, (Argument, UnmountedType)):
- extra_args['name'] = name
+ extra_args["name"] = name
name = None
+
+ # Check if source is actually an argument of the field
if isinstance(source, (Argument, UnmountedType)):
- extra_args['source'] = source
+ extra_args["source"] = source
source = None
+
self.name = name
self._type = type_
self.args = to_arguments(args or {}, extra_args)
@@ -81,6 +110,11 @@ class Field(MountedType):
self.deprecation_reason = deprecation_reason
self.description = description
self.default_value = default_value
+
+ @property
+ def type(self):
+ return get_type(self._type)
+
get_resolver = None
def wrap_resolve(self, parent_resolver):
@@ -88,11 +122,17 @@ class Field(MountedType):
Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
(parent_resolver) if the Field definition has no resolver.
"""
- pass
+ if self.get_resolver is not None:
+ warn_deprecation(
+ "The get_resolver method is being deprecated, please rename it to wrap_resolve."
+ )
+ return self.get_resolver(parent_resolver)
+
+ return self.resolver or parent_resolver
def wrap_subscribe(self, parent_subscribe):
"""
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
(parent_subscribe) if the Field definition has no subscribe.
"""
- pass
+ return parent_subscribe
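
The `source_resolver` added at the top of this file backs the `source=` shortcut: it looks the attribute up on the root and calls it if it turns out to be a function or method. A minimal sketch (the `Person` class is illustrative):

```python
import graphene


class Person:
    def full_name(self):
        return "Ada Lovelace"


class Query(graphene.ObjectType):
    # source= installs source_resolver: attribute lookup on the root,
    # with callables invoked to produce the value.
    name = graphene.String(source="full_name")


schema = graphene.Schema(query=Query)
result = schema.execute("{ name }", root_value=Person())
assert result.data == {"name": "Ada Lovelace"}
```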
diff --git a/graphene/types/generic.py b/graphene/types/generic.py
index fc0488e..2a3c8d5 100644
--- a/graphene/types/generic.py
+++ b/graphene/types/generic.py
@@ -1,6 +1,16 @@
from __future__ import unicode_literals
-from graphql.language.ast import BooleanValueNode, FloatValueNode, IntValueNode, ListValueNode, ObjectValueNode, StringValueNode
+
+from graphql.language.ast import (
+ BooleanValueNode,
+ FloatValueNode,
+ IntValueNode,
+ ListValueNode,
+ ObjectValueNode,
+ StringValueNode,
+)
+
from graphene.types.scalars import MAX_INT, MIN_INT
+
from .scalars import Scalar
@@ -10,5 +20,30 @@ class GenericScalar(Scalar):
GraphQL scalar value that could be:
String, Boolean, Int, Float, List or Object.
"""
+
+ @staticmethod
+ def identity(value):
+ return value
+
serialize = identity
parse_value = identity
+
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, (StringValueNode, BooleanValueNode)):
+ return ast.value
+ elif isinstance(ast, IntValueNode):
+ num = int(ast.value)
+ if MIN_INT <= num <= MAX_INT:
+ return num
+ elif isinstance(ast, FloatValueNode):
+ return float(ast.value)
+ elif isinstance(ast, ListValueNode):
+ return [GenericScalar.parse_literal(value) for value in ast.values]
+ elif isinstance(ast, ObjectValueNode):
+ return {
+ field.name.value: GenericScalar.parse_literal(field.value)
+ for field in ast.fields
+ }
+ else:
+ return None
diff --git a/graphene/types/inputfield.py b/graphene/types/inputfield.py
index a1bc6a6..e7ededb 100644
--- a/graphene/types/inputfield.py
+++ b/graphene/types/inputfield.py
@@ -1,4 +1,5 @@
from graphql import Undefined
+
from .mountedtype import MountedType
from .structures import NonNull
from .utils import get_type
@@ -45,15 +46,29 @@ class InputField(MountedType):
**extra_args (optional, Dict): Not used.
"""
- def __init__(self, type_, name=None, default_value=Undefined,
- deprecation_reason=None, description=None, required=False,
- _creation_counter=None, **extra_args):
+ def __init__(
+ self,
+ type_,
+ name=None,
+ default_value=Undefined,
+ deprecation_reason=None,
+ description=None,
+ required=False,
+ _creation_counter=None,
+ **extra_args,
+ ):
super(InputField, self).__init__(_creation_counter=_creation_counter)
self.name = name
if required:
- assert deprecation_reason is None, f'InputField {name} is required, cannot deprecate it.'
+ assert (
+ deprecation_reason is None
+ ), f"InputField {name} is required, cannot deprecate it."
type_ = NonNull(type_)
self._type = type_
self.deprecation_reason = deprecation_reason
self.default_value = default_value
self.description = description
+
+ @property
+ def type(self):
+ return get_type(self._type)
diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py
index f99e2c1..257f48b 100644
--- a/graphene/types/inputobjecttype.py
+++ b/graphene/types/inputobjecttype.py
@@ -1,19 +1,31 @@
from typing import TYPE_CHECKING
+
from .base import BaseOptions, BaseType
from .inputfield import InputField
from .unmountedtype import UnmountedType
from .utils import yank_fields_from_attrs
+
+# For static type checking with type checker
if TYPE_CHECKING:
- from typing import Dict, Callable
+ from typing import Dict, Callable # NOQA
class InputObjectTypeOptions(BaseOptions):
- fields = None
- container = None
+ fields = None # type: Dict[str, InputField]
+ container = None # type: InputObjectTypeContainer
+# Currently in Graphene, we get a `None` whenever we access an (optional) field that was not set in an InputObjectType
+# using the InputObjectType.<attribute> dot access syntax. This is ambiguous, because in this current (Graphene
+# historical) arrangement, we cannot distinguish between a field not being set and a field being set to None.
+# At the same time, we shouldn't break existing code that expects a `None` when accessing a field that was not set.
_INPUT_OBJECT_TYPE_DEFAULT_VALUE = None
+# To mitigate this, we provide the function `set_input_object_type_default_value` to allow users to change the default
+# value returned in non-specified fields in InputObjectType to another meaningful sentinel value (e.g. Undefined)
+# if they want to. This way, we can keep code that expects a `None` working while we figure out a better solution (or
+# a well-documented breaking change) for this issue.
+
def set_input_object_type_default_value(default_value):
"""
@@ -24,12 +36,11 @@ def set_input_object_type_default_value(default_value):
This function should be called at the beginning of the app or in some other place where it is guaranteed to
be called before any InputObjectType is defined.
"""
- pass
-
-
-class InputObjectTypeContainer(dict, BaseType):
+ global _INPUT_OBJECT_TYPE_DEFAULT_VALUE
+ _INPUT_OBJECT_TYPE_DEFAULT_VALUE = default_value
+class InputObjectTypeContainer(dict, BaseType): # type: ignore
class Meta:
abstract = True
@@ -79,14 +90,14 @@ class InputObjectType(UnmountedType, BaseType):
"""
@classmethod
- def __init_subclass_with_meta__(cls, container=None, _meta=None, **options
- ):
+ def __init_subclass_with_meta__(cls, container=None, _meta=None, **options):
if not _meta:
_meta = InputObjectTypeOptions(cls)
+
fields = {}
for base in reversed(cls.__mro__):
- fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField)
- )
+ fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField))
+
if _meta.fields:
_meta.fields.update(fields)
else:
@@ -94,8 +105,7 @@ class InputObjectType(UnmountedType, BaseType):
if container is None:
container = type(cls.__name__, (InputObjectTypeContainer, cls), {})
_meta.container = container
- super(InputObjectType, cls).__init_subclass_with_meta__(_meta=_meta,
- **options)
+ super(InputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@classmethod
def get_type(cls):
@@ -103,4 +113,4 @@ class InputObjectType(UnmountedType, BaseType):
This function is called when the unmounted type (InputObjectType instance)
is mounted (as a Field, InputField or Argument)
"""
- pass
+ return cls
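
A sketch of the opt-in described in the comment block above; as the docstring says, this must run before any `InputObjectType` is defined:

```python
from graphql import Undefined

from graphene.types.inputobjecttype import set_input_object_type_default_value

# After this call, unset optional fields read back as Undefined
# instead of the (ambiguous) legacy None.
set_input_object_type_default_value(Undefined)
```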
diff --git a/graphene/types/interface.py b/graphene/types/interface.py
index 733a6f1..31bcc7f 100644
--- a/graphene/types/interface.py
+++ b/graphene/types/interface.py
@@ -1,14 +1,17 @@
from typing import TYPE_CHECKING
+
from .base import BaseOptions, BaseType
from .field import Field
from .utils import yank_fields_from_attrs
+
+# For static type checking with type checker
if TYPE_CHECKING:
- from typing import Dict, Iterable, Type
+ from typing import Dict, Iterable, Type # NOQA
class InterfaceOptions(BaseOptions):
- fields = None
- interfaces = ()
+ fields = None # type: Dict[str, Field]
+ interfaces = () # type: Iterable[Type[Interface]]
class Interface(BaseType):
@@ -47,17 +50,27 @@ class Interface(BaseType):
def __init_subclass_with_meta__(cls, _meta=None, interfaces=(), **options):
if not _meta:
_meta = InterfaceOptions(cls)
+
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
+
if _meta.fields:
_meta.fields.update(fields)
else:
_meta.fields = fields
+
if not _meta.interfaces:
_meta.interfaces = interfaces
- super(Interface, cls).__init_subclass_with_meta__(_meta=_meta, **
- options)
+
+ super(Interface, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
+ @classmethod
+ def resolve_type(cls, instance, info):
+ from .objecttype import ObjectType
+
+ if isinstance(instance, ObjectType):
+ return type(instance)
def __init__(self, *args, **kwargs):
- raise Exception('An Interface cannot be initialized')
+ raise Exception("An Interface cannot be initialized")
diff --git a/graphene/types/json.py b/graphene/types/json.py
index a4f8f8d..ca55836 100644
--- a/graphene/types/json.py
+++ b/graphene/types/json.py
@@ -1,7 +1,10 @@
from __future__ import absolute_import
+
import json
+
from graphql import Undefined
from graphql.language.ast import StringValueNode
+
from .scalars import Scalar
@@ -12,3 +15,20 @@ class JSONString(Scalar):
Use of this type is *not recommended* as you lose the benefits of having a defined, static
schema (one of the key benefits of GraphQL).
"""
+
+ @staticmethod
+ def serialize(dt):
+ return json.dumps(dt)
+
+ @staticmethod
+ def parse_literal(node, _variables=None):
+ if isinstance(node, StringValueNode):
+ try:
+ return json.loads(node.value)
+ except Exception as error:
+ raise ValueError(f"Badly formed JSONString: {str(error)}")
+ return Undefined
+
+ @staticmethod
+ def parse_value(value):
+ return json.loads(value)
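
A quick round trip through `JSONString`, whose serialize/parse pair is just `json.dumps`/`json.loads`:

```python
from graphene.types.json import JSONString

payload = {"tags": ["a", "b"], "count": 2}
assert JSONString.parse_value(JSONString.serialize(payload)) == payload
```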
diff --git a/graphene/types/mountedtype.py b/graphene/types/mountedtype.py
index ac4f6e7..c42383e 100644
--- a/graphene/types/mountedtype.py
+++ b/graphene/types/mountedtype.py
@@ -3,10 +3,18 @@ from .unmountedtype import UnmountedType
class MountedType(OrderedType):
-
@classmethod
- def mounted(cls, unmounted):
+ def mounted(cls, unmounted): # noqa: N802
"""
Mount the UnmountedType instance
"""
- pass
+ assert isinstance(
+ unmounted, UnmountedType
+ ), f"{cls.__name__} can't mount {repr(unmounted)}"
+
+ return cls(
+ unmounted.get_type(),
+ *unmounted.args,
+ _creation_counter=unmounted.creation_counter,
+ **unmounted.kwargs,
+ )
diff --git a/graphene/types/mutation.py b/graphene/types/mutation.py
index 1e5c548..2de21b3 100644
--- a/graphene/types/mutation.py
+++ b/graphene/types/mutation.py
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
+
from ..utils.deprecated import warn_deprecation
from ..utils.get_unbound_function import get_unbound_function
from ..utils.props import props
@@ -6,16 +7,18 @@ from .field import Field
from .objecttype import ObjectType, ObjectTypeOptions
from .utils import yank_fields_from_attrs
from .interface import Interface
+
+# For static type checking with type checker
if TYPE_CHECKING:
- from .argument import Argument
- from typing import Dict, Type, Callable, Iterable
+ from .argument import Argument # NOQA
+ from typing import Dict, Type, Callable, Iterable # NOQA
class MutationOptions(ObjectTypeOptions):
- arguments = None
- output = None
- resolver = None
- interfaces = ()
+ arguments = None # type: Dict[str, Argument]
+ output = None # type: Type[ObjectType]
+ resolver = None # type: Callable
+ interfaces = () # type: Iterable[Type[Interface]]
class Mutation(ObjectType):
@@ -63,34 +66,46 @@ class Mutation(ObjectType):
"""
@classmethod
- def __init_subclass_with_meta__(cls, interfaces=(), resolver=None,
- output=None, arguments=None, _meta=None, **options):
+ def __init_subclass_with_meta__(
+ cls,
+ interfaces=(),
+ resolver=None,
+ output=None,
+ arguments=None,
+ _meta=None,
+ **options,
+ ):
if not _meta:
_meta = MutationOptions(cls)
- output = output or getattr(cls, 'Output', None)
+ output = output or getattr(cls, "Output", None)
fields = {}
+
for interface in interfaces:
- assert issubclass(interface, Interface
- ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
+ assert issubclass(
+ interface, Interface
+ ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
fields.update(interface._meta.fields)
if not output:
+ # If output is defined, we don't need to get the fields
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
output = cls
if not arguments:
- input_class = getattr(cls, 'Arguments', None)
+ input_class = getattr(cls, "Arguments", None)
if not input_class:
- input_class = getattr(cls, 'Input', None)
+ input_class = getattr(cls, "Input", None)
if input_class:
warn_deprecation(
- f"""Please use {cls.__name__}.Arguments instead of {cls.__name__}.Input. Input is now only used in ClientMutationID.
-Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"""
- )
+ f"Please use {cls.__name__}.Arguments instead of {cls.__name__}.Input."
+ " Input is now only used in ClientMutationID.\n"
+ "Read more:"
+ " https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"
+ )
arguments = props(input_class) if input_class else {}
if not resolver:
- mutate = getattr(cls, 'mutate', None)
- assert mutate, 'All mutations must define a mutate method in it'
+ mutate = getattr(cls, "mutate", None)
+            assert mutate, "All mutations must define a mutate method"
resolver = get_unbound_function(mutate)
if _meta.fields:
_meta.fields.update(fields)
@@ -100,11 +115,20 @@ Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.m
_meta.output = output
_meta.resolver = resolver
_meta.arguments = arguments
- super(Mutation, cls).__init_subclass_with_meta__(_meta=_meta, **options
- )
+
+ super(Mutation, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@classmethod
- def Field(cls, name=None, description=None, deprecation_reason=None,
- required=False):
+ def Field(
+ cls, name=None, description=None, deprecation_reason=None, required=False
+ ):
"""Mount instance of mutation Field."""
- pass
+ return Field(
+ cls._meta.output,
+ args=cls._meta.arguments,
+ resolver=cls._meta.resolver,
+ name=name,
+ description=description or cls._meta.description,
+ deprecation_reason=deprecation_reason,
+ required=required,
+ )
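
The restored `Field()` classmethod mounts the mutation using the output type, arguments and resolver stored in `_meta`. A minimal usage sketch (the `CreatePerson`/`MyMutations` names are illustrative):

```python
import graphene


class CreatePerson(graphene.Mutation):
    class Arguments:
        name = graphene.String(required=True)

    ok = graphene.Boolean()

    def mutate(root, info, name):
        return CreatePerson(ok=True)


class MyMutations(graphene.ObjectType):
    # Field() mounts the mutation with its stored output, arguments
    # and resolver.
    create_person = CreatePerson.Field()
```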
diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py
index 0f31ada..b3b829f 100644
--- a/graphene/types/objecttype.py
+++ b/graphene/types/objecttype.py
@@ -1,34 +1,48 @@
from typing import TYPE_CHECKING
+
from .base import BaseOptions, BaseType, BaseTypeMeta
from .field import Field
from .interface import Interface
from .utils import yank_fields_from_attrs
+
try:
from dataclasses import make_dataclass, field
except ImportError:
- from ..pyutils.dataclasses import make_dataclass, field
+ from ..pyutils.dataclasses import make_dataclass, field # type: ignore
+# For static type checking with type checker
if TYPE_CHECKING:
- from typing import Dict, Iterable, Type
+ from typing import Dict, Iterable, Type # NOQA
class ObjectTypeOptions(BaseOptions):
- fields = None
- interfaces = ()
+ fields = None # type: Dict[str, Field]
+ interfaces = () # type: Iterable[Type[Interface]]
class ObjectTypeMeta(BaseTypeMeta):
-
def __new__(cls, name_, bases, namespace, **options):
+ # Note: it's safe to pass options as keyword arguments as they are still type-checked by ObjectTypeOptions.
-
+ # We create this type, to then overload it with the dataclass attrs
class InterObjectType:
pass
- base_cls = super().__new__(cls, name_, (InterObjectType,) + bases,
- namespace, **options)
+
+ base_cls = super().__new__(
+ cls, name_, (InterObjectType,) + bases, namespace, **options
+ )
if base_cls._meta:
- fields = [(key, 'typing.Any', field(default=field_value.
- default_value if isinstance(field_value, Field) else None)) for
- key, field_value in base_cls._meta.fields.items()]
+ fields = [
+ (
+ key,
+ "typing.Any",
+ field(
+ default=field_value.default_value
+ if isinstance(field_value, Field)
+ else None
+ ),
+ )
+ for key, field_value in base_cls._meta.fields.items()
+ ]
dataclass = make_dataclass(name_, fields, bases=())
InterObjectType.__init__ = dataclass.__init__
InterObjectType.__eq__ = dataclass.__eq__
@@ -109,19 +123,30 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
"""
@classmethod
- def __init_subclass_with_meta__(cls, interfaces=(), possible_types=(),
- default_resolver=None, _meta=None, **options):
+ def __init_subclass_with_meta__(
+ cls,
+ interfaces=(),
+ possible_types=(),
+ default_resolver=None,
+ _meta=None,
+ **options,
+ ):
if not _meta:
_meta = ObjectTypeOptions(cls)
fields = {}
+
for interface in interfaces:
- assert issubclass(interface, Interface
- ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
+ assert issubclass(
+ interface, Interface
+ ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
fields.update(interface._meta.fields)
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
- assert not (possible_types and cls.is_type_of
- ), f'{cls.__name__}.Meta.possible_types will cause type collision with {cls.__name__}.is_type_of. Please use one or other.'
+ assert not (possible_types and cls.is_type_of), (
+ f"{cls.__name__}.Meta.possible_types will cause type collision with {cls.__name__}.is_type_of. "
+            "Please use one or the other."
+ )
+
if _meta.fields:
_meta.fields.update(fields)
else:
@@ -130,6 +155,7 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
_meta.interfaces = interfaces
_meta.possible_types = possible_types
_meta.default_resolver = default_resolver
- super(ObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **
- options)
+
+ super(ObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
is_type_of = None
diff --git a/graphene/types/resolver.py b/graphene/types/resolver.py
index f9a1c6e..72d2edb 100644
--- a/graphene/types/resolver.py
+++ b/graphene/types/resolver.py
@@ -1 +1,24 @@
+def attr_resolver(attname, default_value, root, info, **args):
+ return getattr(root, attname, default_value)
+
+
+def dict_resolver(attname, default_value, root, info, **args):
+ return root.get(attname, default_value)
+
+
+def dict_or_attr_resolver(attname, default_value, root, info, **args):
+ resolver = dict_resolver if isinstance(root, dict) else attr_resolver
+ return resolver(attname, default_value, root, info, **args)
+
+
default_resolver = dict_or_attr_resolver
+
+
+def set_default_resolver(resolver):
+ global default_resolver
+ assert callable(resolver), "Received non-callable resolver."
+ default_resolver = resolver
+
+
+def get_default_resolver():
+ return default_resolver
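
A short sketch of the resolver module's behavior: the default dispatches on the root's shape, and `set_default_resolver` swaps the global default (a process-wide setting, so use with care):

```python
from graphene.types.resolver import (
    attr_resolver,
    dict_resolver,
    get_default_resolver,
    set_default_resolver,
)

# dict roots are resolved with .get(), everything else with getattr().
assert dict_resolver("name", None, {"name": "from dict"}, None) == "from dict"

# Swapping in attr_resolver makes attribute access the global default.
set_default_resolver(attr_resolver)
assert get_default_resolver() is attr_resolver
```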
diff --git a/graphene/types/scalars.py b/graphene/types/scalars.py
index b1a7427..a468bb3 100644
--- a/graphene/types/scalars.py
+++ b/graphene/types/scalars.py
@@ -1,6 +1,13 @@
from typing import Any
+
from graphql import Undefined
-from graphql.language.ast import BooleanValueNode, FloatValueNode, IntValueNode, StringValueNode
+from graphql.language.ast import (
+ BooleanValueNode,
+ FloatValueNode,
+ IntValueNode,
+ StringValueNode,
+)
+
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
@@ -22,6 +29,7 @@ class Scalar(UnmountedType, BaseType):
def __init_subclass_with_meta__(cls, **options):
_meta = ScalarOptions(cls)
super(Scalar, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
serialize = None
parse_value = None
parse_literal = None
@@ -32,9 +40,14 @@ class Scalar(UnmountedType, BaseType):
This function is called when the unmounted type (Scalar instance)
is mounted (as a Field, InputField or Argument)
"""
- pass
+ return cls
+# As per the GraphQL Spec, Integers are only treated as valid when a valid
+# 32-bit signed integer, providing the broadest support across platforms.
+#
+# n.b. JavaScript's integers are safe between -(2^53 - 1) and 2^53 - 1 because
+# they are internally represented as IEEE 754 doubles.
MAX_INT = 2147483647
MIN_INT = -2147483648
@@ -46,9 +59,31 @@ class Int(Scalar):
represented in JSON as double-precision floating point numbers specified
by [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).
"""
+
+ @staticmethod
+ def coerce_int(value):
+ try:
+ num = int(value)
+ except ValueError:
+ try:
+ num = int(float(value))
+ except ValueError:
+ return Undefined
+ if MIN_INT <= num <= MAX_INT:
+ return num
+ return Undefined
+
serialize = coerce_int
parse_value = coerce_int
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, IntValueNode):
+ num = int(ast.value)
+ if MIN_INT <= num <= MAX_INT:
+ return num
+ return Undefined
+
class BigInt(Scalar):
"""
@@ -56,9 +91,27 @@ class BigInt(Scalar):
`BigInt` is not constrained to 32-bit like the `Int` type and thus is a less
compatible type.
"""
+
+ @staticmethod
+ def coerce_int(value):
+ try:
+ num = int(value)
+ except ValueError:
+ try:
+ num = int(float(value))
+ except ValueError:
+ return Undefined
+ return num
+
serialize = coerce_int
parse_value = coerce_int
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, IntValueNode):
+ return int(ast.value)
+ return Undefined
+
class Float(Scalar):
"""
@@ -66,9 +119,24 @@ class Float(Scalar):
values as specified by
[IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).
"""
+
+ @staticmethod
+ def coerce_float(value):
+ # type: (Any) -> float
+ try:
+ return float(value)
+ except ValueError:
+ return Undefined
+
serialize = coerce_float
parse_value = coerce_float
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, (FloatValueNode, IntValueNode)):
+ return float(ast.value)
+ return Undefined
+
class String(Scalar):
"""
@@ -76,17 +144,37 @@ class String(Scalar):
character sequences. The String type is most often used by GraphQL to
represent free-form human-readable text.
"""
+
+ @staticmethod
+ def coerce_string(value):
+ if isinstance(value, bool):
+ return "true" if value else "false"
+ return str(value)
+
serialize = coerce_string
parse_value = coerce_string
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, StringValueNode):
+ return ast.value
+ return Undefined
+
class Boolean(Scalar):
"""
The `Boolean` scalar type represents `true` or `false`.
"""
+
serialize = bool
parse_value = bool
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, BooleanValueNode):
+ return ast.value
+ return Undefined
+
class ID(Scalar):
"""
@@ -96,5 +184,12 @@ class ID(Scalar):
When expected as an input type, any string (such as `"4"`) or integer
(such as `4`) input value will be accepted as an ID.
"""
+
serialize = str
parse_value = str
+
+ @staticmethod
+ def parse_literal(ast, _variables=None):
+ if isinstance(ast, (StringValueNode, IntValueNode)):
+ return ast.value
+ return Undefined
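
A few spot checks on the restored scalar coercions, including the 32-bit window that separates `Int` from `BigInt`:

```python
from graphql import Undefined

from graphene import BigInt, Int, String

assert Int.parse_value("42") == 42
assert Int.parse_value(2**31) is Undefined   # outside the 32-bit window
assert BigInt.parse_value(2**31) == 2**31    # BigInt is unconstrained
assert String.serialize(True) == "true"      # bools serialize as GraphQL literals
```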
diff --git a/graphene/types/schema.py b/graphene/types/schema.py
index 1227a4c..bceede6 100644
--- a/graphene/types/schema.py
+++ b/graphene/types/schema.py
@@ -1,10 +1,45 @@
from enum import Enum as PyEnum
import inspect
from functools import partial
-from graphql import default_type_resolver, get_introspection_query, graphql, graphql_sync, introspection_types, parse, print_schema, subscribe, validate, ExecutionResult, GraphQLArgument, GraphQLBoolean, GraphQLError, GraphQLEnumValue, GraphQLField, GraphQLFloat, GraphQLID, GraphQLInputField, GraphQLInt, GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString
+
+from graphql import (
+ default_type_resolver,
+ get_introspection_query,
+ graphql,
+ graphql_sync,
+ introspection_types,
+ parse,
+ print_schema,
+ subscribe,
+ validate,
+ ExecutionResult,
+ GraphQLArgument,
+ GraphQLBoolean,
+ GraphQLError,
+ GraphQLEnumValue,
+ GraphQLField,
+ GraphQLFloat,
+ GraphQLID,
+ GraphQLInputField,
+ GraphQLInt,
+ GraphQLList,
+ GraphQLNonNull,
+ GraphQLObjectType,
+ GraphQLSchema,
+ GraphQLString,
+)
+
from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function
-from .definitions import GrapheneEnumType, GrapheneGraphQLType, GrapheneInputObjectType, GrapheneInterfaceType, GrapheneObjectType, GrapheneScalarType, GrapheneUnionType
+from .definitions import (
+ GrapheneEnumType,
+ GrapheneGraphQLType,
+ GrapheneInputObjectType,
+ GrapheneInterfaceType,
+ GrapheneObjectType,
+ GrapheneScalarType,
+ GrapheneUnionType,
+)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
@@ -16,14 +51,48 @@ from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull
from .union import Union
from .utils import get_field_as
+
introspection_query = get_introspection_query()
-IntrospectionSchema = introspection_types['__Schema']
+IntrospectionSchema = introspection_types["__Schema"]
-class TypeMap(dict):
+def assert_valid_root_type(type_):
+ if type_ is None:
+ return
+ is_graphene_objecttype = inspect.isclass(type_) and issubclass(type_, ObjectType)
+ is_graphql_objecttype = isinstance(type_, GraphQLObjectType)
+ assert (
+ is_graphene_objecttype or is_graphql_objecttype
+ ), f"Type {type_} is not a valid ObjectType."
+
+
+def is_graphene_type(type_):
+ if isinstance(type_, (List, NonNull)):
+ return True
+ if inspect.isclass(type_) and issubclass(
+ type_, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
+ ):
+ return True
+
- def __init__(self, query=None, mutation=None, subscription=None, types=
- None, auto_camelcase=True):
+def is_type_of_from_possible_types(possible_types, root, _info):
+ return isinstance(root, possible_types)
+
+
+# We use this resolver for subscriptions
+def identity_resolve(root, info, **arguments):
+ return root
+
+
+class TypeMap(dict):
+ def __init__(
+ self,
+ query=None,
+ mutation=None,
+ subscription=None,
+ types=None,
+ auto_camelcase=True,
+ ):
assert_valid_root_type(query)
assert_valid_root_type(mutation)
assert_valid_root_type(subscription)
@@ -31,19 +100,305 @@ class TypeMap(dict):
types = []
for type_ in types:
assert is_graphene_type(type_)
+
self.auto_camelcase = auto_camelcase
+
create_graphql_type = self.add_type
+
self.query = create_graphql_type(query) if query else None
self.mutation = create_graphql_type(mutation) if mutation else None
- self.subscription = create_graphql_type(subscription
- ) if subscription else None
- self.types = [create_graphql_type(graphene_type) for graphene_type in
- types]
+ self.subscription = create_graphql_type(subscription) if subscription else None
+
+ self.types = [create_graphql_type(graphene_type) for graphene_type in types]
+
+ def add_type(self, graphene_type):
+ if inspect.isfunction(graphene_type):
+ graphene_type = graphene_type()
+ if isinstance(graphene_type, List):
+ return GraphQLList(self.add_type(graphene_type.of_type))
+ if isinstance(graphene_type, NonNull):
+ return GraphQLNonNull(self.add_type(graphene_type.of_type))
+ try:
+ name = graphene_type._meta.name
+ except AttributeError:
+ raise TypeError(f"Expected Graphene type, but received: {graphene_type}.")
+ graphql_type = self.get(name)
+ if graphql_type:
+ return graphql_type
+ if issubclass(graphene_type, ObjectType):
+ graphql_type = self.create_objecttype(graphene_type)
+ elif issubclass(graphene_type, InputObjectType):
+ graphql_type = self.create_inputobjecttype(graphene_type)
+ elif issubclass(graphene_type, Interface):
+ graphql_type = self.create_interface(graphene_type)
+ elif issubclass(graphene_type, Scalar):
+ graphql_type = self.create_scalar(graphene_type)
+ elif issubclass(graphene_type, Enum):
+ graphql_type = self.create_enum(graphene_type)
+ elif issubclass(graphene_type, Union):
+ graphql_type = self.construct_union(graphene_type)
+ else:
+ raise TypeError(f"Expected Graphene type, but received: {graphene_type}.")
+ self[name] = graphql_type
+ return graphql_type
+
+ @staticmethod
+ def create_scalar(graphene_type):
+ # We have a mapping to the original GraphQL types
+ # so there are no collisions.
+ _scalars = {
+ String: GraphQLString,
+ Int: GraphQLInt,
+ Float: GraphQLFloat,
+ Boolean: GraphQLBoolean,
+ ID: GraphQLID,
+ }
+ if graphene_type in _scalars:
+ return _scalars[graphene_type]
+
+ return GrapheneScalarType(
+ graphene_type=graphene_type,
+ name=graphene_type._meta.name,
+ description=graphene_type._meta.description,
+ serialize=getattr(graphene_type, "serialize", None),
+ parse_value=getattr(graphene_type, "parse_value", None),
+ parse_literal=getattr(graphene_type, "parse_literal", None),
+ )
+
+ @staticmethod
+ def create_enum(graphene_type):
+ values = {}
+ for name, value in graphene_type._meta.enum.__members__.items():
+ description = getattr(value, "description", None)
+ # if the "description" attribute is an Enum, it is likely an enum member
+ # called description, not a description property
+ if isinstance(description, PyEnum):
+ description = None
+ if not description and callable(graphene_type._meta.description):
+ description = graphene_type._meta.description(value)
+
+ deprecation_reason = getattr(value, "deprecation_reason", None)
+ if isinstance(deprecation_reason, PyEnum):
+ deprecation_reason = None
+ if not deprecation_reason and callable(
+ graphene_type._meta.deprecation_reason
+ ):
+ deprecation_reason = graphene_type._meta.deprecation_reason(value)
+
+ values[name] = GraphQLEnumValue(
+ value=value,
+ description=description,
+ deprecation_reason=deprecation_reason,
+ )
+
+ type_description = (
+ graphene_type._meta.description(None)
+ if callable(graphene_type._meta.description)
+ else graphene_type._meta.description
+ )
+
+ return GrapheneEnumType(
+ graphene_type=graphene_type,
+ values=values,
+ name=graphene_type._meta.name,
+ description=type_description,
+ )
+
+ def create_objecttype(self, graphene_type):
+ create_graphql_type = self.add_type
+
+ def interfaces():
+ interfaces = []
+ for graphene_interface in graphene_type._meta.interfaces:
+ interface = create_graphql_type(graphene_interface)
+ assert interface.graphene_type == graphene_interface
+ interfaces.append(interface)
+ return interfaces
- def get_function_for_type(self, graphene_type, func_name, name,
- default_value):
+ if graphene_type._meta.possible_types:
+ is_type_of = partial(
+ is_type_of_from_possible_types, graphene_type._meta.possible_types
+ )
+ else:
+ is_type_of = graphene_type.is_type_of
+
+ return GrapheneObjectType(
+ graphene_type=graphene_type,
+ name=graphene_type._meta.name,
+ description=graphene_type._meta.description,
+ fields=partial(self.create_fields_for_type, graphene_type),
+ is_type_of=is_type_of,
+ interfaces=interfaces,
+ )
+
+ def create_interface(self, graphene_type):
+ resolve_type = (
+ partial(
+ self.resolve_type, graphene_type.resolve_type, graphene_type._meta.name
+ )
+ if graphene_type.resolve_type
+ else None
+ )
+
+ def interfaces():
+ interfaces = []
+ for graphene_interface in graphene_type._meta.interfaces:
+ interface = self.add_type(graphene_interface)
+ assert interface.graphene_type == graphene_interface
+ interfaces.append(interface)
+ return interfaces
+
+ return GrapheneInterfaceType(
+ graphene_type=graphene_type,
+ name=graphene_type._meta.name,
+ description=graphene_type._meta.description,
+ fields=partial(self.create_fields_for_type, graphene_type),
+ interfaces=interfaces,
+ resolve_type=resolve_type,
+ )
+
+ def create_inputobjecttype(self, graphene_type):
+ return GrapheneInputObjectType(
+ graphene_type=graphene_type,
+ name=graphene_type._meta.name,
+ description=graphene_type._meta.description,
+ out_type=graphene_type._meta.container,
+ fields=partial(
+ self.create_fields_for_type, graphene_type, is_input_type=True
+ ),
+ )
+
+ def construct_union(self, graphene_type):
+ create_graphql_type = self.add_type
+
+ def types():
+ union_types = []
+ for graphene_objecttype in graphene_type._meta.types:
+ object_type = create_graphql_type(graphene_objecttype)
+ assert object_type.graphene_type == graphene_objecttype
+ union_types.append(object_type)
+ return union_types
+
+ resolve_type = (
+ partial(
+ self.resolve_type, graphene_type.resolve_type, graphene_type._meta.name
+ )
+ if graphene_type.resolve_type
+ else None
+ )
+
+ return GrapheneUnionType(
+ graphene_type=graphene_type,
+ name=graphene_type._meta.name,
+ description=graphene_type._meta.description,
+ types=types,
+ resolve_type=resolve_type,
+ )
+
+ def get_name(self, name):
+ if self.auto_camelcase:
+ return to_camel_case(name)
+ return name
+
+ def create_fields_for_type(self, graphene_type, is_input_type=False):
+ create_graphql_type = self.add_type
+
+ fields = {}
+ for name, field in graphene_type._meta.fields.items():
+ if isinstance(field, Dynamic):
+ field = get_field_as(field.get_type(self), _as=Field)
+ if not field:
+ continue
+ field_type = create_graphql_type(field.type)
+ if is_input_type:
+ _field = GraphQLInputField(
+ field_type,
+ default_value=field.default_value,
+ out_name=name,
+ description=field.description,
+ deprecation_reason=field.deprecation_reason,
+ )
+ else:
+ args = {}
+ for arg_name, arg in field.args.items():
+ arg_type = create_graphql_type(arg.type)
+ processed_arg_name = arg.name or self.get_name(arg_name)
+ args[processed_arg_name] = GraphQLArgument(
+ arg_type,
+ out_name=arg_name,
+ description=arg.description,
+ default_value=arg.default_value,
+ deprecation_reason=arg.deprecation_reason,
+ )
+ subscribe = field.wrap_subscribe(
+ self.get_function_for_type(
+ graphene_type, f"subscribe_{name}", name, field.default_value
+ )
+ )
+
+ # If we are in a subscription, we use (by default) an
+ # identity-based resolver for the root, rather than the
+ # default resolver for objects/dicts.
+ if subscribe:
+ field_default_resolver = identity_resolve
+ elif issubclass(graphene_type, ObjectType):
+ default_resolver = (
+ graphene_type._meta.default_resolver or get_default_resolver()
+ )
+ field_default_resolver = partial(
+ default_resolver, name, field.default_value
+ )
+ else:
+ field_default_resolver = None
+
+ resolve = field.wrap_resolve(
+ self.get_function_for_type(
+ graphene_type, f"resolve_{name}", name, field.default_value
+ )
+ or field_default_resolver
+ )
+
+ _field = GraphQLField(
+ field_type,
+ args=args,
+ resolve=resolve,
+ subscribe=subscribe,
+ deprecation_reason=field.deprecation_reason,
+ description=field.description,
+ )
+ field_name = field.name or self.get_name(name)
+ fields[field_name] = _field
+ return fields
+
- def get_function_for_type(self, graphene_type, func_name, name,
- default_value):
+ def get_function_for_type(self, graphene_type, func_name, name, default_value):
"""Gets a resolve or subscribe function for a given ObjectType"""
- pass
+ if not issubclass(graphene_type, ObjectType):
+ return
+ resolver = getattr(graphene_type, func_name, None)
+ if not resolver:
+ # If we don't find the resolver in the ObjectType class, then try to
+ # find it in each of the interfaces
+ interface_resolver = None
+ for interface in graphene_type._meta.interfaces:
+ if name not in interface._meta.fields:
+ continue
+ interface_resolver = getattr(interface, func_name, None)
+ if interface_resolver:
+ break
+ resolver = interface_resolver
+
+ # Only if it is not decorated with classmethod
+ if resolver:
+ return get_unbound_function(resolver)
+
+ def resolve_type(self, resolve_type_func, type_name, root, info, _type):
+ type_ = resolve_type_func(root, info)
+
+ if inspect.isclass(type_) and issubclass(type_, ObjectType):
+ return type_._meta.name
+
+ return_type = self[type_name]
+ return default_type_resolver(root, info, return_type)
class Schema:
@@ -67,15 +422,28 @@ class Schema:
to camelCase (preferred by GraphQL standard). Default True.
"""
- def __init__(self, query=None, mutation=None, subscription=None, types=
- None, directives=None, auto_camelcase=True):
+ def __init__(
+ self,
+ query=None,
+ mutation=None,
+ subscription=None,
+ types=None,
+ directives=None,
+ auto_camelcase=True,
+ ):
self.query = query
self.mutation = mutation
self.subscription = subscription
- type_map = TypeMap(query, mutation, subscription, types,
- auto_camelcase=auto_camelcase)
- self.graphql_schema = GraphQLSchema(type_map.query, type_map.
- mutation, type_map.subscription, type_map.types, directives)
+ type_map = TypeMap(
+ query, mutation, subscription, types, auto_camelcase=auto_camelcase
+ )
+ self.graphql_schema = GraphQLSchema(
+ type_map.query,
+ type_map.mutation,
+ type_map.subscription,
+ type_map.types,
+ directives,
+ )
def __str__(self):
return print_schema(self.graphql_schema)
@@ -93,6 +461,9 @@ class Schema:
return _type.graphene_type
return _type
+ def lazy(self, _type):
+ return lambda: self.get_type(_type)
+
def execute(self, *args, **kwargs):
"""Execute a GraphQL query on the schema.
Use the `graphql_sync` function from `graphql-core` to provide the result
@@ -117,19 +488,48 @@ class Schema:
Returns:
:obj:`ExecutionResult` containing any data and errors for the operation.
"""
- pass
+ kwargs = normalize_execute_kwargs(kwargs)
+ return graphql_sync(self.graphql_schema, *args, **kwargs)
async def execute_async(self, *args, **kwargs):
"""Execute a GraphQL query on the schema asynchronously.
Same as `execute`, but uses `graphql` instead of `graphql_sync`.
"""
- pass
+ kwargs = normalize_execute_kwargs(kwargs)
+ return await graphql(self.graphql_schema, *args, **kwargs)
async def subscribe(self, query, *args, **kwargs):
"""Execute a GraphQL subscription on the schema asynchronously."""
- pass
+ # Do parsing
+ try:
+ document = parse(query)
+ except GraphQLError as error:
+ return ExecutionResult(data=None, errors=[error])
+
+ # Do validation
+ validation_errors = validate(self.graphql_schema, document)
+ if validation_errors:
+ return ExecutionResult(data=None, errors=validation_errors)
+
+ # Execute the query
+ kwargs = normalize_execute_kwargs(kwargs)
+ return await subscribe(self.graphql_schema, document, *args, **kwargs)
+
+ def introspect(self):
+ introspection = self.execute(introspection_query)
+ if introspection.errors:
+ raise introspection.errors[0]
+ return introspection.data
def normalize_execute_kwargs(kwargs):
"""Replace alias names in keyword arguments for graphql()"""
- pass
+ if "root" in kwargs and "root_value" not in kwargs:
+ kwargs["root_value"] = kwargs.pop("root")
+ if "context" in kwargs and "context_value" not in kwargs:
+ kwargs["context_value"] = kwargs.pop("context")
+ if "variables" in kwargs and "variable_values" not in kwargs:
+ kwargs["variable_values"] = kwargs.pop("variables")
+ if "operation" in kwargs and "operation_name" not in kwargs:
+ kwargs["operation_name"] = kwargs.pop("operation")
+ return kwargs
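To make the alias handling concrete, here is a minimal, self-contained sketch (the query and resolver are illustrative, not part of the patch): `variables`, `context`, `root`, and `operation` are rewritten to the `graphql-core` keyword names before the call reaches `graphql_sync`.

```python
from graphene import ObjectType, Schema, String

class Query(ObjectType):
    hello = String(name=String(default_value="world"))

    def resolve_hello(root, info, name):
        return f"Hello {name}!"

schema = Schema(query=Query)

# "variables" is rewritten to "variable_values" by normalize_execute_kwargs
# before being forwarded to graphql_sync.
result = schema.execute(
    "query greet($name: String) { hello(name: $name) }",
    variables={"name": "graphene"},
)
assert result.errors is None
assert result.data == {"hello": "Hello graphene!"}
```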
diff --git a/graphene/types/structures.py b/graphene/types/structures.py
index 155e1c0..a676397 100644
--- a/graphene/types/structures.py
+++ b/graphene/types/structures.py
@@ -10,21 +10,25 @@ class Structure(UnmountedType):
def __init__(self, of_type, *args, **kwargs):
super(Structure, self).__init__(*args, **kwargs)
- if not isinstance(of_type, Structure) and isinstance(of_type,
- UnmountedType):
+ if not isinstance(of_type, Structure) and isinstance(of_type, UnmountedType):
cls_name = type(self).__name__
of_type_name = type(of_type).__name__
raise Exception(
- f'{cls_name} could not have a mounted {of_type_name}() as inner type. Try with {cls_name}({of_type_name}).'
- )
+ f"{cls_name} could not have a mounted {of_type_name}()"
+ f" as inner type. Try with {cls_name}({of_type_name})."
+ )
self._of_type = of_type
+ @property
+ def of_type(self):
+ return get_type(self._of_type)
+
def get_type(self):
"""
This function is called when the unmounted type (List or NonNull instance)
is mounted (as a Field, InputField or Argument)
"""
- pass
+ return self
class List(Structure):
@@ -45,11 +49,14 @@ class List(Structure):
"""
def __str__(self):
- return f'[{self.of_type}]'
+ return f"[{self.of_type}]"
def __eq__(self, other):
- return isinstance(other, List) and (self.of_type == other.of_type and
- self.args == other.args and self.kwargs == other.kwargs)
+ return isinstance(other, List) and (
+ self.of_type == other.of_type
+ and self.args == other.args
+ and self.kwargs == other.kwargs
+ )
class NonNull(Structure):
@@ -77,13 +84,16 @@ class NonNull(Structure):
def __init__(self, *args, **kwargs):
super(NonNull, self).__init__(*args, **kwargs)
- assert not isinstance(self._of_type, NonNull
- ), f'Can only create NonNull of a Nullable GraphQLType but got: {self._of_type}.'
+ assert not isinstance(
+ self._of_type, NonNull
+ ), f"Can only create NonNull of a Nullable GraphQLType but got: {self._of_type}."
def __str__(self):
- return f'{self.of_type}!'
+ return f"{self.of_type}!"
def __eq__(self, other):
- return isinstance(other, NonNull) and (self.of_type == other.
- of_type and self.args == other.args and self.kwargs == other.kwargs
- )
+ return isinstance(other, NonNull) and (
+ self.of_type == other.of_type
+ and self.args == other.args
+ and self.kwargs == other.kwargs
+ )
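As a quick sanity check on the reformatted `__str__` and `__eq__` implementations, the following assertions follow directly from the code above:

```python
from graphene import List, NonNull, String

assert str(List(String)) == "[String]"
assert str(NonNull(String)) == "String!"
assert str(List(NonNull(String))) == "[String!]"

# Equality compares the wrapped type plus any extra args/kwargs,
# so structurally identical wrappers compare equal.
assert List(String) == List(String)
assert List(String) != NonNull(String)
```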
diff --git a/graphene/types/union.py b/graphene/types/union.py
index cabc8df..b7c5dc6 100644
--- a/graphene/types/union.py
+++ b/graphene/types/union.py
@@ -1,13 +1,16 @@
from typing import TYPE_CHECKING
+
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
+
+# For static type checking with type checker
if TYPE_CHECKING:
- from .objecttype import ObjectType
- from typing import Iterable, Type
+ from .objecttype import ObjectType # NOQA
+ from typing import Iterable, Type # NOQA
class UnionOptions(BaseOptions):
- types = ()
+ types = () # type: Iterable[Type[ObjectType]]
class Union(UnmountedType, BaseType):
@@ -49,8 +52,10 @@ class Union(UnmountedType, BaseType):
@classmethod
def __init_subclass_with_meta__(cls, types=None, **options):
- assert isinstance(types, (list, tuple)) and len(types
- ) > 0, f'Must provide types for Union {cls.__name__}.'
+ assert (
+ isinstance(types, (list, tuple)) and len(types) > 0
+ ), f"Must provide types for Union {cls.__name__}."
+
_meta = UnionOptions(cls)
_meta.types = types
super(Union, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@@ -61,4 +66,11 @@ class Union(UnmountedType, BaseType):
This function is called when the unmounted type (Union instance)
is mounted (as a Field, InputField or Argument)
"""
- pass
+ return cls
+
+ @classmethod
+ def resolve_type(cls, instance, info):
+ from .objecttype import ObjectType # NOQA
+
+ if isinstance(instance, ObjectType):
+ return type(instance)
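The default `resolve_type` added above simply returns the instance's own class when the resolved value is an `ObjectType`. A minimal sketch, where the `Cat`/`Dog`/`Pet` types are illustrative:

```python
from graphene import Field, ObjectType, Schema, String, Union

class Cat(ObjectType):
    name = String()

class Dog(ObjectType):
    name = String()

class Pet(Union):
    class Meta:
        types = (Cat, Dog)

class Query(ObjectType):
    pet = Field(Pet)

    def resolve_pet(root, info):
        # The default Union.resolve_type maps this instance to the Dog type.
        return Dog(name="Rex")

schema = Schema(query=Query)
result = schema.execute("{ pet { __typename ... on Dog { name } } }")
assert result.data == {"pet": {"__typename": "Dog", "name": "Rex"}}
```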
diff --git a/graphene/types/unmountedtype.py b/graphene/types/unmountedtype.py
index 5ca42ce..83a6afe 100644
--- a/graphene/types/unmountedtype.py
+++ b/graphene/types/unmountedtype.py
@@ -49,27 +49,39 @@ class UnmountedType(OrderedType):
This function is called when the UnmountedType instance
is mounted (as a Field, InputField or Argument)
"""
- pass
+ raise NotImplementedError(f"get_type not implemented in {self}")
- def Field(self):
+ def mount_as(self, _as):
+ return _as.mounted(self)
+
+ def Field(self): # noqa: N802
"""
Mount the UnmountedType as Field
"""
- pass
+ from .field import Field
+
+ return self.mount_as(Field)
- def InputField(self):
+ def InputField(self): # noqa: N802
"""
Mount the UnmountedType as InputField
"""
- pass
+ from .inputfield import InputField
- def Argument(self):
+ return self.mount_as(InputField)
+
+ def Argument(self): # noqa: N802
"""
Mount the UnmountedType as Argument
"""
- pass
+ from .argument import Argument
+
+ return self.mount_as(Argument)
def __eq__(self, other):
- return self is other or isinstance(other, UnmountedType
- ) and self.get_type() == other.get_type(
- ) and self.args == other.args and self.kwargs == other.kwargs
+ return self is other or (
+ isinstance(other, UnmountedType)
+ and self.get_type() == other.get_type()
+ and self.args == other.args
+ and self.kwargs == other.kwargs
+ )
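`mount_as` is what lets a single unmounted instance serve in all three roles. A short sketch, assuming nothing beyond the public graphene API:

```python
from graphene import Argument, Field, InputField, String

unmounted = String(description="A name")

# The same unmounted instance can be mounted in three different roles.
assert isinstance(unmounted.Field(), Field)
assert isinstance(unmounted.InputField(), InputField)
assert isinstance(unmounted.Argument(), Argument)

# Mounting preserves the wrapped type and the constructor kwargs.
assert unmounted.Field().type is String
assert unmounted.Field().description == "A name"
```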
diff --git a/graphene/types/utils.py b/graphene/types/utils.py
index 4c05498..1976448 100644
--- a/graphene/types/utils.py
+++ b/graphene/types/utils.py
@@ -1,5 +1,6 @@
import inspect
from functools import partial
+
from ..utils.module_loading import import_string
from .mountedtype import MountedType
from .unmountedtype import UnmountedType
@@ -9,7 +10,12 @@ def get_field_as(value, _as=None):
"""
Get type mounted
"""
- pass
+ if isinstance(value, MountedType):
+ return value
+ elif isinstance(value, UnmountedType):
+ if _as is None:
+ return value
+ return _as.mounted(value)
def yank_fields_from_attrs(attrs, _as=None, sort=True):
@@ -17,9 +23,28 @@ def yank_fields_from_attrs(attrs, _as=None, sort=True):
Extract all the fields in given attributes (dict)
and return them ordered
"""
- pass
+ fields_with_names = []
+ for attname, value in list(attrs.items()):
+ field = get_field_as(value, _as)
+ if not field:
+ continue
+ fields_with_names.append((attname, field))
+
+ if sort:
+ fields_with_names = sorted(fields_with_names, key=lambda f: f[1])
+ return dict(fields_with_names)
+
+
+def get_type(_type):
+ if isinstance(_type, str):
+ return import_string(_type)
+ if inspect.isfunction(_type) or isinstance(_type, partial):
+ return _type()
+ return _type
def get_underlying_type(_type):
"""Get the underlying type even if it is wrapped in structures like NonNull"""
- pass
+ while hasattr(_type, "of_type"):
+ _type = _type.of_type
+ return _type
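A quick sketch of the two helpers added here; the values are illustrative. `get_type` resolves lazy references (dotted strings and callables), while `get_underlying_type` unwraps any stack of `NonNull`/`List` wrappers:

```python
from graphene import List, NonNull, String
from graphene.types.utils import get_type, get_underlying_type

# Lazy references resolve to the real type object.
assert get_type("graphene.String") is String
assert get_type(lambda: String) is String

# Structure wrappers are peeled off until a named type remains.
assert get_underlying_type(NonNull(List(NonNull(String)))) is String
```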
diff --git a/graphene/types/uuid.py b/graphene/types/uuid.py
index bcfe611..f2ba1fc 100644
--- a/graphene/types/uuid.py
+++ b/graphene/types/uuid.py
@@ -1,7 +1,9 @@
from __future__ import absolute_import
from uuid import UUID as _UUID
+
from graphql.language.ast import StringValueNode
from graphql import Undefined
+
from .scalars import Scalar
@@ -10,3 +12,21 @@ class UUID(Scalar):
Leverages the internal Python implementation of UUID (uuid.UUID) to provide native UUID objects
in fields, resolvers and input.
"""
+
+ @staticmethod
+ def serialize(uuid):
+ if isinstance(uuid, str):
+ uuid = _UUID(uuid)
+
+ assert isinstance(uuid, _UUID), f"Expected UUID instance, received {uuid}"
+ return str(uuid)
+
+ @staticmethod
+ def parse_literal(node, _variables=None):
+ if isinstance(node, StringValueNode):
+ return _UUID(node.value)
+ return Undefined
+
+ @staticmethod
+ def parse_value(value):
+ return _UUID(value)
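A round-trip sketch of the three hooks; the sample UUID string is arbitrary:

```python
from uuid import UUID as _UUID
from graphene import UUID

value = "216fff40-ae81-4134-8081-9a2b5a553f4c"

# serialize accepts a uuid.UUID or its string form and emits the canonical string.
assert UUID.serialize(value) == value
assert UUID.serialize(_UUID(value)) == value

# parse_value converts variable input back into a native uuid.UUID.
assert UUID.parse_value(value) == _UUID(value)
```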
diff --git a/graphene/utils/crunch.py b/graphene/utils/crunch.py
index b20feef..b27d371 100644
--- a/graphene/utils/crunch.py
+++ b/graphene/utils/crunch.py
@@ -1,2 +1,35 @@
import json
from collections.abc import Mapping
+
+
+def to_key(value):
+ return json.dumps(value)
+
+
+def insert(value, index, values):
+ key = to_key(value)
+
+ if key not in index:
+ index[key] = len(values)
+ values.append(value)
+ return len(values) - 1
+
+ return index.get(key)
+
+
+def flatten(data, index, values):
+ if isinstance(data, (list, tuple)):
+ flattened = [flatten(child, index, values) for child in data]
+ elif isinstance(data, Mapping):
+ flattened = {key: flatten(child, index, values) for key, child in data.items()}
+ else:
+ flattened = data
+ return insert(flattened, index, values)
+
+
+def crunch(data):
+ index = {}
+ values = []
+
+ flatten(data, index, values)
+ return values
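The net effect of `crunch` is a flat `values` list in which repeated subtrees are stored once and containers hold indices into that list, with the root value last. A few illustrative cases:

```python
from graphene.utils.crunch import crunch

assert crunch("foo") == ["foo"]

# Both "a" and "b" point at index 0, where the shared value 1 lives.
assert crunch({"a": 1, "b": 1}) == [1, {"a": 0, "b": 0}]

# The repeated 1 is referenced twice by index; the root list comes last.
assert crunch([1, 2, 1]) == [1, 2, [0, 1, 0]]
```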
diff --git a/graphene/utils/dataloader.py b/graphene/utils/dataloader.py
index a7136e2..143558a 100644
--- a/graphene/utils/dataloader.py
+++ b/graphene/utils/dataloader.py
@@ -1,43 +1,119 @@
-from asyncio import gather, ensure_future, get_event_loop, iscoroutine, iscoroutinefunction
+from asyncio import (
+ gather,
+ ensure_future,
+ get_event_loop,
+ iscoroutine,
+ iscoroutinefunction,
+)
from collections import namedtuple
from collections.abc import Iterable
from functools import partial
-from typing import List
-Loader = namedtuple('Loader', 'key,future')
+
+from typing import List # flake8: noqa
+
+Loader = namedtuple("Loader", "key,future")
+
+
+def iscoroutinefunctionorpartial(fn):
+ return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)
class DataLoader(object):
batch = True
- max_batch_size = None
+ max_batch_size = None # type: int
cache = True
- def __init__(self, batch_load_fn=None, batch=None, max_batch_size=None,
- cache=None, get_cache_key=None, cache_map=None, loop=None):
+ def __init__(
+ self,
+ batch_load_fn=None,
+ batch=None,
+ max_batch_size=None,
+ cache=None,
+ get_cache_key=None,
+ cache_map=None,
+ loop=None,
+ ):
+
self._loop = loop
+
if batch_load_fn is not None:
self.batch_load_fn = batch_load_fn
- assert iscoroutinefunctionorpartial(self.batch_load_fn
- ), 'batch_load_fn must be coroutine. Received: {}'.format(self.
- batch_load_fn)
+
+ assert iscoroutinefunctionorpartial(
+ self.batch_load_fn
+ ), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn)
+
if not callable(self.batch_load_fn):
- raise TypeError(
- 'DataLoader must be have a batch_load_fn which accepts Iterable<key> and returns Future<Iterable<value>>, but got: {}.'
- .format(batch_load_fn))
+ raise TypeError( # pragma: no cover
+ (
+ "DataLoader must be have a batch_load_fn which accepts "
+ "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
+ ).format(batch_load_fn)
+ )
+
if batch is not None:
- self.batch = batch
+ self.batch = batch # pragma: no cover
+
if max_batch_size is not None:
self.max_batch_size = max_batch_size
+
if cache is not None:
- self.cache = cache
+ self.cache = cache # pragma: no cover
+
self.get_cache_key = get_cache_key or (lambda x: x)
+
self._cache = cache_map if cache_map is not None else {}
- self._queue = []
+ self._queue = [] # type: List[Loader]
+
+ @property
+ def loop(self):
+ if not self._loop:
+ self._loop = get_event_loop()
+
+ return self._loop
def load(self, key=None):
"""
Loads a key, returning a `Future` for the value represented by that key.
"""
- pass
+ if key is None:
+ raise TypeError( # pragma: no cover
+ (
+ "The loader.load() function must be called with a value, "
+ "but got: {}."
+ ).format(key)
+ )
+
+ cache_key = self.get_cache_key(key)
+
+ # If caching and there is a cache-hit, return cached Future.
+ if self.cache:
+ cached_result = self._cache.get(cache_key)
+ if cached_result:
+ return cached_result
+
+ # Otherwise, produce a new Future for this value.
+ future = self.loop.create_future()
+ # If caching, cache this Future.
+ if self.cache:
+ self._cache[cache_key] = future
+
+ self.do_resolve_reject(key, future)
+ return future
+
+ def do_resolve_reject(self, key, future):
+ # Enqueue this Future to be dispatched.
+ self._queue.append(Loader(key=key, future=future))
+ # Determine if a dispatch of this queue should be scheduled.
+ # A single dispatch should be scheduled per queue at the time when the
+ # queue changes from "empty" to "full".
+ if len(self._queue) == 1:
+ if self.batch:
+ # If batching, schedule a task to dispatch the queue.
+ enqueue_post_future_job(self.loop, self)
+ else:
+ # Otherwise dispatch the (queue of one) immediately.
+ dispatch_queue(self) # pragma: no cover
def load_many(self, keys):
"""
@@ -52,14 +128,24 @@ class DataLoader(object):
>>> my_loader.load('b')
>>> )
"""
- pass
+ if not isinstance(keys, Iterable):
+ raise TypeError( # pragma: no cover
+ (
+ "The loader.load_many() function must be called with Iterable<key> "
+ "but got: {}."
+ ).format(keys)
+ )
+
+ return gather(*[self.load(key) for key in keys])
def clear(self, key):
"""
Clears the value at `key` from the cache, if it exists. Returns itself for
method chaining.
"""
- pass
+ cache_key = self.get_cache_key(key)
+ self._cache.pop(cache_key, None)
+ return self
def clear_all(self):
"""
@@ -67,14 +153,44 @@ class DataLoader(object):
invalidations across this particular `DataLoader`. Returns itself for
method chaining.
"""
- pass
+ self._cache.clear()
+ return self
def prime(self, key, value):
"""
Adds the provided key and value to the cache. If the key already exists, no
change is made. Returns itself for method chaining.
"""
- pass
+ cache_key = self.get_cache_key(key)
+
+ # Only add the key if it does not already exist.
+ if cache_key not in self._cache:
+ # Cache a rejected future if the value is an Error, in order to match
+ # the behavior of load(key).
+ future = self.loop.create_future()
+ if isinstance(value, Exception):
+ future.set_exception(value)
+ else:
+ future.set_result(value)
+
+ self._cache[cache_key] = future
+
+ return self
+
+
+def enqueue_post_future_job(loop, loader):
+ async def dispatch():
+ dispatch_queue(loader)
+
+ loop.call_soon(ensure_future, dispatch())
+
+
+def get_chunks(iterable_obj, chunk_size=1):
+ chunk_size = max(1, chunk_size)
+ return (
+ iterable_obj[i : i + chunk_size]
+ for i in range(0, len(iterable_obj), chunk_size)
+ )
def dispatch_queue(loader):
@@ -82,7 +198,77 @@ def dispatch_queue(loader):
Given the current state of a Loader instance, perform a batch load
from its current queue.
"""
- pass
+ # Take the current loader queue, replacing it with an empty queue.
+ queue = loader._queue
+ loader._queue = []
+
+ # If a max_batch_size was provided and the queue is longer, then segment the
+ # queue into multiple batches, otherwise treat the queue as a single batch.
+ max_batch_size = loader.max_batch_size
+
+ if max_batch_size and max_batch_size < len(queue):
+ chunks = get_chunks(queue, max_batch_size)
+ for chunk in chunks:
+ ensure_future(dispatch_queue_batch(loader, chunk))
+ else:
+ ensure_future(dispatch_queue_batch(loader, queue))
+
+
+async def dispatch_queue_batch(loader, queue):
+ # Collect all keys to be loaded in this dispatch
+ keys = [loaded.key for loaded in queue]
+
+ # Call the provided batch_load_fn for this loader with the loader queue's keys.
+ batch_future = loader.batch_load_fn(keys)
+
+ # Assert the expected response from batch_load_fn
+ if not batch_future or not iscoroutine(batch_future):
+ return failed_dispatch( # pragma: no cover
+ loader,
+ queue,
+ TypeError(
+ (
+ "DataLoader must be constructed with a function which accepts "
+ "Iterable<key> and returns Future<Iterable<value>>, but the function did "
+ "not return a Coroutine: {}."
+ ).format(batch_future)
+ ),
+ )
+
+ try:
+ values = await batch_future
+ if not isinstance(values, Iterable):
+ raise TypeError( # pragma: no cover
+ (
+ "DataLoader must be constructed with a function which accepts "
+ "Iterable<key> and returns Future<Iterable<value>>, but the function did "
+ "not return a Future of a Iterable: {}."
+ ).format(values)
+ )
+
+ values = list(values)
+ if len(values) != len(keys):
+ raise TypeError( # pragma: no cover
+ (
+ "DataLoader must be constructed with a function which accepts "
+ "Iterable<key> and returns Future<Iterable<value>>, but the function did "
+ "not return a Future of a Iterable with the same length as the Iterable "
+ "of keys."
+ "\n\nKeys:\n{}"
+ "\n\nValues:\n{}"
+ ).format(keys, values)
+ )
+
+ # Step through the values, resolving or rejecting each Future in the
+ # loaded queue.
+ for loaded, value in zip(queue, values):
+ if isinstance(value, Exception):
+ loaded.future.set_exception(value)
+ else:
+ loaded.future.set_result(value)
+
+ except Exception as e:
+ return failed_dispatch(loader, queue, e)
def failed_dispatch(loader, queue, error):
@@ -90,4 +276,6 @@ def failed_dispatch(loader, queue, error):
Do not cache individual loads if the entire batch dispatch fails,
but still reject each request so they do not hang.
"""
- pass
+ for loaded in queue:
+ loader.clear(loaded.key)
+ loaded.future.set_exception(error)
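Putting the pieces together, a minimal end-to-end sketch (the `UserLoader` class and its data are illustrative): both loads below are collected into a single `batch_load_fn` call, and repeated keys are served from the per-loader cache.

```python
import asyncio
from graphene.utils.dataloader import DataLoader

class UserLoader(DataLoader):
    async def batch_load_fn(self, keys):
        # One batched lookup; the result must align one-to-one with `keys`.
        return [{"id": key, "name": f"user {key}"} for key in keys]

async def main():
    loader = UserLoader()
    # Scheduled in the same tick, so they dispatch as one batch of [1, 2].
    user_a, user_b = await asyncio.gather(loader.load(1), loader.load(2))
    assert user_a["name"] == "user 1" and user_b["name"] == "user 2"
    # A repeated key returns the cached Future, hence the very same object.
    assert await loader.load(1) is user_a

asyncio.run(main())
```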
diff --git a/graphene/utils/deduplicator.py b/graphene/utils/deduplicator.py
index 26f5d8e..3fbf139 100644
--- a/graphene/utils/deduplicator.py
+++ b/graphene/utils/deduplicator.py
@@ -1 +1,32 @@
from collections.abc import Mapping
+
+
+def deflate(node, index=None, path=None):
+ if index is None:
+ index = {}
+ if path is None:
+ path = []
+
+ if node and "id" in node and "__typename" in node:
+ route = ",".join(path)
+ cache_key = ":".join([route, str(node["__typename"]), str(node["id"])])
+
+ if index.get(cache_key) is True:
+ return {"__typename": node["__typename"], "id": node["id"]}
+ else:
+ index[cache_key] = True
+
+ result = {}
+
+ for field_name in node:
+ value = node[field_name]
+
+ new_path = path + [field_name]
+ if isinstance(value, (list, tuple)):
+ result[field_name] = [deflate(child, index, new_path) for child in value]
+ elif isinstance(value, Mapping):
+ result[field_name] = deflate(value, index, new_path)
+ else:
+ result[field_name] = value
+
+ return result
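`deflate` walks an execution result and, whenever it meets a node whose `(path, __typename, id)` combination was already seen, replaces it with a stub carrying only `__typename` and `id`. An illustrative sketch:

```python
from graphene.utils.deduplicator import deflate

result = {
    "events": [
        {"__typename": "Event", "id": 1,
         "movie": {"__typename": "Movie", "id": 1, "name": "A"}},
        {"__typename": "Event", "id": 2,
         "movie": {"__typename": "Movie", "id": 1, "name": "A"}},
    ]
}
deflated = deflate(result)

# The second occurrence of Movie 1 under the same path collapses to a stub.
assert deflated["events"][1]["movie"] == {"__typename": "Movie", "id": 1}
# The first occurrence keeps its full payload.
assert deflated["events"][0]["movie"]["name"] == "A"
```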
diff --git a/graphene/utils/deprecated.py b/graphene/utils/deprecated.py
index d561393..71a5bb4 100644
--- a/graphene/utils/deprecated.py
+++ b/graphene/utils/deprecated.py
@@ -1,7 +1,12 @@
import functools
import inspect
import warnings
-string_types = type(b''), type('')
+
+string_types = (type(b""), type(""))
+
+
+def warn_deprecation(text):
+ warnings.warn(text, category=DeprecationWarning, stacklevel=2)
def deprecated(reason):
@@ -10,4 +15,56 @@ def deprecated(reason):
as deprecated. It will result in a warning being emitted
when the function is used.
"""
- pass
+
+ if isinstance(reason, string_types):
+
+ # The @deprecated is used with a 'reason'.
+ #
+ # .. code-block:: python
+ #
+ # @deprecated("please, use another function")
+ # def old_function(x, y):
+ # pass
+
+ def decorator(func1):
+
+ if inspect.isclass(func1):
+ fmt1 = f"Call to deprecated class {func1.__name__} ({reason})."
+ else:
+ fmt1 = f"Call to deprecated function {func1.__name__} ({reason})."
+
+ @functools.wraps(func1)
+ def new_func1(*args, **kwargs):
+ warn_deprecation(fmt1)
+ return func1(*args, **kwargs)
+
+ return new_func1
+
+ return decorator
+
+ elif inspect.isclass(reason) or inspect.isfunction(reason):
+
+ # The @deprecated is used without any 'reason'.
+ #
+ # .. code-block:: python
+ #
+ # @deprecated
+ # def old_function(x, y):
+ # pass
+
+ func2 = reason
+
+ if inspect.isclass(func2):
+ fmt2 = f"Call to deprecated class {func2.__name__}."
+ else:
+ fmt2 = f"Call to deprecated function {func2.__name__}."
+
+ @functools.wraps(func2)
+ def new_func2(*args, **kwargs):
+ warn_deprecation(fmt2)
+ return func2(*args, **kwargs)
+
+ return new_func2
+
+ else:
+ raise TypeError(repr(type(reason)))
diff --git a/graphene/utils/get_unbound_function.py b/graphene/utils/get_unbound_function.py
index e69de29..bd311e3 100644
--- a/graphene/utils/get_unbound_function.py
+++ b/graphene/utils/get_unbound_function.py
@@ -0,0 +1,4 @@
+def get_unbound_function(func):
+ if not getattr(func, "__self__", True):
+ return func.__func__
+ return func
diff --git a/graphene/utils/is_introspection_key.py b/graphene/utils/is_introspection_key.py
index e69de29..59d72b2 100644
--- a/graphene/utils/is_introspection_key.py
+++ b/graphene/utils/is_introspection_key.py
@@ -0,0 +1,6 @@
+def is_introspection_key(key):
+ # from: https://spec.graphql.org/June2018/#sec-Schema
+ # > All types and directives defined within a schema must not have a name which
+ # > begins with "__" (two underscores), as this is used exclusively
+ # > by GraphQL’s introspection system.
+ return str(key).startswith("__")
diff --git a/graphene/utils/module_loading.py b/graphene/utils/module_loading.py
index 21e42a9..d9095d0 100644
--- a/graphene/utils/module_loading.py
+++ b/graphene/utils/module_loading.py
@@ -10,4 +10,36 @@ def import_string(dotted_path, dotted_attributes=None):
the first step, and return the corresponding value designated by the
attribute path. Raise ImportError if the import failed.
"""
- pass
+ try:
+ module_path, class_name = dotted_path.rsplit(".", 1)
+ except ValueError:
+ raise ImportError("%s doesn't look like a module path" % dotted_path)
+
+ module = import_module(module_path)
+
+ try:
+ result = getattr(module, class_name)
+ except AttributeError:
+ raise ImportError(
+ 'Module "%s" does not define a "%s" attribute/class'
+ % (module_path, class_name)
+ )
+
+ if not dotted_attributes:
+ return result
+ attributes = dotted_attributes.split(".")
+ traveled_attributes = []
+ try:
+ for attribute in attributes:
+ traveled_attributes.append(attribute)
+ result = getattr(result, attribute)
+ return result
+ except AttributeError:
+ raise ImportError(
+ 'Module "%s" does not define a "%s" attribute inside attribute/class "%s"'
+ % (module_path, ".".join(traveled_attributes), class_name)
+ )
+
+
+def lazy_import(dotted_path, dotted_attributes=None):
+ return partial(import_string, dotted_path, dotted_attributes)
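A short sketch of both helpers; `lazy_import` is the deferred variant used for string type references, which is how circular imports are avoided:

```python
from graphene import String
from graphene.utils.module_loading import import_string, lazy_import

# Resolve a "module.attribute" path immediately.
assert import_string("graphene.String") is String

# The optional second argument walks further attributes on the result.
assert import_string("graphene.String", "_meta.name") == "String"

# lazy_import defers the same lookup until the returned partial is called.
lazy_string = lazy_import("graphene.String")
assert lazy_string() is String
```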
diff --git a/graphene/utils/orderedtype.py b/graphene/utils/orderedtype.py
index e97396d..294ad54 100644
--- a/graphene/utils/orderedtype.py
+++ b/graphene/utils/orderedtype.py
@@ -8,17 +8,29 @@ class OrderedType:
def __init__(self, _creation_counter=None):
self.creation_counter = _creation_counter or self.gen_counter()
+ @staticmethod
+ def gen_counter():
+ counter = OrderedType.creation_counter
+ OrderedType.creation_counter += 1
+ return counter
+
+ def reset_counter(self):
+ self.creation_counter = self.gen_counter()
+
def __eq__(self, other):
+ # Needed for @total_ordering
if isinstance(self, type(other)):
return self.creation_counter == other.creation_counter
return NotImplemented
def __lt__(self, other):
+ # This is needed because bisect does not take a comparison function.
if isinstance(other, OrderedType):
return self.creation_counter < other.creation_counter
return NotImplemented
def __gt__(self, other):
+ # This is needed because bisect does not take a comparison function.
if isinstance(other, OrderedType):
return self.creation_counter > other.creation_counter
return NotImplemented
diff --git a/graphene/utils/props.py b/graphene/utils/props.py
index 114245a..26c697e 100644
--- a/graphene/utils/props.py
+++ b/graphene/utils/props.py
@@ -7,3 +7,9 @@ class _NewClass:
_all_vars = set(dir(_OldClass) + dir(_NewClass))
+
+
+def props(x):
+ return {
+ key: vars(x).get(key, getattr(x, key)) for key in dir(x) if key not in _all_vars
+ }
diff --git a/graphene/utils/resolve_only_args.py b/graphene/utils/resolve_only_args.py
index 0f0ddd2..5efff2e 100644
--- a/graphene/utils/resolve_only_args.py
+++ b/graphene/utils/resolve_only_args.py
@@ -1,2 +1,12 @@
from functools import wraps
+
from .deprecated import deprecated
+
+
+@deprecated("This function is deprecated")
+def resolve_only_args(func):
+ @wraps(func)
+ def wrapped_func(root, info, **args):
+ return func(root, **args)
+
+ return wrapped_func
diff --git a/graphene/utils/str_converters.py b/graphene/utils/str_converters.py
index b199df5..2a214f0 100644
--- a/graphene/utils/str_converters.py
+++ b/graphene/utils/str_converters.py
@@ -1 +1,17 @@
import re
+
+
+# Adapted from this response in Stackoverflow
+# http://stackoverflow.com/a/19053800/1072990
+def to_camel_case(snake_str):
+ components = snake_str.split("_")
+ # We capitalize the first letter of each component except the first one
+ # with the 'capitalize' method and join them together.
+ return components[0] + "".join(x.capitalize() if x else "_" for x in components[1:])
+
+
+# From this response in Stackoverflow
+# http://stackoverflow.com/a/1176023/1072990
+def to_snake_case(name):
+ s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+ return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
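Expected behavior, mirroring what the library's own tests exercise:

```python
from graphene.utils.str_converters import to_camel_case, to_snake_case

assert to_camel_case("snakes_on_a_plane") == "snakesOnAPlane"
assert to_camel_case("field_i18n") == "fieldI18n"

assert to_snake_case("SnakesOnAPlane") == "snakes_on_a_plane"
assert to_snake_case("snakesOnAPlane") == "snakes_on_a_plane"
```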
diff --git a/graphene/utils/subclass_with_meta.py b/graphene/utils/subclass_with_meta.py
index 78666c9..c4ee11d 100644
--- a/graphene/utils/subclass_with_meta.py
+++ b/graphene/utils/subclass_with_meta.py
@@ -1,4 +1,5 @@
from inspect import isclass
+
from .props import props
@@ -11,7 +12,7 @@ class SubclassWithMeta_Meta(type):
return cls.__name__
def __repr__(cls):
- return f'<{cls.__name__} meta={repr(cls._meta)}>'
+ return f"<{cls.__name__} meta={repr(cls._meta)}>"
class SubclassWithMeta(metaclass=SubclassWithMeta_Meta):
@@ -19,7 +20,7 @@ class SubclassWithMeta(metaclass=SubclassWithMeta_Meta):
def __init_subclass__(cls, **meta_options):
"""This method just terminates the super() chain"""
- _Meta = getattr(cls, 'Meta', None)
+ _Meta = getattr(cls, "Meta", None)
_meta_props = {}
if _Meta:
if isinstance(_Meta, dict):
@@ -28,16 +29,20 @@ class SubclassWithMeta(metaclass=SubclassWithMeta_Meta):
_meta_props = props(_Meta)
else:
raise Exception(
- f'Meta have to be either a class or a dict. Received {_Meta}'
- )
- delattr(cls, 'Meta')
+ f"Meta have to be either a class or a dict. Received {_Meta}"
+ )
+ delattr(cls, "Meta")
options = dict(meta_options, **_meta_props)
- abstract = options.pop('abstract', False)
+
+ abstract = options.pop("abstract", False)
if abstract:
- assert not options, f"Abstract types can only contain the abstract attribute. Received: abstract, {', '.join(options)}"
+ assert not options, (
+ "Abstract types can only contain the abstract attribute. "
+ f"Received: abstract, {', '.join(options)}"
+ )
else:
super_class = super(cls, cls)
- if hasattr(super_class, '__init_subclass_with_meta__'):
+ if hasattr(super_class, "__init_subclass_with_meta__"):
super_class.__init_subclass_with_meta__(**options)
@classmethod
diff --git a/graphene/utils/thenables.py b/graphene/utils/thenables.py
index 0ad5e44..9628699 100644
--- a/graphene/utils/thenables.py
+++ b/graphene/utils/thenables.py
@@ -1,13 +1,25 @@
"""
This file is used mainly as a bridge for thenable abstractions.
"""
+
from inspect import isawaitable
+def await_and_execute(obj, on_resolve):
+ async def build_resolve_async():
+ return on_resolve(await obj)
+
+ return build_resolve_async()
+
+
def maybe_thenable(obj, on_resolve):
"""
Execute an on_resolve function once the thenable is resolved,
returning the same type of object that was passed in.
If the object is not thenable, it should return on_resolve(obj)
"""
- pass
+ if isawaitable(obj):
+ return await_and_execute(obj, on_resolve)
+
+ # If it's not awaitable, return the function executed over the object
+ return on_resolve(obj)
diff --git a/graphene/utils/trim_docstring.py b/graphene/utils/trim_docstring.py
index 1f137a2..a23c7e7 100644
--- a/graphene/utils/trim_docstring.py
+++ b/graphene/utils/trim_docstring.py
@@ -1 +1,9 @@
import inspect
+
+
+def trim_docstring(docstring):
+ # Cleans up whitespaces from an indented docstring
+ #
+ # See https://www.python.org/dev/peps/pep-0257/
+ # and https://docs.python.org/2/library/inspect.html#inspect.cleandoc
+ return inspect.cleandoc(docstring) if docstring else None
diff --git a/graphene/validation/depth_limit.py b/graphene/validation/depth_limit.py
index d55fcf7..e0f2866 100644
--- a/graphene/validation/depth_limit.py
+++ b/graphene/validation/depth_limit.py
@@ -1,10 +1,195 @@
+# This is a Python port of https://github.com/stems/graphql-depth-limit
+# which is licensed under the terms of the MIT license, reproduced below.
+#
+# -----------
+#
+# MIT License
+#
+# Copyright (c) 2017 Stem
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
try:
from re import Pattern
except ImportError:
+ # backwards compatibility for Python 3.6
from typing import Pattern
from typing import Callable, Dict, List, Optional, Union, Tuple
+
from graphql import GraphQLError
from graphql.validation import ValidationContext, ValidationRule
-from graphql.language import DefinitionNode, FieldNode, FragmentDefinitionNode, FragmentSpreadNode, InlineFragmentNode, Node, OperationDefinitionNode
+from graphql.language import (
+ DefinitionNode,
+ FieldNode,
+ FragmentDefinitionNode,
+ FragmentSpreadNode,
+ InlineFragmentNode,
+ Node,
+ OperationDefinitionNode,
+)
+
from ..utils.is_introspection_key import is_introspection_key
+
+
IgnoreType = Union[Callable[[str], bool], Pattern, str]
+
+
+def depth_limit_validator(
+ max_depth: int,
+ ignore: Optional[List[IgnoreType]] = None,
+ callback: Optional[Callable[[Dict[str, int]], None]] = None,
+):
+ class DepthLimitValidator(ValidationRule):
+ def __init__(self, validation_context: ValidationContext):
+ document = validation_context.document
+ definitions = document.definitions
+
+ fragments = get_fragments(definitions)
+ queries = get_queries_and_mutations(definitions)
+ query_depths = {}
+
+ for name in queries:
+ query_depths[name] = determine_depth(
+ node=queries[name],
+ fragments=fragments,
+ depth_so_far=0,
+ max_depth=max_depth,
+ context=validation_context,
+ operation_name=name,
+ ignore=ignore,
+ )
+ if callable(callback):
+ callback(query_depths)
+ super().__init__(validation_context)
+
+ return DepthLimitValidator
+
+
+def get_fragments(
+ definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, FragmentDefinitionNode]:
+ fragments = {}
+ for definition in definitions:
+ if isinstance(definition, FragmentDefinitionNode):
+ fragments[definition.name.value] = definition
+ return fragments
+
+
+# This will actually get both queries and mutations.
+# We can basically treat them the same.
+def get_queries_and_mutations(
+ definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, OperationDefinitionNode]:
+ operations = {}
+
+ for definition in definitions:
+ if isinstance(definition, OperationDefinitionNode):
+ operation = definition.name.value if definition.name else "anonymous"
+ operations[operation] = definition
+ return operations
+
+
+def determine_depth(
+ node: Node,
+ fragments: Dict[str, FragmentDefinitionNode],
+ depth_so_far: int,
+ max_depth: int,
+ context: ValidationContext,
+ operation_name: str,
+ ignore: Optional[List[IgnoreType]] = None,
+) -> int:
+ if depth_so_far > max_depth:
+ context.report_error(
+ GraphQLError(
+ f"'{operation_name}' exceeds maximum operation depth of {max_depth}.",
+ [node],
+ )
+ )
+ return depth_so_far
+ if isinstance(node, FieldNode):
+ should_ignore = is_introspection_key(node.name.value) or is_ignored(
+ node, ignore
+ )
+
+ if should_ignore or not node.selection_set:
+ return 0
+ return 1 + max(
+ map(
+ lambda selection: determine_depth(
+ node=selection,
+ fragments=fragments,
+ depth_so_far=depth_so_far + 1,
+ max_depth=max_depth,
+ context=context,
+ operation_name=operation_name,
+ ignore=ignore,
+ ),
+ node.selection_set.selections,
+ )
+ )
+ elif isinstance(node, FragmentSpreadNode):
+ return determine_depth(
+ node=fragments[node.name.value],
+ fragments=fragments,
+ depth_so_far=depth_so_far,
+ max_depth=max_depth,
+ context=context,
+ operation_name=operation_name,
+ ignore=ignore,
+ )
+ elif isinstance(
+ node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
+ ):
+ return max(
+ map(
+ lambda selection: determine_depth(
+ node=selection,
+ fragments=fragments,
+ depth_so_far=depth_so_far,
+ max_depth=max_depth,
+ context=context,
+ operation_name=operation_name,
+ ignore=ignore,
+ ),
+ node.selection_set.selections,
+ )
+ )
+ else:
+ raise Exception(
+ f"Depth crawler cannot handle: {node.kind}."
+ ) # pragma: no cover
+
+
+def is_ignored(node: FieldNode, ignore: Optional[List[IgnoreType]] = None) -> bool:
+ if ignore is None:
+ return False
+ for rule in ignore:
+ field_name = node.name.value
+ if isinstance(rule, str):
+ if field_name == rule:
+ return True
+ elif isinstance(rule, Pattern):
+ if rule.match(field_name):
+ return True
+ elif callable(rule):
+ if rule(field_name):
+ return True
+ else:
+ raise ValueError(f"Invalid ignore option: {rule}.")
+ return False
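Usage sketch: the factory returns a `ValidationRule` subclass, so it plugs straight into `graphql-core`'s `validate` (the schema here is illustrative):

```python
from graphql import parse, validate
from graphene import Field, ObjectType, Schema, String
from graphene.validation import depth_limit_validator

class Author(ObjectType):
    name = String()

class Query(ObjectType):
    author = Field(Author)

schema = Schema(query=Query)

# "{ author { name } }" has depth 1: author contributes 1, name terminates at 0.
errors = validate(
    schema.graphql_schema,
    parse("{ author { name } }"),
    rules=(depth_limit_validator(max_depth=1),),
)
assert not errors
```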
diff --git a/graphene/validation/disable_introspection.py b/graphene/validation/disable_introspection.py
index d18720b..49a7d60 100644
--- a/graphene/validation/disable_introspection.py
+++ b/graphene/validation/disable_introspection.py
@@ -1,8 +1,16 @@
from graphql import GraphQLError
from graphql.language import FieldNode
from graphql.validation import ValidationRule
+
from ..utils.is_introspection_key import is_introspection_key
class DisableIntrospection(ValidationRule):
- pass
+ def enter_field(self, node: FieldNode, *_args):
+ field_name = node.name.value
+ if is_introspection_key(field_name):
+ self.report_error(
+ GraphQLError(
+ f"Cannot query '{field_name}': introspection is disabled.", node
+ )
+ )
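And the matching usage sketch for the rule above, again via `graphql-core`'s `validate`; any field whose name starts with `__` is rejected:

```python
from graphql import parse, validate
from graphene import ObjectType, Schema, String
from graphene.validation import DisableIntrospection

class Query(ObjectType):
    name = String()

schema = Schema(query=Query)

errors = validate(
    schema.graphql_schema,
    parse("{ __schema { queryType { name } } }"),
    rules=(DisableIntrospection,),
)
assert errors and "introspection is disabled" in errors[0].message
```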