diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 89086d46e..e771912fe 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -63,21 +63,19 @@ def _qid(self): else: return self._raw_qid - def _tx_ready_run(self, query, parameters, **kwparameters): + def _tx_ready_run(self, query, parameters): # BEGIN+RUN does not carry any extra on the RUN message. # BEGIN {extra} # RUN "query" {parameters} {extra} - self._run(query, parameters, None, None, None, None, **kwparameters) + self._run(query, parameters, None, None, None, None) - def _run(self, query, parameters, db, imp_user, access_mode, bookmarks, - **kwparameters): + def _run(self, query, parameters, db, imp_user, access_mode, bookmarks): query_text = str(query) # Query or string object query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) parameters = DataDehydrator.fix_parameters( - dict(parameters or {}, **kwparameters), - patch_utc="utc" in self._connection.bolt_patches + parameters, patch_utc="utc" in self._connection.bolt_patches ) self._metadata = { diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 4a0383e72..d07aaba8d 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -185,7 +185,8 @@ def run(self, query, parameters=None, **kwparameters): :type query: str, neo4j.Query :param parameters: dictionary of parameters :type parameters: dict - :param kwparameters: additional keyword parameters + :param kwparameters: additional keyword parameters. + These take precedence over parameters passed as ``parameters``. :returns: a new :class:`neo4j.Result` object :rtype: :class:`neo4j.Result` """ @@ -212,10 +213,11 @@ def run(self, query, parameters=None, **kwparameters): cx, hydrant, self._config.fetch_size, self._result_closed, self._result_error ) + parameters = dict(parameters or {}, **kwparameters) self._autoResult._run( query, parameters, self._config.database, self._config.impersonated_user, self._config.default_access_mode, - self._bookmarks, **kwparameters + self._bookmarks ) return self._autoResult diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index c0c756d5e..88f0c1dd3 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -105,7 +105,8 @@ def run(self, query, parameters=None, **kwparameters): :type query: str :param parameters: dictionary of parameters :type parameters: dict - :param kwparameters: additional keyword parameters + :param kwparameters: additional keyword parameters. + These take precedence over parameters passed as ``parameters``. :returns: a new :class:`neo4j.Result` object :rtype: :class:`neo4j.Result` :raise TransactionError: if the transaction is already closed @@ -138,7 +139,8 @@ def run(self, query, parameters=None, **kwparameters): ) self._results.append(result) - result._tx_ready_run(query, parameters, **kwparameters) + parameters = dict(parameters or {}, **kwparameters) + result._tx_ready_run(query, parameters) return result diff --git a/tests/unit/work/test_session.py b/tests/unit/work/test_session.py index 71a3bde85..d43048699 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/work/test_session.py @@ -28,14 +28,23 @@ Transaction, unit_of_work, ) +from neo4j.io import IOPool from ._fake_connection import FakeConnection -@pytest.fixture() +@pytest.fixture def pool(mocker): - pool = mocker.MagicMock() - pool.acquire = mocker.MagicMock(side_effect=iter(FakeConnection, 0)) + pool = mocker.Mock(spec=IOPool) + assert not hasattr(pool, "acquired_connection_mocks") + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + connection = FakeConnection() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect return pool @@ -252,3 +261,49 @@ def work(tx): session.write_transaction(work) else: raise ValueError(run_type) + + +@pytest.mark.parametrize( + ("params", "kw_params", "expected_params"), + ( + ({"x": 1}, {}, {"x": 1}), + ({}, {"x": 1}, {"x": 1}), + ({"x": 1}, {"y": 2}, {"x": 1, "y": 2}), + ({"x": 1}, {"x": 2}, {"x": 2}), + ({"x": 1}, {"x": 2}, {"x": 2}), + ({"x": 1, "y": 3}, {"x": 2}, {"x": 2, "y": 3}), + ({"x": 1}, {"x": 2, "y": 3}, {"x": 2, "y": 3}), + # potentially internally used keyword arguments + ({}, {"timeout": 2}, {"timeout": 2}), + ({"timeout": 2}, {}, {"timeout": 2}), + ({}, {"imp_user": "hans"}, {"imp_user": "hans"}), + ({"imp_user": "hans"}, {}, {"imp_user": "hans"}), + ({}, {"db": "neo4j"}, {"db": "neo4j"}), + ({"db": "neo4j"}, {}, {"db": "neo4j"}), + ({}, {"database": "neo4j"}, {"database": "neo4j"}), + ({"database": "neo4j"}, {}, {"database": "neo4j"}), + ) +) +@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) +def test_session_run_parameter_precedence( + pool, params, kw_params, expected_params, run_type +): + with Session(pool, SessionConfig()) as session: + if run_type == "auto": + session.run("RETURN $x", params, **kw_params) + elif run_type == "unmanaged": + tx = session.begin_transaction() + tx.run("RETURN $x", params, **kw_params) + elif run_type == "managed": + def work(tx): + tx.run("RETURN $x", params, **kw_params) + session.write_transaction(work) + else: + raise ValueError(run_type) + + assert len(pool.acquired_connection_mocks) == 1 + connection_mock = pool.acquired_connection_mocks[0] + connection_mock.run.assert_called_once() + call_args, call_kwargs = connection_mock.run.call_args + assert call_args[0] == "RETURN $x" + assert call_kwargs["parameters"] == expected_params