diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 6bcaabaec..02b8d4604 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -20,7 +20,12 @@
     MetadataCommands,
 )
 from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    create_table_catalog_transform,
+)
 from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
+from databricks.sql.backend.sea.utils.conversion import SqlType
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 if TYPE_CHECKING:
@@ -740,7 +745,23 @@ def get_schemas(
         assert isinstance(
             result, SeaResultSet
         ), "Expected SeaResultSet from SEA backend"
-        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
+
+        # Create dynamic schema columns with catalog name bound to TABLE_CATALOG
+        schema_columns = []
+        for col in MetadataColumnMappings.SCHEMA_COLUMNS:
+            if col.thrift_col_name == "TABLE_CATALOG":
+                # Create a new column with the catalog transform bound
+                dynamic_col = ResultColumn(
+                    col.thrift_col_name,
+                    col.sea_col_name,
+                    col.thrift_col_type,
+                    create_table_catalog_transform(catalog_name),
+                )
+                schema_columns.append(dynamic_col)
+            else:
+                schema_columns.append(col)
+
+        result.prepare_metadata_columns(schema_columns)
         return result
 
     def get_tables(
diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index 09a8df1eb..af68721a1 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -320,7 +320,7 @@ def _prepare_column_mapping(self) -> None:
                     None,
                     None,
                     None,
-                    True,
+                    None,
                 )
 
         # Set the mapping
@@ -356,14 +356,20 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
                 if self._column_index_mapping
                 else None
             )
-
             column = (
                 pyarrow.nulls(table.num_rows)
                 if old_idx is None
                 else table.column(old_idx)
             )
-            new_columns.append(column)
+            # Apply transform if available
+            if result_column.transform_value:
+                # Convert to list, apply transform, and convert back
+                values = column.to_pylist()
+                transformed_values = [result_column.transform_value(v) for v in values]
+                column = pyarrow.array(transformed_values)
+
+            new_columns.append(column)
             column_names.append(result_column.thrift_col_name)
 
         return pyarrow.Table.from_arrays(new_columns, names=column_names)
@@ -382,8 +388,11 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
                     if self._column_index_mapping
                     else None
                 )
-
                 value = None if old_idx is None else row[old_idx]
+
+                # Apply transform if available
+                if result_column.transform_value:
+                    value = result_column.transform_value(value)
                 new_row.append(value)
             transformed_rows.append(new_row)
         return transformed_rows
diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
index 340c4c79e..ff5f2ab8b 100644
--- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py
+++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -1,5 +1,13 @@
 from databricks.sql.backend.sea.utils.result_column import ResultColumn
 from databricks.sql.backend.sea.utils.conversion import SqlType
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    transform_remarks,
+    transform_is_autoincrement,
+    transform_is_nullable,
+    transform_nullable,
+    transform_data_type,
+    transform_ordinal_position,
+)
 
 
 class MetadataColumnMappings:
@@ -18,7 +26,9 @@ class MetadataColumnMappings:
     SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING)
     TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
     TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING)
-    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING)
+    REMARKS_COLUMN = ResultColumn(
+        "REMARKS", "remarks", SqlType.STRING, transform_remarks
+    )
     TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING)
     TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING)
     TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING)
@@ -28,7 +38,9 @@ class MetadataColumnMappings:
     REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)
 
     COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
-    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
+    DATA_TYPE_COLUMN = ResultColumn(
+        "DATA_TYPE", "columnType", SqlType.INT, transform_data_type
+    )
     COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
     COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
     BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
@@ -43,14 +55,19 @@ class MetadataColumnMappings:
         "ORDINAL_POSITION",
         "ordinalPosition",
         SqlType.INT,
+        transform_ordinal_position,
     )
 
-    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
+    NULLABLE_COLUMN = ResultColumn(
+        "NULLABLE", "isNullable", SqlType.INT, transform_nullable
+    )
     COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
     SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
     SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
     CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
-    IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
+    IS_NULLABLE_COLUMN = ResultColumn(
+        "IS_NULLABLE", "isNullable", SqlType.STRING, transform_is_nullable
+    )
 
     SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
     SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
@@ -58,7 +75,10 @@ class MetadataColumnMappings:
     SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)
 
     IS_AUTO_INCREMENT_COLUMN = ResultColumn(
-        "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
+        "IS_AUTO_INCREMENT",
+        "isAutoIncrement",
+        SqlType.STRING,
+        transform_is_autoincrement,
     )
     IS_GENERATED_COLUMN = ResultColumn(
         "IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
diff --git a/src/databricks/sql/backend/sea/utils/metadata_transforms.py b/src/databricks/sql/backend/sea/utils/metadata_transforms.py
new file mode 100644
index 000000000..efff2236a
--- /dev/null
+++ b/src/databricks/sql/backend/sea/utils/metadata_transforms.py
@@ -0,0 +1,83 @@
+"""Simple transformation functions for metadata value normalization."""
+
+
+def transform_is_autoincrement(value):
+    """Transform IS_AUTOINCREMENT: boolean to YES/NO string."""
+    if isinstance(value, bool) or value is None:
+        return "YES" if value else "NO"
+    return value
+
+
+def transform_is_nullable(value):
+    """Transform IS_NULLABLE: true/false to YES/NO string."""
+    if value is True or value == "true":
+        return "YES"
+    elif value is False or value == "false":
+        return "NO"
+    return value
+
+
+def transform_remarks(value):
+    if value is None:
+        return ""
+    return value
+
+
+def transform_nullable(value):
+    """Transform NULLABLE column: boolean/string to integer."""
+    if value is True or value == "true" or value == "YES":
+        return 1
+    elif value is False or value == "false" or value == "NO":
+        return 0
+    return value
+
+
+# Type code mapping based on JDBC specification
+TYPE_CODE_MAP = {
+    "STRING": 12,  # VARCHAR
+    "VARCHAR": 12,  # VARCHAR
+    "CHAR": 1,  # CHAR
+    "INT": 4,  # INTEGER
+    "INTEGER": 4,  # INTEGER
+    "BIGINT": -5,  # BIGINT
+    "SMALLINT": 5,  # SMALLINT
+    "TINYINT": -6,  # TINYINT
+    "DOUBLE": 8,  # DOUBLE
+    "FLOAT": 6,  # FLOAT
+    "REAL": 7,  # REAL
+    "DECIMAL": 3,  # DECIMAL
+    "NUMERIC": 2,  # NUMERIC
+    "BOOLEAN": 16,  # BOOLEAN
+    "DATE": 91,  # DATE
+    "TIMESTAMP": 93,  # TIMESTAMP
+    "BINARY": -2,  # BINARY
+    "ARRAY": 2003,  # ARRAY
+    "MAP": 2002,  # STRUCT
+    "STRUCT": 2002,  # STRUCT
+}
+
+
+def transform_data_type(value):
+    """Transform DATA_TYPE: type name to JDBC type code."""
+    if isinstance(value, str):
+        # Handle parameterized types like DECIMAL(10,2)
+        base_type = value.split("(")[0].upper()
+        return TYPE_CODE_MAP.get(base_type, value)
+    return value
+
+
+def transform_ordinal_position(value):
+    """Transform ORDINAL_POSITION: decrement by 1 (1-based to 0-based)."""
+    if isinstance(value, int):
+        return value - 1
+    return value
+
+
+def create_table_catalog_transform(catalog_name):
+    """Factory function to create TABLE_CATALOG transform with bound catalog name."""
+
+    def transform_table_catalog(value):
+        """Transform TABLE_CATALOG: return the catalog name for all rows."""
+        return catalog_name
+
+    return transform_table_catalog
diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py
index a4c1f619b..2980bd8d9 100644
--- a/src/databricks/sql/backend/sea/utils/result_column.py
+++ b/src/databricks/sql/backend/sea/utils/result_column.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Callable, Any
 
 
 @dataclass(frozen=True)
@@ -11,8 +11,10 @@ class ResultColumn:
         thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
         sea_col_name: Server result column name from SEA (e.g., "catalog")
         thrift_col_type: SQL type name
+        transform_value: Optional callback to transform values for this column
     """
 
     thrift_col_name: str
     sea_col_name: Optional[str]  # None if SEA doesn't return this column
     thrift_col_type: str
+    transform_value: Optional[Callable[[Any], Any]] = None
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index 3fa87b1af..7b86cfbe8 100644
--- a/tests/e2e/test_driver.py
+++ b/tests/e2e/test_driver.py
@@ -562,8 +562,17 @@ def test_get_schemas(self):
         finally:
             cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name))
 
-    def test_get_catalogs(self):
-        with self.cursor({}) as cursor:
+    @pytest.mark.parametrize(
+        "backend_params",
+        [
+            {},
+            {
+                "use_sea": True,
+            },
+        ],
+    )
+    def test_get_catalogs(self, backend_params):
+        with self.cursor(backend_params) as cursor:
             cursor.catalogs()
             cursor.fetchall()
             catalogs_desc = cursor.description
diff --git a/tests/unit/test_metadata_mappings.py b/tests/unit/test_metadata_mappings.py
index 6ab749e0f..f0ffef067 100644
--- a/tests/unit/test_metadata_mappings.py
+++ b/tests/unit/test_metadata_mappings.py
@@ -89,7 +89,7 @@ def test_column_columns_mapping(self):
             "TABLE_SCHEM": ("namespace", SqlType.STRING),
             "TABLE_NAME": ("tableName", SqlType.STRING),
             "COLUMN_NAME": ("col_name", SqlType.STRING),
-            "DATA_TYPE": (None, SqlType.INT),
+            "DATA_TYPE": ("columnType", SqlType.INT),
             "TYPE_NAME": ("columnType", SqlType.STRING),
             "COLUMN_SIZE": ("columnSize", SqlType.INT),
             "DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index 1a2621f06..5f2df8887 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -826,9 +826,6 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
 
         # Verify prepare_metadata_columns was called for successful cases
         assert mock_result_set.prepare_metadata_columns.call_count == 2
-        mock_result_set.prepare_metadata_columns.assert_called_with(
-            MetadataColumnMappings.SCHEMA_COLUMNS
-        )
 
     def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_tables method with various parameter combinations."""
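
Reviewer note (illustrative only, not part of the patch): because ResultColumn.transform_value is a plain per-value callable, the new transforms can be sanity-checked in isolation. A minimal Python sketch, assuming the module path from the diff; the catalog name "main" is a hypothetical example value:

    from databricks.sql.backend.sea.utils.metadata_transforms import (
        create_table_catalog_transform,
        transform_data_type,
        transform_nullable,
    )

    # Parameterized types reduce to their base type before the lookup.
    assert transform_data_type("DECIMAL(10,2)") == 3  # java.sql.Types.DECIMAL
    assert transform_data_type("STRING") == 12        # java.sql.Types.VARCHAR

    # NULLABLE is normalized to the JDBC integer convention.
    assert transform_nullable("true") == 1
    assert transform_nullable(False) == 0

    # The factory closes over the catalog name; every row reports it.
    to_catalog = create_table_catalog_transform("main")  # hypothetical catalog
    assert to_catalog(None) == "main"

Binding the catalog name via a closure keeps ResultColumn a frozen dataclass while still letting get_schemas inject per-request state.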