Skip to content

Enforce Gremlin protocol and serializer based on database type #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd
## Upcoming

- Updated Gremlin config `message_serializer` to accept all TinkerPop serializers ([Link to PR](https://github.com/aws/graph-notebook/pull/685))
- Implemented service-based dynamic allowlists and defaults for Gremlin serializer and protocol combinations ([Link to PR](https://github.com/aws/graph-notebook/pull/697))
- Added `%get_import_task` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/668))
- Added `--export-to` JSON file option to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/684))
- Deprecated Python 3.8 support ([Link to PR](https://github.com/aws/graph-notebook/pull/683))
Expand Down
131 changes: 81 additions & 50 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS,
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE,
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants,
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP,
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP, GREMLIN_SERIALIZERS_WS,
GREMLIN_SERIALIZERS_ALL, NEPTUNE_GREMLIN_SERIALIZERS_HTTP,
DEFAULT_GREMLIN_WS_SERIALIZER, DEFAULT_GREMLIN_HTTP_SERIALIZER,
NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME,
normalize_service_name)
normalize_service_name, normalize_protocol_name,
normalize_serializer_class_name)

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

Expand Down Expand Up @@ -57,7 +60,8 @@ class GremlinSection(object):
"""

def __init__(self, traversal_source: str = '', username: str = '', password: str = '',
message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False):
message_serializer: str = '', connection_protocol: str = '',
include_protocol: bool = False, neptune_service: str = ''):
"""
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
connected to an endpoint that can access multiple graphs.
Expand All @@ -71,57 +75,78 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str
if traversal_source == '':
traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE

serializer_lower = message_serializer.lower()
# TODO: Update with untyped serializers once supported in GremlinPython
# Accept TinkerPop serializer class name
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphson
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphbinary
if serializer_lower == '':
message_serializer = DEFAULT_GREMLIN_SERIALIZER
elif 'graphson' in serializer_lower:
message_serializer = 'GraphSON'
if 'untyped' in serializer_lower:
message_serializer += 'Untyped'
if 'v1' in serializer_lower:
if 'untyped' in serializer_lower:
message_serializer += 'MessageSerializerV1'
else:
message_serializer += 'MessageSerializerGremlinV1'
elif 'v2' in serializer_lower:
message_serializer += 'MessageSerializerV2'
invalid_serializer_input = False
if message_serializer != '':
message_serializer, invalid_serializer_input = normalize_serializer_class_name(message_serializer)

if include_protocol:
# Neptune endpoint
invalid_protocol_input = False
if connection_protocol != '':
connection_protocol, invalid_protocol_input = normalize_protocol_name(connection_protocol)

if neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME:
if connection_protocol != DEFAULT_HTTP_PROTOCOL:
if invalid_protocol_input:
print(f"Invalid connection protocol specified, you must use {DEFAULT_HTTP_PROTOCOL}. ")
elif connection_protocol == DEFAULT_WS_PROTOCOL:
print(f"Enforcing HTTP protocol.")
connection_protocol = DEFAULT_HTTP_PROTOCOL
# temporary restriction until GraphSON-typed and GraphBinary results are supported
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
else:
print(f"{message_serializer} is not currently supported for HTTP connections, "
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
else:
message_serializer += 'MessageSerializerV3'
elif 'graphbinary' in serializer_lower:
message_serializer = GRAPHBINARYV1
if connection_protocol not in [DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL]:
if invalid_protocol_input:
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_WS_PROTOCOL}. "
f"Valid protocols: [websockets, http].")
connection_protocol = DEFAULT_WS_PROTOCOL

if connection_protocol == DEFAULT_HTTP_PROTOCOL:
# temporary restriction until GraphSON-typed and GraphBinary results are supported
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
else:
print(f"{message_serializer} is not currently supported for HTTP connections, "
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
else:
if message_serializer not in GREMLIN_SERIALIZERS_WS:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
elif message_serializer != '':
print(f"{message_serializer} is not currently supported by Gremlin Python driver, "
f"defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER

self.connection_protocol = connection_protocol
else:
print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. '
f'Valid serializers: {GREMLIN_SERIALIZERS_HTTP}.')
message_serializer = DEFAULT_GREMLIN_SERIALIZER
# Non-Neptune database - check and set valid WebSockets serializer if invalid/empty
if message_serializer not in GREMLIN_SERIALIZERS_WS:
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER
if invalid_serializer_input:
print(f'Invalid Gremlin serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. '
f'Valid serializers: {GREMLIN_SERIALIZERS_WS}.')

self.traversal_source = traversal_source
self.username = username
self.password = password
self.message_serializer = message_serializer

if include_protocol:
protocol_lower = connection_protocol.lower()
if message_serializer in GREMLIN_SERIALIZERS_HTTP:
connection_protocol = DEFAULT_HTTP_PROTOCOL
if protocol_lower != '' and protocol_lower not in HTTP_PROTOCOL_FORMATS:
print(f"Enforcing HTTP protocol usage for serializer: {message_serializer}.")
else:
if protocol_lower == '':
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
elif protocol_lower in HTTP_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_HTTP_PROTOCOL
elif protocol_lower in WS_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_WS_PROTOCOL
else:
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. "
f"Valid protocols: [websockets, http].")
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
self.connection_protocol = connection_protocol

def to_dict(self):
return self.__dict__

Expand Down Expand Up @@ -178,8 +203,8 @@ def __init__(self, host: str, port: int,
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL
if gremlin_section is not None:
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else ''
if hasattr(gremlin_section, "connection_protocol"):
if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL:
print("Enforcing HTTP connection protocol for proxy connections.")
Expand All @@ -189,9 +214,12 @@ def __init__(self, host: str, port: int,
else:
final_protocol = default_protocol
self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer,
connection_protocol=final_protocol, include_protocol=True)
connection_protocol=final_protocol,
include_protocol=True,
neptune_service=self.neptune_service)
else:
self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True)
self.gremlin = GremlinSection(include_protocol=True,
neptune_service=self.neptune_service)
self.neo4j = Neo4JSection()
else:
self.is_neptune_config = False
Expand Down Expand Up @@ -331,11 +359,14 @@ def generate_default_config():
auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
protocol_arg = args.gremlin_connection_protocol
include_protocol = False
gremlin_service = ''
if is_allowed_neptune_host(args.host, args.neptune_hosts):
include_protocol = True
gremlin_service = args.neptune_service
if not protocol_arg:
protocol_arg = DEFAULT_HTTP_PROTOCOL \
if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL

config = generate_config(args.host, int(args.port),
AuthModeEnum(auth_mode_arg),
args.ssl, args.ssl_verify,
Expand All @@ -344,7 +375,7 @@ def generate_default_config():
SparqlSection(args.sparql_path, ''),
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
args.gremlin_password, args.gremlin_serializer,
protocol_arg, include_protocol),
protocol_arg, include_protocol, gremlin_service),
Neo4JSection(args.neo4j_username, args.neo4j_password,
args.neo4j_auth, args.neo4j_database),
args.neptune_hosts)
Expand Down
18 changes: 13 additions & 5 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
SparqlSection, GremlinSection, Neo4JSection
from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL, \
DEFAULT_GREMLIN_HTTP_SERIALIZER, DEFAULT_GREMLIN_WS_SERIALIZER, \
normalize_service_name

neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region']
neptune_gremlin_params = ['connection_protocol']
Expand All @@ -30,18 +32,24 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I
is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)

if is_neptune_host:
neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME
if 'neptune_service' in data:
neptune_service = normalize_service_name(data['neptune_service'])
else:
neptune_service = NEPTUNE_DB_SERVICE_NAME
if 'gremlin' in data:
data['gremlin']['include_protocol'] = True
if 'connection_protocol' not in data['gremlin']:
data['gremlin']['connection_protocol'] = DEFAULT_WS_PROTOCOL \
if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
gremlin_section = GremlinSection(**data['gremlin'])
gremlin_section = GremlinSection(**data['gremlin'],
include_protocol=True,
neptune_service=neptune_service)
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
else:
protocol = DEFAULT_WS_PROTOCOL if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
gremlin_section = GremlinSection(include_protocol=True, connection_protocol=protocol)
gremlin_section = GremlinSection(include_protocol=True,
connection_protocol=protocol,
neptune_service=neptune_service)
if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \
or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD:
print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n')
Expand Down
15 changes: 11 additions & 4 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, GREMLIN_EXPLAIN_MODES, \
OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, OPENCYPHER_STATUS_STATE_MODES, \
normalize_service_name, NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, GRAPH_PG_INFO_METRICS, \
DEFAULT_GREMLIN_PROTOCOL, GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
GREMLIN_SERIALIZERS_WS, GREMLIN_SERIALIZERS_CLASS_TO_MIME_MAP, normalize_protocol_name, generate_snapshot_name)
from graph_notebook.network import SPARQLNetwork
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
Expand Down Expand Up @@ -1250,11 +1250,18 @@ def gremlin(self, line, cell, local_ns: dict = None):
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
if self.client.is_neptune_domain():
if args.connection_protocol != '':
connection_protocol = normalize_protocol_name(args.connection_protocol)
connection_protocol, bad_protocol_input = normalize_protocol_name(args.connection_protocol)
if bad_protocol_input:
if self.client.is_analytics_domain():
connection_protocol = DEFAULT_HTTP_PROTOCOL
else:
connection_protocol = DEFAULT_WS_PROTOCOL
print(f"Connection protocol input is invalid for Neptune, "
f"defaulting to {connection_protocol}.")
if connection_protocol == DEFAULT_WS_PROTOCOL and \
self.graph_notebook_config.gremlin.message_serializer not in GREMLIN_SERIALIZERS_WS:
print("Unsupported serializer for GremlinPython client, "
"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
print(f"Serializer is unsupported for GremlinPython client, "
f"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
print("Defaulting to HTTP protocol.")
connection_protocol = DEFAULT_HTTP_PROTOCOL
else:
Expand Down
Loading
Loading