Skip to content

Commit e29a3e4

Browse files
authored
Support parameterized queries for openCypher (#496)
* Support parameterized queries for openCypher * update changelog --------- Co-authored-by: Michael Chin <[email protected]>
1 parent b230002 commit e29a3e4

File tree

4 files changed

+88
-4
lines changed

4 files changed

+88
-4
lines changed

ChangeLog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd
77
- Path: 03-Sample-Applications > 05-Healthcare-and-Life-Sciences-Graphs
88
- Added openCypher and local file path support to `%seed` ([Link to PR](https://github.com/aws/graph-notebook/pull/292))
99
- Added S3 support to `%seed` ([Link to PR](https://github.com/aws/graph-notebook/pull/488))
10+
- Added support for openCypher parameterized queries ([Link to PR](https://github.com/aws/graph-notebook/pull/496))
1011
- Added `%toggle_traceback` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/486))
1112
- Added support for setting `%graph_notebook_vis_options` from a variable ([Link to PR](https://github.com/aws/graph-notebook/pull/487))
1213
- Pinned JupyterLab<4.x to fix Python 3.8/3.10 builds ([Link to PR](https://github.com/aws/graph-notebook/pull/490))

src/graph_notebook/magics/graph_magic.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,6 +2619,9 @@ def handle_opencypher_query(self, line, cell, local_ns):
26192619
parser.add_argument('--explain-type', default='dynamic',
26202620
help='explain mode to use when using the explain query mode',
26212621
choices=['dynamic', 'static', 'details', 'debug'])
2622+
parser.add_argument('-qp', '--query-parameters', type=str, default='',
2623+
help='Parameter definitions to apply to the query. This option can accept a local variable '
2624+
'name, or a string representation of the map.')
26222625
parser.add_argument('-g', '--group-by', type=str, default='~labels',
26232626
help='Property used to group nodes (e.g. code, ~id) default is ~labels')
26242627
parser.add_argument('-gd', '--group-by-depth', action='store_true', default=False,
@@ -2660,6 +2663,20 @@ def handle_opencypher_query(self, line, cell, local_ns):
26602663
res_format = None
26612664
results_df = None
26622665

2666+
query_params = None
2667+
if args.query_parameters:
2668+
if args.query_parameters in local_ns:
2669+
query_params_input = local_ns[args.query_parameters]
2670+
else:
2671+
query_params_input = args.query_parameters
2672+
if isinstance(query_params_input, dict):
2673+
query_params = query_params_input
2674+
else:
2675+
try:
2676+
query_params = json.loads(query_params_input.replace("'", '"'))
2677+
except Exception as e:
2678+
print(f"Invalid query parameter input, ignoring.")
2679+
26632680
if args.no_scroll:
26642681
oc_layout = UNRESTRICTED_LAYOUT
26652682
oc_scrollY = True
@@ -2680,7 +2697,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
26802697

26812698
if args.mode == 'explain':
26822699
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
2683-
res = self.client.opencypher_http(cell, explain=args.explain_type)
2700+
res = self.client.opencypher_http(cell, explain=args.explain_type, query_params=query_params)
26842701
query_time = time.time() * 1000 - query_start
26852702
explain = res.content.decode("utf-8")
26862703
res.raise_for_status()
@@ -2696,7 +2713,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
26962713
link=f"data:text/html;base64,{base64_str}")
26972714
elif args.mode == 'query':
26982715
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
2699-
oc_http = self.client.opencypher_http(cell)
2716+
oc_http = self.client.opencypher_http(cell, query_params=query_params)
27002717
query_time = time.time() * 1000 - query_start
27012718
oc_http.raise_for_status()
27022719

@@ -2754,7 +2771,10 @@ def handle_opencypher_query(self, line, cell, local_ns):
27542771
elif args.mode == 'bolt':
27552772
res_format = 'bolt'
27562773
query_start = time.time() * 1000
2757-
res = self.client.opencyper_bolt(cell)
2774+
if query_params:
2775+
res = self.client.opencyper_bolt(cell, **query_params)
2776+
else:
2777+
res = self.client.opencyper_bolt(cell)
27582778
query_time = time.time() * 1000 - query_start
27592779
if not args.silent:
27602780
oc_metadata = build_opencypher_metadata_from_query(query_type='bolt', results=res,

src/graph_notebook/neptune/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ def _gremlin_query_plan(self, query: str, plan_type: str, args: dict, ) -> reque
327327
res = self._http_session.send(req, verify=self.ssl_verify)
328328
return res
329329

330-
def opencypher_http(self, query: str, headers: dict = None, explain: str = None) -> requests.Response:
330+
def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
331+
query_params: dict = None) -> requests.Response:
331332
if headers is None:
332333
headers = {}
333334

@@ -343,6 +344,8 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None)
343344
if explain:
344345
data['explain'] = explain
345346
headers['Accept'] = "text/html"
347+
if query_params:
348+
data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}'
346349
else:
347350
url += 'db/neo4j/tx/commit'
348351
headers['content-type'] = 'application/json'

test/integration/iam/notebook/test_open_cypher_graph_notebook.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,66 @@ def test_opencypher_bolt(self):
4848
assert len(res) == 1
4949
assert 'b' in res[0]
5050

51+
@pytest.mark.jupyter
52+
@pytest.mark.opencypher
53+
def test_opencypher_query_parameterized_with_var_input(self):
54+
expected_league_name = "English Premier League"
55+
query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name'
56+
57+
store_to_var = 'res'
58+
self.ip.user_ns['params_var'] = {'LEAGUE_NICKNAME': 'EPL'}
59+
cell = f'''%%oc --query-parameters params_var --store-to {store_to_var}
60+
{query}'''
61+
self.ip.run_cell(cell)
62+
res = self.ip.user_ns[store_to_var]
63+
64+
assert len(res['results']) == 1
65+
assert expected_league_name == res['results'][0]['l.name']
66+
67+
@pytest.mark.jupyter
68+
@pytest.mark.opencypher
69+
def test_opencypher_query_parameterized_with_str_input(self):
70+
expected_league_name = "English Premier League"
71+
query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name'
72+
73+
store_to_var = 'res'
74+
params_str = '{"LEAGUE_NICKNAME":"EPL"}'
75+
cell = f'''%%oc --query-parameters {params_str} --store-to {store_to_var}
76+
{query}'''
77+
self.ip.run_cell(cell)
78+
res = self.ip.user_ns[store_to_var]
79+
80+
assert len(res['results']) == 1
81+
assert expected_league_name == res['results'][0]['l.name']
82+
83+
@pytest.mark.jupyter
84+
@pytest.mark.opencypher
85+
def test_opencypher_query_parameterized_invalid(self):
86+
query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name'
87+
88+
self.ip.user_ns['params_var'] = ['LEAGUE_NICKNAME']
89+
cell = f'''%%oc --query-parameters params_var
90+
{query}'''
91+
self.ip.run_cell(cell)
92+
self.assertTrue('graph_notebook_error' in self.ip.user_ns)
93+
94+
@pytest.mark.jupyter
95+
@pytest.mark.opencypher
96+
def test_opencypher_bolt_parameterized(self):
97+
expected_league_name = "English Premier League"
98+
query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name'
99+
100+
store_to_var = 'res'
101+
params_var = '{"LEAGUE_NICKNAME":"EPL"}'
102+
cell = f'''%%oc bolt --query-parameters {params_var} --store-to {store_to_var}
103+
{query}'''
104+
self.ip.run_cell(cell)
105+
self.assertFalse('graph_notebook_error' in self.ip.user_ns)
106+
res = self.ip.user_ns[store_to_var]
107+
108+
assert len(res) == 1
109+
assert expected_league_name == res[0]['l.name']
110+
51111
@pytest.mark.jupyter
52112
def test_load_opencypher_config(self):
53113
config = '''{

0 commit comments

Comments
 (0)