diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index edff10159..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -15,10 +15,16 @@ from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.backend.types import SessionId, CommandId +from databricks.sql.backend.types import SessionId, CommandId, CommandState from databricks.sql.utils import ExecuteResponse from databricks.sql.types import SSLOptions +# Forward reference for type hints +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + class DatabricksClient(ABC): # == Connection and Session Management == @@ -81,7 +87,7 @@ def execute_command( parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. @@ -101,7 +107,7 @@ def execute_command( enforce_embedded_schema_correctness: Whether to enforce schema correctness Returns: - If async_op is False, returns an ExecuteResponse object containing the + If async_op is False, returns a ResultSet object containing the query results and metadata. If async_op is True, returns None and the results must be fetched later using get_execution_result(). @@ -130,7 +136,7 @@ def cancel_command(self, command_id: CommandId) -> None: pass @abstractmethod - def close_command(self, command_id: CommandId) -> ttypes.TStatus: + def close_command(self, command_id: CommandId) -> None: """ Closes a command and releases associated resources. @@ -140,9 +146,6 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: Args: command_id: The command identifier to close - Returns: - ttypes.TStatus: The status of the close operation - Raises: ValueError: If the command ID is invalid OperationalError: If there's an error closing the command @@ -150,7 +153,7 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: pass @abstractmethod - def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: + def get_query_state(self, command_id: CommandId) -> CommandState: """ Gets the current state of a query or command. @@ -160,7 +163,7 @@ def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: command_id: The command identifier to check Returns: - ttypes.TOperationState: The current state of the command + CommandState: The current state of the command Raises: ValueError: If the command ID is invalid @@ -175,7 +178,7 @@ def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves the results of a previously executed command. @@ -187,7 +190,7 @@ def get_execution_result( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the query results and metadata + ResultSet: An object containing the query results and metadata Raises: ValueError: If the command ID is invalid @@ -203,7 +206,7 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of available catalogs. 
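Viewed from the caller's side, the signature changes above amount to a new contract: execute_command hands back a ResultSet (or None for async submissions), get_query_state yields a CommandState, and close_command returns nothing. A minimal sketch of that contract against a mocked backend follows; the mock and its return values are illustrative assumptions, and only the method names and return types come from this diff.

    from unittest.mock import Mock

    from databricks.sql.backend.databricks_client import DatabricksClient
    from databricks.sql.backend.types import CommandState

    # Stand-in backend wired to the revised interface (illustrative values only).
    backend = Mock(spec=DatabricksClient)
    backend.execute_command.return_value = None          # what an async_op=True submission returns
    backend.get_query_state.return_value = CommandState.RUNNING
    backend.close_command.return_value = None            # no longer a ttypes.TStatus

    assert backend.get_query_state(Mock()) is CommandState.RUNNING
    assert backend.close_command(Mock()) is None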
@@ -217,7 +220,7 @@ def get_catalogs( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the catalog metadata + ResultSet: An object containing the catalog metadata Raises: ValueError: If the session ID is invalid @@ -234,7 +237,7 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. @@ -250,7 +253,7 @@ def get_schemas( schema_name: Optional schema name pattern to filter by Returns: - ExecuteResponse: An object containing the schema metadata + ResultSet: An object containing the schema metadata Raises: ValueError: If the session ID is invalid @@ -269,7 +272,7 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. @@ -287,7 +290,7 @@ def get_tables( table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) Returns: - ExecuteResponse: An object containing the table metadata + ResultSet: An object containing the table metadata Raises: ValueError: If the session ID is invalid @@ -306,7 +309,7 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. @@ -324,7 +327,7 @@ def get_columns( column_name: Optional column name pattern to filter by Returns: - ExecuteResponse: An object containing the column metadata + ResultSet: An object containing the column metadata Raises: ValueError: If the session ID is invalid diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index c09397c2f..de388f1d4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -9,9 +9,11 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( + CommandState, SessionId, CommandId, BackendType, @@ -84,8 +86,8 @@ class ThriftDatabricksClient(DatabricksClient): - CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE - ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE + CLOSED_OP_STATE = CommandState.CLOSED + ERROR_OP_STATE = CommandState.FAILED _retry_delay_min: float _retry_delay_max: float @@ -349,6 +351,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
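The CLOSED_OP_STATE and ERROR_OP_STATE constants above now refer to the CommandState enum that backend/types.py introduces later in this diff. A short sketch of the documented Thrift-to-CommandState mapping (the assertion inputs are arbitrary examples; the mapping itself is taken from CommandState.from_thrift_state below):

    from databricks.sql.thrift_api.TCLIService import ttypes
    from databricks.sql.backend.types import CommandState

    # FINISHED maps to SUCCEEDED; ERROR, TIMEDOUT and UKNOWN all collapse into FAILED.
    assert CommandState.from_thrift_state(ttypes.TOperationState.FINISHED_STATE) is CommandState.SUCCEEDED
    assert CommandState.from_thrift_state(ttypes.TOperationState.TIMEDOUT_STATE) is CommandState.FAILED
    assert CommandState.from_thrift_state(ttypes.TOperationState.CANCELED_STATE) is CommandState.CANCELLED
    # Unrecognized values fall through to None, which get_query_state() turns into a ValueError.
    assert CommandState.from_thrift_state(999) is None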
@@ -796,7 +799,7 @@ def _results_message_to_execute_response(self, resp, operation_state): return ExecuteResponse( arrow_queue=arrow_queue_opt, - status=operation_state, + status=CommandState.from_thrift_state(operation_state), has_been_closed_server_side=has_been_closed_server_side, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -808,7 +811,9 @@ def _results_message_to_execute_response(self, resp, operation_state): def get_execution_result( self, command_id: CommandId, cursor: "Cursor" - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -857,9 +862,9 @@ def get_execution_result( ssl_options=self._ssl_options, ) - return ExecuteResponse( + execute_response = ExecuteResponse( arrow_queue=queue, - status=resp.status, + status=CommandState.from_thrift_state(resp.status), has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -869,6 +874,15 @@ def get_execution_result( arrow_schema_bytes=schema_bytes, ) + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -887,7 +901,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, command_id: CommandId) -> "TOperationState": + def get_query_state(self, command_id: CommandId) -> CommandState: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -895,7 +909,10 @@ def get_query_state(self, command_id: CommandId) -> "TOperationState": poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return operation_state + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Unknown command state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -929,7 +946,9 @@ def execute_command( parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -976,7 +995,16 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - return self._handle_execute_response(resp, cursor) + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=use_cloud_fetch, + ) def get_catalogs( self, @@ -984,7 +1012,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = 
session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -996,7 +1026,17 @@ def get_catalogs( ), ) resp = self.make_request(self._client.GetCatalogs, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_schemas( self, @@ -1006,7 +1046,9 @@ def get_schemas( cursor: "Cursor", catalog_name=None, schema_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1020,7 +1062,17 @@ def get_schemas( schemaName=schema_name, ) resp = self.make_request(self._client.GetSchemas, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_tables( self, @@ -1032,7 +1084,9 @@ def get_tables( schema_name=None, table_name=None, table_types=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1048,7 +1102,17 @@ def get_tables( tableTypes=table_types, ) resp = self.make_request(self._client.GetTables, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_columns( self, @@ -1060,7 +1124,9 @@ def get_columns( schema_name=None, table_name=None, column_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1076,7 +1142,17 @@ def get_columns( columnName=column_name, ) resp = self.make_request(self._client.GetColumns, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def _handle_execute_response(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1154,12 +1230,11 @@ def cancel_command(self, command_id: CommandId) -> None: req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - def close_command(self, command_id: CommandId): + def close_command(self, command_id: CommandId) -> None: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") logger.debug("ThriftBackend.close_command(command_id=%s)", 
command_id) req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + self.make_request(self._client.CloseOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 740be0199..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,12 +1,86 @@ from enum import Enum -from typing import Dict, Optional, Any, Union +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id +from databricks.sql.thrift_api.TCLIService import ttypes logger = logging.getLogger(__name__) +class CommandState(Enum): + """ + Enum representing the execution state of a command in Databricks SQL. + + This enum maps Thrift operation states to normalized command states, + providing a consistent interface for tracking command execution status + across different backend implementations. + + Attributes: + PENDING: Command is queued or initialized but not yet running + RUNNING: Command is currently executing + SUCCEEDED: Command completed successfully + FAILED: Command failed due to error, timeout, or unknown state + CLOSED: Command has been closed + CANCELLED: Command was cancelled before completion + """ + + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + CLOSED = "CLOSED" + CANCELLED = "CANCELLED" + + @classmethod + def from_thrift_state( + cls, state: ttypes.TOperationState + ) -> Optional["CommandState"]: + """ + Convert a Thrift TOperationState to a normalized CommandState. + + Args: + state: A TOperationState from the Thrift API representing the current + state of an operation + + Returns: + CommandState: The corresponding normalized command state + + Raises: + ValueError: If the provided state is not a recognized TOperationState + + State Mappings: + - INITIALIZED_STATE, PENDING_STATE -> PENDING + - RUNNING_STATE -> RUNNING + - FINISHED_STATE -> SUCCEEDED + - ERROR_STATE, TIMEDOUT_STATE, UKNOWN_STATE -> FAILED + - CLOSED_STATE -> CLOSED + - CANCELED_STATE -> CANCELLED + """ + + if state in ( + ttypes.TOperationState.INITIALIZED_STATE, + ttypes.TOperationState.PENDING_STATE, + ): + return cls.PENDING + elif state == ttypes.TOperationState.RUNNING_STATE: + return cls.RUNNING + elif state == ttypes.TOperationState.FINISHED_STATE: + return cls.SUCCEEDED + elif state in ( + ttypes.TOperationState.ERROR_STATE, + ttypes.TOperationState.TIMEDOUT_STATE, + ttypes.TOperationState.UKNOWN_STATE, + ): + return cls.FAILED + elif state == ttypes.TOperationState.CLOSED_STATE: + return cls.CLOSED + elif state == ttypes.TOperationState.CANCELED_STATE: + return cls.CANCELLED + else: + return None + + class BackendType(Enum): """ Enum representing the type of backend @@ -40,6 +114,7 @@ def __init__( secret: The secret part of the identifier (only used for Thrift) properties: Additional information about the session """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -55,6 +130,7 @@ def __str__(self) -> str: Returns: A string representation of the session ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -79,6 +155,7 @@ def from_thrift_handle( Returns: A SessionId instance """ + if session_handle is None: return None @@ -105,6 +182,7 @@ def from_sea_session_id( Returns: A SessionId instance """ + return cls(BackendType.SEA, session_id, properties=properties) def 
to_thrift_handle(self): @@ -114,6 +192,7 @@ def to_thrift_handle(self): Returns: A TSessionHandle object or None if this is not a Thrift session ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -132,6 +211,7 @@ def to_sea_session_id(self): Returns: The session ID string or None if this is not a SEA session ID """ + if self.backend_type != BackendType.SEA: return None @@ -141,6 +221,7 @@ def get_guid(self) -> Any: """ Get the ID of the session. """ + return self.guid def get_hex_guid(self) -> str: @@ -150,6 +231,7 @@ def get_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: @@ -163,6 +245,7 @@ def get_protocol_version(self): The server protocol version or None if it does not exist It is not expected to exist for SEA sessions. """ + return self.properties.get("serverProtocolVersion") @@ -194,6 +277,7 @@ def __init__( has_result_set: Whether the command has a result set modified_row_count: The number of rows modified by the command """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -211,6 +295,7 @@ def __str__(self) -> str: Returns: A string representation of the command ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -233,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -259,6 +345,7 @@ def from_sea_statement_id(cls, statement_id: str): Returns: A CommandId instance """ + return cls(BackendType.SEA, statement_id) def to_thrift_handle(self): @@ -268,6 +355,7 @@ def to_thrift_handle(self): Returns: A TOperationHandle object or None if this is not a Thrift command ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -288,6 +376,7 @@ def to_sea_statement_id(self): Returns: The statement ID string or None if this is not a SEA statement ID """ + if self.backend_type != BackendType.SEA: return None @@ -300,6 +389,7 @@ def to_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py index 28975171f..2c440afd2 100644 --- a/src/databricks/sql/backend/utils/guid_utils.py +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -14,6 +14,7 @@ def guid_to_hex_id(guid: bytes) -> str: If conversion to hexadecimal fails, a string representation of the original bytes is returned """ + try: this_uuid = uuid.UUID(bytes=guid) except Exception as e: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1c384c735..9f7c060a7 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -42,14 +42,15 @@ ParameterApproach, ) - +from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session -from databricks.sql.backend.types import CommandId, BackendType +from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, TSparkParameter, TOperationState, ) @@ -320,9 +321,17 @@ def protocol_version(self): return 
self.session.protocol_version @staticmethod - def get_protocol_version(openSessionResp): + def get_protocol_version(openSessionResp: TOpenSessionResp): """Delegate to Session class static method""" - return Session.get_protocol_version(openSessionResp) + properties = ( + {"serverProtocolVersion": openSessionResp.serverProtocolVersion} + if openSessionResp.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle( + openSessionResp.sessionHandle, properties + ) + return Session.get_protocol_version(session_id) @property def open(self) -> bool: @@ -388,6 +397,7 @@ def __init__( Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately visible by other cursors or connections. """ + self.connection = connection self.rowcount = -1 # Return -1 as this is not supported self.buffer_size_bytes = result_buffer_size_bytes @@ -746,6 +756,7 @@ def execute( :returns self """ + logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) @@ -771,7 +782,7 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.execute_command( + self.active_result_set = self.backend.execute_command( operation=prepared_operation, session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, @@ -783,18 +794,8 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) - assert execute_response is not None # async_op = False above - - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -815,6 +816,7 @@ def execute_async( :param parameters: :return: """ + param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -851,7 +853,7 @@ def execute_async( return self - def get_query_state(self) -> "TOperationState": + def get_query_state(self) -> CommandState: """ Get the state of the async executing query or basically poll the status of the query @@ -869,11 +871,7 @@ def is_query_pending(self): :return: """ operation_state = self.get_query_state() - - return not operation_state or operation_state in [ - ttypes.TOperationState.RUNNING_STATE, - ttypes.TOperationState.PENDING_STATE, - ] + return operation_state in [CommandState.PENDING, CommandState.RUNNING] def get_async_execution_result(self): """ @@ -889,19 +887,12 @@ def get_async_execution_result(self): time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) operation_state = self.get_query_state() - if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.backend.get_execution_result( + if operation_state == CommandState.SUCCEEDED: + self.active_result_set = self.backend.get_execution_result( self.active_command_id, self ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -935,20 +926,12 @@ def catalogs(self) -> 
"Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_catalogs( + self.active_result_set = self.backend.get_catalogs( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def schemas( @@ -962,7 +945,7 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_schemas( + self.active_result_set = self.backend.get_schemas( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -970,14 +953,6 @@ def schemas( catalog_name=catalog_name, schema_name=schema_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def tables( @@ -996,7 +971,7 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_tables( + self.active_result_set = self.backend.get_tables( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1006,14 +981,6 @@ def tables( table_name=table_name, table_types=table_types, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def columns( @@ -1032,7 +999,7 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_columns( + self.active_result_set = self.backend.get_columns( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1042,14 +1009,6 @@ def columns( table_name=table_name, column_name=column_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def fetchall(self) -> List[Row]: @@ -1205,312 +1164,3 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass - - -class ResultSet: - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - backend: DatabricksClient, - result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - arraysize: int = 10000, - use_cloud_fetch: bool = True, - ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param backend: The DatabricksClient instance to use for fetching results - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - """ - self.connection = connection - self.command_id = execute_response.command_id - self.op_state = execute_response.status - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.buffer_size_bytes = result_buffer_size_bytes - self.lz4_compressed = execute_response.lz4_compressed - self.arraysize = arraysize - self.backend = backend - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes - self._next_row_index = 0 - self._use_cloud_fetch = use_cloud_fetch - - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity - self._fill_results_buffer() - - def __iter__(self): - while True: - row = self.fetchone() - if row: - yield row - else: - break - - def _fill_results_buffer(self): - if not isinstance(self.backend, ThriftDatabricksClient): - # currently, we are assuming only the Thrift backend exists - raise NotImplementedError( - "Fetching further result batches is currently only implemented for the Thrift backend." - ) - - # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results - thrift_backend_instance = self.backend # type: ThriftDatabricksClient - results, has_more_rows = thrift_backend_instance.fetch_results( - command_id=self.command_id, - max_rows=self.arraysize, - max_bytes=self.buffer_size_bytes, - expected_row_start_offset=self._next_row_index, - lz4_compressed=self.lz4_compressed, - arrow_schema_bytes=self._arrow_schema_bytes, - description=self.description, - use_cloud_fetch=self._use_cloud_fetch, - ) - self.results = results - self.has_more_rows = has_more_rows - - def _convert_columnar_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - result = [] - for row_index in range(table.num_rows): - curr_row = [] - for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) - result.append(ResultRow(*curr_row)) - - return result - - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - - @property - def rownumber(self): - return self._next_row_index - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows of a query result, returning a PyArrow table. - - An empty sequence is returned when no more rows are available. - """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - - def fetchmany_columnar(self, size: int): - """ - Fetch the next set of rows of a query result, returning a Columnar Table. - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) - self._next_row_index += partial_results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results - - def fetchall_columnar(self): - """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) - self._next_row_index += partial_results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - if isinstance(self.results, ColumnQueue): - res = self._convert_columnar_table(self.fetchmany_columnar(1)) - else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) - - if len(res) > 0: - return res[0] - else: - return None - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchall_columnar()) - else: - return self._convert_arrow_table(self.fetchall_arrow()) - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchmany_columnar(size)) - else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) - - def close(self) -> None: - """ - Close the cursor. - - If the connection has not been closed, and the cursor has not already - been closed on the server for some other reason, issue a request to the server to close it. 
- """ - # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to - # when we generalise the ResultSet - try: - if ( - self.op_state != ttypes.TOperationState.CLOSED_STATE - and not self.has_been_closed_server_side - and self.connection.open - ): - self.backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.op_state = ttypes.TOperationState.CLOSED_STATE - - @staticmethod - def _get_schema_description(table_schema_message): - """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 - """ - - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ - - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py new file mode 100644 index 000000000..a0d8d3579 --- /dev/null +++ b/src/databricks/sql/result_set.py @@ -0,0 +1,412 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Any, Union, TYPE_CHECKING + +import logging +import time +import pandas + +from databricks.sql.backend.types import CommandId, CommandState + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.backend.databricks_client import DatabricksClient + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + from databricks.sql.client import Connection + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import Row +from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue + +logger = logging.getLogger(__name__) + + +class ResultSet(ABC): + """ + Abstract base class for result sets returned by different backend implementations. + + This class defines the interface that all concrete result set implementations must follow. + """ + + def __init__( + self, + connection: "Connection", + backend: "DatabricksClient", + command_id: CommandId, + op_state: Optional[CommandState], + has_been_closed_server_side: bool, + arraysize: int, + buffer_size_bytes: int, + ): + """ + A ResultSet manages the results of a single command. 
+ + :param connection: The parent connection that was used to execute this command + :param backend: The specialised backend client to be invoked in the fetch phase + :param execute_response: A `ExecuteResponse` class returned by a command execution + :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + """ + self.command_id = command_id + self.op_state = op_state + self.has_been_closed_server_side = has_been_closed_server_side + self.connection = connection + self.backend = backend + self.arraysize = arraysize + self.buffer_size_bytes = buffer_size_bytes + self._next_row_index = 0 + self.description = None + + def __iter__(self): + while True: + row = self.fetchone() + if row: + yield row + else: + break + + @property + def rownumber(self): + return self._next_row_index + + @property + @abstractmethod + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + pass + + # Define abstract methods that concrete implementations must implement + @abstractmethod + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + pass + + @abstractmethod + def fetchone(self) -> Optional[Row]: + """Fetch the next row of a query result set.""" + pass + + @abstractmethod + def fetchmany(self, size: int) -> List[Row]: + """Fetch the next set of rows of a query result.""" + pass + + @abstractmethod + def fetchall(self) -> List[Row]: + """Fetch all remaining rows of a query result.""" + pass + + @abstractmethod + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """Fetch the next set of rows as an Arrow table.""" + pass + + @abstractmethod + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all remaining rows as an Arrow table.""" + pass + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + try: + if ( + self.op_state != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.op_state = CommandState.CLOSED + + +class ThriftResultSet(ResultSet): + """ResultSet implementation for the Thrift backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: ExecuteResponse, + thrift_client: "ThriftDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + use_cloud_fetch: bool = True, + ): + """ + Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + """ + super().__init__( + connection, + thrift_client, + execute_response.command_id, + execute_response.status, + execute_response.has_been_closed_server_side, + arraysize, + buffer_size_bytes, + ) + + # Initialize ThriftResultSet-specific attributes + self.has_been_closed_server_side = execute_response.has_been_closed_server_side + self.has_more_rows = execute_response.has_more_rows + self.lz4_compressed = execute_response.lz4_compressed + self.description = execute_response.description + self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._use_cloud_fetch = use_cloud_fetch + self._is_staging_operation = execute_response.is_staging_operation + + # Initialize results queue + if execute_response.arrow_queue: + # In this case the server has taken the fast path and returned an initial batch of + # results + self.results = execute_response.arrow_queue + else: + # In this case, there are results waiting on the server so we fetch now for simplicity + self._fill_results_buffer() + + def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available + results, has_more_rows = self.backend.fetch_results( + command_id=self.command_id, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + expected_row_start_offset=self._next_row_index, + lz4_compressed=self.lz4_compressed, + arrow_schema_bytes=self._arrow_schema_bytes, + description=self.description, + use_cloud_fetch=self._use_cloud_fetch, + ) + self.results = results + self.has_more_rows = has_more_rows + + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+ # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + + def merge_columnar(self, result1, result2) -> "ColumnTable": + """ + Function to merge / combine the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows of a query result, returning a PyArrow table. + + An empty sequence is returned when no more rows are available. + """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = pyarrow.concat_tables([results, partial_results]) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. 
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + if isinstance(results, ColumnTable) and isinstance( + partial_results, ColumnTable + ): + results = self.merge_columnar(results, partial_results) + else: + results = pyarrow.concat_tables([results, partial_results]) + self._next_row_index += partial_results.num_rows + + # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table + # Valid only for metadata commands result set + if isinstance(results, ColumnTable) and pyarrow: + data = { + name: col + for name, col in zip(results.column_names, results.column_table) + } + return pyarrow.Table.from_pydict(data) + return results + + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += partial_results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if len(res) > 0: + return res[0] + else: + return None + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. 
+ """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) + + @property + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + return self._is_staging_operation + + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 2ee5e53f1..6d69b5487 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -31,6 +31,7 @@ def __init__( This class handles all session-related behavior and communication with the backend. """ + self.is_open = False self.host = server_hostname self.port = kwargs.get("_port", 443) diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index fef22cd9f..4d9f8be5f 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -158,6 +158,7 @@ def asDict(self, recursive: bool = False) -> Dict[str, Any]: >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} True """ + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") @@ -186,6 +187,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": """create new Row object""" + if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values " @@ -228,6 +230,7 @@ def __reduce__( self, ) -> Union[str, Tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" + if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) else: @@ -235,6 +238,7 @@ def __reduce__( def __repr__(self) -> str: """Printable representation of Row used in Python REPL.""" + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join( "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c541ad3fd..2622b1172 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -74,6 +74,7 @@ def build_queue( Returns: ResultSetQueue """ + if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes @@ -173,12 +174,14 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ + self.cur_row_index = start_row_index self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" + length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice # The second argument should be length, not end index @@ -216,6 +219,7 @@ def __init__( lz4_compressed (bool): Whether the files are lz4 compressed. description (List[List[Any]]): Hive table schema description. 
""" + self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads self.start_row_index = start_row_offset @@ -256,6 +260,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: logger.debug("CloudFetchQueue: no more rows available") # Return empty pyarrow table to cause retry of fetch @@ -285,6 +290,7 @@ def remaining_rows(self) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() @@ -566,6 +572,7 @@ def transform_paramstyle( Returns: str """ + output = operation if ( param_structure == ParameterStructure.POSITIONAL diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c446b6715..22897644f 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -30,6 +30,7 @@ OperationalError, RequestError, ) +from databricks.sql.backend.types import CommandState from tests.e2e.common.predicates import ( pysql_has_version, pysql_supports_arrow, @@ -826,10 +827,7 @@ def test_close_connection_closes_cursors(self): getProgressUpdate=False, ) op_status_at_server = ars.backend._client.GetOperationStatus(status_request) - assert ( - op_status_at_server.operationState - != ttypes.TOperationState.CLOSED_STATE - ) + assert op_status_at_server.operationState != CommandState.CLOSED conn.close() @@ -939,7 +937,7 @@ def test_result_set_close(self): result_set.close() - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.op_state == CommandState.CLOSED assert result_set.op_state != initial_op_state # Closing the result set again should be a no-op and not raise exceptions diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fa6fae1d9..1a7950870 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -14,7 +14,9 @@ TOperationHandle, THandleIdentifier, TOperationType, + TOperationState, ) +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql @@ -22,7 +24,9 @@ from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row -from databricks.sql.client import CommandId +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -36,12 +40,11 @@ def new(cls): ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) - MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + mock_result_set = Mock(spec=ThriftResultSet) cls.apply_property_to_mock( - MockTExecuteStatementResp, + mock_result_set, description=None, - arrow_queue=None, is_staging_operation=False, command_id=None, has_been_closed_server_side=True, @@ -50,7 +53,7 @@ def new(cls): arrow_schema_bytes=b"schema", ) - ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + ThriftBackendMock.execute_command.return_value = mock_result_set return ThriftBackendMock @@ -82,25 +85,79 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch( - "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, - 
ThriftDatabricksClientMockFactory.new(), - ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_closing_connection_closes_commands(self, mock_result_set_class): + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_closing_connection_closes_commands(self, mock_thrift_client_class): + """Test that connection.close() properly closes result sets through the real close chain.""" # Test once with has_been_closed_server side, once without for closed in (True, False): with self.subTest(closed=closed): - mock_result_set_class.return_value = Mock() + # Mock the execute response with controlled state + mock_execute_response = Mock(spec=ExecuteResponse) + + mock_execute_response.command_id = Mock(spec=CommandId) + mock_execute_response.status = ( + CommandState.SUCCEEDED if not closed else CommandState.CLOSED + ) + mock_execute_response.has_been_closed_server_side = closed + mock_execute_response.is_staging_operation = False + + # Mock the backend that will be used by the real ThriftResultSet + mock_backend = Mock(spec=ThriftDatabricksClient) + mock_backend.staging_allowed_local_path = None + + # Configure the decorator's mock to return our specific mock_backend + mock_thrift_client_class.return_value = mock_backend + + # Create connection and cursor connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() - cursor.execute("SELECT 1;") - connection.close() - self.assertTrue( - mock_result_set_class.return_value.has_been_closed_server_side + # Create a REAL ThriftResultSet that will be returned by execute_command + real_result_set = ThriftResultSet( + connection=connection, + execute_response=mock_execute_response, + thrift_client=mock_backend, + ) + + # Verify initial state + self.assertEqual(real_result_set.has_been_closed_server_side, closed) + expected_op_state = ( + CommandState.CLOSED if closed else CommandState.SUCCEEDED + ) + self.assertEqual(real_result_set.op_state, expected_op_state) + + # Mock execute_command to return our real result set + cursor.backend.execute_command = Mock(return_value=real_result_set) + + # Execute a command - this should set cursor.active_result_set to our real result set + cursor.execute("SELECT 1") + + # Verify that cursor.execute() set up the result set correctly + self.assertIsInstance(cursor.active_result_set, ThriftResultSet) + self.assertEqual( + cursor.active_result_set.has_been_closed_server_side, closed ) - mock_result_set_class.return_value.close.assert_called_once_with() + + # Close the connection - this should trigger the real close chain: + # connection.close() -> cursor.close() -> result_set.close() + connection.close() + + # Verify the REAL close logic worked through the chain: + # 1. has_been_closed_server_side should always be True after close() + self.assertTrue(real_result_set.has_been_closed_server_side) + + # 2. op_state should always be CLOSED after close() + self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + + # 3. 
Backend close_command should be called appropriately + if not closed: + # Should have called backend.close_command during the close chain + mock_backend.close_command.assert_called_once_with( + mock_execute_response.command_id + ) + else: + # Should NOT have called backend.close_command (already closed) + mock_backend.close_command.assert_not_called() @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @@ -127,10 +184,11 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - result_set = client.ResultSet( + + result_set = ThriftResultSet( connection=mock_connection, - backend=mock_backend, execute_response=Mock(), + thrift_client=mock_backend, ) # Setup session mock on the mock_connection mock_session = Mock() @@ -152,7 +210,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = client.ResultSet( + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -162,17 +220,16 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.command_id ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command( - self, mock_result_set_class - ): - + def test_executing_multiple_commands_uses_the_most_recent_command(self): mock_result_sets = [Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_sets + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False - cursor = client.Cursor( - connection=Mock(), backend=ThriftDatabricksClientMockFactory.new() - ) + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + + cursor = client.Cursor(connection=Mock(), backend=mock_backend) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -197,7 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -349,14 +406,15 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class - ): + def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_set_instances + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_set_instances: + mock_rs.is_staging_operation = False + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_set_instances cursor = client.Cursor(Mock(), mock_backend) @@ -509,8 +567,9 @@ def test_staging_operation_response_is_handled( ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) - mock_client_class.execute_command.return_value 
= mock_execute_response - mock_client_class.return_value = mock_client_class + mock_client = mock_client_class.return_value + mock_client.execute_command.return_value = Mock(is_staging_operation=True) + mock_client_class.return_value = mock_client connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -617,9 +676,9 @@ def mock_close_normal(): def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" - result_set = client.ResultSet.__new__(client.ResultSet) - result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" + result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) + result_set.backend = Mock() + result_set.backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True result_set.op_state = "RUNNING" @@ -630,31 +689,31 @@ class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - result_set.thrift_backend.close_command.side_effect = MockRequestError() + result_set.backend.close_command.side_effect = MockRequestError() original_close = client.ResultSet.close try: try: if ( - result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): - result_set.thrift_backend.close_command(result_set.command_id) + result_set.backend.close_command(result_set.command_id) except MockRequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state = result_set.backend.CLOSED_OP_STATE - result_set.thrift_backend.close_command.assert_called_once_with( + result_set.backend.close_command.assert_called_once_with( result_set.command_id ) assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 1c6a1b18d..030510a64 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -10,6 +10,7 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ThriftResultSet @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -38,9 +39,8 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, @@ -52,6 +52,7 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, ), + thrift_client=None, ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ @@ -84,9 +85,8 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if 
batch_list and batch_list[0] else 0 - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -101,6 +101,7 @@ def fetch_results( arrow_schema_bytes=None, is_staging_operation=False, ), + thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 949230d1e..37e6cf1c9 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -64,13 +64,7 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - properties = ( - {"serverProtocolVersion": test_input.serverProtocolVersion} - if test_input.serverProtocolVersion - else {} - ) - session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) - assert Connection.get_protocol_version(session_id) == expected + assert Connection.get_protocol_version(test_input) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 41a2a5800..57a2a61e3 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -18,7 +18,8 @@ from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.types import CommandId, SessionId, BackendType +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -882,7 +883,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) self.assertEqual( results_message_response.status, - ttypes.TOperationState.FINISHED_STATE, + CommandState.SUCCEEDED, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -1152,7 +1153,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock) + result = thrift_backend.execute_command( + "foo", Mock(), 100, 200, Mock(), cursor_mock + ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1181,7 +1187,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1209,7 +1218,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_schemas( + result = thrift_backend.get_schemas( Mock(), 100, 200, @@ -1217,6 +1226,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", ) + # Verify the result is a ResultSet + 
self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1246,7 +1258,7 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_tables( + result = thrift_backend.get_tables( Mock(), 100, 200, @@ -1256,6 +1268,9 @@ def test_get_tables_calls_client_and_handle_execute_response( table_name="table_pattern", table_types=["type1", "type2"], ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1287,7 +1302,7 @@ def test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_columns( + result = thrift_backend.get_columns( Mock(), 100, 200, @@ -1297,6 +1312,9 @@ def test_get_columns_calls_client_and_handle_execute_response( table_name="table_pattern", column_name="column_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
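For context on what the reworked tests above assert, here is a minimal usage sketch (not part of the diff) of the refactored API: execute_command now returns a ResultSet, and connection.close() is expected to cascade through cursor.close() to result_set.close(). The connection arguments below are placeholders, and cursor.active_result_set is the internal attribute the unit tests themselves inspect.

# Minimal sketch, assuming a reachable SQL warehouse; hostname/path/token are placeholders.
from databricks import sql
from databricks.sql.backend.types import CommandState

connection = sql.connect(
    server_hostname="...",  # placeholder
    http_path="...",        # placeholder
    access_token="...",     # placeholder
)
cursor = connection.cursor()
cursor.execute("SELECT 1")

# After this change the backend hands back a ResultSet (a ThriftResultSet here),
# which the cursor keeps as its active result set.
result_set = cursor.active_result_set
rows = result_set.fetchall()

# Closing the connection drives the close chain the tests verify:
# connection.close() -> cursor.close() -> result_set.close()
connection.close()
assert result_set.has_been_closed_server_side
assert result_set.op_state == CommandState.CLOSED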