diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 98cb9b2a..a8f04a05 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 if TYPE_CHECKING:
@@ -322,6 +323,11 @@ def _extract_description_from_manifest(
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             name = col_data.get("name", "")
             type_name = col_data.get("type_name", "")
+
+            # Normalize SEA type to Thrift conventions before any processing
+            type_name = normalize_sea_type_to_thrift(type_name, col_data)
+
+            # Now strip _TYPE suffix and convert to lowercase
             type_name = (
                 type_name[:-5] if type_name.endswith("_TYPE") else type_name
             ).lower()
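Note: normalization runs before the existing `_TYPE`-suffix strip and lowercasing, so a SEA `BYTE` column surfaces as `tinyint` in the cursor description. A minimal sketch of the combined steps (the column dict is hypothetical, not part of the patch):

    # Sketch: how one manifest column becomes a description type_code.
    from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

    col_data = {"name": "byte_col", "type_name": "BYTE"}  # hypothetical manifest entry
    type_name = normalize_sea_type_to_thrift(col_data["type_name"], col_data)  # "TINYINT"
    type_name = (type_name[:-5] if type_name.endswith("_TYPE") else type_name).lower()
    # -> "tinyint", matching the keys of SqlTypeConverter.TYPE_MAPPING below
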
""" # Numeric types - BYTE = "byte" - SHORT = "short" - INT = "int" - LONG = "long" - FLOAT = "float" - DOUBLE = "double" - DECIMAL = "decimal" + TINYINT = "tinyint" # Maps to TTypeId.TINYINT_TYPE + SMALLINT = "smallint" # Maps to TTypeId.SMALLINT_TYPE + INT = "int" # Maps to TTypeId.INT_TYPE + BIGINT = "bigint" # Maps to TTypeId.BIGINT_TYPE + FLOAT = "float" # Maps to TTypeId.FLOAT_TYPE + DOUBLE = "double" # Maps to TTypeId.DOUBLE_TYPE + DECIMAL = "decimal" # Maps to TTypeId.DECIMAL_TYPE # Boolean type - BOOLEAN = "boolean" + BOOLEAN = "boolean" # Maps to TTypeId.BOOLEAN_TYPE # Date/Time types - DATE = "date" - TIMESTAMP = "timestamp" - INTERVAL = "interval" + DATE = "date" # Maps to TTypeId.DATE_TYPE + TIMESTAMP = "timestamp" # Maps to TTypeId.TIMESTAMP_TYPE + INTERVAL_YEAR_MONTH = ( + "interval_year_month" # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE + ) + INTERVAL_DAY_TIME = "interval_day_time" # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE # String types - CHAR = "char" - STRING = "string" + CHAR = "char" # Maps to TTypeId.CHAR_TYPE + VARCHAR = "varchar" # Maps to TTypeId.VARCHAR_TYPE + STRING = "string" # Maps to TTypeId.STRING_TYPE # Binary type - BINARY = "binary" + BINARY = "binary" # Maps to TTypeId.BINARY_TYPE # Complex types - ARRAY = "array" - MAP = "map" - STRUCT = "struct" + ARRAY = "array" # Maps to TTypeId.ARRAY_TYPE + MAP = "map" # Maps to TTypeId.MAP_TYPE + STRUCT = "struct" # Maps to TTypeId.STRUCT_TYPE # Other types - NULL = "null" - USER_DEFINED_TYPE = "user_defined_type" + NULL = "null" # Maps to TTypeId.NULL_TYPE + UNION = "union" # Maps to TTypeId.UNION_TYPE + USER_DEFINED = "user_defined" # Maps to TTypeId.USER_DEFINED_TYPE class SqlTypeConverter: """ Utility class for converting SQL types to Python types. - Based on the types supported by the Databricks SDK. + Based on the Thrift TTypeId types after normalization. 
""" # SQL type to conversion function mapping # TODO: complex types TYPE_MAPPING: Dict[str, Callable] = { # Numeric types - SqlType.BYTE: lambda v: int(v), - SqlType.SHORT: lambda v: int(v), + SqlType.TINYINT: lambda v: int(v), + SqlType.SMALLINT: lambda v: int(v), SqlType.INT: lambda v: int(v), - SqlType.LONG: lambda v: int(v), + SqlType.BIGINT: lambda v: int(v), SqlType.FLOAT: lambda v: float(v), SqlType.DOUBLE: lambda v: float(v), SqlType.DECIMAL: _convert_decimal, @@ -112,22 +117,25 @@ class SqlTypeConverter: # Date/Time types SqlType.DATE: lambda v: datetime.date.fromisoformat(v), SqlType.TIMESTAMP: lambda v: parser.parse(v), - SqlType.INTERVAL: lambda v: v, # Keep as string for now + SqlType.INTERVAL_YEAR_MONTH: lambda v: v, # Keep as string for now + SqlType.INTERVAL_DAY_TIME: lambda v: v, # Keep as string for now # String types - no conversion needed SqlType.CHAR: lambda v: v, + SqlType.VARCHAR: lambda v: v, SqlType.STRING: lambda v: v, # Binary type SqlType.BINARY: lambda v: bytes.fromhex(v), # Other types SqlType.NULL: lambda v: None, # Complex types and user-defined types return as-is - SqlType.USER_DEFINED_TYPE: lambda v: v, + SqlType.USER_DEFINED: lambda v: v, } @staticmethod def convert_value( value: str, sql_type: str, + column_name: Optional[str], **kwargs, ) -> object: """ @@ -135,7 +143,8 @@ def convert_value( Args: value: The string value to convert - sql_type: The SQL type (e.g., 'int', 'decimal') + sql_type: The SQL type (e.g., 'tinyint', 'decimal') + column_name: The name of the column being converted **kwargs: Additional keyword arguments for the conversion function Returns: @@ -155,6 +164,10 @@ def convert_value( return converter_func(value, precision, scale) else: return converter_func(value) - except (ValueError, TypeError, decimal.InvalidOperation) as e: - logger.warning(f"Error converting value '{value}' to {sql_type}: {e}") + except Exception as e: + warning_message = f"Error converting value '{value}' to {sql_type}" + if column_name: + warning_message += f" in column {column_name}" + warning_message += f": {e}" + logger.warning(warning_message) return value diff --git a/src/databricks/sql/backend/sea/utils/normalize.py b/src/databricks/sql/backend/sea/utils/normalize.py new file mode 100644 index 00000000..d725d294 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/normalize.py @@ -0,0 +1,50 @@ +""" +Type normalization utilities for SEA backend. + +This module provides functionality to normalize SEA type names to match +Thrift type naming conventions. +""" + +from typing import Dict, Any + +# SEA types that need to be translated to Thrift types +# The list of all SEA types is available in the REST reference at: +# https://docs.databricks.com/api/workspace/statementexecution/executestatement +# The list of all Thrift types can be found in the ttypes.TTypeId definition +# The SEA types that do not align with Thrift are explicitly mapped below +SEA_TO_THRIFT_TYPE_MAP = { + "BYTE": "TINYINT", + "SHORT": "SMALLINT", + "LONG": "BIGINT", + "INTERVAL": "INTERVAL", # Default mapping, will be overridden if type_interval_type is present +} + + +def normalize_sea_type_to_thrift(type_name: str, col_data: Dict[str, Any]) -> str: + """ + Normalize SEA type names to match Thrift type naming conventions. 
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 4271f0d7..19375cde 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -262,9 +262,7 @@ def test_negative_fetch_throws_exception(self):
         mock_backend = Mock()
         mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
-        result_set = ThriftResultSet(
-            Mock(), Mock(), mock_backend
-        )
+        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index ed782a80..c514980e 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -26,12 +26,14 @@ class DownloaderTests(unittest.TestCase):
     def _setup_time_mock_for_download(self, mock_time, end_time):
         """Helper to setup time mock that handles logging system calls."""
         call_count = [0]
+
         def time_side_effect():
             call_count[0] += 1
             if call_count[0] <= 2:  # First two calls (validation, start_time)
                 return 1000
             else:  # All subsequent calls (logging, duration calculation)
                 return end_time
+
         mock_time.side_effect = time_side_effect
 
     @patch("time.time", return_value=1000)
@@ -104,7 +106,7 @@ def test_run_get_response_not_ok(self, mock_time):
     @patch("time.time")
     def test_run_uncompressed_successful(self, mock_time):
         self._setup_time_mock_for_download(mock_time, 1000.5)
-        
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
@@ -133,7 +135,7 @@ def test_run_uncompressed_successful(self, mock_time):
     @patch("time.time")
     def test_run_compressed_successful(self, mock_time):
         self._setup_time_mock_for_download(mock_time, 1000.2)
-        
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index 482ce655..396ad906 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -550,6 +550,66 @@ def test_extract_description_from_manifest(self, sea_client):
         assert description[1][1] == "int"  # type_code
         assert description[1][6] is None  # null_ok
 
+    def test_extract_description_from_manifest_with_type_normalization(
+        self, sea_client
+    ):
+        """Test _extract_description_from_manifest with SEA to Thrift type normalization."""
+        manifest_obj = MagicMock()
+        manifest_obj.schema = {
+            "columns": [
+                {
+                    "name": "byte_col",
+                    "type_name": "BYTE",
+                },
+                {
+                    "name": "short_col",
+                    "type_name": "SHORT",
+                },
+                {
+                    "name": "long_col",
+                    "type_name": "LONG",
+                },
+                {
+                    "name": "interval_ym_col",
+                    "type_name": "INTERVAL",
+                    "type_interval_type": "YEAR TO MONTH",
+                },
+                {
+                    "name": "interval_dt_col",
+                    "type_name": "INTERVAL",
+                    "type_interval_type": "DAY TO SECOND",
+                },
+                {
+                    "name": "interval_default_col",
+                    "type_name": "INTERVAL",
+                    # No type_interval_type field
+                },
+            ]
+        }
+
+        description = sea_client._extract_description_from_manifest(manifest_obj)
+        assert description is not None
+        assert len(description) == 6
+
+        # Check normalized types
+        assert description[0][0] == "byte_col"
+        assert description[0][1] == "tinyint"  # BYTE -> tinyint
+
+        assert description[1][0] == "short_col"
+        assert description[1][1] == "smallint"  # SHORT -> smallint
+
+        assert description[2][0] == "long_col"
+        assert description[2][1] == "bigint"  # LONG -> bigint
+
+        assert description[3][0] == "interval_ym_col"
+        assert description[3][1] == "interval_year_month"  # INTERVAL with YEAR/MONTH
+
+        assert description[4][0] == "interval_dt_col"
+        assert description[4][1] == "interval_day_time"  # INTERVAL with DAY/TIME
+
+        assert description[5][0] == "interval_default_col"
+        assert description[5][1] == "interval"  # INTERVAL without subtype
+
     def test_filter_session_configuration(self):
         """Test that _filter_session_configuration converts all values to strings."""
         session_config = {
diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py
index 13970c5d..234cca86 100644
--- a/tests/unit/test_sea_conversion.py
+++ b/tests/unit/test_sea_conversion.py
@@ -18,59 +18,62 @@ class TestSqlTypeConverter:
     def test_convert_numeric_types(self):
         """Test converting numeric types."""
         # Test integer types
-        assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123
-        assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456
-        assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789
-        assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890
+        assert SqlTypeConverter.convert_value("123", SqlType.TINYINT, None) == 123
+        assert SqlTypeConverter.convert_value("456", SqlType.SMALLINT, None) == 456
+        assert SqlTypeConverter.convert_value("789", SqlType.INT, None) == 789
+        assert (
+            SqlTypeConverter.convert_value("1234567890", SqlType.BIGINT, None)
+            == 1234567890
+        )
 
         # Test floating point types
-        assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45
-        assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90
+        assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT, None) == 123.45
+        assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE, None) == 678.90
 
         # Test decimal type
-        decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL)
+        decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL, None)
         assert isinstance(decimal_value, decimal.Decimal)
         assert decimal_value == decimal.Decimal("123.45")
 
         # Test decimal with precision and scale
         decimal_value = SqlTypeConverter.convert_value(
-            "123.45", SqlType.DECIMAL, precision=5, scale=2
+            "123.45", SqlType.DECIMAL, None, precision=5, scale=2
         )
         assert isinstance(decimal_value, decimal.Decimal)
         assert decimal_value == decimal.Decimal("123.45")
 
         # Test invalid numeric input
-        result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT)
+        result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT, None)
         assert result == "not_a_number"  # Returns original value on error
 
     def test_convert_boolean_type(self):
         """Test converting boolean types."""
         # True values
-        assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True
-        assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True
-        assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True
-        assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True
-        assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True
-        assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True
+        assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN, None) is True
+        assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN, None) is True
+        assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN, None) is True
+        assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN, None) is True
+        assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN, None) is True
+        assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN, None) is True
 
         # False values
-        assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False
-        assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False
-        assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False
-        assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False
-        assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False
-        assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False
+        assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN, None) is False
+        assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN, None) is False
+        assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN, None) is False
+        assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN, None) is False
+        assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN, None) is False
+        assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN, None) is False
 
     def test_convert_datetime_types(self):
         """Test converting datetime types."""
         # Test date type
-        date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE)
+        date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE, None)
         assert isinstance(date_value, datetime.date)
         assert date_value == datetime.date(2023, 1, 15)
 
         # Test timestamp type
         timestamp_value = SqlTypeConverter.convert_value(
-            "2023-01-15T12:30:45", SqlType.TIMESTAMP
+            "2023-01-15T12:30:45", SqlType.TIMESTAMP, None
         )
         assert isinstance(timestamp_value, datetime.datetime)
         assert timestamp_value.year == 2023
@@ -80,51 +83,67 @@ def test_convert_datetime_types(self):
         assert timestamp_value.minute == 30
         assert timestamp_value.second == 45
 
-        # Test interval type (currently returns as string)
-        interval_value = SqlTypeConverter.convert_value(
-            "1 day 2 hours", SqlType.INTERVAL
+        # Test interval types (currently return as string)
+        interval_ym_value = SqlTypeConverter.convert_value(
+            "1-6", SqlType.INTERVAL_YEAR_MONTH, None
+        )
+        assert interval_ym_value == "1-6"
+
+        interval_dt_value = SqlTypeConverter.convert_value(
+            "1 day 2 hours", SqlType.INTERVAL_DAY_TIME, None
         )
-        assert interval_value == "1 day 2 hours"
+        assert interval_dt_value == "1 day 2 hours"
 
         # Test invalid date input
-        result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE)
+        result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE, None)
         assert result == "not_a_date"  # Returns original value on error
 
     def test_convert_string_types(self):
         """Test converting string types."""
         # String types don't need conversion, they should be returned as-is
         assert (
-            SqlTypeConverter.convert_value("test string", SqlType.STRING)
+            SqlTypeConverter.convert_value("test string", SqlType.STRING, None)
             == "test string"
         )
-        assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char"
+        assert (
+            SqlTypeConverter.convert_value("test char", SqlType.CHAR, None)
+            == "test char"
+        )
+        assert (
+            SqlTypeConverter.convert_value("test varchar", SqlType.VARCHAR, None)
+            == "test varchar"
+        )
 
     def test_convert_binary_type(self):
         """Test converting binary type."""
         # Test valid hex string
-        binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY)
+        binary_value = SqlTypeConverter.convert_value(
+            "48656C6C6F", SqlType.BINARY, None
+        )
         assert isinstance(binary_value, bytes)
         assert binary_value == b"Hello"
 
         # Test invalid binary input
-        result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY)
+        result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY, None)
         assert result == "not_hex"  # Returns original value on error
 
     def test_convert_unsupported_type(self):
         """Test converting an unsupported type."""
         # Should return the original value
-        assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test"
+        assert (
+            SqlTypeConverter.convert_value("test", "unsupported_type", None) == "test"
+        )
 
-        # Complex types should return as-is
+        # Complex types should return as-is (not yet implemented in TYPE_MAPPING)
         assert (
-            SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY)
+            SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY, None)
             == "complex_value"
         )
         assert (
-            SqlTypeConverter.convert_value("complex_value", SqlType.MAP)
+            SqlTypeConverter.convert_value("complex_value", SqlType.MAP, None)
             == "complex_value"
         )
         assert (
-            SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT)
+            SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT, None)
             == "complex_value"
         )
diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py
index c42e6665..1c3e3b5b 100644
--- a/tests/unit/test_sea_result_set.py
+++ b/tests/unit/test_sea_result_set.py
@@ -565,50 +565,3 @@ def test_fetchall_empty_arrow_queue(self, result_set_with_arrow_queue):
 
         # Verify _convert_arrow_table was called
         result_set_with_arrow_queue._convert_arrow_table.assert_called_once()
-
-    @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value")
-    def test_convert_json_types_with_errors(
-        self, mock_convert_value, result_set_with_data
-    ):
-        """Test error handling in _convert_json_types."""
-        # Mock the conversion to fail for the second and third values
-        mock_convert_value.side_effect = [
-            "value1",  # First value converts normally
-            Exception("Invalid int"),  # Second value fails
-            Exception("Invalid boolean"),  # Third value fails
-        ]
-
-        # Data with invalid values
-        data_row = ["value1", "not_an_int", "not_a_boolean"]
-
-        # Should not raise an exception but log warnings
-        result = result_set_with_data._convert_json_types(data_row)
-
-        # The first value should be converted normally
-        assert result[0] == "value1"
-
-        # The invalid values should remain as strings
-        assert result[1] == "not_an_int"
-        assert result[2] == "not_a_boolean"
-
-    @patch("databricks.sql.backend.sea.result_set.logger")
-    @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value")
-    def test_convert_json_types_with_logging(
-        self, mock_convert_value, mock_logger, result_set_with_data
-    ):
-        """Test that errors in _convert_json_types are logged."""
-        # Mock the conversion to fail for the second and third values
-        mock_convert_value.side_effect = [
-            "value1",  # First value converts normally
-            Exception("Invalid int"),  # Second value fails
-            Exception("Invalid boolean"),  # Third value fails
-        ]
-
-        # Data with invalid values
-        data_row = ["value1", "not_an_int", "not_a_boolean"]
-
-        # Call the method
-        result_set_with_data._convert_json_types(data_row)
-
-        # Verify warnings were logged
-        assert mock_logger.warning.call_count == 2
""" num_retries = 3 - expected_total_calls = num_retries + 1 + expected_total_calls = num_retries + 1 retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + mock_responses = [ + {"status": 503, "headers": {"Retry-After": str(retry_after)}}, + {"status": 429}, + {"status": 502}, + {"status": 503}, + ] + + with patch( + PATCH_TARGET, return_value=create_mock_conn(mock_responses) + ) as mock_get_conn: start_time = time.time() client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) end_time = time.time() - - assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls - assert end_time - start_time > retry_after \ No newline at end of file + + assert ( + mock_get_conn.return_value.getresponse.call_count + == expected_total_calls + ) + assert end_time - start_time > retry_after