Add retry mechanism to telemetry requests #617
base: telemetry
Changes from all commits
4b6e331
1a9c253
41bdfe7
a92821b
e684cc3
@@ -22,6 +22,8 @@ | |
DatabricksOAuthProvider, | ||
ExternalAuthProvider, | ||
) | ||
from requests.adapters import HTTPAdapter | ||
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType | ||
import sys | ||
import platform | ||
import uuid | ||
|
@@ -31,6 +33,24 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TelemetryHTTPAdapter(HTTPAdapter): | ||
Review comment: let's see if we can skip this adapter (an adapter-free alternative is sketched after this hunk).
||
""" | ||
Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. | ||
This ensures the retry timer is started and the command type is set correctly, | ||
allowing the policy to manage its state for the duration of the request retries. | ||
""" | ||
|
||
def send(self, request, **kwargs): | ||
# The DatabricksRetryPolicy needs state set before the first attempt. | ||
if isinstance(self.max_retries, DatabricksRetryPolicy): | ||
# Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER | ||
# to signal this to the retry policy, bypassing stricter rules for commands | ||
# like ExecuteStatement. | ||
self.max_retries.command_type = CommandType.OTHER | ||
self.max_retries.start_retry_timer() | ||
return super().send(request, **kwargs) | ||
|
||
|
||
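If the adapter ends up being dropped, as the review comment above suggests, the same preparation could instead happen immediately before each post. A minimal sketch of that alternative, assuming the policy is still mounted on the session through a plain HTTPAdapter; the helper names below are hypothetical and not part of this PR:

import json
import requests
from requests.adapters import HTTPAdapter
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType

def make_telemetry_session(retry_policy: DatabricksRetryPolicy) -> requests.Session:
    # A stock HTTPAdapter is enough for urllib3 to consult the policy on failures.
    session = requests.Session()
    session.mount("https://", HTTPAdapter(max_retries=retry_policy))
    return session

def post_telemetry(session, retry_policy, url, payload, headers):
    # Prime the policy state here rather than in a custom HTTPAdapter.send().
    # Telemetry requests are idempotent, so CommandType.OTHER is safe to use.
    retry_policy.command_type = CommandType.OTHER
    retry_policy.start_retry_timer()
    return session.post(url, data=json.dumps(payload), headers=headers)

The trade-off is that every call site must remember to prime the policy, which is exactly what the adapter centralizes.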
class TelemetryHelper: | ||
"""Helper class for getting telemetry related information.""" | ||
|
||
|
@@ -146,6 +166,11 @@ class TelemetryClient(BaseTelemetryClient): | |
It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory. | ||
""" | ||
|
||
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 | ||
TELEMETRY_RETRY_DELAY_MIN = 0.5 # seconds | ||
TELEMETRY_RETRY_DELAY_MAX = 5.0 # seconds | ||
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 | ||
Review comment: let's align these values across drivers or have these consistent across the python driver.
||
|
||
# Telemetry endpoint paths | ||
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" | ||
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" | ||
|
@@ -170,6 +195,18 @@ def __init__( | |
self._host_url = host_url | ||
self._executor = executor | ||
|
||
self._telemetry_retry_policy = DatabricksRetryPolicy( | ||
delay_min=self.TELEMETRY_RETRY_DELAY_MIN, | ||
delay_max=self.TELEMETRY_RETRY_DELAY_MAX, | ||
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, | ||
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, | ||
delay_default=1.0, # Not directly used by telemetry, but required by constructor | ||
force_dangerous_codes=[], # Telemetry doesn't have "dangerous" codes | ||
) | ||
self._session = requests.Session() | ||
adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy) | ||
self._session.mount("https://", adapter) | ||
|
||
def _export_event(self, event): | ||
"""Add an event to the batch queue and flush if batch is full""" | ||
logger.debug("Exporting event for connection %s", self._session_id_hex) | ||
|
@@ -215,7 +252,7 @@ def _send_telemetry(self, events): | |
try: | ||
logger.debug("Submitting telemetry request to thread pool") | ||
future = self._executor.submit( | ||
requests.post, | ||
self._session.post, | ||
url, | ||
data=json.dumps(request), | ||
headers=headers, | ||
|
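Routing the call through self._session.post rather than the module-level requests.post is what makes the retries take effect: only requests issued through this session pass through the mounted TelemetryHTTPAdapter and its DatabricksRetryPolicy, whereas a bare requests.post uses a throwaway session with default, non-retrying adapters.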
@@ -303,6 +340,7 @@ def close(self): | |
"""Flush remaining events before closing""" | ||
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) | ||
self._flush() | ||
self._session.close() | ||
|
||
|
||
class TelemetryClientFactory: | ||
|
@@ -402,7 +440,7 @@ def get_telemetry_client(session_id_hex): | |
if session_id_hex in TelemetryClientFactory._clients: | ||
return TelemetryClientFactory._clients[session_id_hex] | ||
else: | ||
logger.error( | ||
logger.debug( | ||
"Telemetry client not initialized for connection %s", | ||
session_id_hex, | ||
) | ||
|
tests/e2e/test_telemetry_retry.py (new file)
@@ -0,0 +1,213 @@ | ||
# tests/e2e/test_telemetry_retry.py | ||
|
||
import pytest | ||
import logging | ||
from unittest.mock import patch, MagicMock | ||
from functools import wraps | ||
import time | ||
from concurrent.futures import Future | ||
|
||
# Imports for the code being tested | ||
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory | ||
from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType | ||
from databricks.sql.telemetry.models.enums import AuthMech | ||
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType | ||
|
||
# Imports for mocking the network layer correctly | ||
from urllib3.connectionpool import HTTPSConnectionPool | ||
from urllib3.exceptions import MaxRetryError | ||
from requests.exceptions import ConnectionError as RequestsConnectionError | ||
|
||
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' | ||
|
||
# Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse. | ||
def create_urllib3_response(status, headers=None, body=b'{}'): | ||
"""Create a proper mock response that simulates urllib3's HTTPResponse""" | ||
mock_response = MagicMock() | ||
mock_response.status = status | ||
mock_response.headers = headers or {} | ||
mock_response.msg = headers or {} # For urllib3~=1.0 compatibility | ||
mock_response.data = body | ||
mock_response.read.return_value = body | ||
mock_response.get_redirect_location.return_value = False | ||
mock_response.closed = False | ||
mock_response.isclosed.return_value = False | ||
return mock_response | ||
|
||
@pytest.mark.usefixtures("caplog") | ||
Review comment: do we really need caplog?
||
class TestTelemetryClientRetries: | ||
""" | ||
Test suite for verifying the retry mechanism of the TelemetryClient. | ||
This suite patches the low-level urllib3 connection to correctly | ||
trigger and test the retry logic configured in the requests adapter. | ||
""" | ||
|
||
@pytest.fixture(autouse=True) | ||
def setup_and_teardown(self, caplog): | ||
caplog.set_level(logging.DEBUG) | ||
TelemetryClientFactory._initialized = False | ||
TelemetryClientFactory._clients = {} | ||
TelemetryClientFactory._executor = None | ||
yield | ||
if TelemetryClientFactory._executor: | ||
TelemetryClientFactory._executor.shutdown(wait=True) | ||
TelemetryClientFactory._initialized = False | ||
TelemetryClientFactory._clients = {} | ||
TelemetryClientFactory._executor = None | ||
|
||
def get_client(self, session_id, total_retries=3): | ||
TelemetryClientFactory.initialize_telemetry_client( | ||
telemetry_enabled=True, | ||
session_id_hex=session_id, | ||
auth_provider=None, | ||
host_url="test.databricks.com", | ||
) | ||
client = TelemetryClientFactory.get_telemetry_client(session_id) | ||
|
||
retry_policy = DatabricksRetryPolicy( | ||
delay_min=0.01, | ||
delay_max=0.02, | ||
stop_after_attempts_duration=2.0, | ||
stop_after_attempts_count=total_retries, | ||
delay_default=0.1, | ||
force_dangerous_codes=[], | ||
urllib3_kwargs={'total': total_retries} | ||
) | ||
adapter = client._session.adapters.get("https://") | ||
adapter.max_retries = retry_policy | ||
return client, adapter | ||
|
||
def wait_for_async_request(self, timeout=2.0): | ||
"""Wait for async telemetry request to complete""" | ||
start_time = time.time() | ||
while time.time() - start_time < timeout: | ||
if TelemetryClientFactory._executor and TelemetryClientFactory._executor._threads: | ||
# Wait a bit more for threads to complete | ||
time.sleep(0.1) | ||
else: | ||
break | ||
time.sleep(0.1) # Extra buffer for completion | ||
|
||
def test_success_no_retry(self): | ||
client, _ = self.get_client("session-success") | ||
params = DriverConnectionParameters( | ||
http_path="test-path", | ||
mode=DatabricksClientType.THRIFT, | ||
host_info=HostDetails(host_url="test.databricks.com", port=443), | ||
auth_mech=AuthMech.PAT | ||
) | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) | ||
|
||
client.export_initial_telemetry_log(params, "test-agent") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
mock_get_conn.return_value.getresponse.assert_called_once() | ||
|
||
def test_retry_on_503_then_succeeds(self): | ||
client, _ = self.get_client("session-retry-once") | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.side_effect = [ | ||
create_urllib3_response(503), | ||
create_urllib3_response(200), | ||
] | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
assert mock_get_conn.return_value.getresponse.call_count == 2 | ||
|
||
def test_respects_retry_after_header(self, caplog): | ||
client, _ = self.get_client("session-retry-after") | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.side_effect = [ | ||
create_urllib3_response(429, headers={'Retry-After': '1'}), # Use integer seconds to avoid parsing issues | ||
create_urllib3_response(200) | ||
] | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
# Check that the request was retried (should be 2 calls: initial + 1 retry) | ||
assert mock_get_conn.return_value.getresponse.call_count == 2 | ||
assert "Retrying after" in caplog.text | ||
Review comment: can we collapse these tests a bit, i.e. test 503, 429, retry-after header, retry count etc. in a single test (first response returns 503, second does 429, etc.)? Appreciate the detailed testing, but I would like it to be more readable and maintainable in the long term. (A possible parametrized shape is sketched below.)
||
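One possible consolidation along those lines, sketched as a hypothetical test body that reuses the helpers already defined in this file (create_urllib3_response, get_client, wait_for_async_request); it is not part of the PR:

@pytest.mark.parametrize(
    "responses, expected_calls",
    [
        # 503 then success: exactly one retry
        ([create_urllib3_response(503), create_urllib3_response(200)], 2),
        # 429 with a Retry-After header then success: exactly one retry
        ([create_urllib3_response(429, headers={'Retry-After': '1'}),
          create_urllib3_response(200)], 2),
    ],
)
def test_retryable_status_then_success(self, responses, expected_calls):
    client, _ = self.get_client("session-retry-matrix")
    with patch(PATCH_TARGET) as mock_get_conn:
        mock_get_conn.return_value.getresponse.side_effect = responses
        client.export_failure_log("TestError", "Test message")
        self.wait_for_async_request()
        TelemetryClientFactory.close(client._session_id_hex)
    assert mock_get_conn.return_value.getresponse.call_count == expected_calls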
|
||
def test_exceeds_retry_count_limit(self, caplog): | ||
client, _ = self.get_client("session-exceed-limit", total_retries=3) | ||
expected_call_count = 4 | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(503) | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
assert mock_get_conn.return_value.getresponse.call_count == expected_call_count | ||
assert "Telemetry request failed with exception" in caplog.text | ||
assert "Max retries exceeded" in caplog.text | ||
|
||
def test_no_retry_on_401_unauthorized(self, caplog): | ||
"""Test that 401 responses are not retried (per retry policy)""" | ||
client, _ = self.get_client("session-401") | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(401) | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
# 401 should not be retried based on the retry policy | ||
mock_get_conn.return_value.getresponse.assert_called_once() | ||
assert "Telemetry request failed with status code: 401" in caplog.text | ||
|
||
def test_retries_on_400_bad_request(self, caplog): | ||
"""Test that 400 responses are retried (this is the current behavior for telemetry)""" | ||
client, _ = self.get_client("session-400") | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(400) | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
# Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER) | ||
Review comment: 400 is being retried? are you sure?
||
expected_call_count = 4 # total + 1 (initial + 3 retries) | ||
assert mock_get_conn.return_value.getresponse.call_count == expected_call_count | ||
assert "Telemetry request failed with exception" in caplog.text | ||
assert "Max retries exceeded" in caplog.text | ||
|
||
def test_no_retry_on_403_forbidden(self, caplog): | ||
"""Test that 403 responses are not retried (per retry policy)""" | ||
client, _ = self.get_client("session-403") | ||
with patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(403) | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
# 403 should not be retried based on the retry policy | ||
mock_get_conn.return_value.getresponse.assert_called_once() | ||
assert "Telemetry request failed with status code: 403" in caplog.text | ||
Review comment: would like to collapse 4xx error codes into a single test with something like parametrized tests. (A possible shape is sketched below.)
||
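A possible shape for that consolidation, assuming 401 and 403 keep the same no-retry expectation; this is a hypothetical sketch, not code from the PR:

@pytest.mark.parametrize("status_code", [401, 403])
def test_non_retryable_4xx_not_retried(self, status_code, caplog):
    client, _ = self.get_client(f"session-{status_code}")
    with patch(PATCH_TARGET) as mock_get_conn:
        mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(status_code)
        client.export_failure_log("TestError", "Test message")
        self.wait_for_async_request()
        TelemetryClientFactory.close(client._session_id_hex)
    # Non-retryable per the policy: exactly one attempt, and the failure is logged.
    mock_get_conn.return_value.getresponse.assert_called_once()
    assert f"Telemetry request failed with status code: {status_code}" in caplog.text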
|
||
def test_retry_policy_command_type_is_set_to_other(self): | ||
client, adapter = self.get_client("session-command-type") | ||
|
||
original_send = adapter.send | ||
@wraps(original_send) | ||
def wrapper(request, **kwargs): | ||
assert adapter.max_retries.command_type == CommandType.OTHER | ||
return original_send(request, **kwargs) | ||
|
||
with patch.object(adapter, 'send', side_effect=wrapper, autospec=True), \ | ||
patch(PATCH_TARGET) as mock_get_conn: | ||
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) | ||
|
||
client.export_failure_log("TestError", "Test message") | ||
self.wait_for_async_request() | ||
TelemetryClientFactory.close(client._session_id_hex) | ||
|
||
assert adapter.send.call_count == 1 |
Review comment: why this change?