diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 75d2c665c..6bcaabaec 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
 from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes
 
@@ -700,7 +701,10 @@ def get_catalogs(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.CATALOG_COLUMNS)
         return result
 
     def get_schemas(
@@ -733,7 +737,10 @@ def get_schemas(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
         return result
 
     def get_tables(
@@ -774,13 +781,17 @@ def get_tables(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
 
         # Apply client-side filtering by table_types
        from databricks.sql.backend.sea.utils.filters import ResultSetFilter
 
         result = ResultSetFilter.filter_tables_by_type(result, table_types)
 
+        result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS)
+
         return result
 
     def get_columns(
@@ -821,5 +832,8 @@ def get_columns(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.COLUMN_COLUMNS)
         return result
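For illustration (not part of the patch): routing every metadata entry point through prepare_metadata_columns means cursor.description carries the Thrift/JDBC column names regardless of backend. A minimal sketch, assuming a working connection; the hostname, HTTP path, and token below are placeholders.

from databricks import sql

# Placeholder connection parameters; substitute real workspace values.
with sql.connect(
    server_hostname="example.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/abc123",
    access_token="dapi...",
) as connection:
    with connection.cursor() as cursor:
        cursor.catalogs()
        # With the SEA backend, the description is normalised to the JDBC-style
        # name, e.g. the first column is "TABLE_CAT" rather than SEA's "catalog".
        print([col[0] for col in cursor.description])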
diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index afa70bc89..09a8df1eb 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, TYPE_CHECKING, Dict, Union
 
 import logging
 
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
 
 try:
     import pyarrow
@@ -82,6 +83,10 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+        self._metadata_columns: Optional[List[ResultColumn]] = None
+        # new index -> old index
+        self._column_index_mapping: Optional[Dict[int, Union[int, None]]] = None
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
 
         results = self.results.next_n_rows(size)
+        results = self._normalise_json_metadata_cols(results)
         self._next_row_index += len(results)
 
         return results
@@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]:
         """
 
         results = self.results.remaining_rows()
+        results = self._normalise_json_metadata_cols(results)
         self._next_row_index += len(results)
 
         return results
@@ -198,6 +205,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         results = self.results.next_n_rows(size)
         if isinstance(self.results, JsonQueue):
             results = self._convert_json_to_arrow_table(results)
+        results = self._normalise_arrow_metadata_cols(results)
 
         self._next_row_index += results.num_rows
 
@@ -211,6 +219,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         results = self.results.remaining_rows()
         if isinstance(self.results, JsonQueue):
             results = self._convert_json_to_arrow_table(results)
+        results = self._normalise_arrow_metadata_cols(results)
 
         self._next_row_index += results.num_rows
 
@@ -263,3 +272,118 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
+
+    def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None:
+        """
+        Prepare result set for metadata column normalization.
+
+        Args:
+            metadata_columns: List of ResultColumn objects defining the expected columns
+                and their mappings from SEA column names
+        """
+        self._metadata_columns = metadata_columns
+        self._prepare_column_mapping()
+
+    def _prepare_column_mapping(self) -> None:
+        """
+        Prepare column index mapping for metadata queries.
+        Updates description to use Thrift column names.
+        """
+        # Ensure description is available
+        if not self.description:
+            raise ValueError("Cannot prepare column mapping without result description")
+
+        # Build mapping from SEA column names to their indices
+        sea_column_indices = {}
+        for idx, col in enumerate(self.description):
+            sea_column_indices[col[0]] = idx
+
+        # Create new description and index mapping
+        new_description = []
+        self._column_index_mapping = {}  # Maps new index -> old index
+
+        for new_idx, result_column in enumerate(self._metadata_columns or []):
+            # Determine the old index and get column metadata
+            if (
+                result_column.sea_col_name
+                and result_column.sea_col_name in sea_column_indices
+            ):
+                old_idx = sea_column_indices[result_column.sea_col_name]
+                old_col = self.description[old_idx]
+                # Use original column metadata
+                display_size, internal_size, precision, scale, null_ok = old_col[2:7]
+            else:
+                old_idx = None
+                # Use None values for missing columns
+                display_size, internal_size, precision, scale, null_ok = (
+                    None,
+                    None,
+                    None,
+                    None,
+                    True,
+                )
+
+            # Set the mapping
+            self._column_index_mapping[new_idx] = old_idx
+
+            # Create the new description entry
+            new_description.append(
+                (
+                    result_column.thrift_col_name,  # Thrift (normalised) name
+                    result_column.thrift_col_type,  # Expected type
+                    display_size,
+                    internal_size,
+                    precision,
+                    scale,
+                    null_ok,
+                )
+            )
+
+        self.description = new_description
+
+    def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
+        """Transform arrow table columns for metadata normalization."""
+        if not self._metadata_columns or len(table.schema) == 0:
+            return table
+
+        # Reorder columns and add missing ones
+        new_columns = []
+        column_names = []
+
+        for new_idx, result_column in enumerate(self._metadata_columns or []):
+            old_idx = (
+                self._column_index_mapping.get(new_idx, None)
+                if self._column_index_mapping
+                else None
+            )
+
+            column = (
+                pyarrow.nulls(table.num_rows)
+                if old_idx is None
+                else table.column(old_idx)
+            )
+            new_columns.append(column)
+
+            column_names.append(result_column.thrift_col_name)
+
+        return pyarrow.Table.from_arrays(new_columns, names=column_names)
+
+    def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
+        """Transform JSON rows for metadata normalization."""
+        if not self._metadata_columns or len(rows) == 0:
+            return rows
+
+        transformed_rows = []
+        for row in rows:
+            new_row = []
+            for new_idx, result_column in enumerate(self._metadata_columns or []):
+                old_idx = (
+                    self._column_index_mapping.get(new_idx, None)
+                    if self._column_index_mapping
+                    else None
+                )
+
+                value = None if old_idx is None else row[old_idx]
+                new_row.append(value)
+            transformed_rows.append(new_row)
+        return transformed_rows
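To make the normalisation above concrete (illustrative only, with invented SEA output for a get_schemas call): the index mapping pairs each expected Thrift/JDBC column with the matching SEA column, and columns SEA does not return are padded with NULLs.

# Mirrors the index-mapping logic in _prepare_column_mapping /
# _normalise_json_metadata_cols using plain lists (hypothetical SEA output).
sea_description = [("databaseName", "string", None, None, None, None, True)]
sea_rows = [["default"], ["information_schema"]]

# (thrift_col_name, sea_col_name); None means SEA has no equivalent column.
expected_columns = [("TABLE_SCHEM", "databaseName"), ("TABLE_CATALOG", None)]

sea_indices = {col[0]: idx for idx, col in enumerate(sea_description)}
index_mapping = {
    new_idx: sea_indices.get(sea_name) if sea_name else None
    for new_idx, (_, sea_name) in enumerate(expected_columns)
}

normalised = [
    [None if old_idx is None else row[old_idx] for old_idx in index_mapping.values()]
    for row in sea_rows
]
# normalised == [["default", None], ["information_schema", None]]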
+ """ + # Ensure description is available + if not self.description: + raise ValueError("Cannot prepare column mapping without result description") + + # Build mapping from SEA column names to their indices + sea_column_indices = {} + for idx, col in enumerate(self.description): + sea_column_indices[col[0]] = idx + + # Create new description and index mapping + new_description = [] + self._column_index_mapping = {} # Maps new index -> old index + + for new_idx, result_column in enumerate(self._metadata_columns or []): + # Determine the old index and get column metadata + if ( + result_column.sea_col_name + and result_column.sea_col_name in sea_column_indices + ): + old_idx = sea_column_indices[result_column.sea_col_name] + old_col = self.description[old_idx] + # Use original column metadata + display_size, internal_size, precision, scale, null_ok = old_col[2:7] + else: + old_idx = None + # Use None values for missing columns + display_size, internal_size, precision, scale, null_ok = ( + None, + None, + None, + None, + True, + ) + + # Set the mapping + self._column_index_mapping[new_idx] = old_idx + + # Create the new description entry + new_description.append( + ( + result_column.thrift_col_name, # Thrift (normalised) name + result_column.thrift_col_type, # Expected type + display_size, + internal_size, + precision, + scale, + null_ok, + ) + ) + + self.description = new_description + + def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table": + """Transform arrow table columns for metadata normalization.""" + if not self._metadata_columns or len(table.schema) == 0: + return table + + # Reorder columns and add missing ones + new_columns = [] + column_names = [] + + for new_idx, result_column in enumerate(self._metadata_columns or []): + old_idx = ( + self._column_index_mapping.get(new_idx, None) + if self._column_index_mapping + else None + ) + + column = ( + pyarrow.nulls(table.num_rows) + if old_idx is None + else table.column(old_idx) + ) + new_columns.append(column) + + column_names.append(result_column.thrift_col_name) + + return pyarrow.Table.from_arrays(new_columns, names=column_names) + + def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]: + """Transform JSON rows for metadata normalization.""" + if not self._metadata_columns or len(rows) == 0: + return rows + + transformed_rows = [] + for row in rows: + new_row = [] + for new_idx, result_column in enumerate(self._metadata_columns or []): + old_idx = ( + self._column_index_mapping.get(new_idx, None) + if self._column_index_mapping + else None + ) + + value = None if old_idx is None else row[old_idx] + new_row.append(value) + transformed_rows.append(new_row) + return transformed_rows diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 69c6dfbe2..6253fa823 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -11,6 +11,8 @@ from dateutil import parser from typing import Callable, Dict, Optional +from databricks.sql.thrift_api.TCLIService import ttypes + logger = logging.getLogger(__name__) @@ -48,6 +50,14 @@ def _convert_decimal( return result +def _get_type_name(thrift_type_id: int) -> str: + type_name = ttypes.TTypeId._VALUES_TO_NAMES[thrift_type_id] + type_name = type_name.lower() + if type_name.endswith("_type"): + type_name = type_name[:-5] + return type_name + + class SqlType: """ SQL type constants based on Thrift TTypeId values. 
diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
new file mode 100644
index 000000000..340c4c79e
--- /dev/null
+++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -0,0 +1,112 @@
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
+from databricks.sql.backend.sea.utils.conversion import SqlType
+
+
+class MetadataColumnMappings:
+    """Column mappings for metadata queries following JDBC specification."""
+
+    CATALOG_COLUMN_FOR_GET_CATALOGS = ResultColumn(
+        "TABLE_CAT", "catalog", SqlType.STRING
+    )
+
+    CATALOG_FULL_COLUMN = ResultColumn("TABLE_CATALOG", None, SqlType.STRING)
+    SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn(
+        "TABLE_SCHEM", "databaseName", SqlType.STRING
+    )
+
+    CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalogName", SqlType.STRING)
+    SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING)
+    TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
+    TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING)
+    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING)
+    TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING)
+    TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING)
+    TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING)
+    SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn(
+        "SELF_REFERENCING_COL_NAME", None, SqlType.STRING
+    )
+    REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)
+
+    COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
+    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
+    COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
+    COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
+    BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
+
+    DECIMAL_DIGITS_COLUMN = ResultColumn(
+        "DECIMAL_DIGITS",
+        "decimalDigits",
+        SqlType.INT,
+    )
+    NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", "radix", SqlType.INT)
+    ORDINAL_POSITION_COLUMN = ResultColumn(
+        "ORDINAL_POSITION",
+        "ordinalPosition",
+        SqlType.INT,
+    )
+
+    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
+    COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
+    SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
+    SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
+    CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
+    IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
+
+    SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
+    SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
+    SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.STRING)
+    SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)
+
+    IS_AUTO_INCREMENT_COLUMN = ResultColumn(
+        "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
+    )
+    IS_GENERATED_COLUMN = ResultColumn(
+        "IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
+    )
+
+    CATALOG_COLUMNS = [CATALOG_COLUMN_FOR_GET_CATALOGS]
+
+    SCHEMA_COLUMNS = [
+        SCHEMA_COLUMN_FOR_GET_SCHEMA,
+        CATALOG_FULL_COLUMN,
+    ]
+
+    TABLE_COLUMNS = [
+        CATALOG_COLUMN,
+        SCHEMA_COLUMN,
+        TABLE_NAME_COLUMN,
+        TABLE_TYPE_COLUMN,
+        REMARKS_COLUMN,
+        TYPE_CATALOG_COLUMN,
+        TYPE_SCHEM_COLUMN,
+        TYPE_NAME_COLUMN,
+        SELF_REFERENCING_COL_NAME_COLUMN,
+        REF_GENERATION_COLUMN,
+    ]
+
+    COLUMN_COLUMNS = [
+        CATALOG_COLUMN,
+        SCHEMA_COLUMN,
+        TABLE_NAME_COLUMN,
+        COL_NAME_COLUMN,
+        DATA_TYPE_COLUMN,
+        COLUMN_TYPE_COLUMN,
+        COLUMN_SIZE_COLUMN,
+        BUFFER_LENGTH_COLUMN,
+        DECIMAL_DIGITS_COLUMN,
+        NUM_PREC_RADIX_COLUMN,
+        NULLABLE_COLUMN,
+        REMARKS_COLUMN,
+        COLUMN_DEF_COLUMN,
+        SQL_DATA_TYPE_COLUMN,
+        SQL_DATETIME_SUB_COLUMN,
+        CHAR_OCTET_LENGTH_COLUMN,
+        ORDINAL_POSITION_COLUMN,
+        IS_NULLABLE_COLUMN,
+        SCOPE_CATALOG_COLUMN,
+        SCOPE_SCHEMA_COLUMN,
+        SCOPE_TABLE_COLUMN,
+        SOURCE_DATA_TYPE_COLUMN,
+        IS_AUTO_INCREMENT_COLUMN,
+        # IS_GENERATED_COLUMN is omitted because Thrift does not return an equivalent column
+    ]
diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py
new file mode 100644
index 000000000..a4c1f619b
--- /dev/null
+++ b/src/databricks/sql/backend/sea/utils/result_column.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class ResultColumn:
+    """
+    Represents a mapping between Thrift specification column names and SEA column names.
+
+    Attributes:
+        thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
+        sea_col_name: Server result column name from SEA (e.g., "catalog")
+        thrift_col_type: SQL type name
+    """
+
+    thrift_col_name: str
+    sea_col_name: Optional[str]  # None if SEA doesn't return this column
+    thrift_col_type: str
diff --git a/tests/unit/test_metadata_mappings.py b/tests/unit/test_metadata_mappings.py
new file mode 100644
index 000000000..6ab749e0f
--- /dev/null
+++ b/tests/unit/test_metadata_mappings.py
@@ -0,0 +1,163 @@
+"""
+Tests for SEA metadata column mappings and normalization.
+"""
+
+import pytest
+from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
+from databricks.sql.backend.sea.utils.conversion import SqlType
+
+
+class TestMetadataColumnMappings:
+    """Test suite for metadata column mappings."""
+
+    def test_result_column_creation(self):
+        """Test ResultColumn data class creation and attributes."""
+        col = ResultColumn("TABLE_CAT", "catalog", SqlType.STRING)
+        assert col.thrift_col_name == "TABLE_CAT"
+        assert col.sea_col_name == "catalog"
+        assert col.thrift_col_type == SqlType.STRING
+
+    def test_result_column_with_none_sea_name(self):
+        """Test ResultColumn when SEA doesn't return this column."""
+        col = ResultColumn("TYPE_CAT", None, SqlType.STRING)
+        assert col.thrift_col_name == "TYPE_CAT"
+        assert col.sea_col_name is None
+        assert col.thrift_col_type == SqlType.STRING
+
+    def test_catalog_columns_mapping(self):
+        """Test catalog columns mapping for getCatalogs."""
+        catalog_cols = MetadataColumnMappings.CATALOG_COLUMNS
+        assert len(catalog_cols) == 1
+
+        catalog_col = catalog_cols[0]
+        assert catalog_col.thrift_col_name == "TABLE_CAT"
+        assert catalog_col.sea_col_name == "catalog"
+        assert catalog_col.thrift_col_type == SqlType.STRING
+
+    def test_schema_columns_mapping(self):
+        """Test schema columns mapping for getSchemas."""
+        schema_cols = MetadataColumnMappings.SCHEMA_COLUMNS
+        assert len(schema_cols) == 2
+
+        # Check TABLE_SCHEM column
+        schema_col = schema_cols[0]
+        assert schema_col.thrift_col_name == "TABLE_SCHEM"
+        assert schema_col.sea_col_name == "databaseName"
+        assert schema_col.thrift_col_type == SqlType.STRING
+
+        # Check TABLE_CATALOG column
+        catalog_col = schema_cols[1]
+        assert catalog_col.thrift_col_name == "TABLE_CATALOG"
+        assert catalog_col.sea_col_name is None
+        assert catalog_col.thrift_col_type == SqlType.STRING
+
+    def test_table_columns_mapping(self):
+        """Test table columns mapping for getTables."""
+        table_cols = MetadataColumnMappings.TABLE_COLUMNS
+        assert len(table_cols) == 10
+
+        # Test key columns
+        expected_mappings = [
+            ("TABLE_CAT", "catalogName", SqlType.STRING),
+            ("TABLE_SCHEM", "namespace", SqlType.STRING),
+            ("TABLE_NAME", "tableName", SqlType.STRING),
+            ("TABLE_TYPE", "tableType", SqlType.STRING),
+            ("REMARKS", "remarks", SqlType.STRING),
+            ("TYPE_CAT", None, SqlType.STRING),
+            ("TYPE_SCHEM", None, SqlType.STRING),
+            ("TYPE_NAME", None, SqlType.STRING),
+            ("SELF_REFERENCING_COL_NAME", None, SqlType.STRING),
+            ("REF_GENERATION", None, SqlType.STRING),
+        ]
+
+        for i, (thrift_name, sea_name, sql_type) in enumerate(expected_mappings):
+            col = table_cols[i]
+            assert col.thrift_col_name == thrift_name
+            assert col.sea_col_name == sea_name
+            assert col.thrift_col_type == sql_type
+
+    def test_column_columns_mapping(self):
+        """Test column columns mapping for getColumns."""
+        column_cols = MetadataColumnMappings.COLUMN_COLUMNS
+        # Should have 23 columns (not including IS_GENERATED_COLUMN)
+        assert len(column_cols) == 23
+
+        # Test some key columns
+        key_columns = {
+            "TABLE_CAT": ("catalogName", SqlType.STRING),
+            "TABLE_SCHEM": ("namespace", SqlType.STRING),
+            "TABLE_NAME": ("tableName", SqlType.STRING),
+            "COLUMN_NAME": ("col_name", SqlType.STRING),
+            "DATA_TYPE": (None, SqlType.INT),
+            "TYPE_NAME": ("columnType", SqlType.STRING),
+            "COLUMN_SIZE": ("columnSize", SqlType.INT),
+            "DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
+            "NUM_PREC_RADIX": ("radix", SqlType.INT),
+            "ORDINAL_POSITION": ("ordinalPosition", SqlType.INT),
+            "IS_NULLABLE": ("isNullable", SqlType.STRING),
+            "IS_AUTOINCREMENT": ("isAutoIncrement", SqlType.STRING),
+        }
+
+        for col in column_cols:
+            if col.thrift_col_name in key_columns:
+                expected_sea_name, expected_type = key_columns[col.thrift_col_name]
+                assert col.sea_col_name == expected_sea_name
+                assert col.thrift_col_type == expected_type
+
+    def test_is_generated_column_not_included(self):
+        """Test that IS_GENERATED_COLUMN is not included in COLUMN_COLUMNS."""
+        column_names = [
+            col.thrift_col_name for col in MetadataColumnMappings.COLUMN_COLUMNS
+        ]
+        assert "IS_GENERATEDCOLUMN" not in column_names
+
+    def test_column_type_consistency(self):
+        """Test that column types are consistent with JDBC spec."""
+        # Test numeric types
+        assert MetadataColumnMappings.DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT
+        assert MetadataColumnMappings.COLUMN_SIZE_COLUMN.thrift_col_type == SqlType.INT
+        assert (
+            MetadataColumnMappings.BUFFER_LENGTH_COLUMN.thrift_col_type
+            == SqlType.TINYINT
+        )
+        assert (
+            MetadataColumnMappings.DECIMAL_DIGITS_COLUMN.thrift_col_type == SqlType.INT
+        )
+        assert (
+            MetadataColumnMappings.NUM_PREC_RADIX_COLUMN.thrift_col_type == SqlType.INT
+        )
+        assert (
+            MetadataColumnMappings.ORDINAL_POSITION_COLUMN.thrift_col_type
+            == SqlType.INT
+        )
+        assert MetadataColumnMappings.NULLABLE_COLUMN.thrift_col_type == SqlType.INT
+        assert (
+            MetadataColumnMappings.SQL_DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT
+        )
+        assert (
+            MetadataColumnMappings.SQL_DATETIME_SUB_COLUMN.thrift_col_type
+            == SqlType.INT
+        )
+        assert (
+            MetadataColumnMappings.CHAR_OCTET_LENGTH_COLUMN.thrift_col_type
+            == SqlType.INT
+        )
+        assert (
+            MetadataColumnMappings.SOURCE_DATA_TYPE_COLUMN.thrift_col_type
+            == SqlType.SMALLINT
+        )
+
+        # Test string types
+        assert MetadataColumnMappings.CATALOG_COLUMN.thrift_col_type == SqlType.STRING
+        assert MetadataColumnMappings.SCHEMA_COLUMN.thrift_col_type == SqlType.STRING
+        assert (
+            MetadataColumnMappings.TABLE_NAME_COLUMN.thrift_col_type == SqlType.STRING
+        )
+        assert (
+            MetadataColumnMappings.IS_NULLABLE_COLUMN.thrift_col_type == SqlType.STRING
+        )
+        assert (
+            MetadataColumnMappings.IS_AUTO_INCREMENT_COLUMN.thrift_col_type
+            == SqlType.STRING
+        )
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index f604f2874..1a2621f06 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -13,6 +13,8 @@
     _filter_session_configuration,
 )
 from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus
+from databricks.sql.backend.sea.result_set import SeaResultSet
+from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
 from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
 from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter
 from databricks.sql.thrift_api.TCLIService import ttypes
@@ -726,7 +728,7 @@ def test_results_message_to_execute_response_is_staging_operation(self, sea_clie
     def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_catalogs method."""
         # Mock the execute_command method
-        mock_result_set = Mock()
+        mock_result_set = Mock(spec=SeaResultSet)
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:
@@ -755,10 +757,15 @@ def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor):
             # Verify the result is correct
             assert result == mock_result_set
 
+            # Verify prepare_metadata_columns was called
+            mock_result_set.prepare_metadata_columns.assert_called_once_with(
+                MetadataColumnMappings.CATALOG_COLUMNS
+            )
+
     def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_schemas method with various parameter combinations."""
         # Mock the execute_command method
-        mock_result_set = Mock()
+        mock_result_set = Mock(spec=SeaResultSet)
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:
@@ -817,6 +824,12 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
                 )
             assert "Catalog name is required for get_schemas" in str(excinfo.value)
 
+            # Verify prepare_metadata_columns was called for successful cases
+            assert mock_result_set.prepare_metadata_columns.call_count == 2
+            mock_result_set.prepare_metadata_columns.assert_called_with(
+                MetadataColumnMappings.SCHEMA_COLUMNS
+            )
+
     def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_tables method with various parameter combinations."""
         # Mock the execute_command method
@@ -904,10 +917,15 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
                 enforce_embedded_schema_correctness=False,
             )
 
+            # Verify prepare_metadata_columns was called
+            mock_result_set.prepare_metadata_columns.assert_called_with(
+                MetadataColumnMappings.TABLE_COLUMNS
+            )
+
     def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_columns method with various parameter combinations."""
         # Mock the execute_command method
-        mock_result_set = Mock()
+        mock_result_set = Mock(spec=SeaResultSet)
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:
@@ -968,12 +986,18 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
                 )
             assert "Catalog name is required for get_columns" in str(excinfo.value)
 
+            # Verify prepare_metadata_columns was called for successful cases
+            assert mock_result_set.prepare_metadata_columns.call_count == 2
+            mock_result_set.prepare_metadata_columns.assert_called_with(
+                MetadataColumnMappings.COLUMN_COLUMNS
+            )
+
     def test_get_tables_with_cloud_fetch(
         self, sea_client_cloud_fetch, sea_session_id, mock_cursor
     ):
         """Test the get_tables method with cloud fetch enabled."""
         # Mock the execute_command method and ResultSetFilter
-        mock_result_set = Mock()
+        mock_result_set = Mock(spec=SeaResultSet)
 
         with patch.object(
             sea_client_cloud_fetch, "execute_command", return_value=mock_result_set
@@ -1012,7 +1036,7 @@ def test_get_schemas_with_cloud_fetch(
     ):
         """Test the get_schemas method with cloud fetch enabled."""
         # Mock the execute_command method
-        mock_result_set = Mock()
+        mock_result_set = Mock(spec=SeaResultSet)
         with patch.object(
             sea_client_cloud_fetch, "execute_command", return_value=mock_result_set
         ) as mock_execute: