Skip to content

Commit a7caad0

Browse files
authored
Merge pull request #38 from TangoAgency/batch
Add support for batching several requests into one
2 parents 2665d73 + 6bd89f2 commit a7caad0

File tree

3 files changed

+130
-31
lines changed

3 files changed

+130
-31
lines changed

graphene_django/tests/test_views.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,23 @@
88
from urllib.parse import urlencode
99

1010

11-
def url_string(**url_params):
12-
string = '/graphql'
13-
11+
def url_string(string='/graphql', **url_params):
1412
if url_params:
1513
string += '?' + urlencode(url_params)
1614

1715
return string
1816

1917

18+
def batch_url_string(**url_params):
19+
return url_string('/graphql/batch', **url_params)
20+
21+
2022
def response_json(response):
2123
return json.loads(response.content.decode())
2224

2325

2426
j = lambda **kwargs: json.dumps(kwargs)
27+
jl = lambda **kwargs: json.dumps([kwargs])
2528

2629

2730
def test_graphiql_is_enabled(client):
@@ -169,6 +172,17 @@ def test_allows_post_with_json_encoding(client):
169172
}
170173

171174

175+
def test_batch_allows_post_with_json_encoding(client):
176+
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')
177+
178+
assert response.status_code == 200
179+
assert response_json(response) == [{
180+
'id': 1,
181+
'payload': { 'data': {'test': "Hello World"} },
182+
'status': 200,
183+
}]
184+
185+
172186
def test_allows_sending_a_mutation_via_post(client):
173187
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
174188

@@ -199,6 +213,22 @@ def test_supports_post_json_query_with_string_variables(client):
199213
}
200214

201215

216+
217+
def test_batch_supports_post_json_query_with_string_variables(client):
218+
response = client.post(batch_url_string(), jl(
219+
id=1,
220+
query='query helloWho($who: String){ test(who: $who) }',
221+
variables=json.dumps({'who': "Dolly"})
222+
), 'application/json')
223+
224+
assert response.status_code == 200
225+
assert response_json(response) == [{
226+
'id': 1,
227+
'payload': { 'data': {'test': "Hello Dolly"} },
228+
'status': 200,
229+
}]
230+
231+
202232
def test_supports_post_json_query_with_json_variables(client):
203233
response = client.post(url_string(), j(
204234
query='query helloWho($who: String){ test(who: $who) }',
@@ -211,6 +241,21 @@ def test_supports_post_json_query_with_json_variables(client):
211241
}
212242

213243

244+
def test_batch_supports_post_json_query_with_json_variables(client):
245+
response = client.post(batch_url_string(), jl(
246+
id=1,
247+
query='query helloWho($who: String){ test(who: $who) }',
248+
variables={'who': "Dolly"}
249+
), 'application/json')
250+
251+
assert response.status_code == 200
252+
assert response_json(response) == [{
253+
'id': 1,
254+
'payload': { 'data': {'test': "Hello Dolly"} },
255+
'status': 200,
256+
}]
257+
258+
214259
def test_supports_post_url_encoded_query_with_string_variables(client):
215260
response = client.post(url_string(), urlencode(dict(
216261
query='query helloWho($who: String){ test(who: $who) }',
@@ -285,6 +330,33 @@ def test_allows_post_with_operation_name(client):
285330
}
286331

287332

333+
def test_batch_allows_post_with_operation_name(client):
334+
response = client.post(batch_url_string(), jl(
335+
id=1,
336+
query='''
337+
query helloYou { test(who: "You"), ...shared }
338+
query helloWorld { test(who: "World"), ...shared }
339+
query helloDolly { test(who: "Dolly"), ...shared }
340+
fragment shared on QueryRoot {
341+
shared: test(who: "Everyone")
342+
}
343+
''',
344+
operationName='helloWorld'
345+
), 'application/json')
346+
347+
assert response.status_code == 200
348+
assert response_json(response) == [{
349+
'id': 1,
350+
'payload': {
351+
'data': {
352+
'test': 'Hello World',
353+
'shared': 'Hello Everyone'
354+
}
355+
},
356+
'status': 200,
357+
}]
358+
359+
288360
def test_allows_post_with_get_operation_name(client):
289361
response = client.post(url_string(
290362
operationName='helloWorld'

graphene_django/tests/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from ..views import GraphQLView
44

55
urlpatterns = [
6+
url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
67
url(r'^graphql', GraphQLView.as_view(graphiql=True)),
78
]

graphene_django/views.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ class GraphQLView(View):
6262
middleware = None
6363
root_value = None
6464
pretty = False
65+
batch = False
6566

66-
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False):
67+
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
68+
batch=False):
6769
if not schema:
6870
schema = graphene_settings.SCHEMA
6971

@@ -77,8 +79,10 @@ def __init__(self, schema=None, executor=None, middleware=None, root_value=None,
7779
self.root_value = root_value
7880
self.pretty = pretty
7981
self.graphiql = graphiql
82+
self.batch = batch
8083

8184
assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
85+
assert not all((graphiql, batch)), 'Use either graphiql or batch processing'
8286

8387
# noinspection PyUnusedLocal
8488
def get_root_value(self, request):
@@ -99,34 +103,15 @@ def dispatch(self, request, *args, **kwargs):
99103
data = self.parse_body(request)
100104
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
101105

102-
query, variables, operation_name = self.get_graphql_params(request, data)
103-
104-
execution_result = self.execute_graphql_request(
105-
request,
106-
data,
107-
query,
108-
variables,
109-
operation_name,
110-
show_graphiql
111-
)
112-
113-
if execution_result:
114-
response = {}
115-
116-
if execution_result.errors:
117-
response['errors'] = [self.format_error(e) for e in execution_result.errors]
118-
119-
if execution_result.invalid:
120-
status_code = 400
121-
else:
122-
status_code = 200
123-
response['data'] = execution_result.data
124-
125-
result = self.json_encode(request, response, pretty=show_graphiql)
106+
if self.batch:
107+
responses = [self.get_response(request, entry) for entry in data]
108+
result = '[{}]'.format(','.join([response[0] for response in responses]))
109+
status_code = max(responses, key=lambda response: response[1])[1]
126110
else:
127-
result = None
111+
result, status_code = self.get_response(request, data, show_graphiql)
128112

129113
if show_graphiql:
114+
query, variables, operation_name, id = self.get_graphql_params(request, data)
130115
return self.render_graphiql(
131116
request,
132117
graphiql_version=self.graphiql_version,
@@ -150,6 +135,43 @@ def dispatch(self, request, *args, **kwargs):
150135
})
151136
return response
152137

138+
def get_response(self, request, data, show_graphiql=False):
139+
query, variables, operation_name, id = self.get_graphql_params(request, data)
140+
141+
execution_result = self.execute_graphql_request(
142+
request,
143+
data,
144+
query,
145+
variables,
146+
operation_name,
147+
show_graphiql
148+
)
149+
150+
status_code = 200
151+
if execution_result:
152+
response = {}
153+
154+
if execution_result.errors:
155+
response['errors'] = [self.format_error(e) for e in execution_result.errors]
156+
157+
if execution_result.invalid:
158+
status_code = 400
159+
else:
160+
response['data'] = execution_result.data
161+
162+
if self.batch:
163+
response = {
164+
'id': id,
165+
'payload': response,
166+
'status': status_code,
167+
}
168+
169+
result = self.json_encode(request, response, pretty=show_graphiql)
170+
else:
171+
result = None
172+
173+
return result, status_code
174+
153175
def render_graphiql(self, request, **data):
154176
return render(request, self.graphiql_template, data)
155177

@@ -170,7 +192,10 @@ def parse_body(self, request):
170192
elif content_type == 'application/json':
171193
try:
172194
request_json = json.loads(request.body.decode('utf-8'))
173-
assert isinstance(request_json, dict)
195+
if self.batch:
196+
assert isinstance(request_json, list)
197+
else:
198+
assert isinstance(request_json, dict)
174199
return request_json
175200
except:
176201
raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.'))
@@ -242,6 +267,7 @@ def request_wants_html(cls, request):
242267
def get_graphql_params(request, data):
243268
query = request.GET.get('query') or data.get('query')
244269
variables = request.GET.get('variables') or data.get('variables')
270+
id = request.GET.get('id') or data.get('id')
245271

246272
if variables and isinstance(variables, six.text_type):
247273
try:
@@ -251,7 +277,7 @@ def get_graphql_params(request, data):
251277

252278
operation_name = request.GET.get('operationName') or data.get('operationName')
253279

254-
return query, variables, operation_name
280+
return query, variables, operation_name, id
255281

256282
@staticmethod
257283
def format_error(error):

0 commit comments

Comments
 (0)