diff --git a/redisgraph/graph.py b/redisgraph/graph.py index a12fe7f..2322187 100644 --- a/redisgraph/graph.py +++ b/redisgraph/graph.py @@ -1,6 +1,6 @@ import redis -from redisgraph.util import random_string, quote_string +from redisgraph.util import random_string, quote_string, stringify_param_value from redisgraph.query_result import QueryResult from redisgraph.exceptions import VersionMismatchException @@ -162,13 +162,7 @@ def _build_params_header(self, params): # Header starts with "CYPHER" params_header = "CYPHER " for key, value in params.items(): - # If value is string add quotation marks. - if isinstance(value, str): - value = quote_string(value) - # Value is None, replace with "null" string. - elif value is None: - value = "null" - params_header += str(key) + "=" + str(value) + " " + params_header += str(key) + "=" + stringify_param_value(value) + " " return params_header def query(self, q, params=None, timeout=None, read_only=False): diff --git a/redisgraph/util.py b/redisgraph/util.py index fc5480a..36f0f11 100644 --- a/redisgraph/util.py +++ b/redisgraph/util.py @@ -1,7 +1,7 @@ import random import string -__all__ = ['random_string', 'quote_string'] +__all__ = ['random_string', 'quote_string', 'stringify_param_value'] def random_string(length=10): @@ -28,3 +28,30 @@ def quote_string(v): v = v.replace('"', '\\"') return '"{}"'.format(v) + + +def stringify_param_value(value): + """ + Turn a parameter value into a string suitable for the params header of + a Cypher command. + + You may pass any value that would be accepted by `json.dumps()`. + + Ways in which output differs from that of `str()`: + * Strings are quoted. + * None --> "null". + * In dictionaries, keys are _not_ quoted. + + :param value: The parameter value to be turned into a string. + :return: string + """ + if isinstance(value, str): + return quote_string(value) + elif value is None: + return "null" + elif isinstance(value, (list, tuple)): + return f'[{",".join(map(stringify_param_value, value))}]' + elif isinstance(value, dict): + return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' + else: + return str(value) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 5476933..4d524d4 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -17,3 +17,28 @@ def test_quote_string(self): self.assertEqual(util.quote_string('\"'), '"\\\""') self.assertEqual(util.quote_string('"'), '"\\""') self.assertEqual(util.quote_string('a"a'), '"a\\"a"') + + def test_stringify_param_value(self): + cases = [ + [ + "abc", '"abc"' + ], + [ + None, "null" + ], + [ + ["abc", 123, None], + '["abc",123,null]' + ], + [ + {'age': 2, 'color': 'orange'}, + '{age:2,color:"orange"}' + ], + [ + [{'age': 2, 'color': 'orange'}, {'age': 7, 'color': 'gray'}, ], + '[{age:2,color:"orange"},{age:7,color:"gray"}]' + ], + ] + for param, expected in cases: + observed = util.stringify_param_value(param) + self.assertEqual(observed, expected)