diff --git a/redisgraph/graph.py b/redisgraph/graph.py index 909cbb4..f4d802f 100644 --- a/redisgraph/graph.py +++ b/redisgraph/graph.py @@ -218,7 +218,7 @@ def merge(self, pattern): return self.query(query) # Procedures. - def call_procedure(self, procedure, read_only=False, *args, **kwagrs): + def call_procedure(self, procedure, *args, read_only=False, **kwagrs): args = [quote_string(arg) for arg in args] q = 'CALL %s(%s)' % (procedure, ','.join(args)) diff --git a/redisgraph/query_result.py b/redisgraph/query_result.py index 6003802..4d92f1f 100644 --- a/redisgraph/query_result.py +++ b/redisgraph/query_result.py @@ -4,6 +4,7 @@ from .exceptions import VersionMismatchException from prettytable import PrettyTable from redis import ResponseError +from collections import OrderedDict LABELS_ADDED = 'Labels added' NODES_CREATED = 'Nodes created' @@ -39,6 +40,7 @@ class ResultSetScalarTypes: VALUE_EDGE = 7 VALUE_NODE = 8 VALUE_PATH = 9 + VALUE_MAP = 10 class QueryResult: @@ -125,6 +127,14 @@ def parse_entity_properties(self, props): return properties + def parse_string(self, cell): + if isinstance(cell, bytes): + return cell.decode() + elif not isinstance(cell, str): + return str(cell) + else: + return cell + def parse_node(self, cell): # Node ID (integer), # [label string offset (integer)], @@ -156,6 +166,19 @@ def parse_path(self, cell): edges = self.parse_scalar(cell[1]) return Path(nodes, edges) + def parse_map(self, cell): + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = self.parse_scalar(cell[i+1]) + + return m + def parse_scalar(self, cell): scalar_type = int(cell[0]) value = cell[1] @@ -165,12 +188,7 @@ def parse_scalar(self, cell): scalar = None elif scalar_type == ResultSetScalarTypes.VALUE_STRING: - if isinstance(value, bytes): - scalar = value.decode() - elif not isinstance(value, str): - scalar = str(value) - else: - scalar = value + scalar = self.parse_string(value) elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER: scalar = int(value) @@ -202,6 +220,9 @@ def parse_scalar(self, cell): elif scalar_type == ResultSetScalarTypes.VALUE_PATH: scalar = self.parse_path(value) + elif scalar_type == ResultSetScalarTypes.VALUE_MAP: + scalar = self.parse_map(value) + elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN: print("Unknown scalar type\n") diff --git a/tests/functional/test_all.py b/tests/functional/test_all.py index 7f7bf9c..422448a 100644 --- a/tests/functional/test_all.py +++ b/tests/functional/test_all.py @@ -102,6 +102,19 @@ def test_param(self): # All done, remove graph. redis_graph.delete() + def test_map(self): + redis_graph = Graph('map', self.r) + + query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" + + actual = redis_graph.query(query).result_set[0][0] + expected = {'a': 1, 'b': 'str', 'c': None, 'd': [1, 2, 3], 'e': True, 'f': {'x': 1, 'y': 2}} + + self.assertEqual(actual, expected) + + # All done, remove graph. + redis_graph.delete() + def test_index_response(self): redis_graph = Graph('social', self.r)