Skip to content

Commit 6f04dc4

Browse files
committed
fix: use jinja built-in render
instead of regex search and replace
1 parent 184ba72 commit 6f04dc4

File tree

2 files changed

+39
-61
lines changed

2 files changed

+39
-61
lines changed

graphql_server/render_graphiql.py

Lines changed: 16 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Based on (express-graphql)[https://github.com/graphql/express-graphql/blob/main/src/renderGraphiQL.ts] and
22
(subscriptions-transport-ws)[https://github.com/apollographql/subscriptions-transport-ws]"""
3-
import json
4-
import re
53
from typing import Any, Dict, Optional, Tuple
64

75
from jinja2 import Environment
@@ -216,54 +214,6 @@ class GraphiQLOptions(TypedDict):
216214
should_persist_headers: Optional[bool]
217215

218216

219-
def escape_js_value(value: Any) -> Any:
220-
quotation = False
221-
if value.startswith('"') and value.endswith('"'):
222-
quotation = True
223-
value = value[1 : len(value) - 1]
224-
225-
value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n")
226-
if quotation:
227-
value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"'
228-
229-
return value
230-
231-
232-
def process_var(template: str, name: str, value: Any, jsonify=False) -> str:
233-
pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}"
234-
if jsonify and value not in ["null", "undefined"]:
235-
value = json.dumps(value)
236-
value = escape_js_value(value)
237-
238-
return re.sub(pattern, value, template)
239-
240-
241-
def simple_renderer(template: str, **values: Dict[str, Any]) -> str:
242-
replace = [
243-
"graphiql_version",
244-
"graphiql_html_title",
245-
"subscription_url",
246-
"header_editor_enabled",
247-
"should_persist_headers",
248-
]
249-
replace_jsonify = [
250-
"query",
251-
"result",
252-
"variables",
253-
"operation_name",
254-
"default_query",
255-
"headers",
256-
]
257-
258-
for r in replace:
259-
template = process_var(template, r, values.get(r, ""))
260-
261-
for r in replace_jsonify:
262-
template = process_var(template, r, values.get(r, ""), True)
263-
264-
return template
265-
266-
267217
def _render_graphiql(
268218
data: GraphiQLData,
269219
config: GraphiQLConfig,
@@ -296,6 +246,9 @@ def _render_graphiql(
296246
or "false",
297247
}
298248

249+
if template_vars["result"] in ("null", "undefined"):
250+
template_vars["result"] = None
251+
299252
return graphiql_template, template_vars
300253

301254

@@ -305,16 +258,17 @@ async def render_graphiql_async(
305258
options: Optional[GraphiQLOptions] = None,
306259
) -> str:
307260
graphiql_template, template_vars = _render_graphiql(data, config, options)
308-
jinja_env: Optional[Environment] = config.get("jinja_env")
309-
310-
if jinja_env:
311-
template = jinja_env.from_string(graphiql_template)
312-
if jinja_env.is_async:
313-
source = await template.render_async(**template_vars)
314-
else:
315-
source = template.render(**template_vars)
316-
else:
317-
source = simple_renderer(graphiql_template, **template_vars)
261+
262+
jinja_env = config.get("jinja_env") or Environment()
263+
264+
template = jinja_env.from_string(graphiql_template)
265+
266+
source = (
267+
await template.render_async(**template_vars)
268+
if jinja_env.is_async
269+
else template.render(**template_vars)
270+
)
271+
318272
return source
319273

320274

@@ -325,5 +279,6 @@ def render_graphiql_sync(
325279
) -> str:
326280
graphiql_template, template_vars = _render_graphiql(data, config, options)
327281

328-
source = simple_renderer(graphiql_template, **template_vars)
282+
template = Environment().from_string(graphiql_template)
283+
source = template.render(**template_vars)
329284
return source

tests/test_query.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
load_json_body,
1616
run_http_query,
1717
)
18+
from graphql_server.render_graphiql import (
19+
GraphiQLConfig,
20+
GraphiQLData,
21+
render_graphiql_sync,
22+
)
1823

1924
from .schema import invalid_schema, schema
2025
from .utils import as_dicts
@@ -653,3 +658,21 @@ def test_batch_allows_post_with_operation_name():
653658
results, params = run_http_query(schema, "post", data, batch_enabled=True)
654659

655660
assert results == [({"test": "Hello World", "shared": "Hello Everyone"}, None)]
661+
662+
663+
def test_graphiql_render_umlaut():
664+
results, params = run_http_query(
665+
schema,
666+
"get",
667+
data=dict(query="query helloWho($who: String){ test(who: $who) }"),
668+
query_data=dict(variables='{"who": "Björn"}'),
669+
catch=True,
670+
)
671+
result, status_code = encode_execution_results(results)
672+
673+
assert status_code == 200
674+
675+
graphiql_data = GraphiQLData(result=result, query=params[0].query)
676+
source = render_graphiql_sync(data=graphiql_data, config=GraphiQLConfig())
677+
678+
assert "Hello Bj\\\\u00f6rn" in source

0 commit comments

Comments
 (0)