From 8be1ed753d97fe1c3b6a10ab6bdb637e0b240a78 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 7 Oct 2021 10:59:52 +0200 Subject: [PATCH 1/8] Add support for impersonation --- neo4j/__init__.py | 9 +- neo4j/conf.py | 4 + neo4j/io/__init__.py | 120 ++++++++++++++++++-------- neo4j/io/_bolt3.py | 56 +++++++++--- neo4j/io/_bolt4.py | 147 ++++++++++++++++++++++++++++++-- neo4j/work/result.py | 6 +- neo4j/work/simple.py | 50 ++++++++--- neo4j/work/transaction.py | 8 +- testkitbackend/requests.py | 9 +- testkitbackend/test_config.json | 10 +++ tests/integration/conftest.py | 3 +- tests/unit/io/test_direct.py | 3 +- tests/unit/test_conf.py | 1 + tests/unit/test_driver.py | 2 + tests/unit/work/test_result.py | 24 +++--- 15 files changed, 360 insertions(+), 92 deletions(-) diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 950bf6fa8..707d0c437 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -329,11 +329,9 @@ def supports_multi_db(self): :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. 
:rtype: bool """ - cx = self._pool.acquire(access_mode=READ_ACCESS, timeout=self._pool.workspace_config.connection_acquisition_timeout, database=self._pool.workspace_config.database) - support = cx.supports_multiple_databases - self._pool.release(cx) - - return support + with self.session() as session: + session._connect(READ_ACCESS) + return session._connection.supports_multiple_databases class BoltDriver(Direct, Driver): @@ -447,6 +445,7 @@ def _verify_routing_connectivity(self): routing_info[ix] = self._pool.fetch_routing_info( address=table.routers[0], database=self._default_workspace_config.database, + imp_user=self._default_workspace_config.impersonated_user, bookmarks=None, timeout=self._default_workspace_config .connection_acquisition_timeout diff --git a/neo4j/conf.py b/neo4j/conf.py index 80ad44c8b..f74dd2e51 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -283,6 +283,10 @@ class WorkspaceConfig(Config): #: Fetch Size fetch_size = 1000 + #: User to impersonate + impersonated_user = None + # Note that you need appropriate permissions to do so. + class SessionConfig(WorkspaceConfig): """ Session configuration. diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 38410dba6..a688e05ff 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -393,7 +393,7 @@ def __del__(self): pass @abc.abstractmethod - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -401,6 +401,7 @@ def route(self, database=None, bookmarks=None): sent to the network, and a response is fetched. 
:param database: database for which to fetch a routing table + :param imp_user: the user to impersonate :param bookmarks: iterable of bookmark values after which this transaction should begin :return: dictionary of raw routing data @@ -408,8 +409,8 @@ def route(self, database=None, bookmarks=None): pass @abc.abstractmethod - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, - timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -419,6 +420,7 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + :param imp_user: the user to impersonate :param handlers: handler functions passed into the returned Response object :return: Response object """ @@ -447,7 +449,8 @@ def pull(self, n=-1, qid=-1, **handlers): pass @abc.abstractmethod - def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): """ Appends a BEGIN message to the output queue. 
:param mode: access mode for routing - "READ" or "WRITE" (default) @@ -455,6 +458,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + :param imp_user: the user to impersonate :param handlers: handler functions passed into the returned Response object :return: Response object """ @@ -708,12 +712,13 @@ def time_remaining(): "within {!r}s".format(timeout)) def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + imp_user=None, bookmarks=None): """ Acquire a connection to a server that can satisfy a set of parameters. :param access_mode: :param timeout: :param database: + :param imp_user: :param bookmarks: """ @@ -827,7 +832,7 @@ def __repr__(self): return "<{} address={!r}>".format(self.__class__.__name__, self.address) def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + imp_user=None, bookmarks=None): # The access_mode and database is not needed for a direct connection, its just there for consistency. 
return self._acquire(self.address, timeout) @@ -908,15 +913,23 @@ def get_default_database_router_addresses(self): def get_routing_table_for_default_database(self): return self.routing_tables[self.workspace_config.database] - def create_routing_table(self, database): + def get_or_create_routing_table(self, database): if database not in self.routing_tables: - self.routing_tables[database] = RoutingTable(database=database, routers=self.get_default_database_initial_router_addresses()) + self.routing_tables[database] = RoutingTable( + database=database, + routers=self.get_default_database_initial_router_addresses() + ) + return self.routing_tables[database] - def fetch_routing_info(self, address, database, bookmarks, timeout): + def fetch_routing_info(self, address, database, imp_user, bookmarks, + timeout): """ Fetch raw routing info from a given router address. :param address: router address :param database: the database name to get routing table for + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched :param timeout: connection acquisition timeout in seconds @@ -930,7 +943,9 @@ def fetch_routing_info(self, address, database, bookmarks, timeout): cx = self._acquire(address, timeout) try: routing_table = cx.route( - database or self.workspace_config.database, bookmarks + database or self.workspace_config.database, + imp_user or self.workspace_config.impersonated_user, + bookmarks ) finally: self.release(cx) @@ -955,21 +970,26 @@ def fetch_routing_info(self, address, database, bookmarks, timeout): self.deactivate(address) return routing_table - def fetch_routing_table(self, *, address, timeout, database, bookmarks): + def fetch_routing_table(self, *, address, timeout, database, imp_user, + bookmarks): """ Fetch a routing table from a given router address. 
:param address: router address :param timeout: seconds :param database: the database name :type: str + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table :return: a new RoutingTable instance or None if the given router is currently unable to provide routing information """ try: - new_routing_info = self.fetch_routing_info(address, database, - bookmarks, timeout) + new_routing_info = self.fetch_routing_info( + address, database, imp_user, bookmarks, timeout + ) except (ServiceUnavailable, SessionExpired): new_routing_info = None if not new_routing_info: @@ -978,7 +998,10 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): else: servers = new_routing_info[0]["servers"] ttl = new_routing_info[0]["ttl"] - new_routing_table = RoutingTable.parse_routing_info(database=database, servers=servers, ttl=ttl) + database = new_routing_info[0].get("db", database) + new_routing_table = RoutingTable.parse_routing_info( + database=database, servers=servers, ttl=ttl + ) # Parse routing info and count the number of each type of server num_routers = len(new_routing_table.routers) @@ -1001,8 +1024,8 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): # At least one of each is fine, so return this table return new_routing_table - def update_routing_table_from(self, *routers, database=None, - bookmarks=None): + def update_routing_table_from(self, *routers, database=None, imp_user=None, + bookmarks=None, database_callback=None): """ Try to update routing tables with the given routers. 
:return: True if the routing table is successfully updated, @@ -1014,29 +1037,44 @@ def update_routing_table_from(self, *routers, database=None, new_routing_table = self.fetch_routing_table( address=address, timeout=self.pool_config.connection_timeout, - database=database, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks ) if new_routing_table is not None: - self.routing_tables[database].update(new_routing_table) + new_databse = new_routing_table.database + self.get_or_create_routing_table(new_databse)\ + .update(new_routing_table) log.debug( "[#0000] C: address=%r (%r)", - address, self.routing_tables[database] + address, self.routing_tables[new_databse] ) + if callable(database_callback): + database_callback(new_databse) return True self.deactivate(router) return False - def update_routing_table(self, *, database, bookmarks): + def update_routing_table(self, *, database, imp_user, bookmarks, + database_callback=None): """ Update the routing table from the first router able to provide valid routing information. :param database: The database name + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param database_callback: A callback function that will be called with + the database name as only argument when a new routing table has been + acquired. This database name might be different from `database` if that + was None and the underlying protocol supports reporting back the + actual database. 
:raise neo4j.exceptions.ServiceUnavailable: """ # copied because it can be modified - existing_routers = set(self.routing_tables[database].routers) + existing_routers = set( + self.get_or_create_routing_table(database).routers + ) prefer_initial_routing_address = \ self.routing_tables[database].missing_fresh_writer() @@ -1045,20 +1083,23 @@ def update_routing_table(self, *, database, bookmarks): # TODO: Test this state if self.update_routing_table_from( self.first_initial_routing_address, database=database, - bookmarks=bookmarks + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback ): # Why is only the first initial routing address used? return if self.update_routing_table_from( *(existing_routers - {self.first_initial_routing_address}), - database=database, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback ): return if not prefer_initial_routing_address: if self.update_routing_table_from( self.first_initial_routing_address, database=database, - bookmarks=bookmarks + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback ): # Why is only the first initial routing address used? return @@ -1068,13 +1109,13 @@ def update_routing_table(self, *, database, bookmarks): raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): - servers = self.routing_tables[database].servers() + servers = self.get_or_create_routing_table(database).servers() for address in list(self.connections): if address not in servers: super(Neo4jPool, self).deactivate(address) - def ensure_routing_table_is_fresh(self, *, access_mode, database, - bookmarks): + def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, + bookmarks, database_callback=None): """ Update the routing table if stale. 
This method performs two freshness checks, before and after acquiring @@ -1088,36 +1129,41 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, :return: `True` if an update was required, `False` otherwise. """ from neo4j.api import READ_ACCESS - if self.routing_tables[database].is_fresh(readonly=(access_mode == READ_ACCESS)): + if self.get_or_create_routing_table(database)\ + .is_fresh(readonly=(access_mode == READ_ACCESS)): # Readers are fresh. return False with self.refresh_lock: - self.update_routing_table(database=database, bookmarks=bookmarks) + self.update_routing_table( + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ) self.update_connection_pool(database=database) for database in list(self.routing_tables.keys()): # Remove unused databases in the routing table # Remove the routing table after a timeout = TTL + 30s log.debug("[#0000] C: database=%s", database) - if self.routing_tables[database].should_be_purged_from_memory() and database != self.workspace_config.database: + if (self.routing_tables[database].should_be_purged_from_memory() + and database != self.workspace_config.database): del self.routing_tables[database] return True - def _select_address(self, *, access_mode, database, bookmarks): + def _select_address(self, *, access_mode, database, imp_user, bookmarks): from neo4j.api import READ_ACCESS """ Selects the address with the fewest in-use connections. 
""" - self.create_routing_table(database) self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, bookmarks=bookmarks + access_mode=access_mode, database=database, imp_user=imp_user, + bookmarks=bookmarks ) log.debug("[#0000] C: %r", self.routing_tables) if access_mode == READ_ACCESS: - addresses = self.routing_tables[database].readers + addresses = self.get_or_create_routing_table(database).readers else: - addresses = self.routing_tables[database].writers + addresses = self.get_or_create_routing_table(database).writers addresses_by_usage = {} for address in addresses: addresses_by_usage.setdefault(self.in_use_connection_count(address), []).append(address) @@ -1129,7 +1175,7 @@ def _select_address(self, *, access_mode, database, bookmarks): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + imp_user=None, bookmarks=None): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) if not timeout: @@ -1142,7 +1188,7 @@ def acquire(self, access_mode=None, timeout=None, database=None, # Get an address for a connection that have the fewest in-use connections. address = self._select_address( access_mode=access_mode, database=database, - bookmarks=bookmarks + imp_user=imp_user, bookmarks=bookmarks ) log.debug("[#0000] C: database=%r address=%r", database, address) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index fe7608b0b..4ed722d86 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -165,12 +165,22 @@ def hello(self): self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, bookmarks=None): - if database is not None: # default database - raise ConfigurationError("Database name parameter for selecting database is not " - "supported in Bolt Protocol {!r}. 
Database name {!r}. " - "Server Agent {!r}.".format(Bolt3.PROTOCOL_VERSION, database, - self.server_info.agent)) + def route(self, database=None, imp_user=None, bookmarks=None): + if database is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}. " + "Server Agent {!r}".format( + self.PROTOCOL_VERSION, database, self.server_info.agent + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) metadata = {} records = [] @@ -197,9 +207,22 @@ def fail(md): routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): if db is not None: - raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db)) + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) if not parameters: parameters = {} extra = {} @@ -238,9 +261,22 @@ def pull(self, n=-1, qid=-1, **handlers): log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append(b"\x3F", (), Response(self, "pull", **handlers)) - def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): if db is not None: - raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db)) + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) extra = {} if mode in (READ_ACCESS, "r"): extra["mode"] = "r" # It will default to mode "w" if nothing is specified diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 4b7c2045c..086f7baf2 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -32,6 +32,7 @@ Version, ) from neo4j.exceptions import ( + ConfigurationError, DatabaseUnavailable, DriverError, ForbiddenOnReadOnlyDatabase, @@ -122,7 +123,14 @@ def hello(self): self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) metadata = {} records = [] @@ -160,7 +168,15 @@ def fail(md): routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) if not parameters: parameters = {} extra = {} @@ -206,7 +222,14 @@ def pull(self, n=-1, qid=-1, **handlers): self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, **handlers): + db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) extra = {} if mode in (READ_ACCESS, "r"): extra["mode"] = "r" # It will default to mode "w" if nothing is specified @@ -376,7 +399,14 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) def fail(md): from neo4j._exceptions import BoltRoutingError @@ -384,12 +414,15 @@ def fail(md): if code == "Neo.ClientError.Database.DatabaseNotFound": return # surface this error to the user elif code == "Neo.ClientError.Procedure.ProcedureNotFound": - raise BoltRoutingError("Server does not support routing", self.unresolved_address) + raise BoltRoutingError("Server does not support routing", + self.unresolved_address) else: - raise BoltRoutingError("Routing support broken on server", self.unresolved_address) + raise BoltRoutingError("Routing support broken on server", + self.unresolved_address) routing_context = self.routing_context or {} - log.debug("[#%04X] C: ROUTE %r %r", self.local_port, routing_context, database) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, database) metadata = {} if bookmarks is None: bookmarks = [] @@ -440,3 +473,103 @@ class Bolt4x4(Bolt4x3): """ PROTOCOL_VERSION = Version(4, 4) + + def route(self, database=None, imp_user=None, bookmarks=None): + def fail(md): + from neo4j._exceptions import BoltRoutingError + code = md.get("code") + if code == "Neo.ClientError.Database.DatabaseNotFound": + return # surface this error to the user + elif code == "Neo.ClientError.Procedure.ProcedureNotFound": + raise BoltRoutingError("Server does not support routing", + self.unresolved_address) + else: + raise BoltRoutingError("Routing support broken on server", + self.unresolved_address) + + routing_context = self.routing_context or {} + db_context = {} + if database is not None: + db_context.update(db=database) + if imp_user is not None: + db_context.update(imp_user=imp_user) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, db_context) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, 
db_context), + response=Response(self, "route", + on_success=metadata.update, + on_failure=fail)) + self.send_all() + self.fetch_all() + return [metadata.get("rt")] + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, + " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise 
TypeError("Timeout must be specified as a number of " + "seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 75921b0ba..647153bb5 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -67,9 +67,10 @@ def _tx_ready_run(self, query, parameters, **kwparameters): # BEGIN+RUN does not carry any extra on the RUN message. # BEGIN {extra} # RUN "query" {parameters} {extra} - self._run(query, parameters, None, None, None, **kwparameters) + self._run(query, parameters, None, None, None, None, **kwparameters) - def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters): + def _run(self, query, parameters, db, imp_user, access_mode, bookmarks, + **kwparameters): query_text = str(query) # Query or string object query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) @@ -104,6 +105,7 @@ def on_failed_attach(metadata): metadata=query_metadata, timeout=query_timeout, db=db, + imp_user=imp_user, on_success=on_attached, on_failure=on_failed_attach, ) diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 15f08b2ed..bf2106ef7 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -41,6 +41,7 @@ TransientError, TransactionError, ) +from neo4j.io import Neo4jPool from neo4j.work import Workspace from neo4j.work.result import Result from neo4j.work.transaction import Transaction @@ -81,6 +82,9 @@ class Session(Workspace): # :class:`.Transaction` should be carried out. _bookmarks = None + # Sessions are supposed to cache the database on which to operate. + _cached_database = False + # The state this session is in. 
_state_failed = False @@ -106,7 +110,11 @@ def __exit__(self, exception_type, exception_value, traceback): self._state_failed = True self.close() - def _connect(self, access_mode, database): + def _set_cached_database(self, database): + self._cached_database = True + self._config.database = database + + def _connect(self, access_mode): if access_mode is None: access_mode = self._config.default_access_mode if self._connection: @@ -115,10 +123,28 @@ def _connect(self, access_mode, database): self._connection.send_all() self._connection.fetch_all() self._disconnect() + if not self._cached_database: + if (self._config.database is not None + or not isinstance(self._pool, Neo4jPool)): + self._set_cached_database(self._config.database) + else: + # This is the first time we open a connection to a server in a + # cluster environment for this session without explicitly + # configured database. Hence, we request a routing table update + # to try to fetch the home database. If provided by the server, + # we shall use this database explicitly for all subsequent + # actions within this session. 
+ self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._bookmarks, + database_callback=self._set_cached_database + ) self._connection = self._pool.acquire( access_mode=access_mode, timeout=self._config.connection_acquisition_timeout, - database=database, + database=self._config.database, + imp_user=self._config.impersonated_user, bookmarks=self._bookmarks ) @@ -218,7 +244,7 @@ def run(self, query, parameters=None, **kwparameters): self._autoResult._buffer_all() # This will buffer upp all records for the previous auto-transaction if not self._connection: - self._connect(self._config.default_access_mode, database=self._config.database) + self._connect(self._config.default_access_mode) cx = self._connection protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info @@ -231,7 +257,8 @@ def run(self, query, parameters=None, **kwparameters): ) self._autoResult._run( query, parameters, self._config.database, - self._config.default_access_mode, self._bookmarks, **kwparameters + self._config.impersonated_user, self._config.default_access_mode, + self._bookmarks, **kwparameters ) return self._autoResult @@ -266,16 +293,18 @@ def _transaction_error_handler(self, _): self._transaction = None self._disconnect() - def _open_transaction(self, *, access_mode, database, metadata=None, + def _open_transaction(self, *, access_mode, metadata=None, timeout=None): - self._connect(access_mode=access_mode, database=database) + self._connect(access_mode=access_mode) self._transaction = Transaction( self._connection, self._config.fetch_size, self._transaction_closed_handler, self._transaction_error_handler ) - self._transaction._begin(database, self._bookmarks, access_mode, - metadata, timeout) + self._transaction._begin( + self._config.database, self._config.impersonated_user, + self._bookmarks, access_mode, metadata, timeout + ) def begin_transaction(self, metadata=None, timeout=None): """ Begin a new 
unmanaged transaction. Creates a new :class:`.Transaction` within this session. @@ -312,7 +341,8 @@ def begin_transaction(self, metadata=None, timeout=None): if self._transaction: raise TransactionError("Explicit transaction already open") - self._open_transaction(access_mode=self._config.default_access_mode, database=self._config.database, metadata=metadata, timeout=timeout) + self._open_transaction(access_mode=self._config.default_access_mode, + metadata=metadata, timeout=timeout) return self._transaction @@ -332,7 +362,7 @@ def _run_transaction(self, access_mode, transaction_function, *args, **kwargs): while True: try: - self._open_transaction(access_mode=access_mode, database=self._config.database, metadata=metadata, timeout=timeout) + self._open_transaction(access_mode=access_mode, metadata=metadata, timeout=timeout) tx = self._transaction try: result = transaction_function(tx, *args, **kwargs) diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 0d77f3e02..746767480 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -62,9 +62,11 @@ def __exit__(self, exception_type, exception_value, traceback): self.commit() self.close() - def _begin(self, database, bookmarks, access_mode, metadata, timeout): - self._connection.begin(bookmarks=bookmarks, metadata=metadata, - timeout=timeout, mode=access_mode, db=database) + def _begin(self, database, imp_user, bookmarks, access_mode, metadata, timeout): + self._connection.begin( + bookmarks=bookmarks, metadata=metadata, timeout=timeout, + mode=access_mode, db=database, imp_user=imp_user + ) self._error_handling_connection.send_all() self._error_handling_connection.fetch_all() diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index 70036e822..a5cac5910 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -191,12 +191,14 @@ def NewSession(backend, data): elif access_mode == "w": access_mode = neo4j.WRITE_ACCESS else: - raise Exception("Unknown 
access mode:" + access_mode) + raise ValueError("Unknown access mode:" + access_mode) config = { "default_access_mode": access_mode, "bookmarks": data["bookmarks"], "database": data["database"], - "fetch_size": data.get("fetchSize", None) + "fetch_size": data.get("fetchSize", None), + "impersonated_user": data.get("impersonatedUser", None), + } session = driver.session(**config) key = backend.next_key() @@ -390,8 +392,7 @@ def ForcedRoutingTableUpdate(backend, data): database = data["database"] bookmarks = data["bookmarks"] with driver._pool.refresh_lock: - driver._pool.create_routing_table(database) - driver._pool.update_routing_table(database=database, + driver._pool.update_routing_table(database=database, imp_user=None, bookmarks=bookmarks) backend.send_response("Driver", {"id": driver_id}) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 47b97e7a4..469d179ff 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -6,24 +6,32 @@ "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_retry_write_until_success_with_leader_change_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", + "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_retry_write_until_success_with_leader_change_using_tx_function": + "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_retry_write_until_success_with_leader_change_on_run_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v3.RoutingV3.test_should_retry_write_until_success_with_leader_change_on_run_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", 
"stub.routing.test_routing_v4x3.RoutingV4x3.test_should_retry_write_until_success_with_leader_change_on_run_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", + "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_retry_write_until_success_with_leader_change_on_run_using_tx_function": + "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_retry_write_until_success_with_leader_shutdown_during_tx_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v3.RoutingV3.test_should_retry_write_until_success_with_leader_shutdown_during_tx_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_retry_write_until_success_with_leader_shutdown_during_tx_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", + "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_retry_write_until_success_with_leader_shutdown_during_tx_using_tx_function": + "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_successfully_acquire_rt_when_router_ip_changes": "Test makes assumptions about how verify_connectivity is implemented", "stub.routing.test_routing_v3.RoutingV3.test_should_successfully_acquire_rt_when_router_ip_changes": "Test makes assumptions about how verify_connectivity is implemented", "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_successfully_acquire_rt_when_router_ip_changes": "Test makes assumptions about how verify_connectivity is implemented", + "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_successfully_acquire_rt_when_router_ip_changes": + "Test makes assumptions about how verify_connectivity is implemented", 
"stub.retry.test_retry_clustering.TestRetryClustering.test_retry_ForbiddenOnReadOnlyDatabase_ChangingWriter": "Test makes assumptions about how verify_connectivity is implemented", "stub.authorization.test_authorization.TestAuthorizationV4x3.test_should_retry_on_auth_expired_on_begin_using_tx_function": @@ -49,6 +57,8 @@ "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, + "Feature:Bolt:4.4": true, + "Feature:Impersonation": true, "AuthorizationExpiredTreatment": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalResets": true, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 64aa1f3d4..015ba64d8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -299,7 +299,8 @@ def bolt_driver(target, auth): def neo4j_driver(target, auth): try: driver = GraphDatabase.neo4j_driver(target, auth=auth) - driver._pool.update_routing_table(database=None, bookmarks=None) + driver._pool.update_routing_table(database=None, imp_user=None, + bookmarks=None) except ServiceUnavailable as error: if isinstance(error.__cause__, BoltHandshakeError): pytest.skip(error.args[0]) diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py index dd1ef330f..ccfff00dd 100644 --- a/tests/unit/io/test_direct.py +++ b/tests/unit/io/test_direct.py @@ -104,7 +104,8 @@ def opener(addr, timeout): super().__init__(opener, self.pool_config, self.workspace_config) self.address = address - def acquire(self, access_mode=None, timeout=None, database=None): + def acquire(self, access_mode=None, timeout=None, database=None, + imp_user=None): return self._acquire(self.address, timeout) diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index ccd795011..6e685edae 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -63,6 +63,7 @@ "bookmarks": (), "default_access_mode": WRITE_ACCESS, "database": None, + "impersonated_user": None, "fetch_size": 100, } diff 
--git a/tests/unit/test_driver.py b/tests/unit/test_driver.py index 1e35cbf4b..fcb978413 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/test_driver.py @@ -134,12 +134,14 @@ def test_driver_opens_write_session_by_default(uri, mocker): access_mode=WRITE_ACCESS, timeout=mocker.ANY, database=mocker.ANY, + imp_user=mocker.ANY, bookmarks=mocker.ANY ) tx_begin_mock.assert_called_once_with( tx, mocker.ANY, mocker.ANY, + mocker.ANY, WRITE_ACCESS, mocker.ANY, mocker.ANY diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index df33293a2..21627a30c 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -219,7 +219,7 @@ def test_result_iteration(method): records = [[1], [2], [3], [4], [5]] connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), 2, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) _fetch_and_compare_all_records(result, "x", records, method) @@ -232,9 +232,9 @@ def test_parallel_result_iteration(method, invert_fetch): records=(Records(["x"], records1), Records(["x"], records2)) ) result1 = Result(connection, HydratorStub(), 2, noop, noop) - result1._run("CYPHER1", {}, None, "r", None) + result1._run("CYPHER1", {}, None, None, "r", None) result2 = Result(connection, HydratorStub(), 2, noop, noop) - result2._run("CYPHER2", {}, None, "r", None) + result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: _fetch_and_compare_all_records(result2, "x", records2, method) _fetch_and_compare_all_records(result1, "x", records1, method) @@ -252,9 +252,9 @@ def test_interwoven_result_iteration(method, invert_fetch): records=(Records(["x"], records1), Records(["y"], records2)) ) result1 = Result(connection, HydratorStub(), 2, noop, noop) - result1._run("CYPHER1", {}, None, "r", None) + result1._run("CYPHER1", {}, None, None, "r", None) result2 = Result(connection, HydratorStub(), 2, noop, noop) - 
result2._run("CYPHER2", {}, None, "r", None) + result2._run("CYPHER2", {}, None, None, "r", None) start = 0 for n in (1, 2, 3, 1, None): end = n if n is None else start + n @@ -276,7 +276,7 @@ def test_interwoven_result_iteration(method, invert_fetch): def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) for i in range(len(records) + 1): record = result.peek() if i == len(records): @@ -292,7 +292,7 @@ def test_result_peek(records, fetch_size): def test_result_single(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) with pytest.warns(None) as warning_record: record = result.single() if not records: @@ -310,7 +310,7 @@ def test_result_single(records, fetch_size): def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] list(result) assert list(result.keys()) == ["x"] @@ -323,7 +323,7 @@ def test_consume(records, consume_one, summary_meta): connection = ConnectionStub(records=Records(["x"], records), summary_meta=summary_meta) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: next(iter(result)) @@ -356,7 +356,7 @@ def test_time_in_summary(t_first, t_last): run_meta=run_meta, summary_meta=summary_meta) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, 
"r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() if t_first is not None: @@ -377,7 +377,7 @@ def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() assert isinstance(summary.counters, SummaryCounters) @@ -389,7 +389,7 @@ def test_query_type(query_type): summary_meta={"type": query_type}) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() assert isinstance(summary.query_type, str) From aaa3bfad00d29547f18df97d1a58056ae6b23c00 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 7 Oct 2021 14:19:13 +0200 Subject: [PATCH 2/8] Add docs for impersonation --- docs/source/api.rst | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 9a4cc0147..aafc71d00 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -499,6 +499,41 @@ Name of the database to query. :Default: ``neo4j.DEFAULT_DATABASE`` +.. _impersonated-user-ref: + +``impersonated_user`` +--------------------- +Name of the user to impersonate. +This means that all actions in the session will be executed in the security +context of the impersonated user. For this, the user for which the +:class:`Driver` has been created needs to have the appropriate permissions. + +:Type: ``str``, None + + +.. py:data:: None + :noindex: + + Will not perform impersonation. + + +.. Note:: + + The server or all servers of the cluster need to support impersonation. + Otherwise, the driver will raise :py:exc:`.ConfigurationError` + as soon as it encounters a server that does not. + + +.. 
code-block:: python + + from neo4j import GraphDatabase + driver = GraphDatabase.driver(uri, auth=(user, password)) + session = driver.session(impersonated_user="alice") + + +:Default: ``None`` + + .. _default-access-mode-ref: ``default_access_mode`` From f743def2d3f63fbaf47dfe698cc3abe26beab1a9 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 7 Oct 2021 17:54:38 +0200 Subject: [PATCH 3/8] Fix unit test --- tests/unit/test_driver.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_driver.py b/tests/unit/test_driver.py index fcb978413..4741c06e8 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/test_driver.py @@ -124,7 +124,10 @@ def test_driver_trust_config_error( def test_driver_opens_write_session_by_default(uri, mocker): driver = GraphDatabase.driver(uri) from neo4j.work.transaction import Transaction - with driver.session() as session: + # we set a specific db, because else the driver would try to fetch a RT + # to get hold of the actual home database (which won't work in this + # unittest) + with driver.session(database="foobar") as session: acquire_mock = mocker.patch.object(session._pool, "acquire", autospec=True) tx_begin_mock = mocker.patch.object(Transaction, "_begin", From c88f2d333d1dd8932cb49ea8646820caecf76561 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 8 Oct 2021 13:02:27 +0200 Subject: [PATCH 4/8] Fix race condition --- neo4j/io/__init__.py | 89 ++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index a688e05ff..9c28d1a77 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -58,7 +58,6 @@ ) from threading import ( Condition, - Lock, RLock, ) from time import perf_counter @@ -880,7 +879,7 @@ def __init__(self, opener, pool_config, workspace_config, address): log.debug("[#0000] C: routing address %r", address) self.address = address self.routing_tables = {workspace_config.database: 
RoutingTable(database=workspace_config.database, routers=[address])} - self.refresh_lock = Lock() + self.refresh_lock = RLock() def __repr__(self): """ The representation shows the initial routing addresses. @@ -914,12 +913,13 @@ def get_routing_table_for_default_database(self): return self.routing_tables[self.workspace_config.database] def get_or_create_routing_table(self, database): - if database not in self.routing_tables: - self.routing_tables[database] = RoutingTable( - database=database, - routers=self.get_default_database_initial_router_addresses() - ) - return self.routing_tables[database] + with self.refresh_lock: + if database not in self.routing_tables: + self.routing_tables[database] = RoutingTable( + database=database, + routers=self.get_default_database_initial_router_addresses() + ) + return self.routing_tables[database] def fetch_routing_info(self, address, database, imp_user, bookmarks, timeout): @@ -1024,8 +1024,8 @@ def fetch_routing_table(self, *, address, timeout, database, imp_user, # At least one of each is fine, so return this table return new_routing_table - def update_routing_table_from(self, *routers, database=None, imp_user=None, - bookmarks=None, database_callback=None): + def _update_routing_table_from(self, *routers, database=None, imp_user=None, + bookmarks=None, database_callback=None): """ Try to update routing tables with the given routers. 
:return: True if the routing table is successfully updated, @@ -1071,42 +1071,43 @@ def update_routing_table(self, *, database, imp_user, bookmarks, :raise neo4j.exceptions.ServiceUnavailable: """ - # copied because it can be modified - existing_routers = set( - self.get_or_create_routing_table(database).routers - ) - - prefer_initial_routing_address = \ - self.routing_tables[database].missing_fresh_writer() + with self.refresh_lock: + # copied because it can be modified + existing_routers = set( + self.get_or_create_routing_table(database).routers + ) - if prefer_initial_routing_address: - # TODO: Test this state - if self.update_routing_table_from( - self.first_initial_routing_address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + prefer_initial_routing_address = \ + self.routing_tables[database].missing_fresh_writer() + + if prefer_initial_routing_address: + # TODO: Test this state + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + if self._update_routing_table_from( + *(existing_routers - {self.first_initial_routing_address}), + database=database, imp_user=imp_user, bookmarks=bookmarks, database_callback=database_callback ): - # Why is only the first initial routing address used? return - if self.update_routing_table_from( - *(existing_routers - {self.first_initial_routing_address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ): - return - if not prefer_initial_routing_address: - if self.update_routing_table_from( - self.first_initial_routing_address, database=database, - imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ): - # Why is only the first initial routing address used? 
- return + if not prefer_initial_routing_address: + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return - # None of the routers have been successful, so just fail - log.error("Unable to retrieve routing information") - raise ServiceUnavailable("Unable to retrieve routing information") + # None of the routers have been successful, so just fail + log.error("Unable to retrieve routing information") + raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): servers = self.get_or_create_routing_table(database).servers() @@ -1129,11 +1130,11 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, :return: `True` if an update was required, `False` otherwise. """ from neo4j.api import READ_ACCESS - if self.get_or_create_routing_table(database)\ - .is_fresh(readonly=(access_mode == READ_ACCESS)): - # Readers are fresh. - return False with self.refresh_lock: + if self.get_or_create_routing_table(database)\ + .is_fresh(readonly=(access_mode == READ_ACCESS)): + # Readers are fresh. + return False self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, From 06de21688abef1c0346c4dd3bb0d5fa83ccfce34 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 8 Oct 2021 19:36:33 +0200 Subject: [PATCH 5/8] Acquire RT without impersonated user if possible This is an optimization saving a few bytes over the wire. When there is an explicit database configured for a session, there is no need to send the impersonated user along with the `ROUTE` request. The only difference this would make is that insufficient permissions would be noticed when fetching the RT instead of when performing any actual action on the session. 
Since routing is part of the global driver/pool operations and not coupled to sessions, doing it this way seems like the logical conclusion. --- neo4j/io/__init__.py | 21 ++++++++++----------- neo4j/work/simple.py | 1 - tests/unit/io/test_direct.py | 3 +-- tests/unit/test_driver.py | 1 - 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index b5f0e12ef..b3ddf3c89 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -711,13 +711,12 @@ def time_remaining(): "within {!r}s".format(timeout)) def acquire(self, access_mode=None, timeout=None, database=None, - imp_user=None, bookmarks=None): + bookmarks=None): """ Acquire a connection to a server that can satisfy a set of parameters. :param access_mode: :param timeout: :param database: - :param imp_user: :param bookmarks: """ @@ -1152,7 +1151,7 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, return True - def _select_address(self, *, access_mode, database, imp_user, bookmarks): + def _select_address(self, *, access_mode, database): from neo4j.api import READ_ACCESS """ Selects the address with the fewest in-use connections. 
""" @@ -1178,11 +1177,12 @@ def _select_address(self, *, access_mode, database, imp_user, bookmarks): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire(self, access_mode=None, timeout=None, database=None, - imp_user=None, bookmarks=None): + bookmarks=None): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) if not timeout: - raise ClientError("'timeout' must be a float larger than 0; {}".format(timeout)) + raise ClientError("'timeout' must be a float larger than 0; {}" + .format(timeout)) from neo4j.api import check_access_mode access_mode = check_access_mode(access_mode) @@ -1190,17 +1190,16 @@ def acquire(self, access_mode=None, timeout=None, database=None, log.debug("[#0000] C: %r", self.routing_tables) self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, imp_user=imp_user, + access_mode=access_mode, database=database, imp_user=None, bookmarks=bookmarks ) while True: try: - # Get an address for a connection that have the fewest in-use connections. - address = self._select_address( - access_mode=access_mode, database=database, - imp_user=imp_user, bookmarks=bookmarks - ) + # Get an address for a connection that have the fewest in-use + # connections. + address = self._select_address(access_mode=access_mode, + database=database) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired("Failed to obtain connection towards '%s' server." 
% access_mode) from err try: diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index bf2106ef7..84a057772 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -144,7 +144,6 @@ def _connect(self, access_mode): access_mode=access_mode, timeout=self._config.connection_acquisition_timeout, database=self._config.database, - imp_user=self._config.impersonated_user, bookmarks=self._bookmarks ) diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py index ccfff00dd..dd1ef330f 100644 --- a/tests/unit/io/test_direct.py +++ b/tests/unit/io/test_direct.py @@ -104,8 +104,7 @@ def opener(addr, timeout): super().__init__(opener, self.pool_config, self.workspace_config) self.address = address - def acquire(self, access_mode=None, timeout=None, database=None, - imp_user=None): + def acquire(self, access_mode=None, timeout=None, database=None): return self._acquire(self.address, timeout) diff --git a/tests/unit/test_driver.py b/tests/unit/test_driver.py index 4741c06e8..0c1192e47 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/test_driver.py @@ -137,7 +137,6 @@ def test_driver_opens_write_session_by_default(uri, mocker): access_mode=WRITE_ACCESS, timeout=mocker.ANY, database=mocker.ANY, - imp_user=mocker.ANY, bookmarks=mocker.ANY ) tx_begin_mock.assert_called_once_with( From 68a9c1d691d0d4df76c1ca4fd7a9c87778530373 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 12 Oct 2021 12:45:17 +0200 Subject: [PATCH 6/8] Only prefer initial router if RT got initialized w/o writers Previously, the driver would prefer the initial router for fetching a new RT whenever the current RT had no writers (left). However, it should check what the RT looked like, when it was received. 
--- neo4j/io/__init__.py | 2 +- neo4j/routing.py | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index b3ddf3c89..40644e72a 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1077,7 +1077,7 @@ def update_routing_table(self, *, database, imp_user, bookmarks, ) prefer_initial_routing_address = \ - self.routing_tables[database].missing_fresh_writer() + self.routing_tables[database].initialized_without_writers if prefer_initial_routing_address: # TODO: Test this state diff --git a/neo4j/routing.py b/neo4j/routing.py index a0ef48c25..8303f4c25 100644 --- a/neo4j/routing.py +++ b/neo4j/routing.py @@ -110,6 +110,7 @@ def __init__(self, *, database, routers=(), readers=(), writers=(), ttl=0): self.routers = OrderedSet(routers) self.readers = OrderedSet(readers) self.writers = OrderedSet(writers) + self.initialized_without_writers = not self.writers self.last_updated_time = perf_counter() self.ttl = ttl self.database = database @@ -142,14 +143,6 @@ def is_fresh(self, readonly=False): log.debug("[#0000] C: Table has_server_for_mode=%r", has_server_for_mode) return not expired and self.routers and has_server_for_mode - def missing_fresh_writer(self): - """ Check if the routing table have a fresh write address. - - :return: Return true if it does not have a fresh write address. - :rtype: bool - """ - return not self.is_fresh(readonly=False) - def should_be_purged_from_memory(self): """ Check if the routing table is stale and not used for a long time and should be removed from memory. 
@@ -168,6 +161,7 @@ def update(self, new_routing_table): self.routers.replace(new_routing_table.routers) self.readers.replace(new_routing_table.readers) self.writers.replace(new_routing_table.writers) + self.initialized_without_writers = not self.writers self.last_updated_time = perf_counter() self.ttl = new_routing_table.ttl log.debug("[#0000] S: table=%r", self) From bbafbaec37fa04c5083be74bd40844d0a2a03680 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 12 Oct 2021 16:29:16 +0200 Subject: [PATCH 7/8] Docs: prefer explicit db config if possible --- docs/source/api.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index aafc71d00..3481480fe 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -486,7 +486,12 @@ Name of the database to query. .. Note:: - The default database can be set on the Neo4j instance settings. + The default database can be set on the Neo4j instance settings. + +.. Note:: + It is recommended to always specify the database explicitly when possible. + This allows the driver to work more efficiently, as it will not have to + resolve the home database first. .. 
code-block:: python From 6f02d12b3e22e43b781cd8457c9bd13717eeccda Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 13 Oct 2021 12:07:16 +0200 Subject: [PATCH 8/8] Add unit tests for impersonation in Bolt4x4 class --- tests/unit/io/test_class_bolt4x4.py | 42 ++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/io/test_class_bolt4x4.py index 562a720c6..19378a1cd 100644 --- a/tests/unit/io/test_class_bolt4x4.py +++ b/tests/unit/io/test_class_bolt4x4.py @@ -56,31 +56,43 @@ def test_conn_is_not_stale(fake_socket, set_stale): connection.set_stale() assert connection.stale() is set_stale - -def test_db_extra_in_begin(fake_socket): +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) - connection.begin(db="something") + connection.begin(*args, **kwargs) connection.send_all() - tag, fields = socket.pop_message() + tag, is_fields = socket.pop_message() assert tag == b"\x11" - assert len(fields) == 1 - assert fields[0] == {"db": "something"} - - -def test_db_extra_in_run(fake_socket): + assert tuple(is_fields) == expected_fields + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +def test_extra_in_run(fake_socket, args, kwargs, 
expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) - connection.run("", {}, db="something") + connection.run(*args, **kwargs) connection.send_all() - tag, fields = socket.pop_message() + tag, is_fields = socket.pop_message() assert tag == b"\x10" - assert len(fields) == 3 - assert fields[0] == "" - assert fields[1] == {} - assert fields[2] == {"db": "something"} + assert tuple(is_fields) == expected_fields def test_n_extra_in_discard(fake_socket):