Skip to content

Commit a047733

Browse files
authored
Add %%oc --plan-cache support for Neptune DB (#613)
* Add --plan-cache support for Neptune DB
* Update changelog
1 parent f07e887 commit a047733

File tree

3 files changed

+58
-31
lines changed

3 files changed

+58
-31
lines changed

ChangeLog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.
44

55
## Upcoming
6+
67
- Added `%reset_graph` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/610))
78
- Added `%get_graph` line magic and enabled `%status` for Neptune Analytics ([Link to PR](https://github.com/aws/graph-notebook/pull/611))
9+
- Added `%%oc --plan-cache` support for Neptune DB ([Link to PR](https://github.com/aws/graph-notebook/pull/613))
810
- Upgraded to Gremlin-Python 3.7 ([Link to PR](https://github.com/aws/graph-notebook/pull/597))
911

1012
## Release 4.3.1 (June 3, 2024)

src/graph_notebook/magics/graph_magic.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from graph_notebook.magics.streams import StreamViewer
4747
from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
4848
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
49-
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
49+
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
5050
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
5151
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
5252
STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
@@ -180,7 +180,7 @@
180180
MEDIA_TYPE_NTRIPLES_TEXT, MEDIA_TYPE_TURTLE, MEDIA_TYPE_N3, MEDIA_TYPE_TRIX,
181181
MEDIA_TYPE_TRIG, MEDIA_TYPE_RDF4J_BINARY]
182182

183-
byte_units = {'B': 1, 'KB': 1024, 'MB': 1024**2, 'GB': 1024**3, 'TB': 1024**4}
183+
byte_units = {'B': 1, 'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3, 'TB': 1024 ** 4}
184184

185185

186186
class QueryMode(Enum):
@@ -521,11 +521,11 @@ def neptune_config_allowlist(self, line='', cell=''):
521521

522522
@line_magic
523523
@neptune_db_only
524-
def stream_viewer(self,line):
524+
def stream_viewer(self, line):
525525
parser = argparse.ArgumentParser()
526526
parser.add_argument('language', nargs='?', default=STREAM_PG,
527527
help=f'language (default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]',
528-
choices = [STREAM_PG, STREAM_RDF])
528+
choices=[STREAM_PG, STREAM_RDF])
529529

530530
parser.add_argument('--limit', type=int, default=10, help='Maximum number of rows to display at a time')
531531

@@ -534,7 +534,7 @@ def stream_viewer(self,line):
534534
language = args.language
535535
limit = args.limit
536536
uri = self.client.get_uri_with_port()
537-
viewer = StreamViewer(self.client,uri,language,limit=limit)
537+
viewer = StreamViewer(self.client, uri, language, limit=limit)
538538
viewer.show()
539539

540540
@line_magic
@@ -877,7 +877,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
877877
if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
878878
lines = []
879879
for b in results['results']['bindings']:
880-
lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
880+
lines.append(
881+
f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
881882
raw_output = widgets.Output(layout=sparql_layout)
882883
with raw_output:
883884
html = sparql_construct_template.render(lines=lines)
@@ -1168,7 +1169,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
11681169
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
11691170
if self.graph_notebook_config.proxy_host != '' and self.client.is_neptune_domain():
11701171
using_http = True
1171-
query_res_http = self.client.gremlin_http_query(cell, headers={'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
1172+
query_res_http = self.client.gremlin_http_query(cell, headers={
1173+
'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
11721174
query_res_http.raise_for_status()
11731175
query_res_http_json = query_res_http.json()
11741176
query_res = query_res_http_json['result']['data']
@@ -1603,7 +1605,7 @@ def on_button_delete_clicked(b):
16031605
with output:
16041606
job_status_output.clear_output()
16051607
interval_output.close()
1606-
total_status_wait = max_status_retries*poll_interval
1608+
total_status_wait = max_status_retries * poll_interval
16071609
print(result)
16081610
if interval_check_response.get("status") != 'healthy':
16091611
print(f"Could not retrieve the status of the reset operation within the allotted time of "
@@ -1849,7 +1851,7 @@ def load(self, line='', local_ns: dict = None):
18491851
value=str(args.concurrency),
18501852
placeholder=1,
18511853
min=1,
1852-
max=2**16,
1854+
max=2 ** 16,
18531855
disabled=False,
18541856
layout=widgets.Layout(display=concurrency_hbox_visibility,
18551857
width=widget_width)
@@ -2057,8 +2059,8 @@ def on_button_clicked(b):
20572059
named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
20582060
base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
20592061
dep_hbox.children = (dep_hbox_label, dependencies,)
2060-
concurrency_hbox.children = (concurrency_hbox_label, concurrency, )
2061-
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit, )
2062+
concurrency_hbox.children = (concurrency_hbox_label, concurrency,)
2063+
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit,)
20622064

20632065
validated = True
20642066
validation_label_style = DescriptionStyle(color='red')
@@ -2210,8 +2212,9 @@ def on_button_clicked(b):
22102212

22112213
if poll_status.value == 'FALSE':
22122214
start_msg_label = widgets.Label(f'Load started successfully!')
2213-
polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
2214-
f'in another cell to check the current status of your bulk load.')
2215+
polling_msg_label = widgets.Label(
2216+
f'You can run "%load_status {load_result["payload"]["loadId"]}" '
2217+
f'in another cell to check the current status of your bulk load.')
22152218
start_msg_hbox = widgets.HBox([start_msg_label])
22162219
polling_msg_hbox = widgets.HBox([polling_msg_label])
22172220
vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
@@ -2254,11 +2257,13 @@ def on_button_clicked(b):
22542257
with job_status_output:
22552258
# parse status & execution_time differently for Analytics and NeptuneDB
22562259
overall_status = \
2257-
interval_check_response["payload"]["status"] if self.client.is_analytics_domain() \
2258-
else interval_check_response["payload"]["overallStatus"]["status"]
2260+
interval_check_response["payload"][
2261+
"status"] if self.client.is_analytics_domain() \
2262+
else interval_check_response["payload"]["overallStatus"]["status"]
22592263
total_time_spent = \
2260-
interval_check_response["payload"]["timeElapsedSeconds"] if self.client.is_analytics_domain() \
2261-
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
2264+
interval_check_response["payload"][
2265+
"timeElapsedSeconds"] if self.client.is_analytics_domain() \
2266+
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
22622267
print(f'Overall Status: {overall_status}')
22632268
if overall_status in FINAL_LOAD_STATUSES:
22642269
execution_time = total_time_spent
@@ -3179,7 +3184,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
31793184
"""
31803185
parser = argparse.ArgumentParser()
31813186
parser.add_argument('-pc', '--plan-cache', type=str.lower, default='auto',
3182-
help=f'Neptune Analytics only. Specifies the plan cache mode to use. '
3187+
help=f'Specifies the plan cache mode to use. '
31833188
f'Accepted values: {OPENCYPHER_PLAN_CACHE_MODES}')
31843189
parser.add_argument('-qt', '--query-timeout', type=int, default=None,
31853190
help=f'Neptune Analytics only. Specifies the maximum query timeout in milliseconds.')
@@ -3286,17 +3291,23 @@ def handle_opencypher_query(self, line, cell, local_ns):
32863291
first_tab_html = opencypher_explain_template.render(table=explain,
32873292
link=f"data:text/html;base64,{base64_str}")
32883293
elif args.mode == 'query':
3289-
if not self.client.is_analytics_domain():
3290-
if args.plan_cache != 'auto':
3291-
print("planCache is not supported for Neptune DB, ignoring.")
3292-
if args.query_timeout is not None:
3293-
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
3294+
if not self.client.is_analytics_domain() and args.query_timeout is not None:
3295+
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
32943296

32953297
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
32963298
oc_http = self.client.opencypher_http(cell, query_params=query_params,
32973299
plan_cache=args.plan_cache,
32983300
query_timeout=args.query_timeout)
32993301
query_time = time.time() * 1000 - query_start
3302+
if oc_http.status_code == 400 and not self.client.is_analytics_domain() and args.plan_cache != "auto":
3303+
try:
3304+
oc_http_ex = json.loads(oc_http.content.decode('utf-8'))
3305+
if (oc_http_ex["code"] == "MalformedQueryException"
3306+
and oc_http_ex["detailedMessage"].startswith("Invalid input")):
3307+
print("Please ensure that you are on NeptuneDB 1.3.2.0 or later when attempting to use "
3308+
"--plan-cache.")
3309+
except:
3310+
pass
33003311
oc_http.raise_for_status()
33013312

33023313
try:

src/graph_notebook/neptune/client.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ def normalize_service_name(neptune_service: str):
163163
return NEPTUNE_DB_SERVICE_NAME
164164

165165

166+
def set_plan_cache_hint(query: str, plan_cache_value: str):
167+
plan_cache_op_re = r"(?i)USING\s+QUERY:\s*PLANCACHE"
168+
if re.search(plan_cache_op_re, query) is not None:
169+
print("planCache hint is already present in query. Ignoring parameter value.")
170+
return query
171+
plan_cache_hint = f'USING QUERY: PLANCACHE "{plan_cache_value}"\n'
172+
query_with_hint = plan_cache_hint + query
173+
return query_with_hint
174+
175+
166176
class Client(object):
167177
def __init__(self, host: str, port: int = DEFAULT_PORT,
168178
neptune_service: str = NEPTUNE_DB_SERVICE_NAME,
@@ -407,19 +417,23 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
407417
if 'content-type' not in headers:
408418
headers['content-type'] = 'application/x-www-form-urlencoded'
409419
url += 'openCypher'
410-
data = {
411-
'query': query
412-
}
420+
data = {}
421+
if plan_cache:
422+
if plan_cache not in OPENCYPHER_PLAN_CACHE_MODES:
423+
print('Invalid --plan-cache mode specified, defaulting to auto.')
424+
else:
425+
if self.is_analytics_domain():
426+
data['planCache'] = plan_cache
427+
elif plan_cache != 'auto':
428+
query = set_plan_cache_hint(query, plan_cache)
429+
data['query'] = query
413430
if explain:
414431
data['explain'] = explain
415432
headers['Accept'] = "text/html"
416433
if query_params:
417434
data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}'
418-
if self.is_analytics_domain():
419-
if plan_cache:
420-
data['planCache'] = plan_cache
421-
if query_timeout:
422-
data['queryTimeoutMilliseconds'] = str(query_timeout)
435+
if query_timeout and self.is_analytics_domain():
436+
data['queryTimeoutMilliseconds'] = str(query_timeout)
423437
else:
424438
url += 'db/neo4j/tx/commit'
425439
headers['content-type'] = 'application/json'

0 commit comments

Comments (0)