From 5e5147b0a77cdd75e23d88be8261e99628648bd7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 28 May 2025 17:51:24 +0530 Subject: [PATCH 01/23] Separate Session related functionality from Connection class (#571) * decouple session class from existing Connection ensure maintenance of current APIs of Connection while delegating responsibility Signed-off-by: varun-edachali-dbx * add open property to Connection to ensure maintenance of existing API Signed-off-by: varun-edachali-dbx * update unit tests to address ThriftBackend through session instead of through Connection Signed-off-by: varun-edachali-dbx * chore: move session specific tests from test_client to test_session Signed-off-by: varun-edachali-dbx * formatting (black) as in CONTRIBUTING.md Signed-off-by: varun-edachali-dbx * use connection open property instead of long chain through session Signed-off-by: varun-edachali-dbx * trigger integration workflow Signed-off-by: varun-edachali-dbx * fix: ensure open attribute of Connection never fails in case the openSession takes long, the initialisation of the session will not complete immediately. This could make the session attribute inaccessible. If the Connection is deleted in this time, the open() check will throw because the session attribute does not exist. Thus, we default to the Connection being closed in this case. This was not an issue before because open was a direct attribute of the Connection class. Caught in the integration tests. Signed-off-by: varun-edachali-dbx * fix: de-complicate earlier connection open logic earlier, one of the integration tests was failing because 'session was not an attribute of Connection'. This is likely tied to a local configuration issue related to unittest that was causing an error in the test suite itself. The tests are now passing without checking for the session attribute. https://github.com/databricks/databricks-sql-python/pull/567/commits/c676f9b0281cc3e4fe9c6d8216cc62fc75eade3b Signed-off-by: varun-edachali-dbx * Revert "fix: de-complicate earlier connection open logic" This reverts commit d6b1b196c98a6e9d8e593a88c34bbde010519ef4. Signed-off-by: varun-edachali-dbx * [empty commit] attempt to trigger ci e2e workflow Signed-off-by: varun-edachali-dbx * Update CODEOWNERS (#562) new codeowners Signed-off-by: varun-edachali-dbx * Enhance Cursor close handling and context manager exception management to prevent server side resource leaks (#554) * Enhance Cursor close handling and context manager exception management * tests * fmt * Fix Cursor.close() to properly handle CursorAlreadyClosedError * Remove specific test message from Cursor.close() error handling * Improve error handling in connection and cursor context managers to ensure proper closure during exceptions, including KeyboardInterrupt. Add tests for nested cursor management and verify operation closure on server-side errors. 
* add * add Signed-off-by: varun-edachali-dbx * PECOBLR-86 improve logging on python driver (#556) * PECOBLR-86 Improve logging for debug level Signed-off-by: Sai Shree Pradhan * PECOBLR-86 Improve logging for debug level Signed-off-by: Sai Shree Pradhan * fixed format Signed-off-by: Sai Shree Pradhan * used lazy logging Signed-off-by: Sai Shree Pradhan * changed debug to error logs Signed-off-by: Sai Shree Pradhan * used lazy logging Signed-off-by: Sai Shree Pradhan --------- Signed-off-by: Sai Shree Pradhan Signed-off-by: varun-edachali-dbx * Revert "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session" This reverts commit dbb2ec52306b91072a2ee842270c7113aece9aff, reversing changes made to 7192f117279d4f0adcbafcdf2238c18663324515. Signed-off-by: varun-edachali-dbx * Reapply "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session" This reverts commit bdb83817f49e1d88a01679b11da8e55e8e80b42f. Signed-off-by: varun-edachali-dbx * fix: separate session opening logic from instantiation ensures correctness of self.session.open call in Connection Signed-off-by: varun-edachali-dbx * fix: use is_open attribute to denote session availability Signed-off-by: varun-edachali-dbx * fix: access thrift backend through session Signed-off-by: varun-edachali-dbx * chore: use get_handle() instead of private session attribute in client Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * fix: remove accidentally removed assertions Signed-off-by: varun-edachali-dbx --------- Signed-off-by: varun-edachali-dbx Signed-off-by: Sai Shree Pradhan Co-authored-by: Jothi Prakash Co-authored-by: Madhav Sainanee Co-authored-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 147 +++++++---------------- src/databricks/sql/session.py | 160 +++++++++++++++++++++++++ tests/e2e/test_driver.py | 2 +- tests/unit/test_client.py | 216 ++++------------------------------ tests/unit/test_session.py | 187 +++++++++++++++++++++++++++++ 5 files changed, 416 insertions(+), 296 deletions(-) create mode 100644 src/databricks/sql/session.py create mode 100644 tests/unit/test_session.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0c9a08a85..d6a9e6b08 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -45,6 +45,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -224,66 +225,28 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. 
" - "This parameter will be removed in the upcoming releases." - ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + # Create the session + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, **kwargs, ) + self.session.open() - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + logger.info( + "Successfully opened connection with session " + + str(self.get_session_id_hex()) ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) @@ -342,34 +305,32 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the session ID from the Session object""" + return self.session.get_id() - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. 
- """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format from the Session object""" + return self.session.get_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp): + """Delegate to Session class static method""" + return Session.get_protocol_version(openSessionResp) + + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return self.session.is_open def cursor( self, @@ -386,7 +347,7 @@ def cursor( cursor = Cursor( self, - self.thrift_backend, + self.session.thrift_backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -402,28 +363,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -833,7 +776,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -896,7 +839,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -992,7 +935,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, 
max_bytes=self.buffer_size_bytes, cursor=self, @@ -1018,7 +961,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1051,7 +994,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1086,7 +1029,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 000000000..f2f38d572 --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,160 @@ +import logging +from typing import Dict, Tuple, List, Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.thrift_backend import ThriftBackend + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + self.is_open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + self.session_configuration = session_configuration + self.catalog = catalog + self.schema = schema + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." 
+ ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.thrift_backend = ThriftBackend( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self._handle = None + self.protocol_version = None + + def open(self) -> None: + self._open_session_resp = self.thrift_backend.open_session( + self.session_configuration, self.catalog, self.schema + ) + self._handle = self._open_session_resp.sessionHandle + self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.is_open = True + logger.info("Successfully opened session " + str(self.get_id_hex())) + + @staticmethod + def get_protocol_version(openSessionResp): + """ + Since the sessionHandle will sometimes have a serverProtocolVersion, it takes + precedence over the serverProtocolVersion defined in the OpenSessionResponse. + """ + if ( + openSessionResp.sessionHandle + and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") + and openSessionResp.sessionHandle.serverProtocolVersion + ): + return openSessionResp.sessionHandle.serverProtocolVersion + return openSessionResp.serverProtocolVersion + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + def get_handle(self): + return self._handle + + def get_id(self): + handle = self.get_handle() + if handle is None: + return None + return self.thrift_backend.handle_to_id(handle) + + def get_id_hex(self): + handle = self.get_handle() + if handle is None: + return None + return self.thrift_backend.handle_to_hex_id(handle) + + def close(self) -> None: + """Close the underlying session.""" + logger.info(f"Closing session {self.get_id_hex()}") + if not self.is_open: + logger.debug("Session appears to have been closed already") + return + + try: + self.thrift_backend.close_session(self.get_handle()) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + f"Attempted to close session that was already closed: {e}" + ) + else: + logger.warning( + f"Attempt to close session raised an exception at the server: {e}" + ) + except Exception as e: + logger.error(f"Attempt to close session raised a local exception: {e}") + + self.is_open = False diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index d0c721109..abe0e22d2 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -920,7 +920,7 @@ def 
test_cursor_error_handling(self): assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.thrift_backend.close_command(op_handle) + conn.session.thrift_backend.close_command(op_handle) cursor.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..51439b2b4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -83,105 +83,10 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - 
self.assertIn(user_agent_header_with_entry, http_headers) - - @patch("databricks.sql.client.ThriftBackend") - def test_closing_connection_closes_commands(self, mock_thrift_client_class): - """Test that closing a connection properly closes commands. - - This test verifies that when a connection is closed: - 1. the active result set is marked as closed server-side - 2. The operation state is set to CLOSED - 3. backend.close_command is called only for commands that weren't already closed - - Args: - mock_thrift_client_class: Mock for ThriftBackend class - """ + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.client.ResultSet" % PACKAGE_NAME) + def test_closing_connection_closes_commands(self, mock_result_set_class): + # Test once with has_been_closed_server side, once without for closed in (True, False): with self.subTest(closed=closed): # Set initial state based on whether the command is already closed @@ -243,7 +148,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Should NOT have called backend.close_command (already closed) mock_backend.close_command.assert_not_called() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -253,7 +158,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -273,7 +178,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): thrift_backend=mock_backend, execute_response=Mock(), ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = PropertyMock(return_value=mock_session) result_set.close() @@ -285,7 +193,11 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -343,37 +255,14 @@ def test_context_manager_closes_cursor(self): mock_close.assert_called_once_with() cursor = client.Cursor(Mock(), Mock()) - cursor.close = Mock() - try: - with self.assertRaises(KeyboardInterrupt): - with cursor: - raise KeyboardInterrupt("Simulated interrupt") - finally: - cursor.close.assert_called() - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = 
instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + cursor.close = Mock() - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close = Mock() try: with self.assertRaises(KeyboardInterrupt): - with connection: + with cursor: raise KeyboardInterrupt("Simulated interrupt") finally: - connection.close.assert_called() + cursor.close.assert_called() def dict_product(self, dicts): """ @@ -473,21 +362,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -496,33 +370,6 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) @@ -582,7 +429,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -595,7 +442,7 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): @@ -680,24 +527,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - 
instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -716,7 +546,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): @@ -735,7 +565,7 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..eb392a229 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,187 @@ +import unittest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, +) + +import databricks.sql + + +class SessionTestSuite(unittest.TestCase): + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + self.assertEqual(args["server_hostname"], host) + self.assertEqual(args["http_path"], http_path) + connection.close() + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = 
[("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + self.assertIn(("foo", "bar"), call_args) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") + self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") + self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") + self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + self.assertIn(user_agent_header, http_headers) + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + self.assertIn(user_agent_header_with_entry, http_headers) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + 
mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + +if __name__ == "__main__": + unittest.main() From 57370b350216b08b9e1254e95064674f2ca8b615 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 30 May 2025 22:24:43 +0530 Subject: [PATCH 02/23] Introduce Backend Interface (DatabricksClient) (#573) NOTE: the `test_complex_types` e2e test was not working at the time of this merge. The test must be triggered when the test is back up and running as intended. * remove excess logs, assertions, instantiations large merge artifacts Signed-off-by: varun-edachali-dbx * formatting (black) + remove excess log (merge artifact) Signed-off-by: varun-edachali-dbx * fix typing Signed-off-by: varun-edachali-dbx * remove un-necessary check Signed-off-by: varun-edachali-dbx * remove un-necessary replace call Signed-off-by: varun-edachali-dbx * introduce __str__ methods for CommandId and SessionId Signed-off-by: varun-edachali-dbx * docstrings for DatabricksClient interface Signed-off-by: varun-edachali-dbx * stronger typing of Cursor and ExecuteResponse Signed-off-by: varun-edachali-dbx * remove utility functions from backend interface, fix circular import Signed-off-by: varun-edachali-dbx * rename info to properties Signed-off-by: varun-edachali-dbx * newline for cleanliness Signed-off-by: varun-edachali-dbx * fix circular import Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * to_hex_id -> get_hex_id Signed-off-by: varun-edachali-dbx * better comment on protocol version getter Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * move guid to hex id to new utils module Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * move staging allowed local path to connection props Signed-off-by: varun-edachali-dbx * add strong return type for execute_command Signed-off-by: varun-edachali-dbx * skip auth, error handling in databricksclient interface Signed-off-by: varun-edachali-dbx * chore: docstring + line width Signed-off-by: varun-edachali-dbx * get_id -> get_guid Signed-off-by: varun-edachali-dbx * chore: docstring Signed-off-by: varun-edachali-dbx * fix: to_hex_id -> to_hex_guid Signed-off-by: varun-edachali-dbx --------- Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 344 ++++++++++++++++++ .../sql/{ => backend}/thrift_backend.py | 263 +++++++------ src/databricks/sql/backend/types.py | 306 ++++++++++++++++ src/databricks/sql/backend/utils/__init__.py | 3 + .../sql/backend/utils/guid_utils.py | 22 ++ src/databricks/sql/client.py | 124 ++++--- src/databricks/sql/session.py | 53 ++- src/databricks/sql/utils.py | 3 +- tests/e2e/test_driver.py | 27 +- tests/unit/test_client.py | 91 +++-- tests/unit/test_fetches.py | 13 +- tests/unit/test_fetches_bench.py | 4 +- 
tests/unit/test_parameters.py | 17 +- tests/unit/test_session.py | 91 +++-- tests/unit/test_thrift_backend.py | 230 +++++++----- 15 files changed, 1185 insertions(+), 406 deletions(-) create mode 100644 src/databricks/sql/backend/databricks_client.py rename src/databricks/sql/{ => backend}/thrift_backend.py (87%) create mode 100644 src/databricks/sql/backend/types.py create mode 100644 src/databricks/sql/backend/utils/__init__.py create mode 100644 src/databricks/sql/backend/utils/guid_utils.py diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py new file mode 100644 index 000000000..edff10159 --- /dev/null +++ b/src/databricks/sql/backend/databricks_client.py @@ -0,0 +1,344 @@ +""" +Abstract client interface for interacting with Databricks SQL services. + +Implementations of this class are responsible for: +- Managing connections to Databricks SQL services +- Executing SQL queries and commands +- Retrieving query results +- Fetching metadata about catalogs, schemas, tables, and columns +""" + +from abc import ABC, abstractmethod +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.backend.types import SessionId, CommandId +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions + + +class DatabricksClient(ABC): + # == Connection and Session Management == + @abstractmethod + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service. + + This method establishes a new session with the server and returns a session + identifier that can be used for subsequent operations. + + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + InvalidServerResponseError: If the server response is invalid or unexpected + """ + pass + + @abstractmethod + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + This method terminates the session identified by the given session ID and + releases any resources associated with it. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + pass + + # == Query Execution, Command Management == + @abstractmethod + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ) -> Optional[ExecuteResponse]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. 
+ + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns an ExecuteResponse object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ + pass + + @abstractmethod + def cancel_command(self, command_id: CommandId) -> None: + """ + Cancels a running command or query. + + This method attempts to cancel a command that is currently being executed. + It can be called from a different thread than the one executing the command. + + Args: + command_id: The command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error canceling the command + """ + pass + + @abstractmethod + def close_command(self, command_id: CommandId) -> ttypes.TStatus: + """ + Closes a command and releases associated resources. + + This method informs the server that the client is done with the command + and any resources associated with it can be released. + + Args: + command_id: The command identifier to close + + Returns: + ttypes.TStatus: The status of the close operation + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error closing the command + """ + pass + + @abstractmethod + def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: + """ + Gets the current state of a query or command. + + This method retrieves the current execution state of a command from the server. + + Args: + command_id: The command identifier to check + + Returns: + ttypes.TOperationState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the state + ServerOperationError: If the command is in an error state + DatabaseError: If the command has been closed unexpectedly + """ + pass + + @abstractmethod + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ) -> ExecuteResponse: + """ + Retrieves the results of a previously executed command. + + This method fetches the results of a command that was executed asynchronously + or retrieves additional results from a command that has more rows available. 
+ + Args: + command_id: The command identifier for which to retrieve results + cursor: The cursor object that will handle the results + + Returns: + ExecuteResponse: An object containing the query results and metadata + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the results + """ + pass + + # == Metadata Operations == + @abstractmethod + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> ExecuteResponse: + """ + Retrieves a list of available catalogs. + + This method fetches metadata about all catalogs available in the current + session's context. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + + Returns: + ExecuteResponse: An object containing the catalog metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the catalogs + """ + pass + + @abstractmethod + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. + + This method fetches metadata about schemas available in the specified catalog + or all catalogs if no catalog is specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + + Returns: + ExecuteResponse: An object containing the schema metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the schemas + """ + pass + + @abstractmethod + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. + + This method fetches metadata about tables available in the specified catalog + and schema, or all catalogs and schemas if not specified. 
+ + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) + + Returns: + ExecuteResponse: An object containing the table metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the tables + """ + pass + + @abstractmethod + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. + + This method fetches metadata about columns available in the specified table, + or all tables if not specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + column_name: Optional column name pattern to filter by + + Returns: + ExecuteResponse: An object containing the column metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the columns + """ + pass + + @property + @abstractmethod + def max_download_threads(self) -> int: + """ + Gets the maximum number of download threads for cloud fetch operations. 
+ + Returns: + int: The maximum number of download threads + """ + pass diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py similarity index 87% rename from src/databricks/sql/thrift_backend.py rename to src/databricks/sql/backend/thrift_backend.py index e3dc38ad5..c09397c2f 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,9 +5,18 @@ import time import uuid import threading -from typing import List, Union +from typing import List, Optional, Union, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState +from databricks.sql.backend.types import ( + SessionId, + CommandId, + BackendType, +) +from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -41,6 +50,7 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.databricks_client import DatabricksClient logger = logging.getLogger(__name__) @@ -73,7 +83,7 @@ } -class ThriftBackend: +class ThriftDatabricksClient(DatabricksClient): CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE @@ -91,7 +101,6 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): # Internal arguments in **kwargs: @@ -150,7 +159,6 @@ def __init__( else: raise ValueError("No valid connection settings.") - self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -161,7 +169,7 @@ def __init__( ) # Cloud fetch - self.max_download_threads = kwargs.get("max_download_threads", 10) + self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options @@ -224,6 +232,10 @@ def __init__( self._request_lock = threading.RLock() + @property + def max_download_threads(self) -> int: + return self._max_download_threads + # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): # Configure retries & timing: use user-settings or defaults, and bound @@ -446,8 +458,10 @@ def attempt_request(attempt): logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) - error_message = ThriftBackend._extract_error_message_from_headers( - getattr(self._transport, "headers", {}) + error_message = ( + ThriftDatabricksClient._extract_error_message_from_headers( + getattr(self._transport, "headers", {}) + ) ) finally: # Calling `close()` here releases the active HTTP connection back to the pool @@ -483,7 +497,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftDatabricksClient._check_response_for_error(response) return response error_info = response_or_error_info @@ -534,7 +548,7 @@ def _check_session_configuration(self, session_configuration): ) ) - def open_session(self, session_configuration, catalog, schema): + def open_session(self, session_configuration, catalog, schema) -> SessionId: try: self._transport.open() session_configuration = { @@ -562,13 +576,22 @@ def 
open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - return response + properties = ( + {"serverProtocolVersion": response.serverProtocolVersion} + if response.serverProtocolVersion + else {} + ) + return SessionId.from_thrift_handle(response.sessionHandle, properties) except: self._transport.close() raise - def close_session(self, session_handle) -> None: - req = ttypes.TCloseSessionReq(sessionHandle=session_handle) + def close_session(self, session_id: SessionId) -> None: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") + + req = ttypes.TCloseSessionReq(sessionHandle=thrift_handle) try: self.make_request(self._client.CloseSession, req) finally: @@ -583,7 +606,7 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.displayMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, ) @@ -592,18 +615,18 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.errorMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( "Command {} unexpectedly closed server side".format( - op_handle and self.guid_to_hex_id(op_handle.operationId.guid) + op_handle and guid_to_hex_id(op_handle.operationId.guid) ), { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid) + and guid_to_hex_id(op_handle.operationId.guid) }, ) @@ -707,7 +730,8 @@ def _col_to_description(col): @staticmethod def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftDatabricksClient._col_to_description(col) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -767,6 +791,9 @@ def _results_message_to_execute_response(self, resp, operation_state): ) else: arrow_queue_opt = None + + command_id = CommandId.from_thrift_handle(resp.operationHandle) + return ExecuteResponse( arrow_queue=arrow_queue_opt, status=operation_state, @@ -774,21 +801,24 @@ def _results_message_to_execute_response(self, resp, operation_state): has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=resp.operationHandle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) - def get_execution_result(self, op_handle, cursor): - - assert op_handle is not None + def get_execution_result( + self, command_id: CommandId, cursor: "Cursor" + ) -> ExecuteResponse: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=cursor.arraysize, maxBytes=cursor.buffer_size_bytes, @@ -834,7 +864,7 @@ def get_execution_result(self, op_handle, cursor): 
has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=op_handle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) @@ -857,51 +887,57 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, op_handle) -> "TOperationState": - poll_resp = self._poll_for_status(op_handle) + def get_query_state(self, command_id: CommandId) -> "TOperationState": + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") + + poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState - self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) + self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) return operation_state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation ) def execute_command( self, - operation, - session_handle, - max_rows, - max_bytes, - lz4_compression, - cursor, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", use_cloud_fetch=True, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ): - assert session_handle is not None + ) -> Optional[ExecuteResponse]: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") logger.debug( "ThriftBackend.execute_command(operation=%s, session_handle=%s)", operation, - session_handle, + thrift_handle, ) spark_arrow_types = ttypes.TSparkArrowTypes( @@ -913,7 +949,7 @@ def execute_command( intervalTypesAsArrow=False, ) req = ttypes.TExecuteStatementReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, statement=operation, runAsync=True, # For async operation we don't want the direct results @@ -938,14 +974,23 @@ def execute_command( if async_op: self._handle_execute_response_async(resp, cursor) + return None else: return self._handle_execute_response(resp, cursor) - def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): - assert session_handle is not None + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetCatalogsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -955,17 +1000,19 @@ def get_catalogs(self, session_handle, max_rows, 
max_bytes, cursor): def get_schemas( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetSchemasReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -977,19 +1024,21 @@ def get_schemas( def get_tables( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, table_types=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetTablesReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1003,19 +1052,21 @@ def get_tables( def get_columns( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, column_name=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetColumnsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1028,7 +1079,9 @@ def get_columns( return self._handle_execute_response(resp, cursor) def _handle_execute_response(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( @@ -1039,28 +1092,31 @@ def _handle_execute_response(self, resp, cursor): return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) def fetch_results( self, - op_handle, - max_rows, - max_bytes, - expected_row_start_offset, - lz4_compressed, + command_id: CommandId, + max_rows: int, + max_bytes: int, + expected_row_start_offset: int, + lz4_compressed: bool, arrow_schema_bytes, description, use_cloud_fetch=True, ): - assert op_handle is not None + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=max_rows, maxBytes=max_bytes, @@ -1089,46 +1145,21 @@ def fetch_results( return queue, resp.hasMoreRows - def close_command(self, op_handle): - 
logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) - req = ttypes.TCloseOperationReq(operationHandle=op_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + def cancel_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - def cancel_command(self, active_op_handle): - logger.debug( - "Cancelling command {}".format( - self.guid_to_hex_id(active_op_handle.operationId.guid) - ) - ) - req = ttypes.TCancelOperationReq(active_op_handle) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - @staticmethod - def handle_to_id(session_handle): - return session_handle.sessionId.guid - - @staticmethod - def handle_to_hex_id(session_handle: TCLIService.TSessionHandle): - this_uuid = uuid.UUID(bytes=session_handle.sessionId.guid) - return str(this_uuid) + def close_command(self, command_id: CommandId): + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - @staticmethod - def guid_to_hex_id(guid: bytes) -> str: - """Return a hexadecimal string instead of bytes - - Example: - IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' - OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' - - If conversion to hexadecimal fails, the original bytes are returned - """ - - this_uuid: Union[bytes, uuid.UUID] - - try: - this_uuid = uuid.UUID(bytes=guid) - except Exception as e: - logger.debug(f"Unable to convert bytes to UUID: {bytes} -- {str(e)}") - this_uuid = guid - return str(this_uuid) + logger.debug("ThriftBackend.close_command(command_id=%s)", command_id) + req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) + resp = self.make_request(self._client.CloseOperation, req) + return resp.status diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py new file mode 100644 index 000000000..740be0199 --- /dev/null +++ b/src/databricks/sql/backend/types.py @@ -0,0 +1,306 @@ +from enum import Enum +from typing import Dict, Optional, Any, Union +import logging + +from databricks.sql.backend.utils import guid_to_hex_id + +logger = logging.getLogger(__name__) + + +class BackendType(Enum): + """ + Enum representing the type of backend + """ + + THRIFT = "thrift" + SEA = "sea" + + +class SessionId: + """ + A normalized session identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TSessionHandle and + SEA's session ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + properties: Optional[Dict[str, Any]] = None, + ): + """ + Initialize a SessionId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the session + secret: The secret part of the identifier (only used for Thrift) + properties: Additional information about the session + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.properties = properties or {} + + def __str__(self) -> str: + """ + Return a string representation of the SessionId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". 
+ + Returns: + A string representation of the session ID + """ + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.get_hex_guid()}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle( + cls, session_handle, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a Thrift session handle. + + Args: + session_handle: A TSessionHandle object from the Thrift API + + Returns: + A SessionId instance + """ + if session_handle is None: + return None + + guid_bytes = session_handle.sessionId.guid + secret_bytes = session_handle.sessionId.secret + + if session_handle.serverProtocolVersion is not None: + if properties is None: + properties = {} + properties["serverProtocolVersion"] = session_handle.serverProtocolVersion + + return cls(BackendType.THRIFT, guid_bytes, secret_bytes, properties) + + @classmethod + def from_sea_session_id( + cls, session_id: str, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a SEA session ID. + + Args: + session_id: The SEA session ID string + + Returns: + A SessionId instance + """ + return cls(BackendType.SEA, session_id, properties=properties) + + def to_thrift_handle(self): + """ + Convert this SessionId to a Thrift TSessionHandle. + + Returns: + A TSessionHandle object or None if this is not a Thrift session ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + server_protocol_version = self.properties.get("serverProtocolVersion") + return ttypes.TSessionHandle( + sessionId=handle_identifier, serverProtocolVersion=server_protocol_version + ) + + def to_sea_session_id(self): + """ + Get the SEA session ID string. + + Returns: + The session ID string or None if this is not a SEA session ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def get_guid(self) -> Any: + """ + Get the ID of the session. + """ + return self.guid + + def get_hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the session ID. + + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) + + def get_protocol_version(self): + """ + Get the server protocol version for this session. + + Returns: + The server protocol version or None if it does not exist + It is not expected to exist for SEA sessions. + """ + return self.properties.get("serverProtocolVersion") + + +class CommandId: + """ + A normalized command identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TOperationHandle and + SEA's statement ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, + ): + """ + Initialize a CommandId. 
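A short usage sketch for SessionId as defined above; the SEA id value is made up for illustration:

    from databricks.sql.backend.types import SessionId, BackendType

    sea_id = SessionId.from_sea_session_id("01eec5da-1bd0-1987-a2f3-a9ca57af9b28")
    assert sea_id.backend_type == BackendType.SEA
    assert sea_id.to_sea_session_id() == "01eec5da-1bd0-1987-a2f3-a9ca57af9b28"
    assert sea_id.to_thrift_handle() is None  # cross-backend conversion yields None
    print(sea_id)                             # SEA ids render as the bare guid string
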
+ + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the command + secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle(cls, operation_handle): + """ + Create a CommandId from a Thrift operation handle. + + Args: + operation_handle: A TOperationHandle object from the Thrift API + + Returns: + A CommandId instance + """ + if operation_handle is None: + return None + + guid_bytes = operation_handle.operationId.guid + secret_bytes = operation_handle.operationId.secret + + return cls( + BackendType.THRIFT, + guid_bytes, + secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, + ) + + @classmethod + def from_sea_statement_id(cls, statement_id: str): + """ + Create a CommandId from a SEA statement ID. + + Args: + statement_id: The SEA statement ID string + + Returns: + A CommandId instance + """ + return cls(BackendType.SEA, statement_id) + + def to_thrift_handle(self): + """ + Convert this CommandId to a Thrift TOperationHandle. + + Returns: + A TOperationHandle object or None if this is not a Thrift command ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + return ttypes.TOperationHandle( + operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, + ) + + def to_sea_statement_id(self): + """ + Get the SEA statement ID string. + + Returns: + The statement ID string or None if this is not a SEA statement ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def to_hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the command ID. 
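CommandId mirrors SessionId for operations; a minimal round-trip sketch, building a Thrift handle with the same ttypes used elsewhere in this patch (the guid and secret bytes are arbitrary examples):

    import uuid
    from databricks.sql.thrift_api.TCLIService import ttypes
    from databricks.sql.backend.types import CommandId

    handle = ttypes.TOperationHandle(
        operationId=ttypes.THandleIdentifier(guid=uuid.uuid4().bytes, secret=b"\x00"),
        operationType=ttypes.TOperationType.EXECUTE_STATEMENT,
        hasResultSet=False,
    )
    command_id = CommandId.from_thrift_handle(handle)
    assert command_id.to_thrift_handle().operationId.guid == handle.operationId.guid
    print(command_id.to_hex_guid())  # hex form, surfaced later via Cursor.query_id
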
+ + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) diff --git a/src/databricks/sql/backend/utils/__init__.py b/src/databricks/sql/backend/utils/__init__.py new file mode 100644 index 000000000..3d601e5e6 --- /dev/null +++ b/src/databricks/sql/backend/utils/__init__.py @@ -0,0 +1,3 @@ +from .guid_utils import guid_to_hex_id + +__all__ = ["guid_to_hex_id"] diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py new file mode 100644 index 000000000..28975171f --- /dev/null +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -0,0 +1,22 @@ +import uuid +import logging + +logger = logging.getLogger(__name__) + + +def guid_to_hex_id(guid: bytes) -> str: + """Return a hexadecimal string instead of bytes + + Example: + IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' + OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' + + If conversion to hexadecimal fails, a string representation of the original + bytes is returned + """ + try: + this_uuid = uuid.UUID(bytes=guid) + except Exception as e: + logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}") + return str(guid) + return str(this_uuid) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d6a9e6b08..1c384c735 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -21,7 +21,8 @@ CursorAlreadyClosedError, ) from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( ExecuteResponse, ParamEscaper, @@ -46,6 +47,7 @@ from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session +from databricks.sql.backend.types import CommandId, BackendType from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -230,7 +232,6 @@ def read(self) -> Optional[OAuthToken]: self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) self._cursors = [] # type: List[Cursor] - # Create the session self.session = Session( server_hostname, http_path, @@ -243,14 +244,10 @@ def read(self) -> Optional[OAuthToken]: ) self.session.open() - logger.info( - "Successfully opened connection with session " - + str(self.get_session_id_hex()) - ) - self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) ) + self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -305,11 +302,11 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - """Get the session ID from the Session object""" + """Get the raw session ID (backend-specific)""" return self.session.get_id() def get_session_id_hex(self): - """Get the session ID in hex format from the Session object""" + """Get the session ID in hex format""" return self.session.get_id_hex() @staticmethod @@ -347,7 +344,7 @@ def cursor( cursor = Cursor( self, - self.session.thrift_backend, + self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ 
-380,7 +377,7 @@ class Cursor: def __init__( self, connection: Connection, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, ) -> None: @@ -399,8 +396,8 @@ def __init__( # Note that Cursor closed => active result set closed, but not vice versa self.open = True self.executing_command_id = None - self.thrift_backend = thrift_backend - self.active_op_handle = None + self.backend = backend + self.active_command_id = None self.escaper = ParamEscaper() self.lastrowid = None @@ -774,9 +771,9 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.execute_command( + execute_response = self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -786,10 +783,12 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) + assert execute_response is not None # async_op = False above + self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, self.connection.use_cloud_fetch, @@ -797,7 +796,7 @@ def execute( if execute_response.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -837,9 +836,9 @@ def execute_async( self._check_not_closed() self._close_and_clear_active_result_set() - self.thrift_backend.execute_command( + self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -859,7 +858,9 @@ def get_query_state(self) -> "TOperationState": :return: """ self._check_not_closed() - return self.thrift_backend.get_query_state(self.active_op_handle) + if self.active_command_id is None: + raise Error("No active command to get state for") + return self.backend.get_query_state(self.active_command_id) def is_query_pending(self): """ @@ -889,20 +890,20 @@ def get_async_execution_result(self): operation_state = self.get_query_state() if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.thrift_backend.get_execution_result( - self.active_op_handle, self + execute_response = self.backend.get_execution_result( + self.active_command_id, self ) self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, ) if execute_response.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -934,8 +935,8 @@ def catalogs(self) -> "Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_catalogs( + session_id=self.connection.session.get_session_id(), 
max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -943,9 +944,10 @@ def catalogs(self) -> "Cursor": self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -960,8 +962,8 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_schemas( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -971,9 +973,10 @@ def schemas( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -993,8 +996,8 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_tables( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_tables( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1006,9 +1009,10 @@ def tables( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -1028,8 +1032,8 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_columns( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_columns( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1041,9 +1045,10 @@ def columns( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -1117,8 +1122,8 @@ def cancel(self) -> None: The command should be closed to free resources from the server. This method can be called from another thread. """ - if self.active_op_handle is not None: - self.thrift_backend.cancel_command(self.active_op_handle) + if self.active_command_id is not None: + self.backend.cancel_command(self.active_command_id) else: logger.warning( "Attempting to cancel a command, but there is no " @@ -1130,9 +1135,9 @@ def close(self) -> None: self.open = False # Close active operation handle if it exists - if self.active_op_handle: + if self.active_command_id: try: - self.thrift_backend.close_command(self.active_op_handle) + self.backend.close_command(self.active_command_id) except RequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") @@ -1141,7 +1146,7 @@ def close(self) -> None: except Exception as e: logging.warning(f"Error closing operation handle: {e}") finally: - self.active_op_handle = None + self.active_command_id = None if self.active_result_set: self._close_and_clear_active_result_set() @@ -1154,8 +1159,8 @@ def query_id(self) -> Optional[str]: This attribute will be ``None`` if the cursor has not had an operation invoked via the execute method yet, or if cursor was closed. 
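The hex formatting that query_id relies on just below bottoms out in guid_to_hex_id from this patch's new utils module; for example:

    from databricks.sql.backend.utils import guid_to_hex_id

    guid_to_hex_id(b"\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd")
    # -> '01ee1d29-a419-1db6-a9c0-8df1feba42dd'
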
""" - if self.active_op_handle is not None: - return str(UUID(bytes=self.active_op_handle.operationId.guid)) + if self.active_command_id is not None: + return self.active_command_id.to_hex_guid() return None @property @@ -1207,7 +1212,7 @@ def __init__( self, connection: Connection, execute_response: ExecuteResponse, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -1217,18 +1222,20 @@ def __init__( :param connection: The parent connection that was used to execute this command :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param backend: The DatabricksClient instance to use for fetching results + :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results """ self.connection = connection - self.command_id = execute_response.command_handle + self.command_id = execute_response.command_id self.op_state = execute_response.status self.has_been_closed_server_side = execute_response.has_been_closed_server_side self.has_more_rows = execute_response.has_more_rows self.buffer_size_bytes = result_buffer_size_bytes self.lz4_compressed = execute_response.lz4_compressed self.arraysize = arraysize - self.thrift_backend = thrift_backend + self.backend = backend self.description = execute_response.description self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 @@ -1251,9 +1258,16 @@ def __iter__(self): break def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available - results, has_more_rows = self.thrift_backend.fetch_results( - op_handle=self.command_id, + if not isinstance(self.backend, ThriftDatabricksClient): + # currently, we are assuming only the Thrift backend exists + raise NotImplementedError( + "Fetching further result batches is currently only implemented for the Thrift backend." + ) + + # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results + thrift_backend_instance = self.backend # type: ThriftDatabricksClient + results, has_more_rows = thrift_backend_instance.fetch_results( + command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, expected_row_start_offset=self._next_row_index, @@ -1468,19 +1482,21 @@ def close(self) -> None: If the connection has not been closed, and the cursor has not already been closed on the server for some other reason, issue a request to the server to close it. 
""" + # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to + # when we generalise the ResultSet try: if ( - self.op_state != self.thrift_backend.CLOSED_OP_STATE + self.op_state != ttypes.TOperationState.CLOSED_STATE and not self.has_been_closed_server_side and self.connection.open ): - self.thrift_backend.close_command(self.command_id) + self.backend.close_command(self.command_id) except RequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = self.thrift_backend.CLOSED_OP_STATE + self.op_state = ttypes.TOperationState.CLOSED_STATE @staticmethod def _get_schema_description(table_schema_message): diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f2f38d572..2ee5e53f1 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -7,7 +7,9 @@ from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) @@ -71,7 +73,7 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.thrift_backend = ThriftBackend( + self.backend: DatabricksClient = ThriftDatabricksClient( self.host, self.port, http_path, @@ -82,31 +84,21 @@ def __init__( **kwargs, ) - self._handle = None self.protocol_version = None - def open(self) -> None: - self._open_session_resp = self.thrift_backend.open_session( - self.session_configuration, self.catalog, self.schema + def open(self): + self._session_id = self.backend.open_session( + session_configuration=self.session_configuration, + catalog=self.catalog, + schema=self.schema, ) - self._handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.protocol_version = self.get_protocol_version(self._session_id) self.is_open = True logger.info("Successfully opened session " + str(self.get_id_hex())) @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. 
- """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_protocol_version(session_id: SessionId): + return session_id.get_protocol_version() @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -118,20 +110,17 @@ def server_parameterized_queries_enabled(protocolVersion): else: return False - def get_handle(self): - return self._handle + def get_session_id(self) -> SessionId: + """Get the normalized session ID""" + return self._session_id def get_id(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_id(handle) + """Get the raw session ID (backend-specific)""" + return self._session_id.get_guid() - def get_id_hex(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_hex_id(handle) + def get_id_hex(self) -> str: + """Get the session ID in hex format""" + return self._session_id.get_hex_guid() def close(self) -> None: """Close the underlying session.""" @@ -141,7 +130,7 @@ def close(self) -> None: return try: - self.thrift_backend.close_session(self.get_handle()) + self.backend.close_session(self._session_id) except RequestError as e: if isinstance(e.args[1], SessionAlreadyClosedError): logger.info("Session was closed by a prior request") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0ce2fa169..733d425d6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -26,6 +26,7 @@ TSparkRowSetType, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.types import CommandId from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter @@ -345,7 +346,7 @@ def _create_empty_table(self) -> "pyarrow.Table": ExecuteResponse = namedtuple( "ExecuteResponse", "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_handle arrow_queue arrow_schema_bytes", + "command_id arrow_queue arrow_schema_bytes", ) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index abe0e22d2..c446b6715 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -822,11 +822,10 @@ def test_close_connection_closes_cursors(self): # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( - operationHandle=ars.command_id, getProgressUpdate=False - ) - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( - status_request + operationHandle=ars.command_id.to_thrift_handle(), + getProgressUpdate=False, ) + op_status_at_server = ars.backend._client.GetOperationStatus(status_request) assert ( op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE @@ -836,7 +835,7 @@ def test_close_connection_closes_cursors(self): # When connection closes, any cursor operations should no longer exist at the server with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + op_status_at_server = ars.backend._client.GetOperationStatus( status_request ) @@ -866,9 +865,9 @@ def test_cursor_close_properly_closes_operation(self): cursor = conn.cursor() try: 
cursor.execute("SELECT 1 AS test") - assert cursor.active_op_handle is not None + assert cursor.active_command_id is not None cursor.close() - assert cursor.active_op_handle is None + assert cursor.active_command_id is None assert not cursor.open finally: if cursor.open: @@ -894,19 +893,19 @@ def test_nested_cursor_context_managers(self): with self.connection() as conn: with conn.cursor() as cursor1: cursor1.execute("SELECT 1 AS test1") - assert cursor1.active_op_handle is not None + assert cursor1.active_command_id is not None with conn.cursor() as cursor2: cursor2.execute("SELECT 2 AS test2") - assert cursor2.active_op_handle is not None + assert cursor2.active_command_id is not None # After inner context manager exit, cursor2 should be not open assert not cursor2.open - assert cursor2.active_op_handle is None + assert cursor2.active_command_id is None # After outer context manager exit, cursor1 should be not open assert not cursor1.open - assert cursor1.active_op_handle is None + assert cursor1.active_command_id is None def test_cursor_error_handling(self): """Test that cursor close handles errors properly to prevent orphaned operations.""" @@ -915,12 +914,12 @@ def test_cursor_error_handling(self): cursor.execute("SELECT 1 AS test") - op_handle = cursor.active_op_handle + op_handle = cursor.active_command_id assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.session.thrift_backend.close_command(op_handle) + conn.session.backend.close_command(op_handle) cursor.close() @@ -940,7 +939,7 @@ def test_result_set_close(self): result_set.close() - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE assert result_set.op_state != initial_op_state # Closing the result set again should be a no-op and not raise exceptions diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 51439b2b4..f77cab782 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -16,13 +16,14 @@ TOperationState, TOperationType, ) -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row +from databricks.sql.client import CommandId from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests @@ -30,10 +31,10 @@ from tests.unit.test_arrow_queue import ArrowQueueSuite -class ThriftBackendMockFactory: +class ThriftDatabricksClientMockFactory: @classmethod def new(cls): - ThriftBackendMock = Mock(spec=ThriftBackend) + ThriftBackendMock = Mock(spec=ThriftDatabricksClient) ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) @@ -44,7 +45,7 @@ def new(cls): description=None, arrow_queue=None, is_staging_operation=False, - command_handle=b"\x22", + command_id=None, has_been_closed_server_side=True, has_more_rows=True, lz4_compressed=True, @@ -83,7 +84,10 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) 
@patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -148,7 +152,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): # Should NOT have called backend.close_command (already closed) mock_backend.close_command.assert_not_called() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -158,7 +162,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -175,7 +179,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_backend = Mock() result_set = client.ResultSet( connection=mock_connection, - thrift_backend=mock_backend, + backend=mock_backend, execute_response=Mock(), ) # Setup session mock on the mock_connection @@ -205,7 +209,7 @@ def test_closing_result_set_hard_closes_commands(self): result_set.close() mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle + mock_results_response.command_id ) @patch("%s.client.ResultSet" % PACKAGE_NAME) @@ -217,7 +221,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_set_class.side_effect = mock_result_sets cursor = client.Cursor( - connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() + connection=Mock(), backend=ThriftDatabricksClientMockFactory.new() ) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -255,11 +259,11 @@ def test_context_manager_closes_cursor(self): mock_close.assert_called_once_with() cursor = client.Cursor(Mock(), Mock()) - cursor.close = Mock() + cursor.close = Mock() try: with self.assertRaises(KeyboardInterrupt): - with cursor: + with cursor: raise KeyboardInterrupt("Simulated interrupt") finally: cursor.close.assert_called() @@ -276,7 +280,7 @@ def dict_product(self, dicts): """ return (dict(zip(dicts.keys(), x)) for x in itertools.product(*dicts.values())) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -297,7 +301,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -320,7 +324,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( 
dict( @@ -346,10 +350,10 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe def test_cancel_command_calls_the_backend(self): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) - mock_op_handle = Mock() - cursor.active_op_handle = mock_op_handle + mock_command_id = Mock() + cursor.active_command_id = mock_command_id cursor.cancel() - mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) + mock_thrift_backend.cancel_command.assert_called_with(mock_command_id) @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @@ -371,7 +375,7 @@ def test_version_is_canonical(self): self.assertIsNotNone(re.match(canonical_version_re, version)) def test_execute_parameter_passthrough(self): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = ThriftDatabricksClientMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) tests = [ @@ -395,16 +399,16 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend + self, mock_result_set_class ): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] mock_result_set_class.side_effect = mock_result_set_instances - mock_thrift_backend = ThriftBackendMockFactory.new() - cursor = client.Cursor(Mock(), mock_thrift_backend()) + mock_backend = ThriftDatabricksClientMockFactory.new() + + cursor = client.Cursor(Mock(), mock_backend) params = [{"x": None}, {"x": "foo1"}, {"x": "bar2"}] expected_queries = ["SELECT NULL", "SELECT 'foo1'", "SELECT 'bar2'"] @@ -412,13 +416,13 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), + len(mock_backend.execute_command.call_args_list), len(expected_queries), "Expected execute_command to be called the same number of times as params were passed", ) for expected_query, call_args in zip( - expected_queries, mock_thrift_backend.execute_command.call_args_list + expected_queries, mock_backend.execute_command.call_args_list ): self.assertEqual(call_args[1]["operation"], expected_query) @@ -429,7 +433,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -442,14 +446,14 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): c.rollback() @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): def make_fake_row_slice(n_rows): 
mock_slice = Mock() @@ -474,7 +478,7 @@ def make_fake_row_slice(n_rows): self.assertEqual(cursor.rownumber, 29) @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_thrift_backend = mock_thrift_backend_class.return_value mock_table = Mock() @@ -527,7 +531,7 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -546,13 +550,13 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called - ThriftBackendMockFactory.apply_property_to_mock( + ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) mock_client_class.execute_command.return_value = mock_execute_response @@ -565,7 +569,10 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" @@ -574,9 +581,13 @@ def test_access_current_query_id(self): self.assertIsNone(cursor.query_id) - cursor.active_op_handle = TOperationHandle( - operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT, + cursor.active_command_id = CommandId.from_thrift_handle( + TOperationHandle( + operationId=THandleIdentifier( + guid=UUID(operation_id).bytes, secret=0x00 + ), + operationType=TOperationType.EXECUTE_STATEMENT, + ) ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) @@ -587,18 +598,18 @@ def test_cursor_close_handles_exception(self): """Test that Cursor.close() handles exceptions from close_command properly.""" mock_backend = Mock() mock_connection = Mock() - mock_op_handle = Mock() + mock_command_id = Mock() mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) - cursor.active_op_handle = mock_op_handle + cursor.active_command_id = mock_command_id cursor.close() - mock_backend.close_command.assert_called_once_with(mock_op_handle) + mock_backend.close_command.assert_called_once_with(mock_command_id) - self.assertIsNone(cursor.active_op_handle) + self.assertIsNone(cursor.active_command_id) self.assertFalse(cursor.open) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2cb..1c6a1b18d 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -9,6 +9,7 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.thrift_backend 
import ThriftDatabricksClient @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -39,14 +40,14 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) rs = client.ResultSet( connection=Mock(), - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, @@ -64,7 +65,7 @@ def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 def fetch_results( - op_handle, + command_id, max_rows, max_bytes, expected_row_start_offset, @@ -79,13 +80,13 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 rs = client.ResultSet( connection=Mock(), - thrift_backend=mock_thrift_backend, + backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -95,7 +96,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=None, arrow_schema_bytes=None, is_staging_operation=False, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 552872221..b302c00da 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -31,13 +31,13 @@ def make_dummy_result_set_from_initial_results(arrow_table): arrow_queue = ArrowQueue(arrow_table, arrow_table.num_rows, 0) rs = client.ResultSet( connection=None, - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema=arrow_table.schema, ), diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 249730789..65e65faff 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -24,6 +24,7 @@ MapParameter, ArrayParameter, ) +from databricks.sql.backend.types import SessionId from databricks.sql.parameters.native import ( TDbsqlParameter, TSparkParameter, @@ -46,7 +47,10 @@ class TestSessionHandleChecks(object): ( TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle(1, None), + sessionHandle=TSessionHandle( + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=None, + ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ), @@ -55,7 +59,8 @@ class TestSessionHandleChecks(object): TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, sessionHandle=TSessionHandle( - 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, @@ -63,7 +68,13 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - assert Connection.get_protocol_version(test_input) == expected + properties = ( + 
{"serverProtocolVersion": test_input.serverProtocolVersion} + if test_input.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) + assert Connection.get_protocol_version(session_id) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index eb392a229..858119f92 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4,7 +4,10 @@ from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, + TSessionHandle, + THandleIdentifier, ) +from databricks.sql.backend.types import SessionId, BackendType import databricks.sql @@ -21,22 +24,23 @@ class SessionTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.close() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -63,7 +67,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -71,7 +75,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -87,7 +91,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -108,22 +112,23 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - 
@patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: pass - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -133,54 +138,62 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() + + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) + # Check that open_session was called with the correct session_configuration as keyword argument + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["session_configuration"], mock_session_config) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + # Check that open_session was called with the correct 
catalog and schema as keyword arguments + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["catalog"], mock_cat) + self.assertEqual(call_kwargs["schema"], mock_schem) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) # not strictly necessary as the refcount is 0, but just to be sure gc.collect() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") if __name__ == "__main__": diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..41a2a5800 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -17,7 +17,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.types import CommandId, SessionId, BackendType def retry_policy_factory(): @@ -51,6 +52,7 @@ class ThriftBackendTestSuite(unittest.TestCase): open_session_resp = ttypes.TOpenSessionResp( status=okay_status, serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, + sessionHandle=session_handle, ) metadata_resp = ttypes.TGetResultSetMetadataResp( @@ -73,7 +75,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -92,7 +94,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -126,14 +128,16 @@ def test_hive_schema_to_arrow_schema_preserves_column_names(self): ] t_table_schema = ttypes.TTableSchema(columns) - arrow_schema = ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + arrow_schema = ThriftDatabricksClient._hive_schema_to_arrow_schema( + t_table_schema + ) self.assertEqual(arrow_schema.field(0).name, "column 1") self.assertEqual(arrow_schema.field(1).name, "column 2") self.assertEqual(arrow_schema.field(2).name, "column 2") self.assertEqual(arrow_schema.field(3).name, "") - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value bad_protocol_versions = [ @@ 
-163,7 +167,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): "expected server to use a protocol version", str(cm.exception) ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value good_protocol_versions = [ @@ -174,7 +178,9 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): for protocol_version in good_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version + status=self.okay_status, + serverProtocolVersion=protocol_version, + sessionHandle=self.session_handle, ) thrift_backend = self._make_fake_thrift_backend() @@ -182,7 +188,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -229,7 +235,7 @@ def test_tls_cert_args_are_propagated( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -315,7 +321,7 @@ def test_tls_no_verify_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -339,7 +345,7 @@ def test_tls_verify_hostname_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -356,7 +362,7 @@ def test_tls_verify_hostname_is_respected( @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -371,7 +377,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname", 123, "path_value", @@ -386,7 +392,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname/", 123, "path_value", @@ -401,7 +407,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -413,7 +419,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -423,7 +429,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend( + 
ThriftDatabricksClient( "hostname", 123, "path_value", @@ -434,7 +440,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -467,9 +473,9 @@ def test_non_primitive_types_raise_error(self): t_table_schema = ttypes.TTableSchema(columns) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + ThriftDatabricksClient._hive_schema_to_arrow_schema(t_table_schema) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_description(t_table_schema) + ThriftDatabricksClient._hive_schema_to_description(t_table_schema) def test_hive_schema_to_description_preserves_column_names_and_types(self): # Full coverage of all types is done in integration tests, this is just a @@ -493,7 +499,7 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, @@ -532,7 +538,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, [ @@ -545,7 +551,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -589,7 +595,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -628,7 +634,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -642,7 +648,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_checks_operation_state_in_polls( self, tcli_service_class ): @@ -672,7 +678,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( ) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -686,7 +692,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( if op_state_resp.errorMessage: self.assertIn(op_state_resp.errorMessage, str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -710,7 +716,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) 
tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -724,7 +730,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_direct_results_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -750,7 +756,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -812,7 +818,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -825,7 +831,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( self, tcli_service_class ): @@ -863,7 +869,7 @@ def test_handle_execute_response_can_handle_without_direct_results( op_state_2, op_state_3, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -900,7 +906,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -917,7 +923,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ttypes.TOperationState.FINISHED_STATE, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value arrow_schema_mock = MagicMock(name="Arrow schema mock") @@ -946,7 +952,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value hive_schema_mock = MagicMock(name="Hive schema mock") @@ -976,7 +982,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def 
test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): @@ -1020,7 +1026,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): @@ -1064,7 +1070,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1075,7 +1081,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( self.assertEqual(has_more_rows, has_more_rows_resp) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue tcli_service_instance = tcli_service_class.return_value @@ -1108,7 +1114,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1117,7 +1123,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): ssl_options=SSLOptions(), ) arrow_queue, has_more_results = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1128,14 +1134,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1157,14 +1163,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1185,14 +1191,14 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value 
response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1222,14 +1228,14 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1263,14 +1269,14 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1304,12 +1310,12 @@ def test_get_columns_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1320,10 +1326,10 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1331,16 +1337,17 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_command(self.operation_handle) + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.close_command(command_id) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1348,13 +1355,14 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - 
thrift_backend.close_session(self.session_handle) + session_id = SessionId.from_thrift_handle(self.session_handle) + thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception( self, tcli_service_class ): @@ -1392,7 +1400,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1403,12 +1411,16 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") - @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") + @patch( + "databricks.sql.backend.thrift_backend.convert_arrow_based_set_to_arrow_table" + ) + @patch( + "databricks.sql.backend.thrift_backend.convert_column_based_set_to_arrow_table" + ) def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1443,7 +1455,7 @@ def test_create_arrow_table_calls_correct_conversion_method( def test_convert_arrow_based_set_to_arrow_table( self, open_stream_mock, lz4_decompress_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1597,17 +1609,18 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = self._make_fake_thrift_backend() - active_op_handle_mock = Mock() - thrift_backend.cancel_command(active_op_handle_mock) + # Create a proper CommandId from the existing operation_handle + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.cancel_command(command_id) self.assertEqual( tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, - active_op_handle_mock, + self.operation_handle, ) def test_handle_execute_response_sets_active_op_handle(self): @@ -1615,19 +1628,27 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() thrift_backend._results_message_to_execute_response = Mock() + + # Create a mock response with a real operation handle mock_resp = Mock() + mock_resp.operationHandle = ( + self.operation_handle + ) # Use the real operation handle from the test class mock_cursor = Mock() thrift_backend._handle_execute_response(mock_resp, mock_cursor) - self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + self.assertEqual( + mock_resp.operationHandle, 
mock_cursor.active_command_id.to_thrift_handle() + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class @@ -1654,7 +1675,7 @@ def test_make_request_will_retry_GetOperationStatus( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1681,7 +1702,7 @@ def test_make_request_will_retry_GetOperationStatus( ) with self.assertLogs( - "databricks.sql.thrift_backend", level=logging.WARNING + "databricks.sql.backend.thrift_backend", level=logging.WARNING ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1702,7 +1723,8 @@ def test_make_request_will_retry_GetOperationStatus( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos @@ -1731,7 +1753,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1763,7 +1785,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1779,7 +1801,8 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class @@ -1791,7 +1814,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1820,7 +1843,7 @@ def test_make_request_will_read_error_message_headers_if_set( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1944,7 +1967,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100, } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1959,7 +1982,12 @@ def test_retry_args_passthrough(self, mock_http_client): @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): + for 
k, ( + _, + _, + min, + max, + ) in databricks.sql.backend.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ( (min - 1, min), (max + 1, max), @@ -1970,7 +1998,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1986,7 +2014,7 @@ def test_retry_args_bounding(self, mock_http_client): for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_configuration_passthrough(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1998,7 +2026,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2011,12 +2039,12 @@ def test_configuration_passthrough(self, tcli_client_class): open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.configuration, expected_config) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2036,13 +2064,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, canUseMultipleCatalogs=can_use_multiple_cats, initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), + sessionHandle=self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2066,14 +2095,14 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_set_in_open_session_req( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2086,13 +2115,13 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) - 
@patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2126,7 +2155,7 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( ) backend.open_session({}, cat, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -2135,9 +2164,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, canUseMultipleCatalogs=True, initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), + sessionHandle=self.session_handle, ) - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2154,8 +2184,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class ): @@ -2172,7 +2204,7 @@ def test_execute_command_sets_complex_type_fields_correctly( if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", From 75752bf66a1999a0cabfbccf66b06da15f3ca36f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 12:10:35 +0530 Subject: [PATCH 03/23] Implement ResultSet Abstraction (backend interfaces for fetch phase) (#574) * ensure backend client returns a ResultSet type in backend tests Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * newline for cleanliness Signed-off-by: varun-edachali-dbx * fix circular import Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * to_hex_id -> get_hex_id Signed-off-by: varun-edachali-dbx * better comment on protocol version getter Signed-off-by: varun-edachali-dbx * formatting (black) Signed-off-by: varun-edachali-dbx * stricter typing for cursor Signed-off-by: varun-edachali-dbx * correct typing Signed-off-by: varun-edachali-dbx * correct tests and merge artifacts Signed-off-by: varun-edachali-dbx * remove accidentally modified workflow files remnants of old merge Signed-off-by: varun-edachali-dbx * chore: remove accidentally modified workflow files Signed-off-by: varun-edachali-dbx * add back accidentally removed docstrings Signed-off-by: varun-edachali-dbx * clean up docstrings Signed-off-by: varun-edachali-dbx * log hex Signed-off-by: varun-edachali-dbx * remove unnecessary _replace call Signed-off-by: varun-edachali-dbx * add __str__ for CommandId Signed-off-by: varun-edachali-dbx * take TOpenSessionResp in 
get_protocol_version to maintain existing interface Signed-off-by: varun-edachali-dbx * active_op_handle -> active_mmand_id Signed-off-by: varun-edachali-dbx * ensure None returned for close_command Signed-off-by: varun-edachali-dbx * account for ResultSet return in new pydocs Signed-off-by: varun-edachali-dbx * pydoc for types Signed-off-by: varun-edachali-dbx * move common state to ResultSet aprent Signed-off-by: varun-edachali-dbx * stronger typing in resultSet behaviour Signed-off-by: varun-edachali-dbx * remove redundant patch in test Signed-off-by: varun-edachali-dbx * add has_been_closed_server_side assertion Signed-off-by: varun-edachali-dbx * remove redundancies in tests Signed-off-by: varun-edachali-dbx * more robust close check Signed-off-by: varun-edachali-dbx * use normalised state in e2e test Signed-off-by: varun-edachali-dbx * simplify corrected test Signed-off-by: varun-edachali-dbx * add line gaps after multi-line pydocs for consistency Signed-off-by: varun-edachali-dbx * use normalised CommandState type in ExecuteResponse Signed-off-by: varun-edachali-dbx --------- Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 41 +- src/databricks/sql/backend/thrift_backend.py | 117 ++++- src/databricks/sql/backend/types.py | 92 +++- .../sql/backend/utils/guid_utils.py | 1 + src/databricks/sql/client.py | 404 ++--------------- src/databricks/sql/result_set.py | 412 ++++++++++++++++++ src/databricks/sql/session.py | 1 + src/databricks/sql/types.py | 4 + src/databricks/sql/utils.py | 7 + tests/e2e/test_driver.py | 8 +- tests/unit/test_client.py | 149 ++++--- tests/unit/test_fetches.py | 9 +- tests/unit/test_parameters.py | 8 +- tests/unit/test_thrift_backend.py | 32 +- 14 files changed, 775 insertions(+), 510 deletions(-) create mode 100644 src/databricks/sql/result_set.py diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index edff10159..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -15,10 +15,16 @@ from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.backend.types import SessionId, CommandId +from databricks.sql.backend.types import SessionId, CommandId, CommandState from databricks.sql.utils import ExecuteResponse from databricks.sql.types import SSLOptions +# Forward reference for type hints +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + class DatabricksClient(ABC): # == Connection and Session Management == @@ -81,7 +87,7 @@ def execute_command( parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. @@ -101,7 +107,7 @@ def execute_command( enforce_embedded_schema_correctness: Whether to enforce schema correctness Returns: - If async_op is False, returns an ExecuteResponse object containing the + If async_op is False, returns a ResultSet object containing the query results and metadata. If async_op is True, returns None and the results must be fetched later using get_execution_result(). 
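# Illustrative sketch of the execute_command contract documented above,
# assuming an already-constructed DatabricksClient implementation
# (`backend`), an open `session_id`, and a `cursor`; these names, the
# query text, the limits, and the lz4_compression/use_cloud_fetch values
# are placeholders, and fetchall() is assumed from the DB-API-style
# ResultSet this interface now returns.

result_set = backend.execute_command(
    operation="SELECT 1",
    session_id=session_id,
    max_rows=10000,
    max_bytes=10 * 1024 * 1024,
    lz4_compression=False,
    cursor=cursor,
    use_cloud_fetch=False,
    parameters=[],
    async_op=False,  # synchronous: a ResultSet is returned directly
    enforce_embedded_schema_correctness=False,
)
rows = result_set.fetchall()  # results are consumed through the ResultSet

# With async_op=True the same call returns None, and the results are
# retrieved later via get_execution_result() (see the polling sketch
# further below).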
@@ -130,7 +136,7 @@ def cancel_command(self, command_id: CommandId) -> None: pass @abstractmethod - def close_command(self, command_id: CommandId) -> ttypes.TStatus: + def close_command(self, command_id: CommandId) -> None: """ Closes a command and releases associated resources. @@ -140,9 +146,6 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: Args: command_id: The command identifier to close - Returns: - ttypes.TStatus: The status of the close operation - Raises: ValueError: If the command ID is invalid OperationalError: If there's an error closing the command @@ -150,7 +153,7 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: pass @abstractmethod - def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: + def get_query_state(self, command_id: CommandId) -> CommandState: """ Gets the current state of a query or command. @@ -160,7 +163,7 @@ def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: command_id: The command identifier to check Returns: - ttypes.TOperationState: The current state of the command + CommandState: The current state of the command Raises: ValueError: If the command ID is invalid @@ -175,7 +178,7 @@ def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves the results of a previously executed command. @@ -187,7 +190,7 @@ def get_execution_result( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the query results and metadata + ResultSet: An object containing the query results and metadata Raises: ValueError: If the command ID is invalid @@ -203,7 +206,7 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of available catalogs. @@ -217,7 +220,7 @@ def get_catalogs( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the catalog metadata + ResultSet: An object containing the catalog metadata Raises: ValueError: If the session ID is invalid @@ -234,7 +237,7 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. @@ -250,7 +253,7 @@ def get_schemas( schema_name: Optional schema name pattern to filter by Returns: - ExecuteResponse: An object containing the schema metadata + ResultSet: An object containing the schema metadata Raises: ValueError: If the session ID is invalid @@ -269,7 +272,7 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. @@ -287,7 +290,7 @@ def get_tables( table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) Returns: - ExecuteResponse: An object containing the table metadata + ResultSet: An object containing the table metadata Raises: ValueError: If the session ID is invalid @@ -306,7 +309,7 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. 
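# A hedged sketch of the asynchronous fetch phase implied by the interface
# above: poll get_query_state() until the normalized CommandState leaves
# the PENDING/RUNNING states, then materialize a ResultSet with
# get_execution_result(). `backend`, `command_id`, and `cursor` are
# assumed to already exist; the poll interval is arbitrary.

import time

from databricks.sql.backend.types import CommandState

state = backend.get_query_state(command_id)
while state in (CommandState.PENDING, CommandState.RUNNING):
    time.sleep(2)  # back off between status polls
    state = backend.get_query_state(command_id)

if state == CommandState.SUCCEEDED:
    result_set = backend.get_execution_result(command_id, cursor)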
@@ -324,7 +327,7 @@ def get_columns( column_name: Optional column name pattern to filter by Returns: - ExecuteResponse: An object containing the column metadata + ResultSet: An object containing the column metadata Raises: ValueError: If the session ID is invalid diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index c09397c2f..de388f1d4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -9,9 +9,11 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( + CommandState, SessionId, CommandId, BackendType, @@ -84,8 +86,8 @@ class ThriftDatabricksClient(DatabricksClient): - CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE - ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE + CLOSED_OP_STATE = CommandState.CLOSED + ERROR_OP_STATE = CommandState.FAILED _retry_delay_min: float _retry_delay_max: float @@ -349,6 +351,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -796,7 +799,7 @@ def _results_message_to_execute_response(self, resp, operation_state): return ExecuteResponse( arrow_queue=arrow_queue_opt, - status=operation_state, + status=CommandState.from_thrift_state(operation_state), has_been_closed_server_side=has_been_closed_server_side, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -808,7 +811,9 @@ def _results_message_to_execute_response(self, resp, operation_state): def get_execution_result( self, command_id: CommandId, cursor: "Cursor" - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -857,9 +862,9 @@ def get_execution_result( ssl_options=self._ssl_options, ) - return ExecuteResponse( + execute_response = ExecuteResponse( arrow_queue=queue, - status=resp.status, + status=CommandState.from_thrift_state(resp.status), has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -869,6 +874,15 @@ def get_execution_result( arrow_schema_bytes=schema_bytes, ) + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -887,7 +901,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, command_id: CommandId) -> "TOperationState": + def get_query_state(self, command_id: CommandId) -> CommandState: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -895,7 +909,10 @@ def 
get_query_state(self, command_id: CommandId) -> "TOperationState": poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return operation_state + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Unknown command state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -929,7 +946,9 @@ def execute_command( parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -976,7 +995,16 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - return self._handle_execute_response(resp, cursor) + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=use_cloud_fetch, + ) def get_catalogs( self, @@ -984,7 +1012,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -996,7 +1026,17 @@ def get_catalogs( ), ) resp = self.make_request(self._client.GetCatalogs, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_schemas( self, @@ -1006,7 +1046,9 @@ def get_schemas( cursor: "Cursor", catalog_name=None, schema_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1020,7 +1062,17 @@ def get_schemas( schemaName=schema_name, ) resp = self.make_request(self._client.GetSchemas, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_tables( self, @@ -1032,7 +1084,9 @@ def get_tables( schema_name=None, table_name=None, table_types=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1048,7 +1102,17 @@ def get_tables( tableTypes=table_types, ) resp = self.make_request(self._client.GetTables, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + 
thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_columns( self, @@ -1060,7 +1124,9 @@ def get_columns( schema_name=None, table_name=None, column_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1076,7 +1142,17 @@ def get_columns( columnName=column_name, ) resp = self.make_request(self._client.GetColumns, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def _handle_execute_response(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1154,12 +1230,11 @@ def cancel_command(self, command_id: CommandId) -> None: req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - def close_command(self, command_id: CommandId): + def close_command(self, command_id: CommandId) -> None: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") logger.debug("ThriftBackend.close_command(command_id=%s)", command_id) req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + self.make_request(self._client.CloseOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 740be0199..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,12 +1,86 @@ from enum import Enum -from typing import Dict, Optional, Any, Union +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id +from databricks.sql.thrift_api.TCLIService import ttypes logger = logging.getLogger(__name__) +class CommandState(Enum): + """ + Enum representing the execution state of a command in Databricks SQL. + + This enum maps Thrift operation states to normalized command states, + providing a consistent interface for tracking command execution status + across different backend implementations. + + Attributes: + PENDING: Command is queued or initialized but not yet running + RUNNING: Command is currently executing + SUCCEEDED: Command completed successfully + FAILED: Command failed due to error, timeout, or unknown state + CLOSED: Command has been closed + CANCELLED: Command was cancelled before completion + """ + + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + CLOSED = "CLOSED" + CANCELLED = "CANCELLED" + + @classmethod + def from_thrift_state( + cls, state: ttypes.TOperationState + ) -> Optional["CommandState"]: + """ + Convert a Thrift TOperationState to a normalized CommandState. 
+ + Args: + state: A TOperationState from the Thrift API representing the current + state of an operation + + Returns: + CommandState: The corresponding normalized command state + + Raises: + ValueError: If the provided state is not a recognized TOperationState + + State Mappings: + - INITIALIZED_STATE, PENDING_STATE -> PENDING + - RUNNING_STATE -> RUNNING + - FINISHED_STATE -> SUCCEEDED + - ERROR_STATE, TIMEDOUT_STATE, UKNOWN_STATE -> FAILED + - CLOSED_STATE -> CLOSED + - CANCELED_STATE -> CANCELLED + """ + + if state in ( + ttypes.TOperationState.INITIALIZED_STATE, + ttypes.TOperationState.PENDING_STATE, + ): + return cls.PENDING + elif state == ttypes.TOperationState.RUNNING_STATE: + return cls.RUNNING + elif state == ttypes.TOperationState.FINISHED_STATE: + return cls.SUCCEEDED + elif state in ( + ttypes.TOperationState.ERROR_STATE, + ttypes.TOperationState.TIMEDOUT_STATE, + ttypes.TOperationState.UKNOWN_STATE, + ): + return cls.FAILED + elif state == ttypes.TOperationState.CLOSED_STATE: + return cls.CLOSED + elif state == ttypes.TOperationState.CANCELED_STATE: + return cls.CANCELLED + else: + return None + + class BackendType(Enum): """ Enum representing the type of backend @@ -40,6 +114,7 @@ def __init__( secret: The secret part of the identifier (only used for Thrift) properties: Additional information about the session """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -55,6 +130,7 @@ def __str__(self) -> str: Returns: A string representation of the session ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -79,6 +155,7 @@ def from_thrift_handle( Returns: A SessionId instance """ + if session_handle is None: return None @@ -105,6 +182,7 @@ def from_sea_session_id( Returns: A SessionId instance """ + return cls(BackendType.SEA, session_id, properties=properties) def to_thrift_handle(self): @@ -114,6 +192,7 @@ def to_thrift_handle(self): Returns: A TSessionHandle object or None if this is not a Thrift session ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -132,6 +211,7 @@ def to_sea_session_id(self): Returns: The session ID string or None if this is not a SEA session ID """ + if self.backend_type != BackendType.SEA: return None @@ -141,6 +221,7 @@ def get_guid(self) -> Any: """ Get the ID of the session. """ + return self.guid def get_hex_guid(self) -> str: @@ -150,6 +231,7 @@ def get_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: @@ -163,6 +245,7 @@ def get_protocol_version(self): The server protocol version or None if it does not exist It is not expected to exist for SEA sessions. 
""" + return self.properties.get("serverProtocolVersion") @@ -194,6 +277,7 @@ def __init__( has_result_set: Whether the command has a result set modified_row_count: The number of rows modified by the command """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -211,6 +295,7 @@ def __str__(self) -> str: Returns: A string representation of the command ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -233,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -259,6 +345,7 @@ def from_sea_statement_id(cls, statement_id: str): Returns: A CommandId instance """ + return cls(BackendType.SEA, statement_id) def to_thrift_handle(self): @@ -268,6 +355,7 @@ def to_thrift_handle(self): Returns: A TOperationHandle object or None if this is not a Thrift command ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -288,6 +376,7 @@ def to_sea_statement_id(self): Returns: The statement ID string or None if this is not a SEA statement ID """ + if self.backend_type != BackendType.SEA: return None @@ -300,6 +389,7 @@ def to_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py index 28975171f..2c440afd2 100644 --- a/src/databricks/sql/backend/utils/guid_utils.py +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -14,6 +14,7 @@ def guid_to_hex_id(guid: bytes) -> str: If conversion to hexadecimal fails, a string representation of the original bytes is returned """ + try: this_uuid = uuid.UUID(bytes=guid) except Exception as e: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1c384c735..9f7c060a7 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -42,14 +42,15 @@ ParameterApproach, ) - +from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session -from databricks.sql.backend.types import CommandId, BackendType +from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, TSparkParameter, TOperationState, ) @@ -320,9 +321,17 @@ def protocol_version(self): return self.session.protocol_version @staticmethod - def get_protocol_version(openSessionResp): + def get_protocol_version(openSessionResp: TOpenSessionResp): """Delegate to Session class static method""" - return Session.get_protocol_version(openSessionResp) + properties = ( + {"serverProtocolVersion": openSessionResp.serverProtocolVersion} + if openSessionResp.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle( + openSessionResp.sessionHandle, properties + ) + return Session.get_protocol_version(session_id) @property def open(self) -> bool: @@ -388,6 +397,7 @@ def __init__( Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately visible by other cursors or connections. 
""" + self.connection = connection self.rowcount = -1 # Return -1 as this is not supported self.buffer_size_bytes = result_buffer_size_bytes @@ -746,6 +756,7 @@ def execute( :returns self """ + logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) @@ -771,7 +782,7 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.execute_command( + self.active_result_set = self.backend.execute_command( operation=prepared_operation, session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, @@ -783,18 +794,8 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) - assert execute_response is not None # async_op = False above - - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -815,6 +816,7 @@ def execute_async( :param parameters: :return: """ + param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -851,7 +853,7 @@ def execute_async( return self - def get_query_state(self) -> "TOperationState": + def get_query_state(self) -> CommandState: """ Get the state of the async executing query or basically poll the status of the query @@ -869,11 +871,7 @@ def is_query_pending(self): :return: """ operation_state = self.get_query_state() - - return not operation_state or operation_state in [ - ttypes.TOperationState.RUNNING_STATE, - ttypes.TOperationState.PENDING_STATE, - ] + return operation_state in [CommandState.PENDING, CommandState.RUNNING] def get_async_execution_result(self): """ @@ -889,19 +887,12 @@ def get_async_execution_result(self): time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) operation_state = self.get_query_state() - if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.backend.get_execution_result( + if operation_state == CommandState.SUCCEEDED: + self.active_result_set = self.backend.get_execution_result( self.active_command_id, self ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -935,20 +926,12 @@ def catalogs(self) -> "Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_catalogs( + self.active_result_set = self.backend.get_catalogs( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def schemas( @@ -962,7 +945,7 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_schemas( + self.active_result_set = self.backend.get_schemas( 
session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -970,14 +953,6 @@ def schemas( catalog_name=catalog_name, schema_name=schema_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def tables( @@ -996,7 +971,7 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_tables( + self.active_result_set = self.backend.get_tables( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1006,14 +981,6 @@ def tables( table_name=table_name, table_types=table_types, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def columns( @@ -1032,7 +999,7 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_columns( + self.active_result_set = self.backend.get_columns( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1042,14 +1009,6 @@ def columns( table_name=table_name, column_name=column_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def fetchall(self) -> List[Row]: @@ -1205,312 +1164,3 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass - - -class ResultSet: - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - backend: DatabricksClient, - result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - arraysize: int = 10000, - use_cloud_fetch: bool = True, - ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param backend: The DatabricksClient instance to use for fetching results - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - """ - self.connection = connection - self.command_id = execute_response.command_id - self.op_state = execute_response.status - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.buffer_size_bytes = result_buffer_size_bytes - self.lz4_compressed = execute_response.lz4_compressed - self.arraysize = arraysize - self.backend = backend - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes - self._next_row_index = 0 - self._use_cloud_fetch = use_cloud_fetch - - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity - self._fill_results_buffer() - - def __iter__(self): - while True: - row = self.fetchone() - if row: - yield row - else: - break - - def _fill_results_buffer(self): - if not isinstance(self.backend, ThriftDatabricksClient): - # currently, we are assuming only the Thrift backend exists - raise NotImplementedError( - "Fetching further result batches is currently only implemented for the Thrift backend." - ) - - # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results - thrift_backend_instance = self.backend # type: ThriftDatabricksClient - results, has_more_rows = thrift_backend_instance.fetch_results( - command_id=self.command_id, - max_rows=self.arraysize, - max_bytes=self.buffer_size_bytes, - expected_row_start_offset=self._next_row_index, - lz4_compressed=self.lz4_compressed, - arrow_schema_bytes=self._arrow_schema_bytes, - description=self.description, - use_cloud_fetch=self._use_cloud_fetch, - ) - self.results = results - self.has_more_rows = has_more_rows - - def _convert_columnar_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - result = [] - for row_index in range(table.num_rows): - curr_row = [] - for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) - result.append(ResultRow(*curr_row)) - - return result - - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - - @property - def rownumber(self): - return self._next_row_index - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows of a query result, returning a PyArrow table. - - An empty sequence is returned when no more rows are available. - """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - - def fetchmany_columnar(self, size: int): - """ - Fetch the next set of rows of a query result, returning a Columnar Table. - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) - self._next_row_index += partial_results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results - - def fetchall_columnar(self): - """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) - self._next_row_index += partial_results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - if isinstance(self.results, ColumnQueue): - res = self._convert_columnar_table(self.fetchmany_columnar(1)) - else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) - - if len(res) > 0: - return res[0] - else: - return None - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchall_columnar()) - else: - return self._convert_arrow_table(self.fetchall_arrow()) - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchmany_columnar(size)) - else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) - - def close(self) -> None: - """ - Close the cursor. - - If the connection has not been closed, and the cursor has not already - been closed on the server for some other reason, issue a request to the server to close it. 
- """ - # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to - # when we generalise the ResultSet - try: - if ( - self.op_state != ttypes.TOperationState.CLOSED_STATE - and not self.has_been_closed_server_side - and self.connection.open - ): - self.backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.op_state = ttypes.TOperationState.CLOSED_STATE - - @staticmethod - def _get_schema_description(table_schema_message): - """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 - """ - - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ - - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py new file mode 100644 index 000000000..a0d8d3579 --- /dev/null +++ b/src/databricks/sql/result_set.py @@ -0,0 +1,412 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Any, Union, TYPE_CHECKING + +import logging +import time +import pandas + +from databricks.sql.backend.types import CommandId, CommandState + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.backend.databricks_client import DatabricksClient + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + from databricks.sql.client import Connection + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import Row +from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue + +logger = logging.getLogger(__name__) + + +class ResultSet(ABC): + """ + Abstract base class for result sets returned by different backend implementations. + + This class defines the interface that all concrete result set implementations must follow. + """ + + def __init__( + self, + connection: "Connection", + backend: "DatabricksClient", + command_id: CommandId, + op_state: Optional[CommandState], + has_been_closed_server_side: bool, + arraysize: int, + buffer_size_bytes: int, + ): + """ + A ResultSet manages the results of a single command. 
+
+        :param connection: The parent connection that was used to execute this command
+        :param backend: The specialised backend client to be invoked in the fetch phase
+        :param command_id: The identifier of the command whose results this set manages
+        :param op_state: The last known execution state of the command
+        :param has_been_closed_server_side: Whether the command has already been closed on the server
+        :param arraysize: The max number of rows to fetch at a time (PEP-249)
+        :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount
+        """
+        self.command_id = command_id
+        self.op_state = op_state
+        self.has_been_closed_server_side = has_been_closed_server_side
+        self.connection = connection
+        self.backend = backend
+        self.arraysize = arraysize
+        self.buffer_size_bytes = buffer_size_bytes
+        self._next_row_index = 0
+        self.description = None
+
+    def __iter__(self):
+        while True:
+            row = self.fetchone()
+            if row:
+                yield row
+            else:
+                break
+
+    @property
+    def rownumber(self):
+        return self._next_row_index
+
+    @property
+    @abstractmethod
+    def is_staging_operation(self) -> bool:
+        """Whether this result set represents a staging operation."""
+        pass
+
+    # Define abstract methods that concrete implementations must implement
+    @abstractmethod
+    def _fill_results_buffer(self):
+        """Fill the results buffer from the backend."""
+        pass
+
+    @abstractmethod
+    def fetchone(self) -> Optional[Row]:
+        """Fetch the next row of a query result set."""
+        pass
+
+    @abstractmethod
+    def fetchmany(self, size: int) -> List[Row]:
+        """Fetch the next set of rows of a query result."""
+        pass
+
+    @abstractmethod
+    def fetchall(self) -> List[Row]:
+        """Fetch all remaining rows of a query result."""
+        pass
+
+    @abstractmethod
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
+        """Fetch the next set of rows as an Arrow table."""
+        pass
+
+    @abstractmethod
+    def fetchall_arrow(self) -> "pyarrow.Table":
+        """Fetch all remaining rows as an Arrow table."""
+        pass
+
+    def close(self) -> None:
+        """
+        Close the result set.
+
+        If the connection has not been closed, and the result set has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
+        try:
+            if (
+                self.op_state != CommandState.CLOSED
+                and not self.has_been_closed_server_side
+                and self.connection.open
+            ):
+                self.backend.close_command(self.command_id)
+        except RequestError as e:
+            if isinstance(e.args[1], CursorAlreadyClosedError):
+                logger.info("Operation was canceled by a prior request")
+        finally:
+            self.has_been_closed_server_side = True
+            self.op_state = CommandState.CLOSED
+
+
+class ThriftResultSet(ResultSet):
+    """ResultSet implementation for the Thrift backend."""
+
+    def __init__(
+        self,
+        connection: "Connection",
+        execute_response: ExecuteResponse,
+        thrift_client: "ThriftDatabricksClient",
+        buffer_size_bytes: int = 104857600,
+        arraysize: int = 10000,
+        use_cloud_fetch: bool = True,
+    ):
+        """
+        Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
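+
+        Note (illustrative): instances are normally constructed by the backend rather
+        than by user code, e.g. ``ThriftResultSet(connection=cursor.connection,
+        execute_response=execute_response, thrift_client=self)`` as in the
+        execute and metadata paths of ThriftDatabricksClient above.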
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + """ + super().__init__( + connection, + thrift_client, + execute_response.command_id, + execute_response.status, + execute_response.has_been_closed_server_side, + arraysize, + buffer_size_bytes, + ) + + # Initialize ThriftResultSet-specific attributes + self.has_been_closed_server_side = execute_response.has_been_closed_server_side + self.has_more_rows = execute_response.has_more_rows + self.lz4_compressed = execute_response.lz4_compressed + self.description = execute_response.description + self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._use_cloud_fetch = use_cloud_fetch + self._is_staging_operation = execute_response.is_staging_operation + + # Initialize results queue + if execute_response.arrow_queue: + # In this case the server has taken the fast path and returned an initial batch of + # results + self.results = execute_response.arrow_queue + else: + # In this case, there are results waiting on the server so we fetch now for simplicity + self._fill_results_buffer() + + def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available + results, has_more_rows = self.backend.fetch_results( + command_id=self.command_id, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + expected_row_start_offset=self._next_row_index, + lz4_compressed=self.lz4_compressed, + arrow_schema_bytes=self._arrow_schema_bytes, + description=self.description, + use_cloud_fetch=self._use_cloud_fetch, + ) + self.results = results + self.has_more_rows = has_more_rows + + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
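+        # (illustrative: an int64 column containing NULLs would otherwise be coerced
+        # to float64 by pandas, changing the reported value types)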
+        # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
+        # NOTE: This API is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
+        dtype_mapping = {
+            pyarrow.int8(): pandas.Int8Dtype(),
+            pyarrow.int16(): pandas.Int16Dtype(),
+            pyarrow.int32(): pandas.Int32Dtype(),
+            pyarrow.int64(): pandas.Int64Dtype(),
+            pyarrow.uint8(): pandas.UInt8Dtype(),
+            pyarrow.uint16(): pandas.UInt16Dtype(),
+            pyarrow.uint32(): pandas.UInt32Dtype(),
+            pyarrow.uint64(): pandas.UInt64Dtype(),
+            pyarrow.bool_(): pandas.BooleanDtype(),
+            pyarrow.float32(): pandas.Float32Dtype(),
+            pyarrow.float64(): pandas.Float64Dtype(),
+            pyarrow.string(): pandas.StringDtype(),
+        }
+
+        # Need to rename columns, as the to_pandas function cannot handle duplicate column names
+        table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
+        df = table_renamed.to_pandas(
+            types_mapper=dtype_mapping.get,
+            date_as_object=True,
+            timestamp_as_object=True,
+        )
+
+        res = df.to_numpy(na_value=None, dtype="object")
+        return [ResultRow(*v) for v in res]
+
+    def merge_columnar(self, result1, result2) -> "ColumnTable":
+        """
+        Merge two columnar results into a single result
+        :param result1: The first columnar result
+        :param result2: The second columnar result
+        :return: A single ColumnTable containing the rows of both inputs
+        """
+
+        if result1.column_names != result2.column_names:
+            raise ValueError("The columns in the results don't match")
+
+        merged_result = [
+            result1.column_table[i] + result2.column_table[i]
+            for i in range(result1.num_columns)
+        ]
+        return ColumnTable(merged_result, result1.column_names)
+
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
+        """
+        Fetch the next set of rows of a query result, returning a PyArrow table.
+
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)
+        results = self.results.next_n_rows(size)
+        n_remaining_rows = size - results.num_rows
+        self._next_row_index += results.num_rows
+
+        while (
+            n_remaining_rows > 0
+            and not self.has_been_closed_server_side
+            and self.has_more_rows
+        ):
+            self._fill_results_buffer()
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = pyarrow.concat_tables([results, partial_results])
+            n_remaining_rows -= partial_results.num_rows
+            self._next_row_index += partial_results.num_rows
+
+        return results
+
+    def fetchmany_columnar(self, size: int):
+        """
+        Fetch the next set of rows of a query result, returning a Columnar Table.
+
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)
+
+        results = self.results.next_n_rows(size)
+        n_remaining_rows = size - results.num_rows
+        self._next_row_index += results.num_rows
+
+        while (
+            n_remaining_rows > 0
+            and not self.has_been_closed_server_side
+            and self.has_more_rows
+        ):
+            self._fill_results_buffer()
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = self.merge_columnar(results, partial_results)
+            n_remaining_rows -= partial_results.num_rows
+            self._next_row_index += partial_results.num_rows
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
+        """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
+        results = self.results.remaining_rows()
+        self._next_row_index += results.num_rows
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            if isinstance(results, ColumnTable) and isinstance(
+                partial_results, ColumnTable
+            ):
+                results = self.merge_columnar(results, partial_results)
+            else:
+                results = pyarrow.concat_tables([results, partial_results])
+            self._next_row_index += partial_results.num_rows
+
+        # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
+        # Valid only for metadata commands result set
+        if isinstance(results, ColumnTable) and pyarrow:
+            data = {
+                name: col
+                for name, col in zip(results.column_names, results.column_table)
+            }
+            return pyarrow.Table.from_pydict(data)
+        return results
+
+    def fetchall_columnar(self):
+        """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
+        results = self.results.remaining_rows()
+        self._next_row_index += results.num_rows
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            results = self.merge_columnar(results, partial_results)
+            self._next_row_index += partial_results.num_rows
+
+        return results
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+        if isinstance(self.results, ColumnQueue):
+            res = self._convert_columnar_table(self.fetchmany_columnar(1))
+        else:
+            res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
+        if len(res) > 0:
+            return res[0]
+        else:
+            return None
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of rows.
+        """
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchall_columnar())
+        else:
+            return self._convert_arrow_table(self.fetchall_arrow())
+
+    def fetchmany(self, size: int) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        An empty sequence is returned when no more rows are available.
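+
+        Example (illustrative)::
+
+            rows = result_set.fetchmany(100)  # returns at most 100 Row objects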
+        """
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchmany_columnar(size))
+        else:
+            return self._convert_arrow_table(self.fetchmany_arrow(size))
+
+    @property
+    def is_staging_operation(self) -> bool:
+        """Whether this result set represents a staging operation."""
+        return self._is_staging_operation
+
+    @staticmethod
+    def _get_schema_description(table_schema_message):
+        """
+        Takes a TableSchema message and returns a list of description 7-tuples as specified by PEP-249
+        """
+
+        def map_col_type(type_):
+            if type_.startswith("decimal"):
+                return "decimal"
+            else:
+                return type_
+
+        return [
+            (column.name, map_col_type(column.datatype), None, None, None, None, None)
+            for column in table_schema_message.columns
+        ]
diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py
index 2ee5e53f1..6d69b5487 100644
--- a/src/databricks/sql/session.py
+++ b/src/databricks/sql/session.py
@@ -31,6 +31,7 @@ def __init__(
         This class handles all session-related behavior and communication with the
         backend.
         """
+
         self.is_open = False
         self.host = server_hostname
         self.port = kwargs.get("_port", 443)
diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py
index fef22cd9f..4d9f8be5f 100644
--- a/src/databricks/sql/types.py
+++ b/src/databricks/sql/types.py
@@ -158,6 +158,7 @@ def asDict(self, recursive: bool = False) -> Dict[str, Any]:
         >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
         True
         """
+
         if not hasattr(self, "__fields__"):
             raise TypeError("Cannot convert a Row class into dict")

@@ -186,6 +187,7 @@ def __contains__(self, item: Any) -> bool:

     # let object acts like class
     def __call__(self, *args: Any) -> "Row":
         """create new Row object"""
+
         if len(args) > len(self):
             raise ValueError(
                 "Can not create Row with fields %s, expected %d values "
@@ -228,6 +230,7 @@ def __reduce__(
         self,
     ) -> Union[str, Tuple[Any, ...]]:
         """Returns a tuple so Python knows how to pickle Row."""
+
         if hasattr(self, "__fields__"):
             return (_create_row, (self.__fields__, tuple(self)))
         else:
@@ -235,6 +238,7 @@ def __reduce__(

     def __repr__(self) -> str:
         """Printable representation of Row used in Python REPL."""
+
         if hasattr(self, "__fields__"):
             return "Row(%s)" % ", ".join(
                 "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 733d425d6..515ec763a 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -74,6 +74,7 @@ def build_queue(
     Returns:
         ResultSetQueue
     """
+
     if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
         arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
             t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
@@ -173,12 +174,14 @@ def __init__(
         :param n_valid_rows: The index of the last valid row in the table
         :param start_row_index: The first row in the table we should start fetching from
         """
+
         self.cur_row_index = start_row_index
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows

     def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get up to the next n rows of the Arrow dataframe"""
+
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
         # The second argument should be length, not end index
@@ -216,6 +219,7 @@ def __init__(
             lz4_compressed (bool): Whether the files are lz4 compressed.
             description (List[List[Any]]): Hive table schema description.
         """
+
         self.schema_bytes = schema_bytes
         self.max_download_threads = max_download_threads
         self.start_row_index = start_row_offset
@@ -256,6 +260,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         Returns:
             pyarrow.Table
         """
+
         if not self.table:
             logger.debug("CloudFetchQueue: no more rows available")
             # Return empty pyarrow table to cause retry of fetch
@@ -285,6 +290,7 @@ def remaining_rows(self) -> "pyarrow.Table":
         Returns:
             pyarrow.Table
         """
+
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
@@ -577,6 +583,7 @@ def transform_paramstyle(
     Returns:
         str
     """
+
     output = operation
     if (
         param_structure == ParameterStructure.POSITIONAL
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index c446b6715..22897644f 100644
--- a/tests/e2e/test_driver.py
+++ b/tests/e2e/test_driver.py
@@ -30,6 +30,7 @@
     OperationalError,
     RequestError,
 )
+from databricks.sql.backend.types import CommandState
 from tests.e2e.common.predicates import (
     pysql_has_version,
     pysql_supports_arrow,
@@ -826,10 +827,7 @@ def test_close_connection_closes_cursors(self):
                 getProgressUpdate=False,
             )
             op_status_at_server = ars.backend._client.GetOperationStatus(status_request)
-            assert (
-                op_status_at_server.operationState
-                != ttypes.TOperationState.CLOSED_STATE
-            )
+            assert CommandState.from_thrift_state(op_status_at_server.operationState) != CommandState.CLOSED

             conn.close()

@@ -939,7 +937,7 @@ def test_result_set_close(self):

             result_set.close()

-            assert result_set.op_state == result_set.backend.CLOSED_OP_STATE
+            assert result_set.op_state == CommandState.CLOSED
             assert result_set.op_state != initial_op_state

             # Closing the result set again should be a no-op and not raise exceptions
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index f77cab782..8ec4cc499 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -15,7 +15,8 @@
     THandleIdentifier,
     TOperationState,
     TOperationType,
 )
+from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.backend.thrift_backend import ThriftDatabricksClient

 import databricks.sql
@@ -23,7 +24,8 @@
 from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
 from databricks.sql.exc import RequestError, CursorAlreadyClosedError
 from databricks.sql.types import Row
-from databricks.sql.client import CommandId
+from databricks.sql.result_set import ResultSet, ThriftResultSet
+from databricks.sql.backend.types import CommandId, CommandState
 from databricks.sql.utils import ExecuteResponse
 from tests.unit.test_fetches import FetchTests
@@ -38,12 +40,11 @@ def new(cls):
         ThriftBackendMock.return_value = ThriftBackendMock
         cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None)

-        MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp())
+        mock_result_set = Mock(spec=ThriftResultSet)

         cls.apply_property_to_mock(
-            MockTExecuteStatementResp,
+            mock_result_set,
             description=None,
-            arrow_queue=None,
             is_staging_operation=False,
             command_id=None,
             has_been_closed_server_side=True,
@@ -52,7 +53,7 @@ def new(cls):
             arrow_schema_bytes=b"schema",
         )

-        ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp
+        ThriftBackendMock.execute_command.return_value = mock_result_set

         return ThriftBackendMock

@@ -84,69 +85,75 @@ class ClientTestSuite(unittest.TestCase):
         "access_token": "tok",
     }

-    @patch(
-        "%s.session.ThriftDatabricksClient" % PACKAGE_NAME,
-        ThriftDatabricksClientMockFactory.new(),
-    )
- @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_closing_connection_closes_commands(self, mock_result_set_class): + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_closing_connection_closes_commands(self, mock_thrift_client_class): + """Test that connection.close() properly closes result sets through the real close chain.""" # Test once with has_been_closed_server side, once without for closed in (True, False): with self.subTest(closed=closed): - # Set initial state based on whether the command is already closed - initial_state = ( - TOperationState.FINISHED_STATE - if not closed - else TOperationState.CLOSED_STATE - ) - # Mock the execute response with controlled state mock_execute_response = Mock(spec=ExecuteResponse) - mock_execute_response.status = initial_state + + mock_execute_response.command_id = Mock(spec=CommandId) + mock_execute_response.status = ( + CommandState.SUCCEEDED if not closed else CommandState.CLOSED + ) mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False - # Mock the backend that will be used - mock_backend = Mock(spec=ThriftBackend) + # Mock the backend that will be used by the real ThriftResultSet + mock_backend = Mock(spec=ThriftDatabricksClient) + mock_backend.staging_allowed_local_path = None + + # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend # Create connection and cursor - connection = databricks.sql.connect( - server_hostname="foo", - http_path="dummy_path", - access_token="tok", - ) + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() - # Mock execute_command to return our execute response - cursor.thrift_backend.execute_command = Mock( - return_value=mock_execute_response + # Create a REAL ThriftResultSet that will be returned by execute_command + real_result_set = ThriftResultSet( + connection=connection, + execute_response=mock_execute_response, + thrift_client=mock_backend, ) - # Execute a command + # Verify initial state + self.assertEqual(real_result_set.has_been_closed_server_side, closed) + expected_op_state = ( + CommandState.CLOSED if closed else CommandState.SUCCEEDED + ) + self.assertEqual(real_result_set.op_state, expected_op_state) + + # Mock execute_command to return our real result set + cursor.backend.execute_command = Mock(return_value=real_result_set) + + # Execute a command - this should set cursor.active_result_set to our real result set cursor.execute("SELECT 1") - # Get the active result set for later assertions - active_result_set = cursor.active_result_set + # Verify that cursor.execute() set up the result set correctly + self.assertIsInstance(cursor.active_result_set, ThriftResultSet) + self.assertEqual( + cursor.active_result_set.has_been_closed_server_side, closed + ) - # Close the connection + # Close the connection - this should trigger the real close chain: + # connection.close() -> cursor.close() -> result_set.close() connection.close() - # Verify the close logic worked: + # Verify the REAL close logic worked through the chain: # 1. has_been_closed_server_side should always be True after close() - assert active_result_set.has_been_closed_server_side is True + self.assertTrue(real_result_set.has_been_closed_server_side) # 2. 
op_state should always be CLOSED after close() - assert ( - active_result_set.op_state - == connection.thrift_backend.CLOSED_OP_STATE - ) + self.assertEqual(real_result_set.op_state, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: # Should have called backend.close_command during the close chain mock_backend.close_command.assert_called_once_with( - mock_execute_response.command_handle + mock_execute_response.command_id ) else: # Should NOT have called backend.close_command (already closed) @@ -177,10 +186,11 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - result_set = client.ResultSet( + + result_set = ThriftResultSet( connection=mock_connection, - backend=mock_backend, execute_response=Mock(), + thrift_client=mock_backend, ) # Setup session mock on the mock_connection mock_session = Mock() @@ -202,7 +212,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = client.ResultSet( + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -212,17 +222,16 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.command_id ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command( - self, mock_result_set_class - ): - + def test_executing_multiple_commands_uses_the_most_recent_command(self): mock_result_sets = [Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_sets + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False - cursor = client.Cursor( - connection=Mock(), backend=ThriftDatabricksClientMockFactory.new() - ) + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + + cursor = client.Cursor(connection=Mock(), backend=mock_backend) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -247,7 +256,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -399,14 +408,15 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class - ): + def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_set_instances + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_set_instances: + mock_rs.is_staging_operation = False + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_set_instances cursor = client.Cursor(Mock(), mock_backend) @@ -559,8 +569,9 @@ def test_staging_operation_response_is_handled( ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, 
is_staging_operation=True ) - mock_client_class.execute_command.return_value = mock_execute_response - mock_client_class.return_value = mock_client_class + mock_client = mock_client_class.return_value + mock_client.execute_command.return_value = Mock(is_staging_operation=True) + mock_client_class.return_value = mock_client connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -667,9 +678,9 @@ def mock_close_normal(): def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" - result_set = client.ResultSet.__new__(client.ResultSet) - result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" + result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) + result_set.backend = Mock() + result_set.backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True result_set.op_state = "RUNNING" @@ -680,31 +691,31 @@ class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - result_set.thrift_backend.close_command.side_effect = MockRequestError() + result_set.backend.close_command.side_effect = MockRequestError() original_close = client.ResultSet.close try: try: if ( - result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): - result_set.thrift_backend.close_command(result_set.command_id) + result_set.backend.close_command(result_set.command_id) except MockRequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state = result_set.backend.CLOSED_OP_STATE - result_set.thrift_backend.close_command.assert_called_once_with( + result_set.backend.close_command.assert_called_once_with( result_set.command_id ) assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 1c6a1b18d..030510a64 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -10,6 +10,7 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ThriftResultSet @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -38,9 +39,8 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, @@ -52,6 +52,7 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, ), + thrift_client=None, ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ @@ -84,9 +85,8 @@ def fetch_results( 
mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -101,6 +101,7 @@ def fetch_results( arrow_schema_bytes=None, is_staging_operation=False, ), + thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 65e65faff..cf2e24951 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -68,13 +68,7 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - properties = ( - {"serverProtocolVersion": test_input.serverProtocolVersion} - if test_input.serverProtocolVersion - else {} - ) - session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) - assert Connection.get_protocol_version(session_id) == expected + assert Connection.get_protocol_version(test_input) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 41a2a5800..57a2a61e3 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -18,7 +18,8 @@ from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.types import CommandId, SessionId, BackendType +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -882,7 +883,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) self.assertEqual( results_message_response.status, - ttypes.TOperationState.FINISHED_STATE, + CommandState.SUCCEEDED, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -1152,7 +1153,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock) + result = thrift_backend.execute_command( + "foo", Mock(), 100, 200, Mock(), cursor_mock + ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1181,7 +1187,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1209,7 +1218,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_schemas( + result = thrift_backend.get_schemas( Mock(), 100, 200, @@ -1217,6 +1226,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", 
            schema_name="schema_pattern",
         )
+        # Verify the result is a ResultSet
+        self.assertIsInstance(result, ResultSet)
+
         # Check call to client
         req = tcli_service_instance.GetSchemas.call_args[0][0]
         get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1246,7 +1258,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
         thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()

-        thrift_backend.get_tables(
+        result = thrift_backend.get_tables(
             Mock(),
             100,
             200,
@@ -1256,6 +1268,9 @@
             table_name="table_pattern",
             table_types=["type1", "type2"],
         )
+        # Verify the result is a ResultSet
+        self.assertIsInstance(result, ResultSet)
+
         # Check call to client
         req = tcli_service_instance.GetTables.call_args[0][0]
         get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1287,7 +1302,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
         thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()

-        thrift_backend.get_columns(
+        result = thrift_backend.get_columns(
             Mock(),
             100,
             200,
@@ -1297,6 +1312,9 @@
             table_name="table_pattern",
             column_name="column_pattern",
         )
+        # Verify the result is a ResultSet
+        self.assertIsInstance(result, ResultSet)
+
         # Check call to client
         req = tcli_service_instance.GetColumns.call_args[0][0]
         get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)

From 450b80dff677721e66051c90d6afff607dbaedf2 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 03:49:38 +0000
Subject: [PATCH 04/23] remove unnecessary initialisation assertions

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_client.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 8ec4cc499..6155bc815 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -121,25 +121,12 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
                     thrift_client=mock_backend,
                 )

-                # Verify initial state
-                self.assertEqual(real_result_set.has_been_closed_server_side, closed)
-                expected_op_state = (
-                    CommandState.CLOSED if closed else CommandState.SUCCEEDED
-                )
-                self.assertEqual(real_result_set.op_state, expected_op_state)
-
                 # Mock execute_command to return our real result set
                 cursor.backend.execute_command = Mock(return_value=real_result_set)

                 # Execute a command - this should set cursor.active_result_set to our real result set
                 cursor.execute("SELECT 1")

-                # Verify that cursor.execute() set up the result set correctly
-                self.assertIsInstance(cursor.active_result_set, ThriftResultSet)
-                self.assertEqual(
-                    cursor.active_result_set.has_been_closed_server_side, closed
-                )
-
                 # Close the connection - this should trigger the real close chain:
                 # connection.close() -> cursor.close() -> result_set.close()
                 connection.close()

From a926f02d2466cba6808d297994d403840271650c Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 03:56:23 +0000
Subject: [PATCH 05/23] remove unnecessary line breaks

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/types.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py
index 4d9f8be5f..e188ef577 100644
--- a/src/databricks/sql/types.py
+++ b/src/databricks/sql/types.py
@@ -187,7 +187,6 @@ def __contains__(self, item: Any) -> bool:

     # let object acts like class
     def __call__(self, *args: Any) -> "Row":
         """create new Row object"""
-
         if len(args) > len(self):
             raise ValueError(
                 "Can not create Row with fields %s, expected %d values "
@@ -230,7 +229,6 @@ def __reduce__(
         self,
     ) -> Union[str, Tuple[Any, ...]]:
         """Returns a tuple so Python knows how to pickle Row."""
-
         if hasattr(self, "__fields__"):
             return (_create_row, (self.__fields__, tuple(self)))
         else:
@@ -238,7 +236,6 @@ def __reduce__(

     def __repr__(self) -> str:
         """Printable representation of Row used in Python REPL."""
-
         if hasattr(self, "__fields__"):
             return "Row(%s)" % ", ".join(
                 "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))

From 55ad0012d2e82892f901aaf900186c4a30fb29a0 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 03:57:48 +0000
Subject: [PATCH 06/23] more unnecessary line breaks

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 515ec763a..8b25eccc6 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -181,7 +181,6 @@ def __init__(

     def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get up to the next n rows of the Arrow dataframe"""
-
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
         # The second argument should be length, not end index

From fa15730a8e972867a7dac2db51c59c51988a17f7 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:13:10 +0000
Subject: [PATCH 07/23] constrain diff of test_closing_connection_closes_commands

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_client.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 6155bc815..2b4e66a99 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -89,32 +89,40 @@ class ClientTestSuite(unittest.TestCase):

     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
     def test_closing_connection_closes_commands(self, mock_thrift_client_class):
-        """Test that connection.close() properly closes result sets through the real close chain."""
-        # Test once with has_been_closed_server side, once without
+        """Test that closing a connection properly closes commands.
+
+        This test verifies that when a connection is closed:
+        1. The active result set is marked as closed server-side
+        2. The operation state is set to CLOSED
+        3. backend.close_command is called only for commands that weren't already closed
+
+        Args:
+            mock_thrift_client_class: Mock for ThriftBackend class
+        """
+
         for closed in (True, False):
             with self.subTest(closed=closed):
+                # set initial state based on whether the command is already closed
+                initial_state = (
+                    CommandState.CLOSED if closed else CommandState.SUCCEEDED
+                )
+
                 # Mock the execute response with controlled state
                 mock_execute_response = Mock(spec=ExecuteResponse)
-
-                mock_execute_response.command_id = Mock(spec=CommandId)
-                mock_execute_response.status = (
-                    CommandState.SUCCEEDED if not closed else CommandState.CLOSED
-                )
+                mock_execute_response.status = initial_state
                 mock_execute_response.has_been_closed_server_side = closed
                 mock_execute_response.is_staging_operation = False
+                mock_execute_response.command_id = Mock(spec=CommandId)

                 # Mock the backend that will be used by the real ThriftResultSet
                 mock_backend = Mock(spec=ThriftDatabricksClient)
                 mock_backend.staging_allowed_local_path = None
-
-                # Configure the decorator's mock to return our specific mock_backend
                 mock_thrift_client_class.return_value = mock_backend

                 # Create connection and cursor
                 connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
                 cursor = connection.cursor()
-
                 # Create a REAL ThriftResultSet that will be returned by execute_command
                 real_result_set = ThriftResultSet(
                     connection=connection,
                     execute_response=mock_execute_response,
                     thrift_client=mock_backend,
                 )

From 726abe777b9aa17d145bd6790b2c7d99f1af6bdb Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:15:39 +0000
Subject: [PATCH 08/23] reduce diff of test_closing_connection_closes_commands

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 2b4e66a99..e0a7ba1ff 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -102,7 +102,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):

         for closed in (True, False):
             with self.subTest(closed=closed):
-                # set initial state based on whether the command is already closed
+                # Set initial state based on whether the command is already closed
                 initial_state = (
                     CommandState.CLOSED if closed else CommandState.SUCCEEDED
                 )
@@ -114,7 +114,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
                 mock_execute_response.is_staging_operation = False
                 mock_execute_response.command_id = Mock(spec=CommandId)

-                # Mock the backend that will be used by the real ThriftResultSet
+                # Mock the backend that will be used
                 mock_backend = Mock(spec=ThriftDatabricksClient)
                 mock_backend.staging_allowed_local_path = None
                 mock_thrift_client_class.return_value = mock_backend
@@ -132,7 +132,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
                 # Mock execute_command to return our real result set
                 cursor.backend.execute_command = Mock(return_value=real_result_set)

-                # Execute a command - this should set cursor.active_result_set to our real result set
+                # Execute a command
                 cursor.execute("SELECT 1")
 
                 # Close the connection

From 726abe777b9aa17d145bd6790b2c7d99f1af6bdb Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:18:29 +0000
Subject: [PATCH 09/23] use pytest-like assertions for test_closing_connection_closes_commands

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index e0a7ba1ff..66533f606 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -138,12 +138,12 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
                 # Close the connection
                 connection.close()
 
-                # Verify the REAL close logic worked through the chain:
+                # Verify the close logic worked:
                 # 1. has_been_closed_server_side should always be True after close()
-                self.assertTrue(real_result_set.has_been_closed_server_side)
+                assert real_result_set.has_been_closed_server_side is True
 
                 # 2. op_state should always be CLOSED after close()
-                self.assertEqual(real_result_set.op_state, CommandState.CLOSED)
+                assert real_result_set.op_state == CommandState.CLOSED
 
                 # 3. Backend close_command should be called appropriately
                 if not closed:

From bf6d41c15fcdd373f264604d08f95c66f4bbd316 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 05:46:34 +0000
Subject: [PATCH 10/23] ensure command_id is not None

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index de388f1d4..4517ebcec 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -796,6 +796,8 @@ def _results_message_to_execute_response(self, resp, operation_state):
             arrow_queue_opt = None
 
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
+        if command_id is None:
+            raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")
 
         return ExecuteResponse(
             arrow_queue=arrow_queue_opt,
@@ -1156,6 +1158,8 @@ def get_columns(
 
     def _handle_execute_response(self, resp, cursor):
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
+        if command_id is None:
+            raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")
 
         cursor.active_command_id = command_id
         self._check_direct_results_for_error(resp.directResults)
@@ -1169,6 +1173,9 @@ def _handle_execute_response(self, resp, cursor):
 
     def _handle_execute_response_async(self, resp, cursor):
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
+        if command_id is None:
+            raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")
+
         cursor.active_command_id = command_id
         self._check_direct_results_for_error(resp.directResults)

From 5afa7337c328bc1ec486111f6b168e5c1fbf2cb4 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 05:57:18 +0000
Subject: [PATCH 11/23] add line breaks after multi-line docstrings

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/result_set.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index a0d8d3579..99faa7b75 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -51,6 +51,7 @@ def __init__(
         :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount
         :param arraysize: The max number of rows to fetch at a time (PEP-249)
         """
+
self.command_id = command_id self.op_state = op_state self.has_been_closed_server_side = has_been_closed_server_side @@ -117,6 +118,7 @@ def close(self) -> None: If the connection has not been closed, and the result set has not already been closed on the server for some other reason, issue a request to the server to close it. """ + try: if ( self.op_state != CommandState.CLOSED @@ -155,6 +157,7 @@ def __init__( arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results """ + super().__init__( connection, thrift_client, From e3dfd36ce61632ecfc5666bd7d90b5dc46704941 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 06:05:35 +0000 Subject: [PATCH 12/23] ensure non null operationHandle for commandId creation Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..2cfad7bf4 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -595,6 +595,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) thrift_backend = ThriftDatabricksClient( "foobar", @@ -753,6 +754,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp @@ -783,6 +785,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_2 = resp_type( @@ -795,6 +798,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_3 = resp_type( @@ -805,6 +809,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=ttypes.TFetchResultsResp(status=self.bad_status), closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_4 = resp_type( @@ -815,6 +820,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=ttypes.TCloseOperationResp(status=self.bad_status), ), + operationHandle=self.operation_handle, ) for error_resp in [resp_1, resp_2, resp_3, resp_4]: From 63360b305de9741d4d030fb859f4059656e0ff69 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 06:11:09 +0000 Subject: [PATCH 13/23] use command_id methods instead of explicit guid_to_hex_id conversion Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 4517ebcec..c85b7e1c0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,20 +3,17 @@ import logging import math import time -import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import Union, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet + from databricks.sql.result_set import ResultSet -from databricks.sql.thrift_api.TCLIService.ttypes 
import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, - BackendType, ) from databricks.sql.backend.utils import guid_to_hex_id @@ -1233,7 +1230,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.to_hex_guid())) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) From 13ffb8d1c1ef7d5f071d5c0a48acc8d9c247facc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 06:22:41 +0000 Subject: [PATCH 14/23] remove un-necessary artifacts in test_session, add back assertion Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 858119f92..161af37c8 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -128,6 +128,15 @@ def test_context_manager_closes_connection(self, mock_client_class): self.assertEqual(close_session_call_args.guid, b"\x22") self.assertEqual(close_session_call_args.secret, b"\x33") + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close = Mock() + try: + with self.assertRaises(KeyboardInterrupt): + with connection: + raise KeyboardInterrupt("Simulated interrupt") + finally: + connection.close.assert_called() + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( @@ -146,16 +155,10 @@ def test_socket_timeout_passthrough(self, mock_client_class): @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() - - # Create a mock SessionId that will be returned by open_session - mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") - mock_client_class.return_value.open_session.return_value = mock_session_id - databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) - # Check that open_session was called with the correct session_configuration as keyword argument call_kwargs = mock_client_class.return_value.open_session.call_args[1] self.assertEqual(call_kwargs["session_configuration"], mock_session_config) @@ -163,16 +166,10 @@ def test_configuration_passthrough(self, mock_client_class): def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() - - # Create a mock SessionId that will be returned by open_session - mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") - mock_client_class.return_value.open_session.return_value = mock_session_id - databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem ) - # Check that open_session was called with the correct catalog and schema as keyword arguments call_kwargs = mock_client_class.return_value.open_session.call_args[1] self.assertEqual(call_kwargs["catalog"], mock_cat) self.assertEqual(call_kwargs["schema"], mock_schem) @@ -181,7 +178,6 @@ def test_initial_namespace_passthrough(self, mock_client_class): def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value - # Create a mock SessionId that will be returned by open_session 
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") instance.open_session.return_value = mock_session_id From d75905084128e13a02853589d7119a1cb2723a62 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 11:06:45 +0000 Subject: [PATCH 15/23] add from __future__ import annotations to remove string literals around forward refs, remove some unused imports Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 37 ++++++++----------- src/databricks/sql/backend/thrift_backend.py | 1 - src/databricks/sql/result_set.py | 18 ++++----- src/databricks/sql/session.py | 2 +- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..43138f560 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -8,22 +8,17 @@ - Fetching metadata about catalogs, schemas, tables, and columns """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING +from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions - -# Forward reference for type hints -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet class DatabricksClient(ABC): @@ -82,12 +77,12 @@ def execute_command( max_rows: int, max_bytes: int, lz4_compression: bool, - cursor: "Cursor", + cursor: Cursor, use_cloud_fetch: bool, parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: + ) -> Union[ResultSet, None]: """ Executes a SQL command or query within the specified session. @@ -177,8 +172,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState: def get_execution_result( self, command_id: CommandId, - cursor: "Cursor", - ) -> "ResultSet": + cursor: Cursor, + ) -> ResultSet: """ Retrieves the results of a previously executed command. @@ -205,8 +200,8 @@ def get_catalogs( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", - ) -> "ResultSet": + cursor: Cursor, + ) -> ResultSet: """ Retrieves a list of available catalogs. @@ -234,10 +229,10 @@ def get_schemas( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": + ) -> ResultSet: """ Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. @@ -267,12 +262,12 @@ def get_tables( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": + ) -> ResultSet: """ Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. 
@@ -304,12 +299,12 @@ def get_columns( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": + ) -> ResultSet: """ Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index c85b7e1c0..f930897ae 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1,4 +1,3 @@ -from decimal import Decimal import errno import logging import math diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 99faa7b75..2ffc3f257 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING import logging -import time import pandas +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import CommandId, CommandState try: @@ -13,13 +15,11 @@ pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection -from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row -from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue logger = logging.getLogger(__name__) @@ -34,8 +34,8 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", + connection: Connection, + backend: DatabricksClient, command_id: CommandId, op_state: Optional[CommandState], has_been_closed_server_side: bool, @@ -139,9 +139,9 @@ class ThriftResultSet(ResultSet): def __init__( self, - connection: "Connection", + connection: Connection, execute_response: ExecuteResponse, - thrift_client: "ThriftDatabricksClient", + thrift_client: ThriftDatabricksClient, buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 6d69b5487..9ddcdf172 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -9,7 +9,7 @@ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.backend.types import SessionId logger = logging.getLogger(__name__) From 1e2143490a2f580069625ca6f60b171a756984f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 11:21:07 +0000 Subject: [PATCH 16/23] move docstring of DatabricksClient within class Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 43138f560..0337d8d06 
100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -1,13 +1,3 @@ -""" -Abstract client interface for interacting with Databricks SQL services. - -Implementations of this class are responsible for: -- Managing connections to Databricks SQL services -- Executing SQL queries and commands -- Retrieving query results -- Fetching metadata about catalogs, schemas, tables, and columns -""" - from __future__ import annotations from abc import ABC, abstractmethod @@ -22,6 +12,16 @@ class DatabricksClient(ABC): + """ + Abstract client interface for interacting with Databricks SQL services. + + Implementations of this class are responsible for: + - Managing connections to Databricks SQL services + - Executing SQL queries and commands + - Retrieving query results + - Fetching metadata about catalogs, schemas, tables, and columns + """ + # == Connection and Session Management == @abstractmethod def open_session( From cd4015b1a6049ad96467db3aa91df3a468fc13f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 11:23:14 +0000 Subject: [PATCH 17/23] move ThriftResultSet import to top of file Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f930897ae..b752d3678 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,6 +5,8 @@ import threading from typing import Union, TYPE_CHECKING +from databricks.sql.result_set import ThriftResultSet + if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet @@ -810,8 +812,6 @@ def _results_message_to_execute_response(self, resp, operation_state): def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -945,8 +945,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") From ed8b610ebfb28c638602e753976fcc17aacf7c36 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 11:25:32 +0000 Subject: [PATCH 18/23] make backend/utils __init__ file empty Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/backend/types.py | 2 +- src/databricks/sql/backend/utils/__init__.py | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b752d3678..4a4a02738 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -16,7 +16,7 @@ SessionId, CommandId, ) -from databricks.sql.backend.utils import guid_to_hex_id +from databricks.sql.backend.utils.guid_utils import guid_to_hex_id try: import pyarrow diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..834944b31 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -2,7 +2,7 @@ from typing 
import Dict, Optional, Any import logging -from databricks.sql.backend.utils import guid_to_hex_id +from databricks.sql.backend.utils.guid_utils import guid_to_hex_id from databricks.sql.thrift_api.TCLIService import ttypes logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/backend/utils/__init__.py b/src/databricks/sql/backend/utils/__init__.py index 3d601e5e6..e69de29bb 100644 --- a/src/databricks/sql/backend/utils/__init__.py +++ b/src/databricks/sql/backend/utils/__init__.py @@ -1,3 +0,0 @@ -from .guid_utils import guid_to_hex_id - -__all__ = ["guid_to_hex_id"] From 94d951ea6dfd2fff6b45cc1019cf8ddde8b1c73d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 11:32:00 +0000 Subject: [PATCH 19/23] use from __future__ import annotations to remove string literals around Cursor Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 4a4a02738..08f76dd05 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import errno import logging import math @@ -810,7 +812,7 @@ def _results_message_to_execute_response(self, resp, operation_state): ) def get_execution_result( - self, command_id: CommandId, cursor: "Cursor" + self, command_id: CommandId, cursor: Cursor ) -> "ResultSet": thrift_handle = command_id.to_thrift_handle() if not thrift_handle: @@ -939,7 +941,7 @@ def execute_command( max_rows: int, max_bytes: int, lz4_compression: bool, - cursor: "Cursor", + cursor: Cursor, use_cloud_fetch=True, parameters=[], async_op=False, @@ -1007,7 +1009,7 @@ def get_catalogs( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, ) -> "ResultSet": from databricks.sql.result_set import ThriftResultSet @@ -1039,7 +1041,7 @@ def get_schemas( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name=None, schema_name=None, ) -> "ResultSet": @@ -1075,7 +1077,7 @@ def get_tables( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name=None, schema_name=None, table_name=None, @@ -1115,7 +1117,7 @@ def get_columns( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name=None, schema_name=None, table_name=None, From c20058e3fee7d3d4ce7bfc676591a137309dadd7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 12:46:55 +0000 Subject: [PATCH 20/23] use lazy logging Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/backend/utils/guid_utils.py | 2 +- src/databricks/sql/session.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 08f76dd05..514d937d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1229,7 +1229,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.to_hex_guid())) + logger.debug("Cancelling command %s", command_id.to_hex_guid()) req = ttypes.TCancelOperationReq(thrift_handle) 
self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py index 2c440afd2..a6cb0e0db 100644 --- a/src/databricks/sql/backend/utils/guid_utils.py +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -18,6 +18,6 @@ def guid_to_hex_id(guid: bytes) -> str: try: this_uuid = uuid.UUID(bytes=guid) except Exception as e: - logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}") + logger.debug("Unable to convert bytes to UUID: %r -- %s", guid, str(e)) return str(guid) return str(this_uuid) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 9ddcdf172..93108b02a 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -95,7 +95,7 @@ def open(self): ) self.protocol_version = self.get_protocol_version(self._session_id) self.is_open = True - logger.info("Successfully opened session " + str(self.get_id_hex())) + logger.info("Successfully opened session %s", str(self.get_id_hex())) @staticmethod def get_protocol_version(session_id: SessionId): @@ -125,7 +125,7 @@ def get_id_hex(self) -> str: def close(self) -> None: """Close the underlying session.""" - logger.info(f"Closing session {self.get_id_hex()}") + logger.info("Closing session %s", self.get_id_hex()) if not self.is_open: logger.debug("Session appears to have been closed already") return @@ -138,13 +138,13 @@ def close(self) -> None: except DatabaseError as e: if "Invalid SessionHandle" in str(e): logger.warning( - f"Attempted to close session that was already closed: {e}" + "Attempted to close session that was already closed: %s", e ) else: logger.warning( - f"Attempt to close session raised an exception at the server: {e}" + "Attempt to close session raised an exception at the server: %s", e ) except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") + logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False From fe3acb168b5b2b91e80fbace068d39c38cfbb26f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 15:21:36 +0000 Subject: [PATCH 21/23] replace getters with property tag Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 15 +++++---------- src/databricks/sql/client.py | 16 ++++++++-------- src/databricks/sql/session.py | 19 +++++++++++-------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 834944b31..ddeac474a 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -139,7 +139,7 @@ def __str__(self) -> str: if isinstance(self.secret, bytes) else str(self.secret) ) - return f"{self.get_hex_guid()}|{secret_hex}" + return f"{self.hex_guid}|{secret_hex}" return str(self.guid) @classmethod @@ -217,14 +217,8 @@ def to_sea_session_id(self): return self.guid - def get_guid(self) -> Any: - """ - Get the ID of the session. - """ - - return self.guid - - def get_hex_guid(self) -> str: + @property + def hex_guid(self) -> str: """ Get a hexadecimal string representation of the session ID. @@ -237,7 +231,8 @@ def get_hex_guid(self) -> str: else: return str(self.guid) - def get_protocol_version(self): + @property + def protocol_version(self): """ Get the server protocol version for this session. 
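
A minimal sketch of the getter-to-property pattern the hunks above apply. ExampleId is a hypothetical stand-in used only to illustrate the @property mechanics; it is not the driver's actual SessionId class:

    import uuid

    class ExampleId:
        def __init__(self, guid: bytes):
            self.guid = guid

        @property
        def hex_guid(self) -> str:
            # Computed on attribute access: callers write sid.hex_guid with
            # no call parentheses, exactly like a plain attribute.
            try:
                return str(uuid.UUID(bytes=self.guid))
            except ValueError:
                # Fall back to the raw representation for non-UUID guids.
                return str(self.guid)

    sid = ExampleId(uuid.uuid4().bytes)
    print(sid.hex_guid)

The payoff is visible in the client.py hunk below: call sites shrink from method calls like session.get_id_hex() to attribute reads like session.guid_hex, while the computation stays encapsulated and overridable.
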
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..93937ce43 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -304,11 +304,11 @@ def __del__(self): def get_session_id(self): """Get the raw session ID (backend-specific)""" - return self.session.get_id() + return self.session.guid def get_session_id_hex(self): """Get the session ID in hex format""" - return self.session.get_id_hex() + return self.session.guid_hex @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -784,7 +784,7 @@ def execute( self._close_and_clear_active_result_set() self.active_result_set = self.backend.execute_command( operation=prepared_operation, - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -840,7 +840,7 @@ def execute_async( self._close_and_clear_active_result_set() self.backend.execute_command( operation=prepared_operation, - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -927,7 +927,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() self.active_result_set = self.backend.get_catalogs( - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -946,7 +946,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() self.active_result_set = self.backend.get_schemas( - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -972,7 +972,7 @@ def tables( self._close_and_clear_active_result_set() self.active_result_set = self.backend.get_tables( - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1000,7 +1000,7 @@ def columns( self._close_and_clear_active_result_set() self.active_result_set = self.backend.get_columns( - session_id=self.connection.session.get_session_id(), + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 93108b02a..3bf0532dc 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -95,11 +95,11 @@ def open(self): ) self.protocol_version = self.get_protocol_version(self._session_id) self.is_open = True - logger.info("Successfully opened session %s", str(self.get_id_hex())) + logger.info("Successfully opened session %s", str(self.guid_hex)) @staticmethod def get_protocol_version(session_id: SessionId): - return session_id.get_protocol_version() + return session_id.protocol_version @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -111,21 +111,24 @@ def server_parameterized_queries_enabled(protocolVersion): else: return False - def get_session_id(self) -> SessionId: + @property + def session_id(self) -> SessionId: """Get the normalized session ID""" return self._session_id - def get_id(self): + @property + def guid(self) -> Any: """Get the raw session 
ID (backend-specific)""" - return self._session_id.get_guid() + return self._session_id.guid - def get_id_hex(self) -> str: + @property + def guid_hex(self) -> str: """Get the session ID in hex format""" - return self._session_id.get_hex_guid() + return self._session_id.hex_guid def close(self) -> None: """Close the underlying session.""" - logger.info("Closing session %s", self.get_id_hex()) + logger.info("Closing session %s", self.guid_hex) if not self.is_open: logger.debug("Session appears to have been closed already") return From 61dfc4dc99788a9b474b9a46effb729da858d15e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 15:33:12 +0000 Subject: [PATCH 22/23] set active_command_id to None, not active_op_handle Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 2 +- tests/unit/test_client.py | 15 --------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index fd9c82d1e..7886c2f6f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1079,7 +1079,7 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - self.active_op_handle = None + self.active_command_id = None if self.active_result_set: self._close_and_clear_active_result_set() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4491674df..a5db003e7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -260,21 +260,6 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary From 64fb9b277aa70db90d90d08c19591c98c8cc111f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 15:39:55 +0000 Subject: [PATCH 23/23] align test_session with pytest instead of unittest Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 52 +++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 161af37c8..a5c751782 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1,4 +1,4 @@ -import unittest +import pytest from unittest.mock import patch, MagicMock, Mock, PropertyMock import gc @@ -12,7 +12,7 @@ import databricks.sql -class SessionTestSuite(unittest.TestCase): +class TestSession: """ Unit tests for Session functionality """ @@ -37,8 +37,8 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): # Check that close_session was called with the correct SessionId close_session_call_args = instance.close_session.call_args[0][0] - self.assertEqual(close_session_call_args.guid, b"\x22") - self.assertEqual(close_session_call_args.secret, b"\x33") + assert close_session_call_args.guid == b"\x22" + assert 
close_session_call_args.secret == b"\x33" @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): @@ -63,8 +63,8 @@ def test_auth_args(self, mock_client_class): for args in connection_args: connection = databricks.sql.connect(**args) host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) + assert args["server_hostname"] == host + assert args["http_path"] == http_path connection.close() @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -73,7 +73,7 @@ def test_http_header_passthrough(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) + assert ("foo", "bar") in call_args @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): @@ -86,10 +86,10 @@ def test_tls_arg_passthrough(self, mock_client_class): ) kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + assert kwargs["_tls_verify_hostname"] == "hostname" + assert kwargs["_tls_trusted_ca_file"] == "trusted ca file" + assert kwargs["_tls_client_cert_key_file"] == "trusted client cert" + assert kwargs["_tls_client_cert_key_password"] == "key password" @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): @@ -100,7 +100,7 @@ def test_useragent_header(self, mock_client_class): "User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), ) - self.assertIn(user_agent_header, http_headers) + assert user_agent_header in http_headers databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") user_agent_header_with_entry = ( @@ -110,7 +110,7 @@ def test_useragent_header(self, mock_client_class): ), ) http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) + assert user_agent_header_with_entry in http_headers @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): @@ -125,13 +125,13 @@ def test_context_manager_closes_connection(self, mock_client_class): # Check that close_session was called with the correct SessionId close_session_call_args = instance.close_session.call_args[0][0] - self.assertEqual(close_session_call_args.guid, b"\x22") - self.assertEqual(close_session_call_args.secret, b"\x33") + assert close_session_call_args.guid == b"\x22" + assert close_session_call_args.secret == b"\x33" connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.close = Mock() try: - with self.assertRaises(KeyboardInterrupt): + with pytest.raises(KeyboardInterrupt): with connection: raise KeyboardInterrupt("Simulated interrupt") finally: @@ -143,14 +143,12 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS ) - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) + assert mock_client_class.call_args[1]["_retry_stop_after_attempts_count"] == 54 
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + assert mock_client_class.call_args[1]["_socket_timeout"] == 234 @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): @@ -160,7 +158,7 @@ def test_configuration_passthrough(self, mock_client_class): ) call_kwargs = mock_client_class.return_value.open_session.call_args[1] - self.assertEqual(call_kwargs["session_configuration"], mock_session_config) + assert call_kwargs["session_configuration"] == mock_session_config @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): @@ -171,8 +169,8 @@ def test_initial_namespace_passthrough(self, mock_client_class): ) call_kwargs = mock_client_class.return_value.open_session.call_args[1] - self.assertEqual(call_kwargs["catalog"], mock_cat) - self.assertEqual(call_kwargs["schema"], mock_schem) + assert call_kwargs["catalog"] == mock_cat + assert call_kwargs["schema"] == mock_schem @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): @@ -188,9 +186,5 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): # Check that close_session was called with the correct SessionId close_session_call_args = instance.close_session.call_args[0][0] - self.assertEqual(close_session_call_args.guid, b"\x22") - self.assertEqual(close_session_call_args.secret, b"\x33") - - -if __name__ == "__main__": - unittest.main() + assert close_session_call_args.guid == b"\x22" + assert close_session_call_args.secret == b"\x33"