diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py
index 30fd6c26..fbdbe6a5 100644
--- a/src/databricks/sql/exc.py
+++ b/src/databricks/sql/exc.py
@@ -1,8 +1,6 @@
 import json
 import logging
 
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-
 logger = logging.getLogger(__name__)
 
 ### PEP-249 Mandated ###
@@ -21,11 +19,11 @@ def __init__(
         self.context = context or {}
 
         error_name = self.__class__.__name__
-        if session_id_hex:
-            telemetry_client = TelemetryClientFactory.get_telemetry_client(
-                session_id_hex
-            )
-            telemetry_client.export_failure_log(error_name, self.message)
+
+        from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+
+        telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+        telemetry_client.export_failure_log(error_name, self.message)
 
     def __str__(self):
         return self.message
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 10aa04ef..83435db6 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -22,6 +22,8 @@
     DatabricksOAuthProvider,
     ExternalAuthProvider,
 )
+from requests.adapters import HTTPAdapter
+from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
 import sys
 import platform
 import uuid
@@ -31,6 +33,24 @@
 logger = logging.getLogger(__name__)
 
 
+class TelemetryHTTPAdapter(HTTPAdapter):
+    """
+    Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
+    This ensures the retry timer is started and the command type is set correctly,
+    allowing the policy to manage its state for the duration of the request retries.
+    """
+
+    def send(self, request, **kwargs):
+        # The DatabricksRetryPolicy needs state set before the first attempt.
+        if isinstance(self.max_retries, DatabricksRetryPolicy):
+            # Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER
+            # to signal this to the retry policy, bypassing stricter rules for commands
+            # like ExecuteStatement.
+            self.max_retries.command_type = CommandType.OTHER
+            self.max_retries.start_retry_timer()
+        return super().send(request, **kwargs)
+
+
 class TelemetryHelper:
     """Helper class for getting telemetry related information."""
 
@@ -146,6 +166,11 @@ class TelemetryClient(BaseTelemetryClient):
     It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory.
     """
 
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
+    TELEMETRY_RETRY_DELAY_MIN = 0.5  # seconds
+    TELEMETRY_RETRY_DELAY_MAX = 5.0  # seconds
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
+
     # Telemetry endpoint paths
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
@@ -170,6 +195,18 @@ def __init__(
         self._host_url = host_url
         self._executor = executor
 
+        self._telemetry_retry_policy = DatabricksRetryPolicy(
+            delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
+            delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
+            stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
+            stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
+            delay_default=1.0,  # Not directly used by telemetry, but required by constructor
+            force_dangerous_codes=[],  # Telemetry doesn't have "dangerous" codes
+        )
+        self._session = requests.Session()
+        adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy)
+        self._session.mount("https://", adapter)
+
     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
         logger.debug("Exporting event for connection %s", self._session_id_hex)
@@ -215,7 +252,7 @@ def _send_telemetry(self, events):
         try:
             logger.debug("Submitting telemetry request to thread pool")
             future = self._executor.submit(
-                requests.post,
+                self._session.post,
                 url,
                 data=json.dumps(request),
                 headers=headers,
@@ -303,6 +340,7 @@ def close(self):
         """Flush remaining events before closing"""
         logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
         self._flush()
+        self._session.close()
 
 
 class TelemetryClientFactory:
@@ -402,7 +440,7 @@ def get_telemetry_client(session_id_hex):
             if session_id_hex in TelemetryClientFactory._clients:
                 return TelemetryClientFactory._clients[session_id_hex]
             else:
-                logger.error(
+                logger.debug(
                     "Telemetry client not initialized for connection %s",
                     session_id_hex,
                 )
diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py
new file mode 100644
index 00000000..57274ed2
--- /dev/null
+++ b/tests/e2e/test_telemetry_retry.py
@@ -0,0 +1,213 @@
+# tests/e2e/test_telemetry_retry.py
+
+import pytest
+import logging
+from unittest.mock import patch, MagicMock
+from functools import wraps
+import time
+from concurrent.futures import Future
+
+# Imports for the code being tested
+from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType
+from databricks.sql.telemetry.models.enums import AuthMech
+from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
+
+# Imports for mocking the network layer correctly
+from urllib3.connectionpool import HTTPSConnectionPool
+from urllib3.exceptions import MaxRetryError
+from requests.exceptions import ConnectionError as RequestsConnectionError
+
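+# Patching HTTPSConnectionPool._get_conn swaps out only the network connection itself, so
+# urllib3's real retry/back-off machinery still runs against the mocked responses below.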
+PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
+
+# Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse.
+def create_urllib3_response(status, headers=None, body=b'{}'):
+    """Create a proper mock response that simulates urllib3's HTTPResponse"""
+    mock_response = MagicMock()
+    mock_response.status = status
+    mock_response.headers = headers or {}
+    mock_response.msg = headers or {}  # For urllib3~=1.0 compatibility
+    mock_response.data = body
+    mock_response.read.return_value = body
+    mock_response.get_redirect_location.return_value = False
+    mock_response.closed = False
+    mock_response.isclosed.return_value = False
+    return mock_response
+
+@pytest.mark.usefixtures("caplog")
+class TestTelemetryClientRetries:
+    """
+    Test suite for verifying the retry mechanism of the TelemetryClient.
+    This suite patches the low-level urllib3 connection to correctly
+    trigger and test the retry logic configured in the requests adapter.
+    """
+
+    @pytest.fixture(autouse=True)
+    def setup_and_teardown(self, caplog):
+        caplog.set_level(logging.DEBUG)
+        TelemetryClientFactory._initialized = False
+        TelemetryClientFactory._clients = {}
+        TelemetryClientFactory._executor = None
+        yield
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._initialized = False
+        TelemetryClientFactory._clients = {}
+        TelemetryClientFactory._executor = None
+
+    def get_client(self, session_id, total_retries=3):
+        TelemetryClientFactory.initialize_telemetry_client(
+            telemetry_enabled=True,
+            session_id_hex=session_id,
+            auth_provider=None,
+            host_url="test.databricks.com",
+        )
+        client = TelemetryClientFactory.get_telemetry_client(session_id)
+
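+        # Swap the client's default retry policy for one with very short delays and a short
+        # overall duration, so the retry/back-off paths can be exercised quickly in tests.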
+        retry_policy = DatabricksRetryPolicy(
+            delay_min=0.01,
+            delay_max=0.02,
+            stop_after_attempts_duration=2.0,
+            stop_after_attempts_count=total_retries,
+            delay_default=0.1,
+            force_dangerous_codes=[],
+            urllib3_kwargs={'total': total_retries}
+        )
+        adapter = client._session.adapters.get("https://")
+        adapter.max_retries = retry_policy
+        return client, adapter
+
+    def wait_for_async_request(self, timeout=2.0):
+        """Wait for async telemetry request to complete"""
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            if TelemetryClientFactory._executor and TelemetryClientFactory._executor._threads:
+                # Wait a bit more for threads to complete
+                time.sleep(0.1)
+            else:
+                break
+        time.sleep(0.1)  # Extra buffer for completion
+
+    def test_success_no_retry(self):
+        client, _ = self.get_client("session-success")
+        params = DriverConnectionParameters(
+            http_path="test-path",
+            mode=DatabricksClientType.THRIFT,
+            host_info=HostDetails(host_url="test.databricks.com", port=443),
+            auth_mech=AuthMech.PAT
+        )
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200)
+
+            client.export_initial_telemetry_log(params, "test-agent")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            mock_get_conn.return_value.getresponse.assert_called_once()
+
+    def test_retry_on_503_then_succeeds(self):
+        client, _ = self.get_client("session-retry-once")
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.side_effect = [
+                create_urllib3_response(503),
+                create_urllib3_response(200),
+            ]
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            assert mock_get_conn.return_value.getresponse.call_count == 2
+
+    def test_respects_retry_after_header(self, caplog):
+        client, _ = self.get_client("session-retry-after")
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.side_effect = [
+                create_urllib3_response(429, headers={'Retry-After': '1'}),  # Use integer seconds to avoid parsing issues
+                create_urllib3_response(200)
+            ]
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            # Check that the request was retried (should be 2 calls: initial + 1 retry)
+            assert mock_get_conn.return_value.getresponse.call_count == 2
+            assert "Retrying after" in caplog.text
+
+    def test_exceeds_retry_count_limit(self, caplog):
+        client, _ = self.get_client("session-exceed-limit", total_retries=3)
+        expected_call_count = 4
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(503)
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            assert mock_get_conn.return_value.getresponse.call_count == expected_call_count
+            assert "Telemetry request failed with exception" in caplog.text
+            assert "Max retries exceeded" in caplog.text
+
+    def test_no_retry_on_401_unauthorized(self, caplog):
+        """Test that 401 responses are not retried (per retry policy)"""
+        client, _ = self.get_client("session-401")
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(401)
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            # 401 should not be retried based on the retry policy
+            mock_get_conn.return_value.getresponse.assert_called_once()
+            assert "Telemetry request failed with status code: 401" in caplog.text
+
+    def test_retries_on_400_bad_request(self, caplog):
+        """Test that 400 responses are retried (this is the current behavior for telemetry)"""
+        client, _ = self.get_client("session-400")
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(400)
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            # 400 responses are retried for CommandType.OTHER requests, so all retries are exhausted
+            expected_call_count = 4  # initial attempt + 3 retries
+            assert mock_get_conn.return_value.getresponse.call_count == expected_call_count
+            assert "Telemetry request failed with exception" in caplog.text
+            assert "Max retries exceeded" in caplog.text
+
+    def test_no_retry_on_403_forbidden(self, caplog):
+        """Test that 403 responses are not retried (per retry policy)"""
+        client, _ = self.get_client("session-403")
+        with patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(403)
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            # 403 should not be retried based on the retry policy
+            mock_get_conn.return_value.getresponse.assert_called_once()
+            assert "Telemetry request failed with status code: 403" in caplog.text
+
+    def test_retry_policy_command_type_is_set_to_other(self):
+        client, adapter = self.get_client("session-command-type")
+
+        original_send = adapter.send
+        @wraps(original_send)
+        def wrapper(request, **kwargs):
+            assert adapter.max_retries.command_type == CommandType.OTHER
+            return original_send(request, **kwargs)
+
+        with patch.object(adapter, 'send', side_effect=wrapper, autospec=True), \
+             patch(PATCH_TARGET) as mock_get_conn:
+            mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200)
+
+            client.export_failure_log("TestError", "Test message")
+            self.wait_for_async_request()
+            TelemetryClientFactory.close(client._session_id_hex)
+
+            assert adapter.send.call_count == 1
\ No newline at end of file
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 699480bb..8a8f974c 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup):
         client._flush.assert_called_once()
         assert len(client._events_batch) == 10
 
-    @patch("requests.post")
+    @patch("requests.Session.post")
     def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
         """Test sending telemetry to the server with authentication."""
         client = telemetry_client_setup["client"]
@@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
         executor.submit.assert_called_once()
         args, kwargs = executor.submit.call_args
 
-        assert args[0] == requests.post
+        assert args[0] == client._session.post
         assert kwargs["timeout"] == 10
         assert "Authorization" in kwargs["headers"]
         assert kwargs["headers"]["Authorization"] == "Bearer test-token"
 
-    @patch("requests.post")
+    @patch("requests.Session.post")
     def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
         """Test sending telemetry to the server without authentication."""
         host_url = telemetry_client_setup["host_url"]
@@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup)
         executor.submit.assert_called_once()
         args, kwargs = executor.submit.call_args
 
-        assert args[0] == requests.post
+        assert args[0] == unauthenticated_client._session.post
         assert kwargs["timeout"] == 10
         assert "Authorization" not in kwargs["headers"]  # No auth header
         assert kwargs["headers"]["Accept"] == "application/json"
@@ -331,6 +331,25 @@ class TestBaseClient(BaseTelemetryClient):
         with pytest.raises(TypeError):
             TestBaseClient()  # Can't instantiate abstract class
 
+    def test_telemetry_http_adapter_configuration(self, telemetry_client_setup):
+        """Test that TelemetryHTTPAdapter is properly configured with correct retry parameters."""
+        from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter
+        from databricks.sql.auth.retry import DatabricksRetryPolicy
+
+        client = telemetry_client_setup["client"]
+
+        # Verify that the session has the TelemetryHTTPAdapter mounted
+        adapter = client._session.adapters.get("https://")
+        assert isinstance(adapter, TelemetryHTTPAdapter)
+        assert isinstance(adapter.max_retries, DatabricksRetryPolicy)
+
+        # Verify that the retry policy has the correct static configuration
+        retry_policy = adapter.max_retries
+        assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN
+        assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX
+        assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT
+        assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION
+
 
 class TestTelemetryHelper:
     """Tests for the TelemetryHelper class."""