diff --git a/arango/database.py b/arango/database.py index 1c1713e0..ca0895b9 100644 --- a/arango/database.py +++ b/arango/database.py @@ -2948,6 +2948,14 @@ def begin_batch_execution( """ return BatchDatabase(self._conn, return_result, max_workers) + def fetch_transaction(self, transaction_id: str) -> "TransactionDatabase": + """Fetch an existing transaction. + + :param transaction_id: The ID of the existing transaction. + :type transaction_id: str + """ + return TransactionDatabase(connection=self._conn, transaction_id=transaction_id) + def begin_transaction( self, read: Union[str, Sequence[str], None] = None, @@ -3125,6 +3133,9 @@ class TransactionDatabase(Database): :type lock_timeout: int | None :param max_size: Max transaction size in bytes. :type max_size: int | None + :param transaction_id: Initialize using an existing transaction instead of creating + a new transaction. + :type transaction_id: str | None """ def __init__( @@ -3137,6 +3148,7 @@ def __init__( allow_implicit: Optional[bool] = None, lock_timeout: Optional[int] = None, max_size: Optional[int] = None, + transaction_id: Optional[str] = None, ) -> None: self._executor: TransactionApiExecutor super().__init__( @@ -3150,6 +3162,7 @@ def __init__( allow_implicit=allow_implicit, lock_timeout=lock_timeout, max_size=max_size, + transaction_id=transaction_id, ), ) diff --git a/arango/exceptions.py b/arango/exceptions.py index 000a0f8f..52ad8ffd 100644 --- a/arango/exceptions.py +++ b/arango/exceptions.py @@ -772,6 +772,10 @@ class TransactionAbortError(ArangoServerError): """Failed to abort transaction.""" +class TransactionFetchError(ArangoServerError): + """Failed to fetch existing transaction.""" + + class TransactionListError(ArangoServerError): """Failed to retrieve transactions.""" diff --git a/arango/executor.py b/arango/executor.py index 47ac4a19..b854c671 100644 --- a/arango/executor.py +++ b/arango/executor.py @@ -19,6 +19,7 @@ OverloadControlExecutorError, TransactionAbortError, TransactionCommitError, + TransactionFetchError, TransactionInitError, TransactionStatusError, ) @@ -241,6 +242,9 @@ class TransactionApiExecutor: :type max_size: int :param allow_dirty_read: Allow reads from followers in a cluster. :type allow_dirty_read: bool | None + :param transaction_id: Initialize using an existing transaction instead of starting + a new transaction. + :type transaction_id: str | None """ def __init__( @@ -254,6 +258,7 @@ def __init__( lock_timeout: Optional[int] = None, max_size: Optional[int] = None, allow_dirty_read: bool = False, + transaction_id: Optional[str] = None, ) -> None: self._conn = connection @@ -275,19 +280,29 @@ def __init__( if max_size is not None: data["maxTransactionSize"] = max_size - request = Request( - method="post", - endpoint="/_api/transaction/begin", - data=data, - headers={"x-arango-allow-dirty-read": "true"} if allow_dirty_read else None, - ) - resp = self._conn.send_request(request) + if transaction_id is None: + request = Request( + method="post", + endpoint="/_api/transaction/begin", + data=data, + headers=( + {"x-arango-allow-dirty-read": "true"} if allow_dirty_read else None + ), + ) + resp = self._conn.send_request(request) - if not resp.is_success: - raise TransactionInitError(resp, request) + if not resp.is_success: + raise TransactionInitError(resp, request) + + result = resp.body["result"] + self._id: str = result["id"] + else: + self._id = transaction_id - result: Json = resp.body["result"] - self._id: str = result["id"] + try: + self.status() + except TransactionStatusError as err: + raise TransactionFetchError(err.response, err.request) @property def context(self) -> str: diff --git a/docs/transaction.rst b/docs/transaction.rst index 18d60a68..66fb50c8 100644 --- a/docs/transaction.rst +++ b/docs/transaction.rst @@ -68,6 +68,15 @@ logical unit of work (ACID compliant). assert '_rev' in txn_col.insert({'_key': 'Lily'}) assert len(txn_col) == 6 + # Fetch an existing transaction. Useful if you have received a Transaction ID + # from some other part of your system or an external system. + original_txn = db.begin_transaction(write='students') + txn_col = original_txn.collection('students') + assert '_rev' in txn_col.insert({'_key': 'Chip'}) + txn_db = db.fetch_transaction(original_txn.transaction_id) + txn_col = txn_db.collection('students') + assert '_rev' in txn_col.insert({'_key': 'Alya'}) + # Abort the transaction txn_db.abort_transaction() assert 'Kate' not in col diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 59e86b7c..7edc2a9c 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -5,6 +5,7 @@ TransactionAbortError, TransactionCommitError, TransactionExecuteError, + TransactionFetchError, TransactionInitError, TransactionStatusError, ) @@ -117,6 +118,38 @@ def test_transaction_commit(db, col, docs): assert err.value.error_code in {10, 1655} +def test_transaction_fetch_existing(db, col, docs): + original_txn = db.begin_transaction( + read=col.name, + write=col.name, + exclusive=[], + sync=True, + allow_implicit=False, + lock_timeout=1000, + max_size=10000, + ) + txn_col = original_txn.collection(col.name) + + assert "_rev" in txn_col.insert(docs[0]) + assert "_rev" in txn_col.delete(docs[0]) + + txn_db = db.fetch_transaction(transaction_id=original_txn.transaction_id) + + txn_col = txn_db.collection(col.name) + assert "_rev" in txn_col.insert(docs[1]) + assert "_rev" in txn_col.delete(docs[1]) + + txn_db.commit_transaction() + assert txn_db.transaction_status() == "committed" + assert original_txn.transaction_status() == "committed" + assert txn_db.transaction_id == original_txn.transaction_id + + # Test fetch transaction that does not exist + with pytest.raises(TransactionFetchError) as err: + db.fetch_transaction(transaction_id="illegal") + assert err.value.error_code in {10, 1655} + + def test_transaction_abort(db, col, docs): txn_db = db.begin_transaction(write=col.name) txn_col = txn_db.collection(col.name)