Add retry mechanism to telemetry requests #617

Open · wants to merge 5 commits into base: telemetry
12 changes: 5 additions & 7 deletions src/databricks/sql/exc.py
@@ -1,8 +1,6 @@
import json
import logging

from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

logger = logging.getLogger(__name__)

### PEP-249 Mandated ###
@@ -21,11 +19,11 @@ def __init__(
self.context = context or {}

error_name = self.__class__.__name__
if session_id_hex:
telemetry_client = TelemetryClientFactory.get_telemetry_client(
session_id_hex
)
telemetry_client.export_failure_log(error_name, self.message)

from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
Review comment (Contributor): why this change?

telemetry_client.export_failure_log(error_name, self.message)

def __str__(self):
return self.message
42 changes: 40 additions & 2 deletions src/databricks/sql/telemetry/telemetry_client.py
@@ -22,6 +22,8 @@
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from requests.adapters import HTTPAdapter
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
import sys
import platform
import uuid
@@ -31,6 +33,24 @@
logger = logging.getLogger(__name__)


class TelemetryHTTPAdapter(HTTPAdapter):
Review comment (Contributor): let's see if we can skip this adapter

"""
Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
This ensures the retry timer is started and the command type is set correctly,
allowing the policy to manage its state for the duration of the request retries.
"""

def send(self, request, **kwargs):
# The DatabricksRetryPolicy needs state set before the first attempt.
if isinstance(self.max_retries, DatabricksRetryPolicy):
# Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER
# to signal this to the retry policy, bypassing stricter rules for commands
# like ExecuteStatement.
self.max_retries.command_type = CommandType.OTHER
self.max_retries.start_retry_timer()
return super().send(request, **kwargs)

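Regarding the comment above about skipping the adapter: one possible alternative (a sketch only, not part of this PR) is to mount a stock HTTPAdapter and prime the shared retry policy immediately before each telemetry request is submitted, e.g. inside _send_telemetry. This assumes the policy is constructed exactly as in __init__ below and that telemetry requests are not sent concurrently, since the policy state is shared across requests.

import requests
from requests.adapters import HTTPAdapter
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType

# Build the same policy this PR constructs in TelemetryClient.__init__.
retry_policy = DatabricksRetryPolicy(
    delay_min=0.5,
    delay_max=5.0,
    stop_after_attempts_count=3,
    stop_after_attempts_duration=30.0,
    delay_default=1.0,
    force_dangerous_codes=[],
)

# Mount a plain HTTPAdapter instead of the custom TelemetryHTTPAdapter.
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry_policy))

# Then, right before submitting each telemetry request:
retry_policy.command_type = CommandType.OTHER  # telemetry requests are idempotent
retry_policy.start_retry_timer()               # reset the per-request retry clock
# session.post(url, data=payload, headers=headers, timeout=10)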

class TelemetryHelper:
"""Helper class for getting telemetry related information."""

@@ -146,6 +166,11 @@ class TelemetryClient(BaseTelemetryClient):
It uses a thread pool, obtained from the TelemetryClientFactory, to handle asynchronous operations.
"""

TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
TELEMETRY_RETRY_DELAY_MIN = 0.5 # seconds
TELEMETRY_RETRY_DELAY_MAX = 5.0 # seconds
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
Review comment (Contributor): let's align these values across drivers or have these consistent across the python driver


# Telemetry endpoint paths
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
@@ -170,6 +195,18 @@ def __init__(
self._host_url = host_url
self._executor = executor

self._telemetry_retry_policy = DatabricksRetryPolicy(
delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
delay_default=1.0, # Not directly used by telemetry, but required by constructor
force_dangerous_codes=[], # Telemetry doesn't have "dangerous" codes
)
self._session = requests.Session()
adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy)
self._session.mount("https://", adapter)

def _export_event(self, event):
"""Add an event to the batch queue and flush if batch is full"""
logger.debug("Exporting event for connection %s", self._session_id_hex)
@@ -215,7 +252,7 @@ def _send_telemetry(self, events):
try:
logger.debug("Submitting telemetry request to thread pool")
future = self._executor.submit(
requests.post,
self._session.post,
url,
data=json.dumps(request),
headers=headers,
@@ -303,6 +340,7 @@ def close(self):
"""Flush remaining events before closing"""
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
self._flush()
self._session.close()


class TelemetryClientFactory:
@@ -402,7 +440,7 @@ def get_telemetry_client(session_id_hex):
if session_id_hex in TelemetryClientFactory._clients:
return TelemetryClientFactory._clients[session_id_hex]
else:
logger.error(
logger.debug(
"Telemetry client not initialized for connection %s",
session_id_hex,
)
213 changes: 213 additions & 0 deletions tests/e2e/test_telemetry_retry.py
@@ -0,0 +1,213 @@
# tests/e2e/test_telemetry_retry.py

import pytest
import logging
from unittest.mock import patch, MagicMock
from functools import wraps
import time
from concurrent.futures import Future

# Imports for the code being tested
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType
from databricks.sql.telemetry.models.enums import AuthMech
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType

# Imports for mocking the network layer correctly
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.exceptions import MaxRetryError
from requests.exceptions import ConnectionError as RequestsConnectionError

PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'

# Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse.
def create_urllib3_response(status, headers=None, body=b'{}'):
"""Create a proper mock response that simulates urllib3's HTTPResponse"""
mock_response = MagicMock()
mock_response.status = status
mock_response.headers = headers or {}
mock_response.msg = headers or {} # For urllib3~=1.0 compatibility
mock_response.data = body
mock_response.read.return_value = body
mock_response.get_redirect_location.return_value = False
mock_response.closed = False
mock_response.isclosed.return_value = False
return mock_response

@pytest.mark.usefixtures("caplog")
Review comment (Contributor): do we really need caplog?

class TestTelemetryClientRetries:
"""
Test suite for verifying the retry mechanism of the TelemetryClient.
This suite patches the low-level urllib3 connection to correctly
trigger and test the retry logic configured in the requests adapter.
"""

@pytest.fixture(autouse=True)
def setup_and_teardown(self, caplog):
caplog.set_level(logging.DEBUG)
TelemetryClientFactory._initialized = False
TelemetryClientFactory._clients = {}
TelemetryClientFactory._executor = None
yield
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._initialized = False
TelemetryClientFactory._clients = {}
TelemetryClientFactory._executor = None

def get_client(self, session_id, total_retries=3):
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True,
session_id_hex=session_id,
auth_provider=None,
host_url="test.databricks.com",
)
client = TelemetryClientFactory.get_telemetry_client(session_id)

retry_policy = DatabricksRetryPolicy(
delay_min=0.01,
delay_max=0.02,
stop_after_attempts_duration=2.0,
stop_after_attempts_count=total_retries,
delay_default=0.1,
force_dangerous_codes=[],
urllib3_kwargs={'total': total_retries}
)
adapter = client._session.adapters.get("https://")
adapter.max_retries = retry_policy
return client, adapter

def wait_for_async_request(self, timeout=2.0):
"""Wait for async telemetry request to complete"""
start_time = time.time()
while time.time() - start_time < timeout:
if TelemetryClientFactory._executor and TelemetryClientFactory._executor._threads:
# Wait a bit more for threads to complete
time.sleep(0.1)
else:
break
time.sleep(0.1) # Extra buffer for completion

def test_success_no_retry(self):
client, _ = self.get_client("session-success")
params = DriverConnectionParameters(
http_path="test-path",
mode=DatabricksClientType.THRIFT,
host_info=HostDetails(host_url="test.databricks.com", port=443),
auth_mech=AuthMech.PAT
)
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200)

client.export_initial_telemetry_log(params, "test-agent")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

mock_get_conn.return_value.getresponse.assert_called_once()

def test_retry_on_503_then_succeeds(self):
client, _ = self.get_client("session-retry-once")
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.side_effect = [
create_urllib3_response(503),
create_urllib3_response(200),
]

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

assert mock_get_conn.return_value.getresponse.call_count == 2

def test_respects_retry_after_header(self, caplog):
client, _ = self.get_client("session-retry-after")
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.side_effect = [
create_urllib3_response(429, headers={'Retry-After': '1'}), # Use integer seconds to avoid parsing issues
create_urllib3_response(200)
]

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

# Check that the request was retried (should be 2 calls: initial + 1 retry)
assert mock_get_conn.return_value.getresponse.call_count == 2
assert "Retrying after" in caplog.text
Review comment (Contributor): can we collapse these tests a bit, i.e. test 503, 429, the Retry-After header, retry count, etc. in a single test (first response returns 503, second 429, etc.)? Appreciate the detailed testing, but I would like it to be more readable and maintainable in the long term.

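A rough sketch of how the 503, 429/Retry-After, and retry-count cases could be folded into one test, reusing the fixtures and helpers in this file (the test name is illustrative, not part of this PR):

def test_retry_sequence_then_succeeds(self, caplog):
    """Sketch: first attempt gets 503, second gets 429 with Retry-After, third succeeds."""
    client, _ = self.get_client("session-retry-sequence", total_retries=3)
    with patch(PATCH_TARGET) as mock_get_conn:
        mock_get_conn.return_value.getresponse.side_effect = [
            create_urllib3_response(503),                                 # retried
            create_urllib3_response(429, headers={'Retry-After': '1'}),  # retried after the header delay
            create_urllib3_response(200),                                 # succeeds
        ]

        client.export_failure_log("TestError", "Test message")
        self.wait_for_async_request()
        TelemetryClientFactory.close(client._session_id_hex)

    # One initial attempt plus two retries.
    assert mock_get_conn.return_value.getresponse.call_count == 3
    assert "Retrying after" in caplog.text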

def test_exceeds_retry_count_limit(self, caplog):
client, _ = self.get_client("session-exceed-limit", total_retries=3)
expected_call_count = 4
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(503)

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

assert mock_get_conn.return_value.getresponse.call_count == expected_call_count
assert "Telemetry request failed with exception" in caplog.text
assert "Max retries exceeded" in caplog.text

def test_no_retry_on_401_unauthorized(self, caplog):
"""Test that 401 responses are not retried (per retry policy)"""
client, _ = self.get_client("session-401")
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(401)

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

# 401 should not be retried based on the retry policy
mock_get_conn.return_value.getresponse.assert_called_once()
assert "Telemetry request failed with status code: 401" in caplog.text

def test_retries_on_400_bad_request(self, caplog):
"""Test that 400 responses are retried (this is the current behavior for telemetry)"""
client, _ = self.get_client("session-400")
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(400)

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

# Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER)
Review comment (Contributor): 400 is being retried? are you sure?

expected_call_count = 4  # initial attempt + 3 retries
assert mock_get_conn.return_value.getresponse.call_count == expected_call_count
assert "Telemetry request failed with exception" in caplog.text
assert "Max retries exceeded" in caplog.text

def test_no_retry_on_403_forbidden(self, caplog):
"""Test that 403 responses are not retried (per retry policy)"""
client, _ = self.get_client("session-403")
with patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(403)

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

# 403 should not be retried based on the retry policy
mock_get_conn.return_value.getresponse.assert_called_once()
assert "Telemetry request failed with status code: 403" in caplog.text
Review comment (Contributor): would like to collapse 4xx error codes into a single test with something like parametrized tests

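A sketch of the parametrized form suggested above, covering the 4xx codes that are currently expected not to be retried (401 and 403; 400 behaves differently per the test above). Names are illustrative, not part of this PR:

@pytest.mark.parametrize("status_code", [401, 403])
def test_no_retry_on_non_retryable_4xx(self, status_code, caplog):
    """Sketch: 4xx codes the retry policy treats as terminal should not be retried."""
    client, _ = self.get_client(f"session-{status_code}")
    with patch(PATCH_TARGET) as mock_get_conn:
        mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(status_code)

        client.export_failure_log("TestError", "Test message")
        self.wait_for_async_request()
        TelemetryClientFactory.close(client._session_id_hex)

    # A single attempt, no retries.
    mock_get_conn.return_value.getresponse.assert_called_once()
    assert f"Telemetry request failed with status code: {status_code}" in caplog.text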

def test_retry_policy_command_type_is_set_to_other(self):
client, adapter = self.get_client("session-command-type")

original_send = adapter.send
@wraps(original_send)
def wrapper(request, **kwargs):
assert adapter.max_retries.command_type == CommandType.OTHER
return original_send(request, **kwargs)

with patch.object(adapter, 'send', side_effect=wrapper, autospec=True), \
patch(PATCH_TARGET) as mock_get_conn:
mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200)

client.export_failure_log("TestError", "Test message")
self.wait_for_async_request()
TelemetryClientFactory.close(client._session_id_hex)

assert adapter.send.call_count == 1
27 changes: 23 additions & 4 deletions tests/unit/test_telemetry.py
@@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup):
client._flush.assert_called_once()
assert len(client._events_batch) == 10

@patch("requests.post")
@patch("requests.Session.post")
def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
"""Test sending telemetry to the server with authentication."""
client = telemetry_client_setup["client"]
@@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):

executor.submit.assert_called_once()
args, kwargs = executor.submit.call_args
assert args[0] == requests.post
assert args[0] == client._session.post
assert kwargs["timeout"] == 10
assert "Authorization" in kwargs["headers"]
assert kwargs["headers"]["Authorization"] == "Bearer test-token"

@patch("requests.post")
@patch("requests.Session.post")
def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
"""Test sending telemetry to the server without authentication."""
host_url = telemetry_client_setup["host_url"]
@@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup)

executor.submit.assert_called_once()
args, kwargs = executor.submit.call_args
assert args[0] == requests.post
assert args[0] == unauthenticated_client._session.post
assert kwargs["timeout"] == 10
assert "Authorization" not in kwargs["headers"] # No auth header
assert kwargs["headers"]["Accept"] == "application/json"
@@ -331,6 +331,25 @@ class TestBaseClient(BaseTelemetryClient):
with pytest.raises(TypeError):
TestBaseClient() # Can't instantiate abstract class

def test_telemetry_http_adapter_configuration(self, telemetry_client_setup):
"""Test that TelemetryHTTPAdapter is properly configured with correct retry parameters."""
from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter
from databricks.sql.auth.retry import DatabricksRetryPolicy

client = telemetry_client_setup["client"]

# Verify that the session has the TelemetryHTTPAdapter mounted
adapter = client._session.adapters.get("https://")
assert isinstance(adapter, TelemetryHTTPAdapter)
assert isinstance(adapter.max_retries, DatabricksRetryPolicy)

# Verify that the retry policy has the correct static configuration
retry_policy = adapter.max_retries
assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN
assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX
assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT
assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION


class TestTelemetryHelper:
"""Tests for the TelemetryHelper class."""