Skip to content

Add %%oc --plan-cache support for Neptune DB #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.

## Upcoming

- Added `%reset_graph` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/610))
- Added `%get_graph` line magic and enabled `%status` for Neptune Analytics ([Link to PR](https://github.com/aws/graph-notebook/pull/611))
- Added `%%oc --plan-cache` support for Neptune DB ([Link to PR](https://github.com/aws/graph-notebook/pull/613))
- Upgraded to Gremlin-Python 3.7 ([Link to PR](https://github.com/aws/graph-notebook/pull/597))

## Release 4.3.1 (June 3, 2024)
Expand Down
57 changes: 34 additions & 23 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from graph_notebook.magics.streams import StreamViewer
from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
Expand Down Expand Up @@ -180,7 +180,7 @@
MEDIA_TYPE_NTRIPLES_TEXT, MEDIA_TYPE_TURTLE, MEDIA_TYPE_N3, MEDIA_TYPE_TRIX,
MEDIA_TYPE_TRIG, MEDIA_TYPE_RDF4J_BINARY]

byte_units = {'B': 1, 'KB': 1024, 'MB': 1024**2, 'GB': 1024**3, 'TB': 1024**4}
byte_units = {'B': 1, 'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3, 'TB': 1024 ** 4}


class QueryMode(Enum):
Expand Down Expand Up @@ -521,11 +521,11 @@ def neptune_config_allowlist(self, line='', cell=''):

@line_magic
@neptune_db_only
def stream_viewer(self,line):
def stream_viewer(self, line):
parser = argparse.ArgumentParser()
parser.add_argument('language', nargs='?', default=STREAM_PG,
help=f'language (default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]',
choices = [STREAM_PG, STREAM_RDF])
choices=[STREAM_PG, STREAM_RDF])

parser.add_argument('--limit', type=int, default=10, help='Maximum number of rows to display at a time')

Expand All @@ -534,7 +534,7 @@ def stream_viewer(self,line):
language = args.language
limit = args.limit
uri = self.client.get_uri_with_port()
viewer = StreamViewer(self.client,uri,language,limit=limit)
viewer = StreamViewer(self.client, uri, language, limit=limit)
viewer.show()

@line_magic
Expand Down Expand Up @@ -877,7 +877,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
lines = []
for b in results['results']['bindings']:
lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
lines.append(
f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
raw_output = widgets.Output(layout=sparql_layout)
with raw_output:
html = sparql_construct_template.render(lines=lines)
Expand Down Expand Up @@ -1168,7 +1169,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
if self.graph_notebook_config.proxy_host != '' and self.client.is_neptune_domain():
using_http = True
query_res_http = self.client.gremlin_http_query(cell, headers={'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http = self.client.gremlin_http_query(cell, headers={
'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http.raise_for_status()
query_res_http_json = query_res_http.json()
query_res = query_res_http_json['result']['data']
Expand Down Expand Up @@ -1603,7 +1605,7 @@ def on_button_delete_clicked(b):
with output:
job_status_output.clear_output()
interval_output.close()
total_status_wait = max_status_retries*poll_interval
total_status_wait = max_status_retries * poll_interval
print(result)
if interval_check_response.get("status") != 'healthy':
print(f"Could not retrieve the status of the reset operation within the allotted time of "
Expand Down Expand Up @@ -1849,7 +1851,7 @@ def load(self, line='', local_ns: dict = None):
value=str(args.concurrency),
placeholder=1,
min=1,
max=2**16,
max=2 ** 16,
disabled=False,
layout=widgets.Layout(display=concurrency_hbox_visibility,
width=widget_width)
Expand Down Expand Up @@ -2057,8 +2059,8 @@ def on_button_clicked(b):
named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
dep_hbox.children = (dep_hbox_label, dependencies,)
concurrency_hbox.children = (concurrency_hbox_label, concurrency, )
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit, )
concurrency_hbox.children = (concurrency_hbox_label, concurrency,)
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit,)

validated = True
validation_label_style = DescriptionStyle(color='red')
Expand Down Expand Up @@ -2210,8 +2212,9 @@ def on_button_clicked(b):

if poll_status.value == 'FALSE':
start_msg_label = widgets.Label(f'Load started successfully!')
polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
f'in another cell to check the current status of your bulk load.')
polling_msg_label = widgets.Label(
f'You can run "%load_status {load_result["payload"]["loadId"]}" '
f'in another cell to check the current status of your bulk load.')
start_msg_hbox = widgets.HBox([start_msg_label])
polling_msg_hbox = widgets.HBox([polling_msg_label])
vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
Expand Down Expand Up @@ -2254,11 +2257,13 @@ def on_button_clicked(b):
with job_status_output:
# parse status & execution_time differently for Analytics and NeptuneDB
overall_status = \
interval_check_response["payload"]["status"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["status"]
interval_check_response["payload"][
"status"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["status"]
total_time_spent = \
interval_check_response["payload"]["timeElapsedSeconds"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
interval_check_response["payload"][
"timeElapsedSeconds"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
print(f'Overall Status: {overall_status}')
if overall_status in FINAL_LOAD_STATUSES:
execution_time = total_time_spent
Expand Down Expand Up @@ -3179,7 +3184,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
"""
parser = argparse.ArgumentParser()
parser.add_argument('-pc', '--plan-cache', type=str.lower, default='auto',
help=f'Neptune Analytics only. Specifies the plan cache mode to use. '
help=f'Specifies the plan cache mode to use. '
f'Accepted values: {OPENCYPHER_PLAN_CACHE_MODES}')
parser.add_argument('-qt', '--query-timeout', type=int, default=None,
help=f'Neptune Analytics only. Specifies the maximum query timeout in milliseconds.')
Expand Down Expand Up @@ -3286,17 +3291,23 @@ def handle_opencypher_query(self, line, cell, local_ns):
first_tab_html = opencypher_explain_template.render(table=explain,
link=f"data:text/html;base64,{base64_str}")
elif args.mode == 'query':
if not self.client.is_analytics_domain():
if args.plan_cache != 'auto':
print("planCache is not supported for Neptune DB, ignoring.")
if args.query_timeout is not None:
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
if not self.client.is_analytics_domain() and args.query_timeout is not None:
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")

query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
oc_http = self.client.opencypher_http(cell, query_params=query_params,
plan_cache=args.plan_cache,
query_timeout=args.query_timeout)
query_time = time.time() * 1000 - query_start
if oc_http.status_code == 400 and not self.client.is_analytics_domain() and args.plan_cache != "auto":
try:
oc_http_ex = json.loads(oc_http.content.decode('utf-8'))
if (oc_http_ex["code"] == "MalformedQueryException"
and oc_http_ex["detailedMessage"].startswith("Invalid input")):
print("Please ensure that you are on NeptuneDB 1.3.2.0 or later when attempting to use "
"--plan-cache.")
except:
pass
oc_http.raise_for_status()

try:
Expand Down
30 changes: 22 additions & 8 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ def normalize_service_name(neptune_service: str):
return NEPTUNE_DB_SERVICE_NAME


def set_plan_cache_hint(query: str, plan_cache_value: str) -> str:
    """Prepend a ``USING QUERY: PLANCACHE`` hint to an openCypher query.

    Used for Neptune DB, where the plan-cache mode is expressed as an
    in-query hint rather than a request parameter.

    :param query: the openCypher query text submitted by the user.
    :param plan_cache_value: the plan cache mode to inject (e.g. ``enabled``).
    :return: the query with the hint prepended, or the original query
             unchanged when it already carries a planCache hint
             (matched case-insensitively); in that case the supplied
             parameter value is ignored and a notice is printed.
    """
    existing_hint = re.compile(r"USING\s+QUERY:\s*PLANCACHE", re.IGNORECASE)
    if existing_hint.search(query):
        print("planCache hint is already present in query. Ignoring parameter value.")
        return query
    return f'USING QUERY: PLANCACHE "{plan_cache_value}"\n{query}'


class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT,
neptune_service: str = NEPTUNE_DB_SERVICE_NAME,
Expand Down Expand Up @@ -407,19 +417,23 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
if 'content-type' not in headers:
headers['content-type'] = 'application/x-www-form-urlencoded'
url += 'openCypher'
data = {
'query': query
}
data = {}
if plan_cache:
if plan_cache not in OPENCYPHER_PLAN_CACHE_MODES:
print('Invalid --plan-cache mode specified, defaulting to auto.')
else:
if self.is_analytics_domain():
data['planCache'] = plan_cache
elif plan_cache != 'auto':
query = set_plan_cache_hint(query, plan_cache)
data['query'] = query
if explain:
data['explain'] = explain
headers['Accept'] = "text/html"
if query_params:
data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}'
if self.is_analytics_domain():
if plan_cache:
data['planCache'] = plan_cache
if query_timeout:
data['queryTimeoutMilliseconds'] = str(query_timeout)
if query_timeout and self.is_analytics_domain():
data['queryTimeoutMilliseconds'] = str(query_timeout)
else:
url += 'db/neo4j/tx/commit'
headers['content-type'] = 'application/json'
Expand Down