Skip to content

Commit 0c67f8f

Browse files
committed
Allow providing directives to GraphQLSchema
Closes #33
1 parent 3ee3ee5 commit 0c67f8f

File tree

7 files changed

+118
-39
lines changed

7 files changed

+118
-39
lines changed

graphql/core/type/directives.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33

44

55
class GraphQLDirective(object):
6-
pass
6+
__slots__ = 'name', 'args', 'description', 'on_operation', 'on_fragment', 'on_field'
7+
8+
def __init__(self, name, description=None, args=None, on_operation=False, on_fragment=False, on_field=False):
9+
self.name = name
10+
self.description = description
11+
self.args = args or []
12+
self.on_operation = on_operation
13+
self.on_fragment = on_fragment
14+
self.on_field = on_field
715

816

917
def arg(name, *args, **kwargs):
@@ -12,25 +20,26 @@ def arg(name, *args, **kwargs):
1220
return a
1321

1422

15-
class GraphQLIncludeDirective(GraphQLDirective):
16-
name = 'include'
17-
args = [arg(
23+
GraphQLIncludeDirective = GraphQLDirective(
24+
name='include',
25+
args=[arg(
1826
'if',
1927
type=GraphQLNonNull(GraphQLBoolean),
2028
description='Directs the executor to include this field or fragment only when the `if` argument is true.',
21-
)]
22-
on_operation = False
23-
on_fragment = True
24-
on_field = True
25-
26-
27-
class GraphQLSkipDirective(GraphQLDirective):
28-
name = 'skip'
29-
args = [arg(
29+
)],
30+
on_operation=False,
31+
on_fragment=True,
32+
on_field=True
33+
)
34+
35+
GraphQLSkipDirective = GraphQLDirective(
36+
name='skip',
37+
args=[arg(
3038
'if',
3139
type=GraphQLNonNull(GraphQLBoolean),
3240
description='Directs the executor to skip this field or fragment only when the `if` argument is true.',
33-
)]
34-
on_operation = False
35-
on_fragment = True
36-
on_field = True
41+
)],
42+
on_operation=False,
43+
on_fragment=True,
44+
on_field=True
45+
)

graphql/core/type/schema.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
GraphQLObjectType,
88
GraphQLUnionType,
99
)
10-
from .directives import GraphQLIncludeDirective, GraphQLSkipDirective
10+
from .directives import GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective
1111
from .introspection import IntrospectionSchema
1212

1313

@@ -26,7 +26,7 @@ class GraphQLSchema(object):
2626
"""
2727
__slots__ = '_query', '_mutation', '_subscription', '_type_map', '_directives',
2828

29-
def __init__(self, query, mutation=None, subscription=None):
29+
def __init__(self, query, mutation=None, subscription=None, directives=None):
3030
assert isinstance(query, GraphQLObjectType), 'Schema query must be Object Type but got: {}.'.format(query)
3131
if mutation:
3232
assert isinstance(mutation, GraphQLObjectType), \
@@ -40,7 +40,19 @@ def __init__(self, query, mutation=None, subscription=None):
4040
self._mutation = mutation
4141
self._subscription = subscription
4242
self._type_map = self._build_type_map()
43-
self._directives = None
43+
44+
if directives is None:
45+
directives = [
46+
GraphQLIncludeDirective,
47+
GraphQLSkipDirective
48+
]
49+
50+
assert all(isinstance(d, GraphQLDirective) for d in directives), \
51+
'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format(
52+
directives
53+
)
54+
55+
self._directives = directives
4456

4557
for type in self._type_map.values():
4658
if isinstance(type, GraphQLObjectType):
@@ -63,12 +75,6 @@ def get_type(self, name):
6375
return self._type_map.get(name)
6476

6577
def get_directives(self):
66-
if self._directives is None:
67-
self._directives = [
68-
GraphQLIncludeDirective,
69-
GraphQLSkipDirective
70-
]
71-
7278
return self._directives
7379

7480
def get_directive(self, name):

graphql/core/utils/build_client_schema.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_input_type,
2323
is_output_type
2424
)
25+
from ..type.directives import GraphQLDirective
2526
from ..type.introspection import TypeKind
2627
from .value_from_ast import value_from_ast
2728

@@ -194,19 +195,42 @@ def build_default_value(f):
194195

195196
def build_input_value_def_map(input_value_introspection, argument_type):
196197
return OrderedDict([
197-
(f['name'], argument_type(
198-
description=f['description'],
199-
type=get_input_type(f['type']),
200-
default_value=build_default_value(f)
201-
)) for f in input_value_introspection
198+
(f['name'], build_input_value(f, argument_type)) for f in input_value_introspection
202199
])
203200

201+
def build_input_value(input_value_introspection, argument_type):
202+
input_value = argument_type(
203+
description=input_value_introspection['description'],
204+
type=get_input_type(input_value_introspection['type']),
205+
default_value=build_default_value(input_value_introspection)
206+
)
207+
input_value.name = input_value_introspection['name']
208+
return input_value
209+
210+
def build_directive(directive_introspection):
211+
return GraphQLDirective(
212+
name=directive_introspection['name'],
213+
description=directive_introspection['description'],
214+
args=[build_input_value(a, GraphQLArgument) for a in directive_introspection['args']],
215+
on_operation=directive_introspection['onOperation'],
216+
on_fragment=directive_introspection['onFragment'],
217+
on_field=directive_introspection['onField']
218+
)
219+
204220
for type_introspection_name in type_introspection_map:
205221
get_named_type(type_introspection_name)
206222

207-
query_type = get_type(schema_introspection['queryType'])
208-
mutation_type = get_type(schema_introspection['mutationType']) if schema_introspection.get('mutationType') else None
209-
subscription_type = get_type(schema_introspection['subscriptionType']) if \
223+
query_type = get_object_type(schema_introspection['queryType'])
224+
mutation_type = get_object_type(schema_introspection['mutationType']) if schema_introspection.get('mutationType') else None
225+
subscription_type = get_object_type(schema_introspection['subscriptionType']) if \
210226
schema_introspection.get('subscriptionType') else None
211227

212-
return GraphQLSchema(query=query_type, mutation=mutation_type, subscription=subscription_type)
228+
directives = [build_directive(d) for d in schema_introspection['directives']] \
229+
if schema_introspection['directives'] else []
230+
231+
return GraphQLSchema(
232+
query=query_type,
233+
mutation=mutation_type,
234+
subscription=subscription_type,
235+
directives=directives
236+
)

tests/core_type/test_validation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ def test_rejects_a_schema_whose_subscription_type_is_an_input_type(self):
164164

165165
assert str(excinfo.value) == 'Schema subscription must be Object Type but got: SomeInputObject.'
166166

167+
def test_rejects_a_schema_whose_directives_are_incorrectly_typed(self):
168+
with raises(AssertionError) as excinfo:
169+
GraphQLSchema(query=SomeObjectType, directives=['somedirective'])
170+
171+
assert str(excinfo.value) == 'Schema directives must be List[GraphQLDirective] if provided but got: ' \
172+
'[\'somedirective\'].'
173+
167174

168175
# noinspection PyMethodMayBeStatic,PyPep8Naming
169176
class TestTypeSystem_ASchemaMustContainUniquelyNamedTypes:

tests/core_utils/test_build_client_schema.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
GraphQLBoolean,
2323
GraphQLID,
2424
)
25+
from graphql.core.type.directives import GraphQLDirective
2526
from graphql.core.utils.introspection_query import introspection_query
2627
from graphql.core.utils.build_client_schema import build_client_schema
2728

@@ -382,6 +383,30 @@ def test_builds_a_schema_with_field_arguments_with_default_values():
382383
_test_schema(schema)
383384

384385

386+
def test_builds_a_schema_with_custom_directives():
387+
schema = GraphQLSchema(
388+
query=GraphQLObjectType(
389+
name='Simple',
390+
description='This is a simple type',
391+
fields={
392+
'string': GraphQLField(
393+
type=GraphQLString,
394+
description='This is a string field'
395+
)
396+
},
397+
),
398+
directives=[
399+
GraphQLDirective(
400+
name='customDirective',
401+
description='This is a custom directive',
402+
on_field=True
403+
)
404+
]
405+
)
406+
407+
_test_schema(schema)
408+
409+
385410
def test_builds_a_schema_aware_of_deprecation():
386411
schema = GraphQLSchema(
387412
query=GraphQLObjectType(

tests/core_validation/test_known_directives.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,12 @@ def test_with_well_placed_directives():
8989
def test_with_misplaced_directives():
9090
expect_fails_rule(KnownDirectives, '''
9191
query Foo @include(if: true) {
92-
name
93-
...Frag
92+
name @operationOnly
93+
...Frag @operationOnly
9494
}
9595
''', [
96-
misplaced_directive('include', 'operation', 2, 17)
96+
misplaced_directive('include', 'operation', 2, 17),
97+
misplaced_directive('operationOnly', 'field', 3, 14),
98+
misplaced_directive('operationOnly', 'fragment', 4, 17),
99+
97100
])

tests/core_validation/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from graphql.core.type.directives import GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective
12
from graphql.core.validation import validate
23
from graphql.core.language.parser import parse
34
from graphql.core.type import (
@@ -162,7 +163,11 @@
162163
'complicatedArgs': GraphQLField(ComplicatedArgs),
163164
})
164165

165-
default_schema = GraphQLSchema(query=QueryRoot)
166+
default_schema = GraphQLSchema(query=QueryRoot, directives=[
167+
GraphQLDirective(name='operationOnly', on_operation=True),
168+
GraphQLIncludeDirective,
169+
GraphQLSkipDirective
170+
])
166171

167172

168173
def expect_valid(schema, rules, query):

0 commit comments

Comments
 (0)