From bb3e94d1a87c9fa36814b4202962337d464739d7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 06:15:36 +0000 Subject: [PATCH 1/8] introduce row_limit Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 2 ++ src/databricks/sql/backend/sea/backend.py | 3 ++- src/databricks/sql/client.py | 9 +++++++++ src/databricks/sql/session.py | 4 ++-- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 973c2932e..5a8e9da20 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -85,6 +85,7 @@ def execute_command( parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. @@ -103,6 +104,7 @@ def execute_command( parameters: List of parameters to bind to the query async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness + row_limit: Maximum number of rows to fetch overall. Only supported for SEA protocol. Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 76903ccd2..3c585dd2a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -405,6 +405,7 @@ def execute_command( parameters: List[Dict[str, Any]], async_op: bool, enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, ) -> Union["ResultSet", None]: """ Execute a SQL command using the SEA backend. @@ -462,7 +463,7 @@ def execute_command( format=format, wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, on_wait_timeout="CONTINUE", - row_limit=max_rows, + row_limit=row_limit, parameters=sea_parameters if sea_parameters else None, result_compression=result_compression, ) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e145e4e58..2a987e4a4 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -341,6 +341,7 @@ def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, + row_limit: int = None, ) -> "Cursor": """ Return a new Cursor object using the connection. @@ -355,6 +356,7 @@ def cursor( self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, + row_limit=row_limit, ) self._cursors.append(cursor) return cursor @@ -388,6 +390,7 @@ def __init__( backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, + row_limit: Optional[int] = None, ) -> None: """ These objects represent a database cursor, which is used to manage the context of a fetch @@ -397,11 +400,17 @@ def __init__( visible by other cursors or connections. """ + if not self.connection.session.use_sea and row_limit is not None: + logger.warning( + "Row limit is only supported for SEA protocol. Ignoring row_limit." + ) + self.connection = connection self.rowcount = -1 # Return -1 as this is not supported self.buffer_size_bytes = result_buffer_size_bytes self.active_result_set: Union[ResultSet, None] = None self.arraysize = arraysize + self.row_limit = row_limit # Note that Cursor closed => active result set closed, but not vice versa self.open = True self.executing_command_id = None diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 76aec4675..670f86dbe 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -97,10 +97,10 @@ def _create_backend( kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" - use_sea = kwargs.get("use_sea", False) + self.use_sea = kwargs.get("use_sea", False) databricks_client_class: Type[DatabricksClient] - if use_sea: + if self.use_sea: logger.debug("Creating SEA backend client") databricks_client_class = SeaDatabricksClient else: From dc3a1fb4f26d819dda368ebf62088ce7966a4fa5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 06:23:25 +0000 Subject: [PATCH 2/8] move use_sea init to Session constructor Signed-off-by: varun-edachali-dbx --- src/databricks/sql/session.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 670f86dbe..f9c4ee049 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -76,7 +76,9 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) + self.use_sea = kwargs.get("use_sea", False) self.backend = self._create_backend( + self.use_sea, server_hostname, http_path, all_headers, @@ -89,6 +91,7 @@ def __init__( def _create_backend( self, + use_sea: bool, server_hostname: str, http_path: str, all_headers: List[Tuple[str, str]], @@ -97,10 +100,8 @@ def _create_backend( kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" - self.use_sea = kwargs.get("use_sea", False) - databricks_client_class: Type[DatabricksClient] - if self.use_sea: + if use_sea: logger.debug("Creating SEA backend client") databricks_client_class = SeaDatabricksClient else: From eba17c1c128f001502ebaddf6a731854d24ff25e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 06:34:07 +0000 Subject: [PATCH 3/8] more explicit typing Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 3 ++- src/databricks/sql/client.py | 23 ++++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..ac7463692 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -4,7 +4,7 @@ import math import time import threading -from typing import List, Union, Any, TYPE_CHECKING +from typing import List, Optional, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -929,6 +929,7 @@ def execute_command( parameters=[], async_op=False, enforce_embedded_schema_correctness=False, + row_limit: Optional[int] = None, ) -> Union["ResultSet", None]: thrift_handle = session_id.to_thrift_handle() if not thrift_handle: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 2a987e4a4..cca14d762 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -341,7 +341,7 @@ def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - row_limit: int = None, + row_limit: Optional[int] = None, ) -> "Cursor": """ Return a new Cursor object using the connection. @@ -400,22 +400,23 @@ def __init__( visible by other cursors or connections. """ - if not self.connection.session.use_sea and row_limit is not None: + self.connection: Connection = connection + + if not connection.session.use_sea and row_limit is not None: logger.warning( "Row limit is only supported for SEA protocol. Ignoring row_limit." ) - self.connection = connection - self.rowcount = -1 # Return -1 as this is not supported - self.buffer_size_bytes = result_buffer_size_bytes + self.rowcount: int = -1 # Return -1 as this is not supported + self.buffer_size_bytes: int = result_buffer_size_bytes self.active_result_set: Union[ResultSet, None] = None - self.arraysize = arraysize - self.row_limit = row_limit + self.arraysize: int = arraysize + self.row_limit: Optional[int] = row_limit # Note that Cursor closed => active result set closed, but not vice versa - self.open = True - self.executing_command_id = None - self.backend = backend - self.active_command_id = None + self.open: bool = True + self.executing_command_id: Optional[CommandId] = None + self.backend: DatabricksClient = backend + self.active_command_id: Optional[CommandId] = None self.escaper = ParamEscaper() self.lastrowid = None From 00e57e7298067348cf011a0ee0d5d0b4fb7e61f1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 08:02:49 +0000 Subject: [PATCH 4/8] add row_limit to Thrift backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 5a8e9da20..49ec59a3f 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -104,7 +104,7 @@ def execute_command( parameters: List of parameters to bind to the query async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness - row_limit: Maximum number of rows to fetch overall. Only supported for SEA protocol. + row_limit: Maximum number of rows in the operation result. Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index ac7463692..3b38fcd56 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -970,6 +970,7 @@ def execute_command( useArrowNativeTypes=spark_arrow_types, parameters=parameters, enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness, + resultRowLimit=row_limit, ) resp = self.make_request(self._client.ExecuteStatement, req) From 304ef0ee72d4d0ab8d02762321b62e772eb3bc70 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 19 Jun 2025 08:46:16 +0000 Subject: [PATCH 5/8] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 49ec59a3f..276954b7c 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -104,7 +104,7 @@ def execute_command( parameters: List of parameters to bind to the query async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness - row_limit: Maximum number of rows in the operation result. + row_limit: Maximum number of rows in the operation result. Returns: If async_op is False, returns a ResultSet object containing the From 414117a30d460ae4351a0c72e76cc23f705bea8a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 03:49:03 +0000 Subject: [PATCH 6/8] add e2e test for thrift resultRowLimit Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 7 ++-- src/databricks/sql/session.py | 5 ++- tests/e2e/test_driver.py | 60 +++++++++++++++++++++++++++++++++-- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index cca14d762..027d117a5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -402,11 +402,6 @@ def __init__( self.connection: Connection = connection - if not connection.session.use_sea and row_limit is not None: - logger.warning( - "Row limit is only supported for SEA protocol. Ignoring row_limit." - ) - self.rowcount: int = -1 # Return -1 as this is not supported self.buffer_size_bytes: int = result_buffer_size_bytes self.active_result_set: Union[ResultSet, None] = None @@ -802,6 +797,7 @@ def execute( parameters=prepared_params, async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, + row_limit=self.row_limit, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -858,6 +854,7 @@ def execute_async( parameters=prepared_params, async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, + row_limit=self.row_limit, ) return self diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f9c4ee049..76aec4675 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -76,9 +76,7 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.use_sea = kwargs.get("use_sea", False) self.backend = self._create_backend( - self.use_sea, server_hostname, http_path, all_headers, @@ -91,7 +89,6 @@ def __init__( def _create_backend( self, - use_sea: bool, server_hostname: str, http_path: str, all_headers: List[Tuple[str, str]], @@ -100,6 +97,8 @@ def _create_backend( kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" + use_sea = kwargs.get("use_sea", False) + databricks_client_class: Type[DatabricksClient] if use_sea: logger.debug("Creating SEA backend client") diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 8cfed7c28..a68e64497 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -113,10 +113,12 @@ def connection(self, extra_params=()): conn.close() @contextmanager - def cursor(self, extra_params=()): + def cursor(self, extra_params=(), extra_cursor_params=()): with self.connection(extra_params) as conn: cursor = conn.cursor( - arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes + arraysize=self.arraysize, + buffer_size_bytes=self.buffer_size_bytes, + **extra_cursor_params, ) try: yield cursor @@ -945,6 +947,60 @@ def test_result_set_close(self): finally: cursor.close() + def test_row_limit_with_larger_result(self): + """Test that row_limit properly constrains results when query would return more rows""" + row_limit = 1000 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns more than row_limit rows + cursor.execute("SELECT * FROM range(2000)") + rows = cursor.fetchall() + + # Check if the number of rows is limited to row_limit + assert len(rows) == row_limit, f"Expected {row_limit} rows, got {len(rows)}" + + def test_row_limit_with_smaller_result(self): + """Test that row_limit doesn't affect results when query returns fewer rows than limit""" + row_limit = 100 + expected_rows = 50 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns fewer than row_limit rows + cursor.execute(f"SELECT * FROM range({expected_rows})") + rows = cursor.fetchall() + + # Check if all rows are returned (not limited by row_limit) + assert ( + len(rows) == expected_rows + ), f"Expected {expected_rows} rows, got {len(rows)}" + + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + def test_row_limit_with_arrow_larger_result(self): + """Test that row_limit properly constrains arrow results when query would return more rows""" + row_limit = 800 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns more than row_limit rows + cursor.execute("SELECT * FROM range(1500)") + arrow_table = cursor.fetchall_arrow() + + # Check if the number of rows in the arrow table is limited to row_limit + assert ( + arrow_table.num_rows == row_limit + ), f"Expected {row_limit} rows, got {arrow_table.num_rows}" + + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + def test_row_limit_with_arrow_smaller_result(self): + """Test that row_limit doesn't affect arrow results when query returns fewer rows than limit""" + row_limit = 200 + expected_rows = 100 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns fewer than row_limit rows + cursor.execute(f"SELECT * FROM range({expected_rows})") + arrow_table = cursor.fetchall_arrow() + + # Check if all rows are returned (not limited by row_limit) + assert ( + arrow_table.num_rows == expected_rows + ), f"Expected {expected_rows} rows, got {arrow_table.num_rows}" + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. From 20589077be41c02262f39de25c9d35c9f2fc8201 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 04:02:30 +0000 Subject: [PATCH 7/8] explicitly convert extra cursor params to dict Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index a68e64497..f0944328a 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -118,7 +118,7 @@ def cursor(self, extra_params=(), extra_cursor_params=()): cursor = conn.cursor( arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes, - **extra_cursor_params, + **dict(extra_cursor_params), ) try: yield cursor From 084da7a9e59f4450913cd742f5456de31073502a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 10:22:48 +0000 Subject: [PATCH 8/8] remove excess tests Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 136 --------------------------------------- 1 file changed, 136 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 5dcc36ef9..8f15bccc6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -811,142 +811,6 @@ def test_catalogs_returns_arrow_table(self): results = cursor.fetchall_arrow() assert isinstance(results, pyarrow.Table) - def test_close_connection_closes_cursors(self): - - from databricks.sql.thrift_api.TCLIService import ttypes - - with self.connection() as conn: - cursor = conn.cursor() - cursor.execute( - "SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()" - ) - ars = cursor.active_result_set - - # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True - # Cursor op state should be open before connection is closed - status_request = ttypes.TGetOperationStatusReq( - operationHandle=ars.command_id.to_thrift_handle(), - getProgressUpdate=False, - ) - op_status_at_server = ars.backend._client.GetOperationStatus(status_request) - assert op_status_at_server.operationState != CommandState.CLOSED - - conn.close() - - # When connection closes, any cursor operations should no longer exist at the server - with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.backend._client.GetOperationStatus( - status_request - ) - - def test_closing_a_closed_connection_doesnt_fail(self, caplog): - caplog.set_level(logging.DEBUG) - # Second .close() call is when this context manager exits - with self.connection() as conn: - # First .close() call is explicit here - conn.close() - assert "Session appears to have been closed already" in caplog.text - - conn = None - try: - with pytest.raises(KeyboardInterrupt): - with self.connection() as c: - conn = c - raise KeyboardInterrupt("Simulated interrupt") - finally: - if conn is not None: - assert ( - not conn.open - ), "Connection should be closed after KeyboardInterrupt" - - def test_cursor_close_properly_closes_operation(self): - """Test that Cursor.close() properly closes the active operation handle on the server.""" - with self.connection() as conn: - cursor = conn.cursor() - try: - cursor.execute("SELECT 1 AS test") - assert cursor.active_command_id is not None - cursor.close() - assert cursor.active_command_id is None - assert not cursor.open - finally: - if cursor.open: - cursor.close() - - conn = None - cursor = None - try: - with self.connection() as c: - conn = c - with pytest.raises(KeyboardInterrupt): - with conn.cursor() as cur: - cursor = cur - raise KeyboardInterrupt("Simulated interrupt") - finally: - if cursor is not None: - assert ( - not cursor.open - ), "Cursor should be closed after KeyboardInterrupt" - - def test_nested_cursor_context_managers(self): - """Test that nested cursor context managers properly close operations on the server.""" - with self.connection() as conn: - with conn.cursor() as cursor1: - cursor1.execute("SELECT 1 AS test1") - assert cursor1.active_command_id is not None - - with conn.cursor() as cursor2: - cursor2.execute("SELECT 2 AS test2") - assert cursor2.active_command_id is not None - - # After inner context manager exit, cursor2 should be not open - assert not cursor2.open - assert cursor2.active_command_id is None - - # After outer context manager exit, cursor1 should be not open - assert not cursor1.open - assert cursor1.active_command_id is None - - def test_cursor_error_handling(self): - """Test that cursor close handles errors properly to prevent orphaned operations.""" - with self.connection() as conn: - cursor = conn.cursor() - - cursor.execute("SELECT 1 AS test") - - op_handle = cursor.active_command_id - - assert op_handle is not None - - # Manually close the operation to simulate server-side closure - conn.session.backend.close_command(op_handle) - - cursor.close() - - assert not cursor.open - - def test_result_set_close(self): - """Test that ResultSet.close() properly closes operations on the server and handles state correctly.""" - with self.connection() as conn: - cursor = conn.cursor() - try: - cursor.execute("SELECT * FROM RANGE(10)") - - result_set = cursor.active_result_set - assert result_set is not None - - initial_op_state = result_set.status - - result_set.close() - - assert result_set.status == CommandState.CLOSED - assert result_set.status != initial_op_state - - # Closing the result set again should be a no-op and not raise exceptions - result_set.close() - finally: - cursor.close() - def test_row_limit_with_larger_result(self): """Test that row_limit properly constrains results when query would return more rows""" row_limit = 1000