From 9de6c8becd397592fc8bb65602e3f401c02abe1f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 31 Jul 2025 13:46:04 +0000 Subject: [PATCH 01/22] init col norm Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 21 ++- src/databricks/sql/backend/sea/result_set.py | 128 +++++++++++++++++- .../backend/sea/utils/metadata_mappings.py | 90 ++++++++++++ .../sql/backend/sea/utils/result_column.py | 18 +++ tests/e2e/test_concurrent_telemetry.py | 49 ++++--- tests/unit/test_telemetry.py | 10 +- tests/unit/test_telemetry_retry.py | 2 +- 7 files changed, 287 insertions(+), 31 deletions(-) create mode 100644 src/databricks/sql/backend/sea/utils/metadata_mappings.py create mode 100644 src/databricks/sql/backend/sea/utils/result_column.py diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a8f04a05a..50515803e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -19,6 +19,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift from databricks.sql.thrift_api.TCLIService import ttypes @@ -699,7 +700,10 @@ def get_catalogs( async_op=False, enforce_embedded_schema_correctness=False, ) - assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "Expected SeaResultSet from SEA backend" + result.prepare_metadata_columns(MetadataColumnMappings.CATALOG_COLUMNS) return result def get_schemas( @@ -732,7 +736,10 @@ def get_schemas( async_op=False, enforce_embedded_schema_correctness=False, ) - assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "Expected SeaResultSet from SEA backend" + result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS) return result def get_tables( @@ -773,7 +780,10 @@ def get_tables( async_op=False, enforce_embedded_schema_correctness=False, ) - assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "Expected SeaResultSet from SEA backend" + result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS) # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter @@ -820,5 +830,8 @@ def get_columns( async_op=False, enforce_embedded_schema_correctness=False, ) - assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "Expected SeaResultSet from SEA backend" + result.prepare_metadata_columns(MetadataColumnMappings.COLUMN_COLUMNS) return result diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index afa70bc89..5ec8badfa 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING, Dict import logging from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter +from databricks.sql.backend.sea.utils.result_column import ResultColumn try: import pyarrow @@ -82,6 +83,10 @@ def __init__( 
arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Initialize metadata columns for post-fetch transformation + self._metadata_columns = None + self._column_index_mapping = None + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert string values in the row to appropriate Python types based on column metadata. @@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") results = self.results.next_n_rows(size) + results = self._transform_json_rows(results) self._next_row_index += len(results) return results @@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]: """ results = self.results.remaining_rows() + results = self._transform_json_rows(results) self._next_row_index += len(results) return results @@ -197,7 +204,12 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) if isinstance(self.results, JsonQueue): - results = self._convert_json_to_arrow_table(results) + # Transform JSON first, then convert to Arrow + transformed_json = self._transform_json_rows(results) + results = self._convert_json_to_arrow_table(transformed_json) + else: + # Transform Arrow table directly + results = self._transform_arrow_table(results) self._next_row_index += results.num_rows @@ -210,7 +222,12 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() if isinstance(self.results, JsonQueue): - results = self._convert_json_to_arrow_table(results) + # Transform JSON first, then convert to Arrow + transformed_json = self._transform_json_rows(results) + results = self._convert_json_to_arrow_table(transformed_json) + else: + # Transform Arrow table directly + results = self._transform_arrow_table(results) self._next_row_index += results.num_rows @@ -263,3 +280,108 @@ def fetchall(self) -> List[Row]: return self._create_json_table(self.fetchall_json()) else: return self._convert_arrow_table(self.fetchall_arrow()) + + def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None: + """ + Prepare result set for metadata column normalization. + + Args: + metadata_columns: List of ResultColumn objects defining the expected columns + and their mappings from SEA column names + """ + self._metadata_columns = metadata_columns + self._prepare_column_mapping() + + def _prepare_column_mapping(self) -> None: + """ + Prepare column index mapping for metadata queries. + Updates description to use JDBC column names. 
+ """ + # Ensure description is available + if not self.description: + raise ValueError("Cannot prepare column mapping without result description") + + # Build mapping from SEA column names to their indices + sea_column_indices = {} + for idx, col in enumerate(self.description): + sea_column_indices[col[0]] = idx + + # Create new description and index mapping + new_description = [] + self._column_index_mapping = {} # Maps new index -> old index + + for new_idx, result_column in enumerate(self._metadata_columns): + # Find the corresponding SEA column + if ( + result_column.result_set_column_name + and result_column.result_set_column_name in sea_column_indices + ): + old_idx = sea_column_indices[result_column.result_set_column_name] + self._column_index_mapping[new_idx] = old_idx + # Use the original column metadata but with JDBC name + old_col = self.description[old_idx] + new_description.append( + ( + result_column.column_name, # JDBC name + result_column.column_type, # Expected type + old_col[2], # display_size + old_col[3], # internal_size + old_col[4], # precision + old_col[5], # scale + old_col[6], # null_ok + ) + ) + else: + # Column doesn't exist in SEA - add with None values + new_description.append( + ( + result_column.column_name, + result_column.column_type, + None, + None, + None, + None, + True, + ) + ) + self._column_index_mapping[new_idx] = None + + self.description = new_description + + def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": + """Transform arrow table columns for metadata normalization.""" + if not self._metadata_columns: + return table + + # Reorder columns and add missing ones + new_columns = [] + column_names = [] + + for new_idx, result_column in enumerate(self._metadata_columns): + old_idx = self._column_index_mapping.get(new_idx) + if old_idx is not None: + new_columns.append(table.column(old_idx)) + else: + # Create null column for missing data + null_array = pyarrow.nulls(table.num_rows) + new_columns.append(null_array) + column_names.append(result_column.column_name) + + return pyarrow.Table.from_arrays(new_columns, names=column_names) + + def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]: + """Transform JSON rows for metadata normalization.""" + if not self._metadata_columns: + return rows + + transformed_rows = [] + for row in rows: + new_row = [] + for new_idx in range(len(self._metadata_columns)): + old_idx = self._column_index_mapping.get(new_idx) + if old_idx is not None: + new_row.append(row[old_idx]) + else: + new_row.append(None) + transformed_rows.append(new_row) + return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py new file mode 100644 index 000000000..218c20b43 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -0,0 +1,90 @@ +from databricks.sql.backend.sea.utils.result_column import ResultColumn + + +class MetadataColumnMappings: + """Column mappings for metadata queries following JDBC specification.""" + + # Common columns used across multiple metadata queries + CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalog", "string") + CATALOG_COLUMN_FOR_TABLES = ResultColumn("TABLE_CAT", "catalogName", "string") + SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string") + SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string") + TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string") + TABLE_TYPE_COLUMN = 
ResultColumn("TABLE_TYPE", "tableType", "string") + REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string") + + # Columns specific to getColumns() + COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string") + DATA_TYPE_COLUMN = ResultColumn( + "DATA_TYPE", None, "int" + ) # SEA doesn't provide this + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string") + COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", None, "int") + DECIMAL_DIGITS_COLUMN = ResultColumn("DECIMAL_DIGITS", None, "int") + NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", None, "int") + NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int") + COLUMN_DEF_COLUMN = ResultColumn( + "COLUMN_DEF", "columnType", "string" + ) # Note: duplicate mapping + SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int") + SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int") + CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int") + ORDINAL_POSITION_COLUMN = ResultColumn("ORDINAL_POSITION", None, "int") + IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string") + + # Columns for getTables() that don't exist in SEA + TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string") + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string") + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string") + SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( + "SELF_REFERENCING_COL_NAME", None, "string" + ) + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string") + + # Column lists for each metadata operation + CATALOG_COLUMNS = [CATALOG_COLUMN] + + SCHEMA_COLUMNS = [ + SCHEMA_COLUMN_FOR_GET_SCHEMA, + ResultColumn("TABLE_CATALOG", None, "string"), # SEA doesn't return this + ] + + TABLE_COLUMNS = [ + CATALOG_COLUMN_FOR_TABLES, + SCHEMA_COLUMN, + TABLE_NAME_COLUMN, + TABLE_TYPE_COLUMN, + REMARKS_COLUMN, + TYPE_CAT_COLUMN, + TYPE_SCHEM_COLUMN, + TYPE_NAME_COLUMN, + SELF_REFERENCING_COL_NAME_COLUMN, + REF_GENERATION_COLUMN, + ] + + COLUMN_COLUMNS = [ + CATALOG_COLUMN_FOR_TABLES, + SCHEMA_COLUMN, + TABLE_NAME_COLUMN, + COLUMN_NAME_COLUMN, + DATA_TYPE_COLUMN, + TYPE_NAME_COLUMN, + COLUMN_SIZE_COLUMN, + ResultColumn("BUFFER_LENGTH", None, "int"), + DECIMAL_DIGITS_COLUMN, + NUM_PREC_RADIX_COLUMN, + NULLABLE_COLUMN, + REMARKS_COLUMN, + COLUMN_DEF_COLUMN, + SQL_DATA_TYPE_COLUMN, + SQL_DATETIME_SUB_COLUMN, + CHAR_OCTET_LENGTH_COLUMN, + ORDINAL_POSITION_COLUMN, + IS_NULLABLE_COLUMN, + ResultColumn("SCOPE_CATALOG", None, "string"), + ResultColumn("SCOPE_SCHEMA", None, "string"), + ResultColumn("SCOPE_TABLE", None, "string"), + ResultColumn("SOURCE_DATA_TYPE", None, "smallint"), + ResultColumn("IS_AUTO_INCREMENT", None, "string"), + ResultColumn("IS_GENERATEDCOLUMN", None, "string"), + ] diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py new file mode 100644 index 000000000..1a1399e54 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/result_column.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class ResultColumn: + """ + Represents a mapping between JDBC specification column names and actual result set column names. 
+ + Attributes: + column_name: JDBC specification column name (e.g., "TABLE_CAT") + result_set_column_name: Server result column name from SEA (e.g., "catalog") + column_type: SQL type code from databricks.sql.types + """ + + column_name: str + result_set_column_name: Optional[str] # None if SEA doesn't return this column + column_type: str diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 656bcd21f..fe753012b 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -5,9 +5,13 @@ import pytest from databricks.sql.telemetry.models.enums import StatementType -from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClient, + TelemetryClientFactory, +) from tests.e2e.test_driver import PySQLPytestTestCase + def run_in_threads(target, num_threads, pass_index=False): """Helper to run target function in multiple threads.""" threads = [ @@ -21,7 +25,6 @@ def run_in_threads(target, num_threads, pass_index=False): class TestE2ETelemetry(PySQLPytestTestCase): - @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): """ @@ -30,7 +33,7 @@ def telemetry_setup_teardown(self): this robust and automatic. """ try: - yield + yield finally: if TelemetryClientFactory._executor: TelemetryClientFactory._executor.shutdown(wait=True) @@ -65,10 +68,10 @@ def callback_wrapper(self_client, future, sent_count): """ try: original_callback(self_client, future, sent_count) - + # Now, capture the result for our assertions response = future.result() - response.raise_for_status() # Raise an exception for 4xx/5xx errors + response.raise_for_status() # Raise an exception for 4xx/5xx errors telemetry_response = response.json() with capture_lock: captured_responses.append(telemetry_response) @@ -76,20 +79,23 @@ def callback_wrapper(self_client, future, sent_count): with capture_lock: captured_exceptions.append(e) - with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + with patch.object( + TelemetryClient, "_send_telemetry", send_telemetry_wrapper + ), patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ): def execute_query_worker(thread_id): """Each thread creates a connection and executes a query.""" time.sleep(random.uniform(0, 0.05)) - + with self.connection(extra_params={"enable_telemetry": True}) as conn: # Capture the session ID from the connection before executing the query session_id_hex = conn.get_session_id_hex() with capture_lock: captured_session_ids.append(session_id_hex) - + with conn.cursor() as cursor: cursor.execute(f"SELECT {thread_id}") # Capture the statement ID after executing the query @@ -107,7 +113,7 @@ def execute_query_worker(thread_id): # --- VERIFICATION --- assert not captured_exceptions assert len(captured_responses) > 0 - + total_successful_events = 0 for response in captured_responses: assert "errors" not in response or not response["errors"] @@ -115,22 +121,29 @@ def execute_query_worker(thread_id): total_successful_events += response["numProtoSuccess"] assert total_successful_events == num_threads * 2 - assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute)) + assert ( + len(captured_telemetry) == num_threads * 2 + ) # 2 events per thread (initial_telemetry_log, latency_log (execute)) 
assert len(captured_session_ids) == num_threads # One session ID per thread - assert len(captured_statement_ids) == num_threads # One statement ID per thread (per query) + assert ( + len(captured_statement_ids) == num_threads + ) # One statement ID per thread (per query) # Separate initial logs from latency logs initial_logs = [ - e for e in captured_telemetry + e + for e in captured_telemetry if e.entry.sql_driver_log.operation_latency_ms is None and e.entry.sql_driver_log.driver_connection_params is not None and e.entry.sql_driver_log.system_configuration is not None ] latency_logs = [ - e for e in captured_telemetry - if e.entry.sql_driver_log.operation_latency_ms is not None - and e.entry.sql_driver_log.sql_statement_id is not None - and e.entry.sql_driver_log.sql_operation.statement_type == StatementType.QUERY + e + for e in captured_telemetry + if e.entry.sql_driver_log.operation_latency_ms is not None + and e.entry.sql_driver_log.sql_statement_id is not None + and e.entry.sql_driver_log.sql_operation.statement_type + == StatementType.QUERY ] # Verify counts @@ -163,4 +176,4 @@ def execute_query_worker(thread_id): for event in latency_logs: log = event.entry.sql_driver_log assert log.sql_statement_id in captured_statement_ids - assert log.session_id in captured_session_ids \ No newline at end of file + assert log.session_id in captured_session_ids diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index d0e28c18d..45fc05108 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -30,7 +30,7 @@ def mock_telemetry_client(): auth_provider=auth_provider, host_url="test-host.com", executor=executor, - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) @@ -215,7 +215,7 @@ def test_client_lifecycle_flow(self): session_id_hex=session_id_hex, auth_provider=auth_provider, host_url="test-host.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -240,7 +240,7 @@ def test_disabled_telemetry_flow(self): session_id_hex=session_id_hex, auth_provider=None, host_url="test-host.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -260,7 +260,7 @@ def test_factory_error_handling(self): session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) # Should fall back to NoopTelemetryClient @@ -279,7 +279,7 @@ def test_factory_shutdown_flow(self): session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) # Factory should be initialized diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index 9f3a5c59d..d5287deb9 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -51,7 +51,7 @@ def get_client(self, session_id, num_retries=3): session_id_hex=session_id, auth_provider=None, host_url="test.databricks.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) client = 
TelemetryClientFactory.get_telemetry_client(session_id) From f5982f0c624b55bee5488211641ec2a42e82d736 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sun, 3 Aug 2025 13:43:16 +0000 Subject: [PATCH 02/22] partial working dump Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 89 ++++++++++++++- .../sql/backend/sea/utils/conversion.py | 4 + .../sql/backend/sea/utils/filters.py | 7 ++ .../backend/sea/utils/metadata_mappings.py | 108 ++++++++++++------ .../sea/utils/metadata_transformations.py | 98 ++++++++++++++++ .../sql/backend/sea/utils/result_column.py | 4 +- 6 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 src/databricks/sql/backend/sea/utils/metadata_transformations.py diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 5ec8badfa..e3b0c024c 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -359,12 +359,57 @@ def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": for new_idx, result_column in enumerate(self._metadata_columns): old_idx = self._column_index_mapping.get(new_idx) + + # Get the source data if old_idx is not None: - new_columns.append(table.column(old_idx)) + column = table.column(old_idx) + values = column.to_pylist() else: - # Create null column for missing data - null_array = pyarrow.nulls(table.num_rows) + values = None + + # Special handling for columns that need data from other columns + if result_column.column_name == "DATA_TYPE" and result_column.result_set_column_name is None: + # Get TYPE_NAME column value for DATA_TYPE calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "TYPE_NAME": + type_idx = self._column_index_mapping.get(idx) + if type_idx is not None: + values = table.column(type_idx).to_pylist() + break + elif result_column.column_name == "NULLABLE" and result_column.result_set_column_name is None: + # Get IS_NULLABLE column value for NULLABLE calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "IS_NULLABLE": + nullable_idx = self._column_index_mapping.get(idx) + if nullable_idx is not None: + values = table.column(nullable_idx).to_pylist() + break + elif result_column.column_name == "BUFFER_LENGTH" and result_column.result_set_column_name is None: + # Get TYPE_NAME column value for BUFFER_LENGTH calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "TYPE_NAME": + type_idx = self._column_index_mapping.get(idx) + if type_idx is not None: + values = table.column(type_idx).to_pylist() + break + + # Apply transformation and create column + if values is not None: + if result_column.transform_value: + transformed_values = [result_column.transform_value(v) for v in values] + column = pyarrow.array(transformed_values) + else: + column = pyarrow.array(values) + new_columns.append(column) + else: + # Create column with default/transformed values + if result_column.transform_value: + default_value = result_column.transform_value(None) + null_array = pyarrow.array([default_value] * table.num_rows) + else: + null_array = pyarrow.nulls(table.num_rows) new_columns.append(null_array) + column_names.append(result_column.column_name) return pyarrow.Table.from_arrays(new_columns, names=column_names) @@ -377,11 +422,43 @@ def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]: transformed_rows = [] for row in rows: new_row = [] - for new_idx in 
range(len(self._metadata_columns)): + for new_idx, result_column in enumerate(self._metadata_columns): old_idx = self._column_index_mapping.get(new_idx) if old_idx is not None: - new_row.append(row[old_idx]) + value = row[old_idx] else: - new_row.append(None) + value = None + + # Special handling for columns that need data from other columns + if result_column.column_name == "DATA_TYPE" and result_column.result_set_column_name is None: + # Get TYPE_NAME column value for DATA_TYPE calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "TYPE_NAME": + type_idx = self._column_index_mapping.get(idx) + if type_idx is not None and type_idx < len(row): + value = row[type_idx] + break + elif result_column.column_name == "NULLABLE" and result_column.result_set_column_name is None: + # Get IS_NULLABLE column value for NULLABLE calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "IS_NULLABLE": + nullable_idx = self._column_index_mapping.get(idx) + if nullable_idx is not None and nullable_idx < len(row): + value = row[nullable_idx] + break + elif result_column.column_name == "BUFFER_LENGTH" and result_column.result_set_column_name is None: + # Get TYPE_NAME column value for BUFFER_LENGTH calculation + for idx, col in enumerate(self._metadata_columns): + if col.column_name == "TYPE_NAME": + type_idx = self._column_index_mapping.get(idx) + if type_idx is not None and type_idx < len(row): + value = row[type_idx] + break + + # Apply transformation if defined + if result_column.transform_value: + value = result_column.transform_value(value) + + new_row.append(value) transformed_rows.append(new_row) return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 69c6dfbe2..139d357c7 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -150,6 +150,10 @@ def convert_value( Returns: The converted value in the appropriate Python type """ + + # Handle None values directly + if value is None: + return None sql_type = sql_type.lower().strip() diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 0bdb23b03..0f3b29339 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -86,6 +86,13 @@ def _filter_sea_result_set( arraysize=result_set.arraysize, ) + # Preserve metadata columns setup from original result set + if hasattr(result_set, '_metadata_columns') and result_set._metadata_columns: + filtered_result_set._metadata_columns = result_set._metadata_columns + filtered_result_set._column_index_mapping = result_set._column_index_mapping + # Update the description to match the original prepared description + filtered_result_set.description = result_set.description + return filtered_result_set @staticmethod diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 218c20b43..e983df005 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -1,56 +1,89 @@ from databricks.sql.backend.sea.utils.result_column import ResultColumn +from databricks.sql.backend.sea.utils.metadata_transformations import ( + transform_table_type, + transform_is_nullable, + transform_nullable_to_int, + transform_remarks_default, + 
transform_numeric_default_zero, + transform_ordinal_position_offset, + calculate_data_type, + calculate_buffer_length, + always_null, + always_null_int, + always_null_smallint, + identity +) class MetadataColumnMappings: """Column mappings for metadata queries following JDBC specification.""" # Common columns used across multiple metadata queries - CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalog", "string") - CATALOG_COLUMN_FOR_TABLES = ResultColumn("TABLE_CAT", "catalogName", "string") - SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string") - SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string") - TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string") - TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string") - REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string") + # FIX 1: Catalog columns - swap the mappings + CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", "string", transform_value=identity) + CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn("TABLE_CAT", "catalog", "string", transform_value=identity) + # Remove CATALOG_COLUMN_FOR_TABLES - will use CATALOG_COLUMN instead + + SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string", transform_value=identity) + SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string", transform_value=identity) + TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string", transform_value=identity) + TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string", transform_value=transform_table_type) + REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string", transform_value=transform_remarks_default) # Columns specific to getColumns() - COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string") + COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string", transform_value=identity) DATA_TYPE_COLUMN = ResultColumn( - "DATA_TYPE", None, "int" - ) # SEA doesn't provide this - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string") - COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", None, "int") - DECIMAL_DIGITS_COLUMN = ResultColumn("DECIMAL_DIGITS", None, "int") - NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", None, "int") - NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int") + "DATA_TYPE", None, "int", transform_value=calculate_data_type + ) # Calculated from columnType + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string", transform_value=identity) + + # FIX 5: SEA actually provides these columns + COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", "int", transform_value=identity) + DECIMAL_DIGITS_COLUMN = ResultColumn("DECIMAL_DIGITS", "decimalDigits", "int", transform_value=transform_numeric_default_zero) + NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", "int", transform_value=transform_numeric_default_zero) + ORDINAL_POSITION_COLUMN = ResultColumn("ORDINAL_POSITION", "ordinalPosition", "int", transform_value=transform_ordinal_position_offset) + + NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int", transform_value=transform_nullable_to_int) # Calculated from isNullable COLUMN_DEF_COLUMN = ResultColumn( - "COLUMN_DEF", "columnType", "string" + "COLUMN_DEF", "columnType", "string", transform_value=identity ) # Note: duplicate mapping - SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int") - SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int") - CHAR_OCTET_LENGTH_COLUMN = 
ResultColumn("CHAR_OCTET_LENGTH", None, "int") - ORDINAL_POSITION_COLUMN = ResultColumn("ORDINAL_POSITION", None, "int") - IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string") + SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int", transform_value=always_null_int) + SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int", transform_value=always_null_int) + CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int", transform_value=always_null_int) + IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string", transform_value=transform_is_nullable) # Columns for getTables() that don't exist in SEA - TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string") - TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string") - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string") + TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string", transform_value=always_null) + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string", transform_value=always_null) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string", transform_value=always_null) SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", None, "string" + "SELF_REFERENCING_COL_NAME", None, "string", transform_value=always_null ) - REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string") + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string", transform_value=always_null) + + # FIX 8: Scope columns (always null per JDBC) + SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, "string", transform_value=always_null) + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, "string", transform_value=always_null) + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, "string", transform_value=always_null) + SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, "smallint", transform_value=always_null_smallint) + + # FIX 9 & 10: Auto increment and generated columns + IS_AUTO_INCREMENT_COLUMN = ResultColumn("IS_AUTOINCREMENT", "isAutoIncrement", "string", transform_value=identity) # No underscore! 
+ IS_GENERATED_COLUMN = ResultColumn("IS_GENERATEDCOLUMN", "isGenerated", "string", transform_value=identity) # SEA provides this + + # FIX 11: Buffer length column + BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, "int", transform_value=always_null_int) # Always null per JDBC # Column lists for each metadata operation - CATALOG_COLUMNS = [CATALOG_COLUMN] + CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] # Use specific catalog column SCHEMA_COLUMNS = [ SCHEMA_COLUMN_FOR_GET_SCHEMA, - ResultColumn("TABLE_CATALOG", None, "string"), # SEA doesn't return this + ResultColumn("TABLE_CATALOG", None, "string", transform_value=always_null), # Will need special population logic ] TABLE_COLUMNS = [ - CATALOG_COLUMN_FOR_TABLES, + CATALOG_COLUMN, # Use general catalog column (catalogName) SCHEMA_COLUMN, TABLE_NAME_COLUMN, TABLE_TYPE_COLUMN, @@ -62,15 +95,16 @@ class MetadataColumnMappings: REF_GENERATION_COLUMN, ] + # FIX 13: Remove IS_GENERATEDCOLUMN from list (should be 23 columns, not 24) COLUMN_COLUMNS = [ - CATALOG_COLUMN_FOR_TABLES, + CATALOG_COLUMN, # Use general catalog column (catalogName) SCHEMA_COLUMN, TABLE_NAME_COLUMN, COLUMN_NAME_COLUMN, DATA_TYPE_COLUMN, TYPE_NAME_COLUMN, COLUMN_SIZE_COLUMN, - ResultColumn("BUFFER_LENGTH", None, "int"), + BUFFER_LENGTH_COLUMN, DECIMAL_DIGITS_COLUMN, NUM_PREC_RADIX_COLUMN, NULLABLE_COLUMN, @@ -81,10 +115,10 @@ class MetadataColumnMappings: CHAR_OCTET_LENGTH_COLUMN, ORDINAL_POSITION_COLUMN, IS_NULLABLE_COLUMN, - ResultColumn("SCOPE_CATALOG", None, "string"), - ResultColumn("SCOPE_SCHEMA", None, "string"), - ResultColumn("SCOPE_TABLE", None, "string"), - ResultColumn("SOURCE_DATA_TYPE", None, "smallint"), - ResultColumn("IS_AUTO_INCREMENT", None, "string"), - ResultColumn("IS_GENERATEDCOLUMN", None, "string"), + SCOPE_CATALOG_COLUMN, + SCOPE_SCHEMA_COLUMN, + SCOPE_TABLE_COLUMN, + SOURCE_DATA_TYPE_COLUMN, + IS_AUTO_INCREMENT_COLUMN, + # DO NOT INCLUDE IS_GENERATED_COLUMN - Thrift returns 23 columns ] diff --git a/src/databricks/sql/backend/sea/utils/metadata_transformations.py b/src/databricks/sql/backend/sea/utils/metadata_transformations.py new file mode 100644 index 000000000..5911fcbf8 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/metadata_transformations.py @@ -0,0 +1,98 @@ +from typing import Any, Optional + +# Table transformations +def transform_table_type(value: Any) -> str: + """Transform empty/null table type to 'TABLE' per JDBC spec.""" + if value is None or value == "": + return "TABLE" + return str(value) + +# Nullable transformations +def transform_is_nullable(value: Any) -> str: + """Transform boolean nullable to YES/NO per JDBC spec.""" + if value is None or value == "true" or value is True: + return "YES" + return "NO" + +def transform_nullable_to_int(value: Any) -> int: + """Transform isNullable to JDBC integer (1=nullable, 0=not nullable).""" + if value is None or value == "true" or value is True: + return 1 + return 0 + +# Default value transformations +def transform_remarks_default(value: Any) -> str: + """Transform null remarks to empty string.""" + if value is None: + return "" + return str(value) + +def transform_numeric_default_zero(value: Any) -> int: + """Transform null numeric values to 0.""" + if value is None: + return 0 + try: + return int(value) + except (ValueError, TypeError): + return 0 + +# Calculated transformations +def calculate_data_type(value: Any) -> int: + """Calculate JDBC SQL type code from Databricks type name.""" + if value is None: + return 1111 # SQL NULL type + + type_name = 
str(value).upper().split('(')[0] + type_map = { + 'STRING': 12, 'VARCHAR': 12, + 'INT': 4, 'INTEGER': 4, + 'DOUBLE': 8, 'FLOAT': 6, + 'BOOLEAN': 16, 'DATE': 91, + 'TIMESTAMP': 93, 'TIMESTAMP_NTZ': 93, + 'DECIMAL': 3, 'NUMERIC': 2, + 'BINARY': -2, 'ARRAY': 2003, + 'MAP': 2002, 'STRUCT': 2002, + 'TINYINT': -6, 'SMALLINT': 5, + 'BIGINT': -5, 'LONG': -5 + } + return type_map.get(type_name, 1111) + +def calculate_buffer_length(value: Any) -> Optional[int]: + """Calculate buffer length from type name.""" + if value is None: + return None + + type_name = str(value).upper() + if 'ARRAY' in type_name or 'MAP' in type_name: + return 255 + + # For other types, return None (will be null in result) + return None + +def transform_ordinal_position_offset(value: Any) -> int: + """Adjust ordinal position from 1-based to 0-based or vice versa if needed.""" + if value is None: + return 0 + try: + # SEA returns 1-based, Thrift expects 0-based + return int(value) - 1 + except (ValueError, TypeError): + return 0 + +# Null column transformations +def always_null(value: Any) -> None: + """Always return null for columns that should be null per JDBC spec.""" + return None + +def always_null_int(value: Any) -> None: + """Always return null for integer columns that should be null per JDBC spec.""" + return None + +def always_null_smallint(value: Any) -> None: + """Always return null for smallint columns that should be null per JDBC spec.""" + return None + +# Identity transformations (for columns that need no change) +def identity(value: Any) -> Any: + """Return value unchanged.""" + return value \ No newline at end of file diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py index 1a1399e54..6e68537ec 100644 --- a/src/databricks/sql/backend/sea/utils/result_column.py +++ b/src/databricks/sql/backend/sea/utils/result_column.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Callable, Any @dataclass(frozen=True) @@ -11,8 +11,10 @@ class ResultColumn: column_name: JDBC specification column name (e.g., "TABLE_CAT") result_set_column_name: Server result column name from SEA (e.g., "catalog") column_type: SQL type code from databricks.sql.types + transform_value: Optional function to transform values for this column """ column_name: str result_set_column_name: Optional[str] # None if SEA doesn't return this column column_type: str + transform_value: Optional[Callable[[Any], Any]] = None From d97d875146844e74c63dd65574d27f19c9b03e0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sun, 3 Aug 2025 13:57:00 +0000 Subject: [PATCH 03/22] refactor Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 115 +++++++------- .../sql/backend/sea/utils/conversion.py | 2 +- .../sql/backend/sea/utils/filters.py | 2 +- .../backend/sea/utils/metadata_mappings.py | 140 +++++++++++++----- .../sea/utils/metadata_transformations.py | 55 ++++--- 5 files changed, 201 insertions(+), 113 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index e3b0c024c..8c33568e8 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -24,6 +24,15 @@ logger = logging.getLogger(__name__) +# Column-to-column data mapping for metadata queries +# Maps target column -> source column to get data from +COLUMN_DATA_MAPPING = { + "DATA_TYPE": "TYPE_NAME", # DATA_TYPE calculated from 
TYPE_NAME + "NULLABLE": "IS_NULLABLE", # NULLABLE calculated from IS_NULLABLE + "BUFFER_LENGTH": "TYPE_NAME", # BUFFER_LENGTH calculated from TYPE_NAME +} + + class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" @@ -292,6 +301,40 @@ def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None self._metadata_columns = metadata_columns self._prepare_column_mapping() + def _populate_columns_from_others( + self, result_column: ResultColumn, row_data: Any + ) -> Any: + """ + Helper function to populate column data from other columns based on COLUMN_DATA_MAPPING. + + Args: + result_column: The result column that needs data + row_data: Row data (list for JSON, PyArrow table for Arrow) + + Returns: + The value to use for this column, or None if not found + """ + target_column = result_column.column_name + if target_column not in COLUMN_DATA_MAPPING: + return None + + source_column = COLUMN_DATA_MAPPING[target_column] + + # Find the source column index + for idx, col in enumerate(self._metadata_columns): + if col.column_name == source_column: + source_idx = self._column_index_mapping.get(idx) + if source_idx is not None: + # Handle Arrow table format + if hasattr(row_data, "column"): # PyArrow table + return row_data.column(source_idx).to_pylist() + # Handle JSON row format + else: + return row_data[source_idx] + break + + return None + def _prepare_column_mapping(self) -> None: """ Prepare column index mapping for metadata queries. @@ -359,44 +402,24 @@ def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": for new_idx, result_column in enumerate(self._metadata_columns): old_idx = self._column_index_mapping.get(new_idx) - - # Get the source data + + # Get the source data if old_idx is not None: column = table.column(old_idx) values = column.to_pylist() else: values = None - + # Special handling for columns that need data from other columns - if result_column.column_name == "DATA_TYPE" and result_column.result_set_column_name is None: - # Get TYPE_NAME column value for DATA_TYPE calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "TYPE_NAME": - type_idx = self._column_index_mapping.get(idx) - if type_idx is not None: - values = table.column(type_idx).to_pylist() - break - elif result_column.column_name == "NULLABLE" and result_column.result_set_column_name is None: - # Get IS_NULLABLE column value for NULLABLE calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "IS_NULLABLE": - nullable_idx = self._column_index_mapping.get(idx) - if nullable_idx is not None: - values = table.column(nullable_idx).to_pylist() - break - elif result_column.column_name == "BUFFER_LENGTH" and result_column.result_set_column_name is None: - # Get TYPE_NAME column value for BUFFER_LENGTH calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "TYPE_NAME": - type_idx = self._column_index_mapping.get(idx) - if type_idx is not None: - values = table.column(type_idx).to_pylist() - break - + if result_column.result_set_column_name is None: + values = self._populate_columns_from_others(result_column, table) + # Apply transformation and create column if values is not None: if result_column.transform_value: - transformed_values = [result_column.transform_value(v) for v in values] + transformed_values = [ + result_column.transform_value(v) for v in values + ] column = pyarrow.array(transformed_values) else: column = pyarrow.array(values) @@ -409,7 +432,7 @@ 
def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": else: null_array = pyarrow.nulls(table.num_rows) new_columns.append(null_array) - + column_names.append(result_column.column_name) return pyarrow.Table.from_arrays(new_columns, names=column_names) @@ -428,37 +451,15 @@ def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]: value = row[old_idx] else: value = None - + # Special handling for columns that need data from other columns - if result_column.column_name == "DATA_TYPE" and result_column.result_set_column_name is None: - # Get TYPE_NAME column value for DATA_TYPE calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "TYPE_NAME": - type_idx = self._column_index_mapping.get(idx) - if type_idx is not None and type_idx < len(row): - value = row[type_idx] - break - elif result_column.column_name == "NULLABLE" and result_column.result_set_column_name is None: - # Get IS_NULLABLE column value for NULLABLE calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "IS_NULLABLE": - nullable_idx = self._column_index_mapping.get(idx) - if nullable_idx is not None and nullable_idx < len(row): - value = row[nullable_idx] - break - elif result_column.column_name == "BUFFER_LENGTH" and result_column.result_set_column_name is None: - # Get TYPE_NAME column value for BUFFER_LENGTH calculation - for idx, col in enumerate(self._metadata_columns): - if col.column_name == "TYPE_NAME": - type_idx = self._column_index_mapping.get(idx) - if type_idx is not None and type_idx < len(row): - value = row[type_idx] - break - + if result_column.result_set_column_name is None: + value = self._populate_columns_from_others(result_column, row) + # Apply transformation if defined if result_column.transform_value: value = result_column.transform_value(value) - + new_row.append(value) transformed_rows.append(new_row) return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 139d357c7..e139a11a8 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -150,7 +150,7 @@ def convert_value( Returns: The converted value in the appropriate Python type """ - + # Handle None values directly if value is None: return None diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 0f3b29339..6b4eb99f3 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -87,7 +87,7 @@ def _filter_sea_result_set( ) # Preserve metadata columns setup from original result set - if hasattr(result_set, '_metadata_columns') and result_set._metadata_columns: + if hasattr(result_set, "_metadata_columns") and result_set._metadata_columns: filtered_result_set._metadata_columns = result_set._metadata_columns filtered_result_set._column_index_mapping = result_set._column_index_mapping # Update the description to match the original prepared description diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index e983df005..848e769ac 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -11,7 +11,7 @@ always_null, always_null_int, always_null_smallint, - identity + identity, ) @@ -20,66 +20,132 @@ class MetadataColumnMappings: 
# Common columns used across multiple metadata queries # FIX 1: Catalog columns - swap the mappings - CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", "string", transform_value=identity) - CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn("TABLE_CAT", "catalog", "string", transform_value=identity) + CATALOG_COLUMN = ResultColumn( + "TABLE_CAT", "catalogName", "string", transform_value=identity + ) + CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn( + "TABLE_CAT", "catalog", "string", transform_value=identity + ) # Remove CATALOG_COLUMN_FOR_TABLES - will use CATALOG_COLUMN instead - - SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string", transform_value=identity) - SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string", transform_value=identity) - TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string", transform_value=identity) - TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string", transform_value=transform_table_type) - REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string", transform_value=transform_remarks_default) + + SCHEMA_COLUMN = ResultColumn( + "TABLE_SCHEM", "namespace", "string", transform_value=identity + ) + SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn( + "TABLE_SCHEM", "databaseName", "string", transform_value=identity + ) + TABLE_NAME_COLUMN = ResultColumn( + "TABLE_NAME", "tableName", "string", transform_value=identity + ) + TABLE_TYPE_COLUMN = ResultColumn( + "TABLE_TYPE", "tableType", "string", transform_value=transform_table_type + ) + REMARKS_COLUMN = ResultColumn( + "REMARKS", "remarks", "string", transform_value=transform_remarks_default + ) # Columns specific to getColumns() - COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string", transform_value=identity) + COLUMN_NAME_COLUMN = ResultColumn( + "COLUMN_NAME", "col_name", "string", transform_value=identity + ) DATA_TYPE_COLUMN = ResultColumn( "DATA_TYPE", None, "int", transform_value=calculate_data_type ) # Calculated from columnType - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string", transform_value=identity) - + TYPE_NAME_COLUMN = ResultColumn( + "TYPE_NAME", "columnType", "string", transform_value=identity + ) + # FIX 5: SEA actually provides these columns - COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", "int", transform_value=identity) - DECIMAL_DIGITS_COLUMN = ResultColumn("DECIMAL_DIGITS", "decimalDigits", "int", transform_value=transform_numeric_default_zero) - NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", "int", transform_value=transform_numeric_default_zero) - ORDINAL_POSITION_COLUMN = ResultColumn("ORDINAL_POSITION", "ordinalPosition", "int", transform_value=transform_ordinal_position_offset) - - NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int", transform_value=transform_nullable_to_int) # Calculated from isNullable + COLUMN_SIZE_COLUMN = ResultColumn( + "COLUMN_SIZE", "columnSize", "int", transform_value=identity + ) + DECIMAL_DIGITS_COLUMN = ResultColumn( + "DECIMAL_DIGITS", + "decimalDigits", + "int", + transform_value=transform_numeric_default_zero, + ) + NUM_PREC_RADIX_COLUMN = ResultColumn( + "NUM_PREC_RADIX", "radix", "int", transform_value=transform_numeric_default_zero + ) + ORDINAL_POSITION_COLUMN = ResultColumn( + "ORDINAL_POSITION", + "ordinalPosition", + "int", + transform_value=transform_ordinal_position_offset, + ) + + NULLABLE_COLUMN = ResultColumn( + "NULLABLE", None, "int", transform_value=transform_nullable_to_int + ) # Calculated from 
isNullable COLUMN_DEF_COLUMN = ResultColumn( "COLUMN_DEF", "columnType", "string", transform_value=identity ) # Note: duplicate mapping - SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int", transform_value=always_null_int) - SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int", transform_value=always_null_int) - CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int", transform_value=always_null_int) - IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string", transform_value=transform_is_nullable) + SQL_DATA_TYPE_COLUMN = ResultColumn( + "SQL_DATA_TYPE", None, "int", transform_value=always_null_int + ) + SQL_DATETIME_SUB_COLUMN = ResultColumn( + "SQL_DATETIME_SUB", None, "int", transform_value=always_null_int + ) + CHAR_OCTET_LENGTH_COLUMN = ResultColumn( + "CHAR_OCTET_LENGTH", None, "int", transform_value=always_null_int + ) + IS_NULLABLE_COLUMN = ResultColumn( + "IS_NULLABLE", "isNullable", "string", transform_value=transform_is_nullable + ) # Columns for getTables() that don't exist in SEA - TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string", transform_value=always_null) - TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string", transform_value=always_null) - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string", transform_value=always_null) + TYPE_CAT_COLUMN = ResultColumn( + "TYPE_CAT", None, "string", transform_value=always_null + ) + TYPE_SCHEM_COLUMN = ResultColumn( + "TYPE_SCHEM", None, "string", transform_value=always_null + ) + TYPE_NAME_COLUMN = ResultColumn( + "TYPE_NAME", None, "string", transform_value=always_null + ) SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( "SELF_REFERENCING_COL_NAME", None, "string", transform_value=always_null ) - REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string", transform_value=always_null) - + REF_GENERATION_COLUMN = ResultColumn( + "REF_GENERATION", None, "string", transform_value=always_null + ) + # FIX 8: Scope columns (always null per JDBC) - SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, "string", transform_value=always_null) - SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, "string", transform_value=always_null) - SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, "string", transform_value=always_null) - SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, "smallint", transform_value=always_null_smallint) - + SCOPE_CATALOG_COLUMN = ResultColumn( + "SCOPE_CATALOG", None, "string", transform_value=always_null + ) + SCOPE_SCHEMA_COLUMN = ResultColumn( + "SCOPE_SCHEMA", None, "string", transform_value=always_null + ) + SCOPE_TABLE_COLUMN = ResultColumn( + "SCOPE_TABLE", None, "string", transform_value=always_null + ) + SOURCE_DATA_TYPE_COLUMN = ResultColumn( + "SOURCE_DATA_TYPE", None, "smallint", transform_value=always_null_smallint + ) + # FIX 9 & 10: Auto increment and generated columns - IS_AUTO_INCREMENT_COLUMN = ResultColumn("IS_AUTOINCREMENT", "isAutoIncrement", "string", transform_value=identity) # No underscore! - IS_GENERATED_COLUMN = ResultColumn("IS_GENERATEDCOLUMN", "isGenerated", "string", transform_value=identity) # SEA provides this - + IS_AUTO_INCREMENT_COLUMN = ResultColumn( + "IS_AUTOINCREMENT", "isAutoIncrement", "string", transform_value=identity + ) # No underscore! 
+ IS_GENERATED_COLUMN = ResultColumn( + "IS_GENERATEDCOLUMN", "isGenerated", "string", transform_value=identity + ) # SEA provides this + # FIX 11: Buffer length column - BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, "int", transform_value=always_null_int) # Always null per JDBC + BUFFER_LENGTH_COLUMN = ResultColumn( + "BUFFER_LENGTH", None, "int", transform_value=always_null_int + ) # Always null per JDBC # Column lists for each metadata operation CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] # Use specific catalog column SCHEMA_COLUMNS = [ SCHEMA_COLUMN_FOR_GET_SCHEMA, - ResultColumn("TABLE_CATALOG", None, "string", transform_value=always_null), # Will need special population logic + ResultColumn( + "TABLE_CATALOG", None, "string", transform_value=always_null + ), # Will need special population logic ] TABLE_COLUMNS = [ diff --git a/src/databricks/sql/backend/sea/utils/metadata_transformations.py b/src/databricks/sql/backend/sea/utils/metadata_transformations.py index 5911fcbf8..8fd27a6dd 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_transformations.py +++ b/src/databricks/sql/backend/sea/utils/metadata_transformations.py @@ -7,6 +7,7 @@ def transform_table_type(value: Any) -> str: return "TABLE" return str(value) + # Nullable transformations def transform_is_nullable(value: Any) -> str: """Transform boolean nullable to YES/NO per JDBC spec.""" @@ -14,12 +15,14 @@ def transform_is_nullable(value: Any) -> str: return "YES" return "NO" + def transform_nullable_to_int(value: Any) -> int: """Transform isNullable to JDBC integer (1=nullable, 0=not nullable).""" if value is None or value == "true" or value is True: return 1 return 0 + # Default value transformations def transform_remarks_default(value: Any) -> str: """Transform null remarks to empty string.""" @@ -27,6 +30,7 @@ def transform_remarks_default(value: Any) -> str: return "" return str(value) + def transform_numeric_default_zero(value: Any) -> int: """Transform null numeric values to 0.""" if value is None: @@ -36,39 +40,52 @@ def transform_numeric_default_zero(value: Any) -> int: except (ValueError, TypeError): return 0 + # Calculated transformations def calculate_data_type(value: Any) -> int: """Calculate JDBC SQL type code from Databricks type name.""" if value is None: return 1111 # SQL NULL type - - type_name = str(value).upper().split('(')[0] + + type_name = str(value).upper().split("(")[0] type_map = { - 'STRING': 12, 'VARCHAR': 12, - 'INT': 4, 'INTEGER': 4, - 'DOUBLE': 8, 'FLOAT': 6, - 'BOOLEAN': 16, 'DATE': 91, - 'TIMESTAMP': 93, 'TIMESTAMP_NTZ': 93, - 'DECIMAL': 3, 'NUMERIC': 2, - 'BINARY': -2, 'ARRAY': 2003, - 'MAP': 2002, 'STRUCT': 2002, - 'TINYINT': -6, 'SMALLINT': 5, - 'BIGINT': -5, 'LONG': -5 + "STRING": 12, + "VARCHAR": 12, + "INT": 4, + "INTEGER": 4, + "DOUBLE": 8, + "FLOAT": 6, + "BOOLEAN": 16, + "DATE": 91, + "TIMESTAMP": 93, + "TIMESTAMP_NTZ": 93, + "DECIMAL": 3, + "NUMERIC": 2, + "BINARY": -2, + "ARRAY": 2003, + "MAP": 2002, + "STRUCT": 2002, + "TINYINT": -6, + "SMALLINT": 5, + "BIGINT": -5, + "LONG": -5, } return type_map.get(type_name, 1111) + def calculate_buffer_length(value: Any) -> Optional[int]: """Calculate buffer length from type name.""" if value is None: return None - + type_name = str(value).upper() - if 'ARRAY' in type_name or 'MAP' in type_name: + if "ARRAY" in type_name or "MAP" in type_name: return 255 - + # For other types, return None (will be null in result) return None + def transform_ordinal_position_offset(value: Any) -> int: """Adjust ordinal position from 
1-based to 0-based or vice versa if needed.""" if value is None: @@ -79,20 +96,24 @@ def transform_ordinal_position_offset(value: Any) -> int: except (ValueError, TypeError): return 0 + # Null column transformations def always_null(value: Any) -> None: """Always return null for columns that should be null per JDBC spec.""" return None + def always_null_int(value: Any) -> None: """Always return null for integer columns that should be null per JDBC spec.""" return None - + + def always_null_smallint(value: Any) -> None: """Always return null for smallint columns that should be null per JDBC spec.""" return None + # Identity transformations (for columns that need no change) def identity(value: Any) -> Any: """Return value unchanged.""" - return value \ No newline at end of file + return value From c5c9859913643bddeeca9cf68b8c67091f14c15c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 01:48:05 +0000 Subject: [PATCH 04/22] remove callback methods for transformations Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 84 +++--------- .../backend/sea/utils/metadata_mappings.py | 124 +++++------------- .../sea/utils/metadata_transformations.py | 119 ----------------- .../sql/backend/sea/utils/result_column.py | 1 - 4 files changed, 49 insertions(+), 279 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/utils/metadata_transformations.py diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 8c33568e8..58f853566 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, TYPE_CHECKING, Dict +from typing import Any, List, Optional, TYPE_CHECKING, Dict, Union import logging @@ -93,8 +93,8 @@ def __init__( ) # Initialize metadata columns for post-fetch transformation - self._metadata_columns = None - self._column_index_mapping = None + self._metadata_columns: Optional[List[ResultColumn]] = None + self._column_index_mapping: Optional[Dict[int, Union[int, None]]] = None def _convert_json_types(self, row: List[str]) -> List[Any]: """ @@ -301,40 +301,6 @@ def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None self._metadata_columns = metadata_columns self._prepare_column_mapping() - def _populate_columns_from_others( - self, result_column: ResultColumn, row_data: Any - ) -> Any: - """ - Helper function to populate column data from other columns based on COLUMN_DATA_MAPPING. - - Args: - result_column: The result column that needs data - row_data: Row data (list for JSON, PyArrow table for Arrow) - - Returns: - The value to use for this column, or None if not found - """ - target_column = result_column.column_name - if target_column not in COLUMN_DATA_MAPPING: - return None - - source_column = COLUMN_DATA_MAPPING[target_column] - - # Find the source column index - for idx, col in enumerate(self._metadata_columns): - if col.column_name == source_column: - source_idx = self._column_index_mapping.get(idx) - if source_idx is not None: - # Handle Arrow table format - if hasattr(row_data, "column"): # PyArrow table - return row_data.column(source_idx).to_pylist() - # Handle JSON row format - else: - return row_data[source_idx] - break - - return None - def _prepare_column_mapping(self) -> None: """ Prepare column index mapping for metadata queries. 
@@ -353,7 +319,7 @@ def _prepare_column_mapping(self) -> None: new_description = [] self._column_index_mapping = {} # Maps new index -> old index - for new_idx, result_column in enumerate(self._metadata_columns): + for new_idx, result_column in enumerate(self._metadata_columns or []): # Find the corresponding SEA column if ( result_column.result_set_column_name @@ -400,8 +366,12 @@ def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": new_columns = [] column_names = [] - for new_idx, result_column in enumerate(self._metadata_columns): - old_idx = self._column_index_mapping.get(new_idx) + for new_idx, result_column in enumerate(self._metadata_columns or []): + old_idx = ( + self._column_index_mapping.get(new_idx) + if self._column_index_mapping + else None + ) # Get the source data if old_idx is not None: @@ -410,27 +380,13 @@ def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": else: values = None - # Special handling for columns that need data from other columns - if result_column.result_set_column_name is None: - values = self._populate_columns_from_others(result_column, table) - # Apply transformation and create column if values is not None: - if result_column.transform_value: - transformed_values = [ - result_column.transform_value(v) for v in values - ] - column = pyarrow.array(transformed_values) - else: - column = pyarrow.array(values) + column = pyarrow.array(values) new_columns.append(column) else: # Create column with default/transformed values - if result_column.transform_value: - default_value = result_column.transform_value(None) - null_array = pyarrow.array([default_value] * table.num_rows) - else: - null_array = pyarrow.nulls(table.num_rows) + null_array = pyarrow.nulls(table.num_rows) new_columns.append(null_array) column_names.append(result_column.column_name) @@ -445,21 +401,17 @@ def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]: transformed_rows = [] for row in rows: new_row = [] - for new_idx, result_column in enumerate(self._metadata_columns): - old_idx = self._column_index_mapping.get(new_idx) + for new_idx, result_column in enumerate(self._metadata_columns or []): + old_idx = ( + self._column_index_mapping.get(new_idx) + if self._column_index_mapping + else None + ) if old_idx is not None: value = row[old_idx] else: value = None - # Special handling for columns that need data from other columns - if result_column.result_set_column_name is None: - value = self._populate_columns_from_others(result_column, row) - - # Apply transformation if defined - if result_column.transform_value: - value = result_column.transform_value(value) - new_row.append(value) transformed_rows.append(new_row) return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 848e769ac..56b614c5c 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -1,18 +1,4 @@ from databricks.sql.backend.sea.utils.result_column import ResultColumn -from databricks.sql.backend.sea.utils.metadata_transformations import ( - transform_table_type, - transform_is_nullable, - transform_nullable_to_int, - transform_remarks_default, - transform_numeric_default_zero, - transform_ordinal_position_offset, - calculate_data_type, - calculate_buffer_length, - always_null, - always_null_int, - always_null_smallint, - identity, -) class MetadataColumnMappings: @@ -20,122 +6,74 @@ 
class MetadataColumnMappings: # Common columns used across multiple metadata queries # FIX 1: Catalog columns - swap the mappings - CATALOG_COLUMN = ResultColumn( - "TABLE_CAT", "catalogName", "string", transform_value=identity - ) - CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn( - "TABLE_CAT", "catalog", "string", transform_value=identity - ) + CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", "string") + CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn("TABLE_CAT", "catalog", "string") # Remove CATALOG_COLUMN_FOR_TABLES - will use CATALOG_COLUMN instead - SCHEMA_COLUMN = ResultColumn( - "TABLE_SCHEM", "namespace", "string", transform_value=identity - ) - SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn( - "TABLE_SCHEM", "databaseName", "string", transform_value=identity - ) - TABLE_NAME_COLUMN = ResultColumn( - "TABLE_NAME", "tableName", "string", transform_value=identity - ) - TABLE_TYPE_COLUMN = ResultColumn( - "TABLE_TYPE", "tableType", "string", transform_value=transform_table_type - ) - REMARKS_COLUMN = ResultColumn( - "REMARKS", "remarks", "string", transform_value=transform_remarks_default - ) + SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string") + SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string") + TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string") + TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string") + REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string") # Columns specific to getColumns() - COLUMN_NAME_COLUMN = ResultColumn( - "COLUMN_NAME", "col_name", "string", transform_value=identity - ) + COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string") DATA_TYPE_COLUMN = ResultColumn( - "DATA_TYPE", None, "int", transform_value=calculate_data_type + "DATA_TYPE", None, "int" ) # Calculated from columnType - TYPE_NAME_COLUMN = ResultColumn( - "TYPE_NAME", "columnType", "string", transform_value=identity - ) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string") # FIX 5: SEA actually provides these columns - COLUMN_SIZE_COLUMN = ResultColumn( - "COLUMN_SIZE", "columnSize", "int", transform_value=identity - ) + COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", "int") DECIMAL_DIGITS_COLUMN = ResultColumn( "DECIMAL_DIGITS", "decimalDigits", "int", - transform_value=transform_numeric_default_zero, - ) - NUM_PREC_RADIX_COLUMN = ResultColumn( - "NUM_PREC_RADIX", "radix", "int", transform_value=transform_numeric_default_zero ) + NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", "int") ORDINAL_POSITION_COLUMN = ResultColumn( "ORDINAL_POSITION", "ordinalPosition", "int", - transform_value=transform_ordinal_position_offset, ) NULLABLE_COLUMN = ResultColumn( - "NULLABLE", None, "int", transform_value=transform_nullable_to_int + "NULLABLE", None, "int" ) # Calculated from isNullable COLUMN_DEF_COLUMN = ResultColumn( - "COLUMN_DEF", "columnType", "string", transform_value=identity + "COLUMN_DEF", "columnType", "string" ) # Note: duplicate mapping - SQL_DATA_TYPE_COLUMN = ResultColumn( - "SQL_DATA_TYPE", None, "int", transform_value=always_null_int - ) - SQL_DATETIME_SUB_COLUMN = ResultColumn( - "SQL_DATETIME_SUB", None, "int", transform_value=always_null_int - ) - CHAR_OCTET_LENGTH_COLUMN = ResultColumn( - "CHAR_OCTET_LENGTH", None, "int", transform_value=always_null_int - ) - IS_NULLABLE_COLUMN = ResultColumn( - "IS_NULLABLE", "isNullable", "string", transform_value=transform_is_nullable - ) + SQL_DATA_TYPE_COLUMN = 
ResultColumn("SQL_DATA_TYPE", None, "int") + SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int") + CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int") + IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string") # Columns for getTables() that don't exist in SEA - TYPE_CAT_COLUMN = ResultColumn( - "TYPE_CAT", None, "string", transform_value=always_null - ) - TYPE_SCHEM_COLUMN = ResultColumn( - "TYPE_SCHEM", None, "string", transform_value=always_null - ) - TYPE_NAME_COLUMN = ResultColumn( - "TYPE_NAME", None, "string", transform_value=always_null - ) + TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string") + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string") + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string") SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", None, "string", transform_value=always_null - ) - REF_GENERATION_COLUMN = ResultColumn( - "REF_GENERATION", None, "string", transform_value=always_null + "SELF_REFERENCING_COL_NAME", None, "string" ) + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string") # FIX 8: Scope columns (always null per JDBC) - SCOPE_CATALOG_COLUMN = ResultColumn( - "SCOPE_CATALOG", None, "string", transform_value=always_null - ) - SCOPE_SCHEMA_COLUMN = ResultColumn( - "SCOPE_SCHEMA", None, "string", transform_value=always_null - ) - SCOPE_TABLE_COLUMN = ResultColumn( - "SCOPE_TABLE", None, "string", transform_value=always_null - ) - SOURCE_DATA_TYPE_COLUMN = ResultColumn( - "SOURCE_DATA_TYPE", None, "smallint", transform_value=always_null_smallint - ) + SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, "string") + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, "string") + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, "string") + SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, "smallint") # FIX 9 & 10: Auto increment and generated columns IS_AUTO_INCREMENT_COLUMN = ResultColumn( - "IS_AUTOINCREMENT", "isAutoIncrement", "string", transform_value=identity + "IS_AUTOINCREMENT", "isAutoIncrement", "string" ) # No underscore! 
IS_GENERATED_COLUMN = ResultColumn( - "IS_GENERATEDCOLUMN", "isGenerated", "string", transform_value=identity + "IS_GENERATEDCOLUMN", "isGenerated", "string" ) # SEA provides this # FIX 11: Buffer length column BUFFER_LENGTH_COLUMN = ResultColumn( - "BUFFER_LENGTH", None, "int", transform_value=always_null_int + "BUFFER_LENGTH", None, "int" ) # Always null per JDBC # Column lists for each metadata operation @@ -144,7 +82,7 @@ class MetadataColumnMappings: SCHEMA_COLUMNS = [ SCHEMA_COLUMN_FOR_GET_SCHEMA, ResultColumn( - "TABLE_CATALOG", None, "string", transform_value=always_null + "TABLE_CATALOG", None, "string" ), # Will need special population logic ] diff --git a/src/databricks/sql/backend/sea/utils/metadata_transformations.py b/src/databricks/sql/backend/sea/utils/metadata_transformations.py deleted file mode 100644 index 8fd27a6dd..000000000 --- a/src/databricks/sql/backend/sea/utils/metadata_transformations.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import Any, Optional - -# Table transformations -def transform_table_type(value: Any) -> str: - """Transform empty/null table type to 'TABLE' per JDBC spec.""" - if value is None or value == "": - return "TABLE" - return str(value) - - -# Nullable transformations -def transform_is_nullable(value: Any) -> str: - """Transform boolean nullable to YES/NO per JDBC spec.""" - if value is None or value == "true" or value is True: - return "YES" - return "NO" - - -def transform_nullable_to_int(value: Any) -> int: - """Transform isNullable to JDBC integer (1=nullable, 0=not nullable).""" - if value is None or value == "true" or value is True: - return 1 - return 0 - - -# Default value transformations -def transform_remarks_default(value: Any) -> str: - """Transform null remarks to empty string.""" - if value is None: - return "" - return str(value) - - -def transform_numeric_default_zero(value: Any) -> int: - """Transform null numeric values to 0.""" - if value is None: - return 0 - try: - return int(value) - except (ValueError, TypeError): - return 0 - - -# Calculated transformations -def calculate_data_type(value: Any) -> int: - """Calculate JDBC SQL type code from Databricks type name.""" - if value is None: - return 1111 # SQL NULL type - - type_name = str(value).upper().split("(")[0] - type_map = { - "STRING": 12, - "VARCHAR": 12, - "INT": 4, - "INTEGER": 4, - "DOUBLE": 8, - "FLOAT": 6, - "BOOLEAN": 16, - "DATE": 91, - "TIMESTAMP": 93, - "TIMESTAMP_NTZ": 93, - "DECIMAL": 3, - "NUMERIC": 2, - "BINARY": -2, - "ARRAY": 2003, - "MAP": 2002, - "STRUCT": 2002, - "TINYINT": -6, - "SMALLINT": 5, - "BIGINT": -5, - "LONG": -5, - } - return type_map.get(type_name, 1111) - - -def calculate_buffer_length(value: Any) -> Optional[int]: - """Calculate buffer length from type name.""" - if value is None: - return None - - type_name = str(value).upper() - if "ARRAY" in type_name or "MAP" in type_name: - return 255 - - # For other types, return None (will be null in result) - return None - - -def transform_ordinal_position_offset(value: Any) -> int: - """Adjust ordinal position from 1-based to 0-based or vice versa if needed.""" - if value is None: - return 0 - try: - # SEA returns 1-based, Thrift expects 0-based - return int(value) - 1 - except (ValueError, TypeError): - return 0 - - -# Null column transformations -def always_null(value: Any) -> None: - """Always return null for columns that should be null per JDBC spec.""" - return None - - -def always_null_int(value: Any) -> None: - """Always return null for integer columns that should be null per JDBC 
spec.""" - return None - - -def always_null_smallint(value: Any) -> None: - """Always return null for smallint columns that should be null per JDBC spec.""" - return None - - -# Identity transformations (for columns that need no change) -def identity(value: Any) -> Any: - """Return value unchanged.""" - return value diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py index 6e68537ec..52406539d 100644 --- a/src/databricks/sql/backend/sea/utils/result_column.py +++ b/src/databricks/sql/backend/sea/utils/result_column.py @@ -17,4 +17,3 @@ class ResultColumn: column_name: str result_set_column_name: Optional[str] # None if SEA doesn't return this column column_type: str - transform_value: Optional[Callable[[Any], Any]] = None From 2b1442f5dcc6c271812064f6411ee87fef02268d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 01:50:14 +0000 Subject: [PATCH 05/22] remove redundant COLUMN_DATA_MAPPING Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 58f853566..3980a360d 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -24,15 +24,6 @@ logger = logging.getLogger(__name__) -# Column-to-column data mapping for metadata queries -# Maps target column -> source column to get data from -COLUMN_DATA_MAPPING = { - "DATA_TYPE": "TYPE_NAME", # DATA_TYPE calculated from TYPE_NAME - "NULLABLE": "IS_NULLABLE", # NULLABLE calculated from IS_NULLABLE - "BUFFER_LENGTH": "TYPE_NAME", # BUFFER_LENGTH calculated from TYPE_NAME -} - - class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" From 1d515e346882e057a01287461006d29c00562a49 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 01:54:15 +0000 Subject: [PATCH 06/22] rename transformation functions to normalise for metadata cols Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 3980a360d..a7e346e4f 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -165,7 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") results = self.results.next_n_rows(size) - results = self._transform_json_rows(results) + results = self._normalise_json_metadata_cols(results) self._next_row_index += len(results) return results @@ -179,7 +179,7 @@ def fetchall_json(self) -> List[List[str]]: """ results = self.results.remaining_rows() - results = self._transform_json_rows(results) + results = self._normalise_json_metadata_cols(results) self._next_row_index += len(results) return results @@ -205,11 +205,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) if isinstance(self.results, JsonQueue): # Transform JSON first, then convert to Arrow - transformed_json = self._transform_json_rows(results) + transformed_json = self._normalise_json_metadata_cols(results) results = self._convert_json_to_arrow_table(transformed_json) else: # Transform Arrow table directly - results = self._transform_arrow_table(results) + 
results = self._normalise_arrow_metadata_cols(results) self._next_row_index += results.num_rows @@ -223,11 +223,11 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() if isinstance(self.results, JsonQueue): # Transform JSON first, then convert to Arrow - transformed_json = self._transform_json_rows(results) + transformed_json = self._normalise_json_metadata_cols(results) results = self._convert_json_to_arrow_table(transformed_json) else: # Transform Arrow table directly - results = self._transform_arrow_table(results) + results = self._normalise_arrow_metadata_cols(results) self._next_row_index += results.num_rows @@ -348,7 +348,7 @@ def _prepare_column_mapping(self) -> None: self.description = new_description - def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": + def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table": """Transform arrow table columns for metadata normalization.""" if not self._metadata_columns: return table @@ -384,7 +384,7 @@ def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table": return pyarrow.Table.from_arrays(new_columns, names=column_names) - def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]: + def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]: """Transform JSON rows for metadata normalization.""" if not self._metadata_columns: return rows From 2be0c86c95b78942aec7f11e1d18c49d91ad85de Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 01:56:58 +0000 Subject: [PATCH 07/22] make mock result set be of type SeaResultSet Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 396ad906f..a4923eaee 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -13,6 +13,7 @@ _filter_session_configuration, ) from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus +from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter from databricks.sql.thrift_api.TCLIService import ttypes @@ -703,7 +704,7 @@ def test_results_message_to_execute_response_is_staging_operation(self, sea_clie def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): """Test the get_catalogs method.""" # Mock the execute_command method - mock_result_set = Mock() + mock_result_set = Mock(spec=SeaResultSet) with patch.object( sea_client, "execute_command", return_value=mock_result_set ) as mock_execute: @@ -735,7 +736,7 @@ def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): """Test the get_schemas method with various parameter combinations.""" # Mock the execute_command method - mock_result_set = Mock() + mock_result_set = Mock(spec=SeaResultSet) with patch.object( sea_client, "execute_command", return_value=mock_result_set ) as mock_execute: @@ -884,7 +885,7 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): def test_get_columns(self, sea_client, sea_session_id, mock_cursor): """Test the get_columns method with various parameter combinations.""" # Mock the execute_command method - mock_result_set = Mock() + mock_result_set = 
Mock(spec=SeaResultSet) with patch.object( sea_client, "execute_command", return_value=mock_result_set ) as mock_execute: From 92c2da47f3aa7e671b52a2a6fe369e63729654f1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 02:00:08 +0000 Subject: [PATCH 08/22] remove redundant comments Signed-off-by: varun-edachali-dbx --- .../backend/sea/utils/metadata_mappings.py | 33 ++++--------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 56b614c5c..8e0b5e5ce 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -5,10 +5,8 @@ class MetadataColumnMappings: """Column mappings for metadata queries following JDBC specification.""" # Common columns used across multiple metadata queries - # FIX 1: Catalog columns - swap the mappings CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", "string") CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn("TABLE_CAT", "catalog", "string") - # Remove CATALOG_COLUMN_FOR_TABLES - will use CATALOG_COLUMN instead SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string") SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string") @@ -18,12 +16,9 @@ class MetadataColumnMappings: # Columns specific to getColumns() COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string") - DATA_TYPE_COLUMN = ResultColumn( - "DATA_TYPE", None, "int" - ) # Calculated from columnType + DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, "int") TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string") - # FIX 5: SEA actually provides these columns COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", "int") DECIMAL_DIGITS_COLUMN = ResultColumn( "DECIMAL_DIGITS", @@ -37,12 +32,8 @@ class MetadataColumnMappings: "int", ) - NULLABLE_COLUMN = ResultColumn( - "NULLABLE", None, "int" - ) # Calculated from isNullable - COLUMN_DEF_COLUMN = ResultColumn( - "COLUMN_DEF", "columnType", "string" - ) # Note: duplicate mapping + NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int") + COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", "columnType", "string") SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int") SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int") CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int") @@ -57,33 +48,24 @@ class MetadataColumnMappings: ) REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string") - # FIX 8: Scope columns (always null per JDBC) SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, "string") SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, "string") SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, "string") SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, "smallint") - # FIX 9 & 10: Auto increment and generated columns IS_AUTO_INCREMENT_COLUMN = ResultColumn( "IS_AUTOINCREMENT", "isAutoIncrement", "string" - ) # No underscore! 
- IS_GENERATED_COLUMN = ResultColumn( - "IS_GENERATEDCOLUMN", "isGenerated", "string" - ) # SEA provides this + ) + IS_GENERATED_COLUMN = ResultColumn("IS_GENERATEDCOLUMN", "isGenerated", "string") - # FIX 11: Buffer length column - BUFFER_LENGTH_COLUMN = ResultColumn( - "BUFFER_LENGTH", None, "int" - ) # Always null per JDBC + BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, "int") # Column lists for each metadata operation CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] # Use specific catalog column SCHEMA_COLUMNS = [ SCHEMA_COLUMN_FOR_GET_SCHEMA, - ResultColumn( - "TABLE_CATALOG", None, "string" - ), # Will need special population logic + ResultColumn("TABLE_CATALOG", None, "string"), ] TABLE_COLUMNS = [ @@ -99,7 +81,6 @@ class MetadataColumnMappings: REF_GENERATION_COLUMN, ] - # FIX 13: Remove IS_GENERATEDCOLUMN from list (should be 23 columns, not 24) COLUMN_COLUMNS = [ CATALOG_COLUMN, # Use general catalog column (catalogName) SCHEMA_COLUMN, From 99481e95078033e004ebce6b9295f307d6be61fe Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 02:21:33 +0000 Subject: [PATCH 09/22] use SqlType for type conv Signed-off-by: varun-edachali-dbx --- .../backend/sea/utils/metadata_mappings.py | 73 ++++++++++--------- .../sql/backend/sea/utils/result_column.py | 2 +- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 8e0b5e5ce..7857a0c18 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -1,71 +1,78 @@ from databricks.sql.backend.sea.utils.result_column import ResultColumn +from databricks.sql.backend.sea.utils.conversion import SqlType class MetadataColumnMappings: """Column mappings for metadata queries following JDBC specification.""" # Common columns used across multiple metadata queries - CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", "string") - CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn("TABLE_CAT", "catalog", "string") + CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.VARCHAR) + CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn( + "TABLE_CAT", "catalog", SqlType.VARCHAR + ) - SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string") - SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string") - TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string") - TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string") - REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string") + SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.VARCHAR) + SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn( + "TABLE_SCHEM", "databaseName", SqlType.VARCHAR + ) + TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.VARCHAR) + TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.VARCHAR) + REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.VARCHAR) # Columns specific to getColumns() - COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string") - DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, "int") - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string") + COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) + DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.VARCHAR) - COLUMN_SIZE_COLUMN = 
ResultColumn("COLUMN_SIZE", "columnSize", "int") + COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT) DECIMAL_DIGITS_COLUMN = ResultColumn( "DECIMAL_DIGITS", "decimalDigits", - "int", + SqlType.INT, ) - NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", "int") + NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", SqlType.INT) ORDINAL_POSITION_COLUMN = ResultColumn( "ORDINAL_POSITION", "ordinalPosition", - "int", + SqlType.INT, ) - NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int") - COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", "columnType", "string") - SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int") - SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int") - CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int") - IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string") + NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT) + COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", "columnType", SqlType.VARCHAR) + SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT) + SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT) + CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT) + IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.VARCHAR) # Columns for getTables() that don't exist in SEA - TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string") - TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string") - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string") + TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.VARCHAR) + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.VARCHAR) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.VARCHAR) SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", None, "string" + "SELF_REFERENCING_COL_NAME", None, SqlType.VARCHAR ) - REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string") + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.VARCHAR) - SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, "string") - SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, "string") - SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, "string") - SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, "smallint") + SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.VARCHAR) + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.VARCHAR) + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.VARCHAR) + SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.INT) IS_AUTO_INCREMENT_COLUMN = ResultColumn( - "IS_AUTOINCREMENT", "isAutoIncrement", "string" + "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.VARCHAR + ) + IS_GENERATED_COLUMN = ResultColumn( + "IS_GENERATEDCOLUMN", "isGenerated", SqlType.VARCHAR ) - IS_GENERATED_COLUMN = ResultColumn("IS_GENERATEDCOLUMN", "isGenerated", "string") - BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, "int") + BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT) # Column lists for each metadata operation CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] # Use specific catalog column SCHEMA_COLUMNS = [ SCHEMA_COLUMN_FOR_GET_SCHEMA, - ResultColumn("TABLE_CATALOG", None, "string"), + ResultColumn("TABLE_CATALOG", "catalogName", SqlType.VARCHAR), ] TABLE_COLUMNS = [ diff --git 
a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py
index 52406539d..30fb1e3e9 100644
--- a/src/databricks/sql/backend/sea/utils/result_column.py
+++ b/src/databricks/sql/backend/sea/utils/result_column.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, Callable, Any
+from typing import Optional
 
 
 @dataclass(frozen=True)

From f90a75ec568c3f929054e5f5d6a0297b38867435 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 4 Aug 2025 02:57:33 +0000
Subject: [PATCH 10/22] verified: get catalogs

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/utils/metadata_mappings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
index 7857a0c18..02a316a54 100644
--- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py
+++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -68,7 +68,7 @@ class MetadataColumnMappings:
     BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT)
 
     # Column lists for each metadata operation
-    CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS]  # Use specific catalog column
+    CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS]

From 946e51327e0b547f20fdcd40e97fc16114bf372a Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 4 Aug 2025 03:02:40 +0000
Subject: [PATCH 11/22] verif: get schemas

Signed-off-by: varun-edachali-dbx
---
 .../sql/backend/sea/utils/metadata_mappings.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
index 02a316a54..6b36d1adf 100644
--- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py
+++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -5,16 +5,19 @@ class MetadataColumnMappings:
     """Column mappings for metadata queries following JDBC specification."""
 
-    # Common columns used across multiple metadata queries
-    CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.VARCHAR)
     CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn(
         "TABLE_CAT", "catalog", SqlType.VARCHAR
     )
-    SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.VARCHAR)
+    CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", "catalogName", SqlType.VARCHAR)
     SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn(
         "TABLE_SCHEM", "databaseName", SqlType.VARCHAR
     )
+
+    CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.VARCHAR)
+
+    SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.VARCHAR)
+
     TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.VARCHAR)
     TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.VARCHAR)
     REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.VARCHAR)
@@ -72,7 +75,7 @@ class MetadataColumnMappings:
 
     SCHEMA_COLUMNS = [
         SCHEMA_COLUMN_FOR_GET_SCHEMA,
-        ResultColumn("TABLE_CATALOG", "catalogName", SqlType.VARCHAR),
+        CATALOG_FULL_COLUMN,
     ]
 
     TABLE_COLUMNS = [

From 070b93161a36c3c91ba81867c8baac457560b719 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 4 Aug 2025 03:10:20 +0000
Subject: [PATCH 12/22] verif: TABLE_COLUMNS from jdbc

Signed-off-by: varun-edachali-dbx
---
 .../backend/sea/utils/metadata_mappings.py | 26 +++++++++++--------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git
a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 6b36d1adf..61eda61e4 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -15,12 +15,23 @@ class MetadataColumnMappings: ) CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.VARCHAR) - SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.VARCHAR) - TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.VARCHAR) TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.VARCHAR) REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.VARCHAR) + TYPE_CATALOG_COLUMN = ResultColumn( + "TYPE_CAT", "TYPE_CATALOG_COLUMN", SqlType.VARCHAR + ) + TYPE_SCHEM_COLUMN = ResultColumn( + "TYPE_SCHEM", "TYPE_SCHEMA_COLUMN", SqlType.VARCHAR + ) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "TYPE_NAME", SqlType.VARCHAR) + SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( + "SELF_REFERENCING_COL_NAME", "SELF_REFERENCING_COLUMN_NAME", SqlType.VARCHAR + ) + REF_GENERATION_COLUMN = ResultColumn( + "REF_GENERATION", "REF_GENERATION_COLUMN", SqlType.VARCHAR + ) # Columns specific to getColumns() COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) @@ -48,13 +59,6 @@ class MetadataColumnMappings: IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.VARCHAR) # Columns for getTables() that don't exist in SEA - TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.VARCHAR) - TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.VARCHAR) - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.VARCHAR) - SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", None, SqlType.VARCHAR - ) - REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.VARCHAR) SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.VARCHAR) SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.VARCHAR) @@ -79,12 +83,12 @@ class MetadataColumnMappings: ] TABLE_COLUMNS = [ - CATALOG_COLUMN, # Use general catalog column (catalogName) + CATALOG_COLUMN, SCHEMA_COLUMN, TABLE_NAME_COLUMN, TABLE_TYPE_COLUMN, REMARKS_COLUMN, - TYPE_CAT_COLUMN, + TYPE_CATALOG_COLUMN, TYPE_SCHEM_COLUMN, TYPE_NAME_COLUMN, SELF_REFERENCING_COL_NAME_COLUMN, From 939c5420f42cc648a9e39d3997e9bc8ad6b1fae0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 03:16:11 +0000 Subject: [PATCH 13/22] verif: COLUMN_COLUMNS from JDBC Signed-off-by: varun-edachali-dbx --- .../backend/sea/utils/metadata_mappings.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index 61eda61e4..e4591c84d 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -33,12 +33,12 @@ class MetadataColumnMappings: "REF_GENERATION", "REF_GENERATION_COLUMN", SqlType.VARCHAR ) - # Columns specific to getColumns() - COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) - DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT) - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.VARCHAR) - + COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) + DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", "dataType", SqlType.INT) + 
COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.VARCHAR) COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT) + BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT) + DECIMAL_DIGITS_COLUMN = ResultColumn( "DECIMAL_DIGITS", "decimalDigits", @@ -51,19 +51,25 @@ class MetadataColumnMappings: SqlType.INT, ) - NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT) + NULLABLE_COLUMN = ResultColumn("NULLABLE", "Nullable", SqlType.INT) COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", "columnType", SqlType.VARCHAR) - SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT) - SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT) - CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT) + SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", "SQLDataType", SqlType.INT) + SQL_DATETIME_SUB_COLUMN = ResultColumn( + "SQL_DATETIME_SUB", "SQLDateTimeSub", SqlType.INT + ) + CHAR_OCTET_LENGTH_COLUMN = ResultColumn( + "CHAR_OCTET_LENGTH", "CharOctetLength", SqlType.INT + ) IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.VARCHAR) - # Columns for getTables() that don't exist in SEA - - SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.VARCHAR) - SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.VARCHAR) - SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.VARCHAR) - SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.INT) + SCOPE_CATALOG_COLUMN = ResultColumn( + "SCOPE_CATALOG", "ScopeCatalog", SqlType.VARCHAR + ) + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", "ScopeSchema", SqlType.VARCHAR) + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", "ScopeTable", SqlType.VARCHAR) + SOURCE_DATA_TYPE_COLUMN = ResultColumn( + "SOURCE_DATA_TYPE", "SourceDataType", SqlType.INT + ) IS_AUTO_INCREMENT_COLUMN = ResultColumn( "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.VARCHAR @@ -72,9 +78,6 @@ class MetadataColumnMappings: "IS_GENERATEDCOLUMN", "isGenerated", SqlType.VARCHAR ) - BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT) - - # Column lists for each metadata operation CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] SCHEMA_COLUMNS = [ @@ -96,12 +99,12 @@ class MetadataColumnMappings: ] COLUMN_COLUMNS = [ - CATALOG_COLUMN, # Use general catalog column (catalogName) + CATALOG_COLUMN, SCHEMA_COLUMN, TABLE_NAME_COLUMN, - COLUMN_NAME_COLUMN, + COL_NAME_COLUMN, DATA_TYPE_COLUMN, - TYPE_NAME_COLUMN, + COLUMN_TYPE_COLUMN, COLUMN_SIZE_COLUMN, BUFFER_LENGTH_COLUMN, DECIMAL_DIGITS_COLUMN, From 7d3174f7a6643ea2d5e66e59ccd54ad599a5d275 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 03:23:29 +0000 Subject: [PATCH 14/22] make stuff missing from SEA None in SEA mapping Signed-off-by: varun-edachali-dbx --- .../backend/sea/utils/metadata_mappings.py | 46 +++++++------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index e4591c84d..dd38923ca 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -9,7 +9,7 @@ class MetadataColumnMappings: "TABLE_CAT", "catalog", SqlType.VARCHAR ) - CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", "catalogName", SqlType.VARCHAR) + CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", None, SqlType.VARCHAR) 
SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn( "TABLE_SCHEM", "databaseName", SqlType.VARCHAR ) @@ -19,22 +19,16 @@ class MetadataColumnMappings: TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.VARCHAR) TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.VARCHAR) REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.VARCHAR) - TYPE_CATALOG_COLUMN = ResultColumn( - "TYPE_CAT", "TYPE_CATALOG_COLUMN", SqlType.VARCHAR - ) - TYPE_SCHEM_COLUMN = ResultColumn( - "TYPE_SCHEM", "TYPE_SCHEMA_COLUMN", SqlType.VARCHAR - ) - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "TYPE_NAME", SqlType.VARCHAR) + TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.VARCHAR) + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.VARCHAR) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.VARCHAR) SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", "SELF_REFERENCING_COLUMN_NAME", SqlType.VARCHAR - ) - REF_GENERATION_COLUMN = ResultColumn( - "REF_GENERATION", "REF_GENERATION_COLUMN", SqlType.VARCHAR + "SELF_REFERENCING_COL_NAME", None, SqlType.VARCHAR ) + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.VARCHAR) COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) - DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", "dataType", SqlType.INT) + DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT) COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.VARCHAR) COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT) BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT) @@ -51,25 +45,17 @@ class MetadataColumnMappings: SqlType.INT, ) - NULLABLE_COLUMN = ResultColumn("NULLABLE", "Nullable", SqlType.INT) - COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", "columnType", SqlType.VARCHAR) - SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", "SQLDataType", SqlType.INT) - SQL_DATETIME_SUB_COLUMN = ResultColumn( - "SQL_DATETIME_SUB", "SQLDateTimeSub", SqlType.INT - ) - CHAR_OCTET_LENGTH_COLUMN = ResultColumn( - "CHAR_OCTET_LENGTH", "CharOctetLength", SqlType.INT - ) + NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT) + COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.VARCHAR) + SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT) + SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT) + CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT) IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.VARCHAR) - SCOPE_CATALOG_COLUMN = ResultColumn( - "SCOPE_CATALOG", "ScopeCatalog", SqlType.VARCHAR - ) - SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", "ScopeSchema", SqlType.VARCHAR) - SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", "ScopeTable", SqlType.VARCHAR) - SOURCE_DATA_TYPE_COLUMN = ResultColumn( - "SOURCE_DATA_TYPE", "SourceDataType", SqlType.INT - ) + SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.VARCHAR) + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.VARCHAR) + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.VARCHAR) + SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.INT) IS_AUTO_INCREMENT_COLUMN = ResultColumn( "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.VARCHAR From 82e9c4f5ad80b75769f4816ce9b613e8a63383ff Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 03:31:23 +0000 Subject: [PATCH 15/22] remove hardcoding in SqlType 
Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 16 +++--- .../sql/backend/sea/utils/conversion.py | 56 +++++++++++-------- .../sql/backend/sea/utils/result_column.py | 13 ++--- 3 files changed, 46 insertions(+), 39 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index a7e346e4f..bcc985b7d 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -313,17 +313,17 @@ def _prepare_column_mapping(self) -> None: for new_idx, result_column in enumerate(self._metadata_columns or []): # Find the corresponding SEA column if ( - result_column.result_set_column_name - and result_column.result_set_column_name in sea_column_indices + result_column.sea_col_name + and result_column.sea_col_name in sea_column_indices ): - old_idx = sea_column_indices[result_column.result_set_column_name] + old_idx = sea_column_indices[result_column.sea_col_name] self._column_index_mapping[new_idx] = old_idx # Use the original column metadata but with JDBC name old_col = self.description[old_idx] new_description.append( ( - result_column.column_name, # JDBC name - result_column.column_type, # Expected type + result_column.thrift_col_name, # JDBC name + result_column.thrift_col_type, # Expected type old_col[2], # display_size old_col[3], # internal_size old_col[4], # precision @@ -335,8 +335,8 @@ def _prepare_column_mapping(self) -> None: # Column doesn't exist in SEA - add with None values new_description.append( ( - result_column.column_name, - result_column.column_type, + result_column.thrift_col_name, + result_column.thrift_col_type, None, None, None, @@ -380,7 +380,7 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Tab null_array = pyarrow.nulls(table.num_rows) new_columns.append(null_array) - column_names.append(result_column.column_name) + column_names.append(result_column.thrift_col_name) return pyarrow.Table.from_arrays(new_columns, names=column_names) diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index e139a11a8..d44193d46 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -11,6 +11,8 @@ from dateutil import parser from typing import Callable, Dict, Optional +from databricks.sql.thrift_api.TCLIService import ttypes + logger = logging.getLogger(__name__) @@ -56,43 +58,49 @@ class SqlType: after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix). 
""" + @staticmethod + def _get_type_name(thrift_type_id: int) -> str: + type_name = ttypes.TTypeId._VALUES_TO_NAMES[thrift_type_id] + type_name = type_name.lower() + if type_name.endswith("_type"): + type_name = type_name[:-5] + return type_name + # Numeric types - TINYINT = "tinyint" # Maps to TTypeId.TINYINT_TYPE - SMALLINT = "smallint" # Maps to TTypeId.SMALLINT_TYPE - INT = "int" # Maps to TTypeId.INT_TYPE - BIGINT = "bigint" # Maps to TTypeId.BIGINT_TYPE - FLOAT = "float" # Maps to TTypeId.FLOAT_TYPE - DOUBLE = "double" # Maps to TTypeId.DOUBLE_TYPE - DECIMAL = "decimal" # Maps to TTypeId.DECIMAL_TYPE + TINYINT = _get_type_name(ttypes.TTypeId.TINYINT_TYPE) + SMALLINT = _get_type_name(ttypes.TTypeId.SMALLINT_TYPE) + INT = _get_type_name(ttypes.TTypeId.INT_TYPE) + BIGINT = _get_type_name(ttypes.TTypeId.BIGINT_TYPE) + FLOAT = _get_type_name(ttypes.TTypeId.FLOAT_TYPE) + DOUBLE = _get_type_name(ttypes.TTypeId.DOUBLE_TYPE) + DECIMAL = _get_type_name(ttypes.TTypeId.DECIMAL_TYPE) # Boolean type - BOOLEAN = "boolean" # Maps to TTypeId.BOOLEAN_TYPE + BOOLEAN = _get_type_name(ttypes.TTypeId.BOOLEAN_TYPE) # Date/Time types - DATE = "date" # Maps to TTypeId.DATE_TYPE - TIMESTAMP = "timestamp" # Maps to TTypeId.TIMESTAMP_TYPE - INTERVAL_YEAR_MONTH = ( - "interval_year_month" # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE - ) - INTERVAL_DAY_TIME = "interval_day_time" # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE + DATE = _get_type_name(ttypes.TTypeId.DATE_TYPE) + TIMESTAMP = _get_type_name(ttypes.TTypeId.TIMESTAMP_TYPE) + INTERVAL_YEAR_MONTH = _get_type_name(ttypes.TTypeId.INTERVAL_YEAR_MONTH_TYPE) + INTERVAL_DAY_TIME = _get_type_name(ttypes.TTypeId.INTERVAL_DAY_TIME_TYPE) # String types - CHAR = "char" # Maps to TTypeId.CHAR_TYPE - VARCHAR = "varchar" # Maps to TTypeId.VARCHAR_TYPE - STRING = "string" # Maps to TTypeId.STRING_TYPE + CHAR = _get_type_name(ttypes.TTypeId.CHAR_TYPE) + VARCHAR = _get_type_name(ttypes.TTypeId.VARCHAR_TYPE) + STRING = _get_type_name(ttypes.TTypeId.STRING_TYPE) # Binary type - BINARY = "binary" # Maps to TTypeId.BINARY_TYPE + BINARY = _get_type_name(ttypes.TTypeId.BINARY_TYPE) # Complex types - ARRAY = "array" # Maps to TTypeId.ARRAY_TYPE - MAP = "map" # Maps to TTypeId.MAP_TYPE - STRUCT = "struct" # Maps to TTypeId.STRUCT_TYPE + ARRAY = _get_type_name(ttypes.TTypeId.ARRAY_TYPE) + MAP = _get_type_name(ttypes.TTypeId.MAP_TYPE) + STRUCT = _get_type_name(ttypes.TTypeId.STRUCT_TYPE) # Other types - NULL = "null" # Maps to TTypeId.NULL_TYPE - UNION = "union" # Maps to TTypeId.UNION_TYPE - USER_DEFINED = "user_defined" # Maps to TTypeId.USER_DEFINED_TYPE + NULL = _get_type_name(ttypes.TTypeId.NULL_TYPE) + UNION = _get_type_name(ttypes.TTypeId.UNION_TYPE) + USER_DEFINED = _get_type_name(ttypes.TTypeId.USER_DEFINED_TYPE) class SqlTypeConverter: diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py index 30fb1e3e9..dcec18f36 100644 --- a/src/databricks/sql/backend/sea/utils/result_column.py +++ b/src/databricks/sql/backend/sea/utils/result_column.py @@ -8,12 +8,11 @@ class ResultColumn: Represents a mapping between JDBC specification column names and actual result set column names. 
     Attributes:
-        column_name: JDBC specification column name (e.g., "TABLE_CAT")
-        result_set_column_name: Server result column name from SEA (e.g., "catalog")
-        column_type: SQL type code from databricks.sql.types
-        transform_value: Optional function to transform values for this column
+        thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
+        sea_col_name: Server result column name from SEA (e.g., "catalog")
+        thrift_col_type: SQL type name
     """

-    column_name: str
-    result_set_column_name: Optional[str]  # None if SEA doesn't return this column
-    column_type: str
+    thrift_col_name: str
+    sea_col_name: Optional[str]  # None if SEA doesn't return this column
+    thrift_col_type: str

From b2ae83c32d3731c9cba866279b6a578849dc445b Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 4 Aug 2025 03:38:27 +0000
Subject: [PATCH 16/22] move helper type name extractor out of class

Signed-off-by: varun-edachali-dbx
---
 .../sql/backend/sea/utils/conversion.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py
index d44193d46..028e437cc 100644
--- a/src/databricks/sql/backend/sea/utils/conversion.py
+++
""" - @staticmethod - def _get_type_name(thrift_type_id: int) -> str: - type_name = ttypes.TTypeId._VALUES_TO_NAMES[thrift_type_id] - type_name = type_name.lower() - if type_name.endswith("_type"): - type_name = type_name[:-5] - return type_name - # Numeric types TINYINT = _get_type_name(ttypes.TTypeId.TINYINT_TYPE) SMALLINT = _get_type_name(ttypes.TTypeId.SMALLINT_TYPE) From 4be68082711e4fa8eafff4e4595866b4aebcb5a3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 03:43:33 +0000 Subject: [PATCH 17/22] ensure SeaResultSet resp Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 53ed836ed..49f8dc31b 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -974,7 +974,7 @@ def test_get_tables_with_cloud_fetch( ): """Test the get_tables method with cloud fetch enabled.""" # Mock the execute_command method and ResultSetFilter - mock_result_set = Mock() + mock_result_set = Mock(spec=SeaResultSet) with patch.object( sea_client_cloud_fetch, "execute_command", return_value=mock_result_set @@ -1013,7 +1013,7 @@ def test_get_schemas_with_cloud_fetch( ): """Test the get_schemas method with cloud fetch enabled.""" # Mock the execute_command method - mock_result_set = Mock() + mock_result_set = Mock(spec=SeaResultSet) with patch.object( sea_client_cloud_fetch, "execute_command", return_value=mock_result_set ) as mock_execute: From 0dad96616117d0c92be2eaa0ba311f72406235a5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 03:47:38 +0000 Subject: [PATCH 18/22] clean up conversion code Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 16 ++++------------ .../sql/backend/sea/utils/conversion.py | 4 ---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index bcc985b7d..f13a62d4c 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -204,12 +204,8 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) if isinstance(self.results, JsonQueue): - # Transform JSON first, then convert to Arrow - transformed_json = self._normalise_json_metadata_cols(results) - results = self._convert_json_to_arrow_table(transformed_json) - else: - # Transform Arrow table directly - results = self._normalise_arrow_metadata_cols(results) + results = self._convert_json_to_arrow_table(results) + results = self._normalise_arrow_metadata_cols(results) self._next_row_index += results.num_rows @@ -222,12 +218,8 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() if isinstance(self.results, JsonQueue): - # Transform JSON first, then convert to Arrow - transformed_json = self._normalise_json_metadata_cols(results) - results = self._convert_json_to_arrow_table(transformed_json) - else: - # Transform Arrow table directly - results = self._normalise_arrow_metadata_cols(results) + results = self._convert_json_to_arrow_table(results) + results = self._normalise_arrow_metadata_cols(results) self._next_row_index += results.num_rows diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 028e437cc..6253fa823 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ 
b/src/databricks/sql/backend/sea/utils/conversion.py @@ -159,10 +159,6 @@ def convert_value( The converted value in the appropriate Python type """ - # Handle None values directly - if value is None: - return None - sql_type = sql_type.lower().strip() if sql_type not in SqlTypeConverter.TYPE_MAPPING: From a28596b5fca062355c3451ae3aa991f025d769ba Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 05:43:43 +0000 Subject: [PATCH 19/22] fix type codes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 3 +- src/databricks/sql/backend/sea/result_set.py | 34 ++++--------- .../backend/sea/utils/metadata_mappings.py | 50 +++++++++---------- 3 files changed, 38 insertions(+), 49 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 845f633f4..6bcaabaec 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -784,13 +784,14 @@ def get_tables( assert isinstance( result, SeaResultSet ), "Expected SeaResultSet from SEA backend" - result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS) # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter result = ResultSetFilter.filter_tables_by_type(result, table_types) + result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS) + return result def get_columns( diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index f13a62d4c..7ea91733c 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -342,7 +342,7 @@ def _prepare_column_mapping(self) -> None: def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table": """Transform arrow table columns for metadata normalization.""" - if not self._metadata_columns: + if not self._metadata_columns or len(table.schema) == 0: return table # Reorder columns and add missing ones @@ -351,26 +351,17 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Tab for new_idx, result_column in enumerate(self._metadata_columns or []): old_idx = ( - self._column_index_mapping.get(new_idx) + self._column_index_mapping.get(new_idx, None) if self._column_index_mapping else None ) - # Get the source data - if old_idx is not None: - column = table.column(old_idx) - values = column.to_pylist() - else: - values = None - - # Apply transformation and create column - if values is not None: - column = pyarrow.array(values) - new_columns.append(column) - else: - # Create column with default/transformed values - null_array = pyarrow.nulls(table.num_rows) - new_columns.append(null_array) + column = ( + pyarrow.nulls(table.num_rows) + if old_idx is None + else table.column(old_idx) + ) + new_columns.append(column) column_names.append(result_column.thrift_col_name) @@ -378,7 +369,7 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Tab def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]: """Transform JSON rows for metadata normalization.""" - if not self._metadata_columns: + if not self._metadata_columns or len(rows) == 0: return rows transformed_rows = [] @@ -386,15 +377,12 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any] new_row = [] for new_idx, result_column in enumerate(self._metadata_columns or []): old_idx = ( - self._column_index_mapping.get(new_idx) + 
self._column_index_mapping.get(new_idx, None) if self._column_index_mapping else None ) - if old_idx is not None: - value = row[old_idx] - else: - value = None + value = None if old_idx is None else row[old_idx] new_row.append(value) transformed_rows.append(new_row) return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py index dd38923ca..340c4c79e 100644 --- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py +++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py @@ -6,32 +6,32 @@ class MetadataColumnMappings: """Column mappings for metadata queries following JDBC specification.""" CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn( - "TABLE_CAT", "catalog", SqlType.VARCHAR + "TABLE_CAT", "catalog", SqlType.STRING ) - CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", None, SqlType.VARCHAR) + CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", None, SqlType.STRING) SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn( - "TABLE_SCHEM", "databaseName", SqlType.VARCHAR + "TABLE_SCHEM", "databaseName", SqlType.STRING ) - CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.VARCHAR) - SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.VARCHAR) - TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.VARCHAR) - TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.VARCHAR) - REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.VARCHAR) - TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.VARCHAR) - TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.VARCHAR) - TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.VARCHAR) + CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.STRING) + SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING) + TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING) + TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING) + REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING) + TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING) + TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING) + TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING) SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn( - "SELF_REFERENCING_COL_NAME", None, SqlType.VARCHAR + "SELF_REFERENCING_COL_NAME", None, SqlType.STRING ) - REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.VARCHAR) + REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING) - COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.VARCHAR) + COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING) DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT) - COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.VARCHAR) + COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING) COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT) - BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.INT) + BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT) DECIMAL_DIGITS_COLUMN = ResultColumn( "DECIMAL_DIGITS", @@ -46,22 +46,22 @@ class MetadataColumnMappings: ) NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT) - COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.VARCHAR) + COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING) 
SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT) SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT) CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT) - IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.VARCHAR) + IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING) - SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.VARCHAR) - SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.VARCHAR) - SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.VARCHAR) - SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.INT) + SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING) + SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING) + SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.STRING) + SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT) IS_AUTO_INCREMENT_COLUMN = ResultColumn( - "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.VARCHAR + "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING ) IS_GENERATED_COLUMN = ResultColumn( - "IS_GENERATEDCOLUMN", "isGenerated", SqlType.VARCHAR + "IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING ) CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS] @@ -108,5 +108,5 @@ class MetadataColumnMappings: SCOPE_TABLE_COLUMN, SOURCE_DATA_TYPE_COLUMN, IS_AUTO_INCREMENT_COLUMN, - # DO NOT INCLUDE IS_GENERATED_COLUMN - Thrift returns 23 columns + # not including IS_GENERATED_COLUMN of SEA because Thrift does not return an equivalent ] From 639bafae95dfa93f136b2f62293fc1ce1f085f9d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 06:08:21 +0000 Subject: [PATCH 20/22] simplify docstring Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 7ea91733c..8415c6295 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -83,8 +83,8 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes, ) - # Initialize metadata columns for post-fetch transformation self._metadata_columns: Optional[List[ResultColumn]] = None + # new index -> old index self._column_index_mapping: Optional[Dict[int, Union[int, None]]] = None def _convert_json_types(self, row: List[str]) -> List[Any]: From b0b58fb9aae0d55cf7620a1b312d140e3b3077b1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 06:38:17 +0000 Subject: [PATCH 21/22] nit: reduce repetition Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 57 ++++++++++--------- .../sql/backend/sea/utils/result_column.py | 2 +- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 8415c6295..09a8df1eb 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -84,7 +84,7 @@ def __init__( ) self._metadata_columns: Optional[List[ResultColumn]] = None - # new index -> old index + # new index -> old index self._column_index_mapping: Optional[Dict[int, Union[int, None]]] = None def _convert_json_types(self, row: List[str]) -> List[Any]: @@ -287,7 +287,7 @@ def prepare_metadata_columns(self, metadata_columns: 
List[ResultColumn]) -> None def _prepare_column_mapping(self) -> None: """ Prepare column index mapping for metadata queries. - Updates description to use JDBC column names. + Updates description to use Thrift column names. """ # Ensure description is available if not self.description: @@ -303,40 +303,41 @@ def _prepare_column_mapping(self) -> None: self._column_index_mapping = {} # Maps new index -> old index for new_idx, result_column in enumerate(self._metadata_columns or []): - # Find the corresponding SEA column + # Determine the old index and get column metadata if ( result_column.sea_col_name and result_column.sea_col_name in sea_column_indices ): old_idx = sea_column_indices[result_column.sea_col_name] - self._column_index_mapping[new_idx] = old_idx - # Use the original column metadata but with JDBC name old_col = self.description[old_idx] - new_description.append( - ( - result_column.thrift_col_name, # JDBC name - result_column.thrift_col_type, # Expected type - old_col[2], # display_size - old_col[3], # internal_size - old_col[4], # precision - old_col[5], # scale - old_col[6], # null_ok - ) - ) + # Use original column metadata + display_size, internal_size, precision, scale, null_ok = old_col[2:7] else: - # Column doesn't exist in SEA - add with None values - new_description.append( - ( - result_column.thrift_col_name, - result_column.thrift_col_type, - None, - None, - None, - None, - True, - ) + old_idx = None + # Use None values for missing columns + display_size, internal_size, precision, scale, null_ok = ( + None, + None, + None, + None, + True, + ) + + # Set the mapping + self._column_index_mapping[new_idx] = old_idx + + # Create the new description entry + new_description.append( + ( + result_column.thrift_col_name, # Thrift (normalised) name + result_column.thrift_col_type, # Expected type + display_size, + internal_size, + precision, + scale, + null_ok, ) - self._column_index_mapping[new_idx] = None + ) self.description = new_description diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py index dcec18f36..a4c1f619b 100644 --- a/src/databricks/sql/backend/sea/utils/result_column.py +++ b/src/databricks/sql/backend/sea/utils/result_column.py @@ -5,7 +5,7 @@ @dataclass(frozen=True) class ResultColumn: """ - Represents a mapping between JDBC specification column names and actual result set column names. + Represents a mapping between Thrift specification column names and SEA column names. Attributes: thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT") From 55d8c759c987f27ff11299fd1b9b94b0a0651546 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 4 Aug 2025 06:58:08 +0000 Subject: [PATCH 22/22] test metadata mappings Signed-off-by: varun-edachali-dbx --- tests/unit/test_metadata_mappings.py | 163 +++++++++++++++++++++++++++ tests/unit/test_sea_backend.py | 23 ++++ 2 files changed, 186 insertions(+) create mode 100644 tests/unit/test_metadata_mappings.py diff --git a/tests/unit/test_metadata_mappings.py b/tests/unit/test_metadata_mappings.py new file mode 100644 index 000000000..6ab749e0f --- /dev/null +++ b/tests/unit/test_metadata_mappings.py @@ -0,0 +1,163 @@ +""" +Tests for SEA metadata column mappings and normalization. 
+""" + +import pytest +from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings +from databricks.sql.backend.sea.utils.result_column import ResultColumn +from databricks.sql.backend.sea.utils.conversion import SqlType + + +class TestMetadataColumnMappings: + """Test suite for metadata column mappings.""" + + def test_result_column_creation(self): + """Test ResultColumn data class creation and attributes.""" + col = ResultColumn("TABLE_CAT", "catalog", SqlType.STRING) + assert col.thrift_col_name == "TABLE_CAT" + assert col.sea_col_name == "catalog" + assert col.thrift_col_type == SqlType.STRING + + def test_result_column_with_none_sea_name(self): + """Test ResultColumn when SEA doesn't return this column.""" + col = ResultColumn("TYPE_CAT", None, SqlType.STRING) + assert col.thrift_col_name == "TYPE_CAT" + assert col.sea_col_name is None + assert col.thrift_col_type == SqlType.STRING + + def test_catalog_columns_mapping(self): + """Test catalog columns mapping for getCatalogs.""" + catalog_cols = MetadataColumnMappings.CATALOG_COLUMNS + assert len(catalog_cols) == 1 + + catalog_col = catalog_cols[0] + assert catalog_col.thrift_col_name == "TABLE_CAT" + assert catalog_col.sea_col_name == "catalog" + assert catalog_col.thrift_col_type == SqlType.STRING + + def test_schema_columns_mapping(self): + """Test schema columns mapping for getSchemas.""" + schema_cols = MetadataColumnMappings.SCHEMA_COLUMNS + assert len(schema_cols) == 2 + + # Check TABLE_SCHEM column + schema_col = schema_cols[0] + assert schema_col.thrift_col_name == "TABLE_SCHEM" + assert schema_col.sea_col_name == "databaseName" + assert schema_col.thrift_col_type == SqlType.STRING + + # Check TABLE_CATALOG column + catalog_col = schema_cols[1] + assert catalog_col.thrift_col_name == "TABLE_CATALOG" + assert catalog_col.sea_col_name is None + assert catalog_col.thrift_col_type == SqlType.STRING + + def test_table_columns_mapping(self): + """Test table columns mapping for getTables.""" + table_cols = MetadataColumnMappings.TABLE_COLUMNS + assert len(table_cols) == 10 + + # Test key columns + expected_mappings = [ + ("TABLE_CAT", "catalogName", SqlType.STRING), + ("TABLE_SCHEM", "namespace", SqlType.STRING), + ("TABLE_NAME", "tableName", SqlType.STRING), + ("TABLE_TYPE", "tableType", SqlType.STRING), + ("REMARKS", "remarks", SqlType.STRING), + ("TYPE_CAT", None, SqlType.STRING), + ("TYPE_SCHEM", None, SqlType.STRING), + ("TYPE_NAME", None, SqlType.STRING), + ("SELF_REFERENCING_COL_NAME", None, SqlType.STRING), + ("REF_GENERATION", None, SqlType.STRING), + ] + + for i, (thrift_name, sea_name, sql_type) in enumerate(expected_mappings): + col = table_cols[i] + assert col.thrift_col_name == thrift_name + assert col.sea_col_name == sea_name + assert col.thrift_col_type == sql_type + + def test_column_columns_mapping(self): + """Test column columns mapping for getColumns.""" + column_cols = MetadataColumnMappings.COLUMN_COLUMNS + # Should have 23 columns (not including IS_GENERATED_COLUMN) + assert len(column_cols) == 23 + + # Test some key columns + key_columns = { + "TABLE_CAT": ("catalogName", SqlType.STRING), + "TABLE_SCHEM": ("namespace", SqlType.STRING), + "TABLE_NAME": ("tableName", SqlType.STRING), + "COLUMN_NAME": ("col_name", SqlType.STRING), + "DATA_TYPE": (None, SqlType.INT), + "TYPE_NAME": ("columnType", SqlType.STRING), + "COLUMN_SIZE": ("columnSize", SqlType.INT), + "DECIMAL_DIGITS": ("decimalDigits", SqlType.INT), + "NUM_PREC_RADIX": ("radix", SqlType.INT), + "ORDINAL_POSITION": 
("ordinalPosition", SqlType.INT), + "IS_NULLABLE": ("isNullable", SqlType.STRING), + "IS_AUTOINCREMENT": ("isAutoIncrement", SqlType.STRING), + } + + for col in column_cols: + if col.thrift_col_name in key_columns: + expected_sea_name, expected_type = key_columns[col.thrift_col_name] + assert col.sea_col_name == expected_sea_name + assert col.thrift_col_type == expected_type + + def test_is_generated_column_not_included(self): + """Test that IS_GENERATED_COLUMN is not included in COLUMN_COLUMNS.""" + column_names = [ + col.thrift_col_name for col in MetadataColumnMappings.COLUMN_COLUMNS + ] + assert "IS_GENERATEDCOLUMN" not in column_names + + def test_column_type_consistency(self): + """Test that column types are consistent with JDBC spec.""" + # Test numeric types + assert MetadataColumnMappings.DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT + assert MetadataColumnMappings.COLUMN_SIZE_COLUMN.thrift_col_type == SqlType.INT + assert ( + MetadataColumnMappings.BUFFER_LENGTH_COLUMN.thrift_col_type + == SqlType.TINYINT + ) + assert ( + MetadataColumnMappings.DECIMAL_DIGITS_COLUMN.thrift_col_type == SqlType.INT + ) + assert ( + MetadataColumnMappings.NUM_PREC_RADIX_COLUMN.thrift_col_type == SqlType.INT + ) + assert ( + MetadataColumnMappings.ORDINAL_POSITION_COLUMN.thrift_col_type + == SqlType.INT + ) + assert MetadataColumnMappings.NULLABLE_COLUMN.thrift_col_type == SqlType.INT + assert ( + MetadataColumnMappings.SQL_DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT + ) + assert ( + MetadataColumnMappings.SQL_DATETIME_SUB_COLUMN.thrift_col_type + == SqlType.INT + ) + assert ( + MetadataColumnMappings.CHAR_OCTET_LENGTH_COLUMN.thrift_col_type + == SqlType.INT + ) + assert ( + MetadataColumnMappings.SOURCE_DATA_TYPE_COLUMN.thrift_col_type + == SqlType.SMALLINT + ) + + # Test string types + assert MetadataColumnMappings.CATALOG_COLUMN.thrift_col_type == SqlType.STRING + assert MetadataColumnMappings.SCHEMA_COLUMN.thrift_col_type == SqlType.STRING + assert ( + MetadataColumnMappings.TABLE_NAME_COLUMN.thrift_col_type == SqlType.STRING + ) + assert ( + MetadataColumnMappings.IS_NULLABLE_COLUMN.thrift_col_type == SqlType.STRING + ) + assert ( + MetadataColumnMappings.IS_AUTO_INCREMENT_COLUMN.thrift_col_type + == SqlType.STRING + ) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 49f8dc31b..1a2621f06 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -14,6 +14,7 @@ ) from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus from databricks.sql.backend.sea.result_set import SeaResultSet +from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter from databricks.sql.thrift_api.TCLIService import ttypes @@ -756,6 +757,11 @@ def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): # Verify the result is correct assert result == mock_result_set + # Verify prepare_metadata_columns was called + mock_result_set.prepare_metadata_columns.assert_called_once_with( + MetadataColumnMappings.CATALOG_COLUMNS + ) + def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): """Test the get_schemas method with various parameter combinations.""" # Mock the execute_command method @@ -818,6 +824,12 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): ) assert "Catalog name is required for get_schemas" 
in str(excinfo.value) + # Verify prepare_metadata_columns was called for successful cases + assert mock_result_set.prepare_metadata_columns.call_count == 2 + mock_result_set.prepare_metadata_columns.assert_called_with( + MetadataColumnMappings.SCHEMA_COLUMNS + ) + def test_get_tables(self, sea_client, sea_session_id, mock_cursor): """Test the get_tables method with various parameter combinations.""" # Mock the execute_command method @@ -905,6 +917,11 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): enforce_embedded_schema_correctness=False, ) + # Verify prepare_metadata_columns was called + mock_result_set.prepare_metadata_columns.assert_called_with( + MetadataColumnMappings.TABLE_COLUMNS + ) + def test_get_columns(self, sea_client, sea_session_id, mock_cursor): """Test the get_columns method with various parameter combinations.""" # Mock the execute_command method @@ -969,6 +986,12 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): ) assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify prepare_metadata_columns was called for successful cases + assert mock_result_set.prepare_metadata_columns.call_count == 2 + mock_result_set.prepare_metadata_columns.assert_called_with( + MetadataColumnMappings.COLUMN_COLUMNS + ) + def test_get_tables_with_cloud_fetch( self, sea_client_cloud_fetch, sea_session_id, mock_cursor ):
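
For readers skimming the series, patch 16 moves the Thrift type-name derivation out of the SqlType class into a module-level helper. A minimal standalone sketch of that helper follows; the _VALUES_TO_NAMES dict below is a hypothetical stand-in for ttypes.TTypeId._VALUES_TO_NAMES, with made-up enum values:

    # Hypothetical stand-in for ttypes.TTypeId._VALUES_TO_NAMES (values made up).
    _VALUES_TO_NAMES = {0: "BOOLEAN_TYPE", 1: "TINYINT_TYPE", 7: "STRING_TYPE"}

    def _get_type_name(thrift_type_id: int) -> str:
        # Lower-case the Thrift enum name and strip the trailing "_type" suffix,
        # e.g. "TINYINT_TYPE" -> "tinyint".
        type_name = _VALUES_TO_NAMES[thrift_type_id].lower()
        if type_name.endswith("_type"):
            type_name = type_name[:-5]
        return type_name

    print(_get_type_name(1))  # tinyint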
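
Patch 17 swaps Mock() for Mock(spec=SeaResultSet) in the cloud-fetch tests so that the backend's isinstance(result, SeaResultSet) assertion still passes against the mocked result. A small self-contained illustration of why spec matters; SeaResultSet here is a stand-in class, not the real one:

    from unittest.mock import Mock

    class SeaResultSet:
        # Stand-in for databricks.sql.backend.sea.result_set.SeaResultSet.
        pass

    assert isinstance(Mock(spec=SeaResultSet), SeaResultSet)  # spec'd mock passes isinstance
    assert not isinstance(Mock(), SeaResultSet)               # a bare Mock does not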
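
Patches 19 and 21 reduce the JSON-path column normalization to an index remap from the Thrift column order onto the SEA response order, with None padding for columns SEA never returns. A minimal sketch assuming the get_schemas layout from MetadataColumnMappings (two Thrift columns, one SEA column); the helper name normalise_row and the sample data are illustrative, not taken from the driver:

    from typing import Any, Dict, List, Optional

    # For get_schemas the Thrift layout is [TABLE_SCHEM, TABLE_CATALOG] while SEA
    # returns only [databaseName]; the mapping is new index -> old index, and None
    # marks a column the SEA response does not contain.
    column_index_mapping: Dict[int, Optional[int]] = {0: 0, 1: None}

    def normalise_row(row: List[Any]) -> List[Any]:
        # Pull each value from its old position, padding missing columns with None
        # so the data stays aligned with the rewritten description.
        return [
            None if old_idx is None else row[old_idx]
            for _, old_idx in sorted(column_index_mapping.items())
        ]

    print(normalise_row(["default"]))  # ['default', None]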