diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d5f80290a9..5a0c355200 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2910,6 +2910,58 @@ def _on_analytics_master_result(self, response, master_future, query_future): self.submit(query_future.send_request) + def prepare_async(self, query, custom_payload=None, keyspace=None): + """ + Prepare the given query and return a :class:`~.ResponseFuture` + object. You may also call :meth:`~.ResponseFuture.result()` + on the :class:`.ResponseFuture` to synchronously block for + prepared statement object at any time. + + See :meth:`Session.prepare` for parameter definitions. + + Example usage:: + + >>> future = session.prepare_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... prepared_statement = future.result() + ... except Exception: + ... log.exception("Operation failed:") + """ + future = self._create_prepare_response_future(query, keyspace, custom_payload) + future._protocol_handler = self.client_protocol_handler + self._on_request(future) + future.send_request() + return future + + def _create_prepare_response_future(self, query, keyspace, custom_payload): + message = PrepareMessage(query=query, keyspace=keyspace) + future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) + + def _prepare_result_processor(future, response): + prepared_keyspace = keyspace if keyspace else None + prepared_statement = PreparedStatement.from_message( + response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, + prepared_keyspace, + self._protocol_version, response.column_metadata, response.result_metadata_id, + self.cluster.column_encryption_policy) + prepared_statement.custom_payload = custom_payload + self.cluster.add_prepared(response.query_id, prepared_statement) + if self.cluster.prepare_on_all_hosts: + # prepare statement on all hosts + host = future._current_host + try: + self.prepare_on_all_nodes(future.message.query, host, future.message.keyspace) + except Exception: + log.exception("Error preparing query on all hosts:") + + return prepared_statement + + future._set_result_processor(_prepare_result_processor) + return future + + def _create_response_future(self, query, parameters, trace, custom_payload, timeout, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None, host=None): @@ -3118,36 +3170,18 @@ def prepare(self, query, custom_payload=None, keyspace=None): **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. + When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method + attempts to prepare given query on all hosts and waits for each node to respond. + Preparing CQL query on other nodes may fail, but error is not propagated + to the caller. + `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ - message = PrepareMessage(query=query, keyspace=keyspace) - future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) - try: - future.send_request() - response = future.result().one() - except Exception: - log.exception("Error preparing query:") - raise + future = self.prepare_async(query, custom_payload, keyspace) + return future.result() - prepared_keyspace = keyspace if keyspace else None - prepared_statement = PreparedStatement.from_message( - response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, - self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy) - prepared_statement.custom_payload = future.custom_payload - - self.cluster.add_prepared(response.query_id, prepared_statement) - - if self.cluster.prepare_on_all_hosts: - host = future._current_host - try: - self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) - except Exception: - log.exception("Error preparing query on all hosts:") - - return prepared_statement - - def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): + def prepare_on_all_nodes(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. @@ -4320,6 +4354,7 @@ class ResponseFuture(object): _col_types = None _final_exception = None _query_traces = None + _result_processor = None _callbacks = None _errbacks = None _current_host = None @@ -4951,10 +4986,20 @@ def result(self): """ self._event.wait() if self._final_result is not _NOT_SET: - return ResultSet(self, self._final_result) + if self._result_processor is not None: + return self._result_processor(self, self._final_result) + else: + return ResultSet(self, self._final_result) else: raise self._final_exception + def _set_result_processor(self, result_processor): + """ + Sets internal result processor which allows to control object + returned by :meth:`ResponseFuture.result()` method. + """ + self._result_processor = result_processor + def get_query_trace_ids(self): """ Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data). diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index a643b19c07..76abc49edc 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -22,6 +22,7 @@ from cassandra import InvalidRequest, DriverException from cassandra import ConsistencyLevel, ProtocolVersion +from cassandra.cluster import ResponseFuture from cassandra.query import PreparedStatement, UNSET_VALUE from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50, requirecassandra, BasicSharedKeyspaceUnitTestCase) @@ -121,6 +122,83 @@ def test_basic(self): results = self.session.execute(bound) self.assertEqual(results, [('x', 'y', 'z')]) + def test_basic_async(self): + """ + Test basic asynchronous PreparedStatement usage + """ + self.session.execute( + """ + DROP KEYSPACE IF EXISTS preparedtests + """ + ) + self.session.execute( + """ + CREATE KEYSPACE preparedtests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + + self.session.set_keyspace("preparedtests") + self.session.execute( + """ + CREATE TABLE cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + + prepared_future = self.session.prepare_async( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + self.assertIsInstance(prepared_future, ResponseFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(('a', 'b', 'c')) + self.session.execute(bound) + + prepared_future = self.session.prepare_async( + """ + SELECT * FROM cf0 WHERE a=? + """) + self.assertIsInstance(prepared_future, ResponseFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(('a')) + results = self.session.execute(bound) + self.assertEqual(results, [('a', 'b', 'c')]) + + # test with new dict binding + prepared_future = self.session.prepare_async( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + self.assertIsInstance(prepared_future, ResponseFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({ + 'a': 'x', + 'b': 'y', + 'c': 'z' + }) + self.session.execute(bound) + + prepared_future = self.session.prepare_async( + """ + SELECT * FROM cf0 WHERE a=? + """) + self.assertIsInstance(prepared_future, ResponseFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({'a': 'x'}) + results = self.session.execute(bound) + self.assertEqual(results, [('x', 'y', 'z')]) + def test_missing_primary_key(self): """ Ensure an InvalidRequest is thrown diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 89486802b4..44c605a48d 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -528,6 +528,19 @@ def test_prepare_on_all_hosts(self): session.execute(select_statement, (1, ), host=host) self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing")) + def test_prepare_async_on_all_hosts(self): + """ + Test to validate prepare_on_all_hosts flag is honored during prepare_async execution. + """ + clus = TestCluster(prepare_on_all_hosts=True) + self.addCleanup(clus.shutdown) + + session = clus.connect(wait_for_all_pools=True) + select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ? AND v = ? ALLOW FILTERING").result() + for host in clus.metadata.all_hosts(): + session.execute(select_statement, (1, 1), host=host) + self.assertEqual(0, self.mock_handler.get_message_count('debug', "Re-preparing")) + def test_prepare_batch_statement(self): """ Test to validate a prepared statement used inside a batch statement is correctly handled @@ -647,7 +660,6 @@ def test_prepared_statement(self): prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') prepared.consistency_level = ConsistencyLevel.ONE - self.assertEqual(str(prepared), '') @@ -717,6 +729,17 @@ def test_prepared_statements(self): self.session.execute_async(batch).result() self.confirm_results() + def test_prepare_async(self): + prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result() + + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(prepared, (i, i)) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + def test_bound_statements(self): prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)") @@ -942,7 +965,7 @@ def test_no_connection_refused_on_timeout(self): exception_type = type(result).__name__ if exception_type == "NoHostAvailable": self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) - if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]: + if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub", "ErrorMessage"]: if type(result).__name__ in ["WriteTimeout", "WriteFailure"]: received_timeout = True continue