From e6c68bbbf8cd022302376c7af27ff15b26d19dab Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 25 Mar 2025 02:13:09 +0400 Subject: [PATCH 1/6] Fix cyclic dependency between clients and sub clients While the Python GC can handle those dependencies, it can cause latency spikes when .options() is used for each query, which causes many clients to be garbage collected, which can be slower in the presence of cycles. --- elasticsearch/_async/client/__init__.py | 178 ++++++++++++++---------- elasticsearch/_async/client/_base.py | 67 +++------ elasticsearch/_sync/client/__init__.py | 178 ++++++++++++++---------- elasticsearch/_sync/client/_base.py | 67 +++------ 4 files changed, 250 insertions(+), 240 deletions(-) diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index f88bb0190..27db48ca4 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -20,10 +20,12 @@ import typing as t from elastic_transport import ( + ApiResponse, AsyncTransport, BaseNode, BinaryApiResponse, HeadApiResponse, + HttpHeaders, NodeConfig, NodePool, NodeSelector, @@ -97,7 +99,7 @@ SelfType = t.TypeVar("SelfType", bound="AsyncElasticsearch") -class AsyncElasticsearch(BaseClient): +class AsyncElasticsearch: """ Elasticsearch low-level client. Provides a straightforward mapping from Python to Elasticsearch REST APIs. 
@@ -224,6 +226,18 @@ def __init__( ): sniff_callback = default_sniff_callback + headers = HttpHeaders() + if headers is not DEFAULT and headers is not None: + headers.update(headers) + if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] + headers["x-opaque-id"] = opaque_id + headers = resolve_auth_headers( + headers, + api_key=api_key, + basic_auth=basic_auth, + bearer_auth=bearer_auth, + ) + if _transport is None: node_configs = client_node_configs( hosts, @@ -295,72 +309,92 @@ def __init__( **transport_kwargs, ) - super().__init__(_transport) + self._base_client = BaseClient(_transport, headers=headers) # These are set per-request so are stored separately. - self._request_timeout = request_timeout - self._max_retries = max_retries - self._retry_on_timeout = retry_on_timeout + self._base_client._request_timeout = request_timeout + self._base_client._max_retries = max_retries + self._base_client._retry_on_timeout = retry_on_timeout if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - self._retry_on_status = retry_on_status + self._base_client._retry_on_status = retry_on_status else: - super().__init__(_transport) - - if headers is not DEFAULT and headers is not None: - self._headers.update(headers) - if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - self._headers["x-opaque-id"] = opaque_id - self._headers = resolve_auth_headers( - self._headers, - api_key=api_key, - basic_auth=basic_auth, - bearer_auth=bearer_auth, - ) + self._base_client = BaseClient(_transport, headers=headers) # namespaced clients for compatibility with API names - self.async_search = AsyncSearchClient(self) - self.autoscaling = AutoscalingClient(self) - self.cat = CatClient(self) - self.cluster = ClusterClient(self) - self.connector = ConnectorClient(self) - self.fleet = FleetClient(self) - self.features = FeaturesClient(self) - self.indices = IndicesClient(self) - self.inference = 
InferenceClient(self) - self.ingest = IngestClient(self) - self.nodes = NodesClient(self) - self.snapshot = SnapshotClient(self) - self.tasks = TasksClient(self) - - self.xpack = XPackClient(self) - self.ccr = CcrClient(self) - self.dangling_indices = DanglingIndicesClient(self) - self.enrich = EnrichClient(self) - self.eql = EqlClient(self) - self.esql = EsqlClient(self) - self.graph = GraphClient(self) - self.ilm = IlmClient(self) - self.license = LicenseClient(self) - self.logstash = LogstashClient(self) - self.migration = MigrationClient(self) - self.ml = MlClient(self) - self.monitoring = MonitoringClient(self) - self.query_rules = QueryRulesClient(self) - self.rollup = RollupClient(self) - self.search_application = SearchApplicationClient(self) - self.searchable_snapshots = SearchableSnapshotsClient(self) - self.security = SecurityClient(self) - self.slm = SlmClient(self) - self.simulate = SimulateClient(self) - self.shutdown = ShutdownClient(self) - self.sql = SqlClient(self) - self.ssl = SslClient(self) - self.synonyms = SynonymsClient(self) - self.text_structure = TextStructureClient(self) - self.transform = TransformClient(self) - self.watcher = WatcherClient(self) + self.async_search = AsyncSearchClient(self._base_client) + self.autoscaling = AutoscalingClient(self._base_client) + self.cat = CatClient(self._base_client) + self.cluster = ClusterClient(self._base_client) + self.connector = ConnectorClient(self._base_client) + self.fleet = FleetClient(self._base_client) + self.features = FeaturesClient(self._base_client) + self.indices = IndicesClient(self._base_client) + self.inference = InferenceClient(self._base_client) + self.ingest = IngestClient(self._base_client) + self.nodes = NodesClient(self._base_client) + self.snapshot = SnapshotClient(self._base_client) + self.tasks = TasksClient(self._base_client) + + self.xpack = XPackClient(self._base_client) + self.ccr = CcrClient(self._base_client) + self.dangling_indices = 
DanglingIndicesClient(self._base_client) + self.enrich = EnrichClient(self._base_client) + self.eql = EqlClient(self._base_client) + self.esql = EsqlClient(self._base_client) + self.graph = GraphClient(self._base_client) + self.ilm = IlmClient(self._base_client) + self.license = LicenseClient(self._base_client) + self.logstash = LogstashClient(self._base_client) + self.migration = MigrationClient(self._base_client) + self.ml = MlClient(self._base_client) + self.monitoring = MonitoringClient(self._base_client) + self.query_rules = QueryRulesClient(self._base_client) + self.rollup = RollupClient(self._base_client) + self.search_application = SearchApplicationClient(self._base_client) + self.searchable_snapshots = SearchableSnapshotsClient(self._base_client) + self.security = SecurityClient(self._base_client) + self.slm = SlmClient(self._base_client) + self.simulate = SimulateClient(self._base_client) + self.shutdown = ShutdownClient(self._base_client) + self.sql = SqlClient(self._base_client) + self.ssl = SslClient(self._base_client) + self.synonyms = SynonymsClient(self._base_client) + self.text_structure = TextStructureClient(self._base_client) + self.transform = TransformClient(self._base_client) + self.watcher = WatcherClient(self._base_client) + + @property + def transport(self) -> AsyncTransport: + return self._base_client._transport + + async def perform_request( + self, + method: str, + path: str, + *, + params: t.Optional[t.Mapping[str, t.Any]] = None, + headers: t.Optional[t.Mapping[str, str]] = None, + body: t.Optional[t.Any] = None, + endpoint_id: t.Optional[str] = None, + path_parts: t.Optional[t.Mapping[str, t.Any]] = None, + ) -> ApiResponse[t.Any]: + with self._base_client._otel.span( + method, + endpoint_id=endpoint_id, + path_parts=path_parts or {}, + ) as otel_span: + response = await self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + 
otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response def __repr__(self) -> str: try: @@ -413,44 +447,44 @@ def options( resolved_headers["x-opaque-id"] = resolved_opaque_id if resolved_headers: - new_headers = self._headers.copy() + new_headers = self._base_client._headers.copy() new_headers.update(resolved_headers) - client._headers = new_headers + client._base_client._headers = new_headers else: - client._headers = self._headers.copy() + client._base_client._headers = self._headers.copy() if request_timeout is not DEFAULT: - client._request_timeout = request_timeout + client._base_client._request_timeout = request_timeout else: - client._request_timeout = self._request_timeout + client._base_client._request_timeout = self._base_client._request_timeout if ignore_status is not DEFAULT: if isinstance(ignore_status, int): ignore_status = (ignore_status,) - client._ignore_status = ignore_status + client._base_client._ignore_status = ignore_status else: - client._ignore_status = self._ignore_status + client._base_client._ignore_status = self._base_client._ignore_status if max_retries is not DEFAULT: if not isinstance(max_retries, int): raise TypeError("'max_retries' must be of type 'int'") - client._max_retries = max_retries + client._base_client._max_retries = max_retries else: - client._max_retries = self._max_retries + client._base_client._max_retries = self._base_client._max_retries if retry_on_status is not DEFAULT: if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - client._retry_on_status = retry_on_status + client._base_client._retry_on_status = retry_on_status else: - client._retry_on_status = self._retry_on_status + client._base_client._retry_on_status = self._base_client._retry_on_status if retry_on_timeout is not DEFAULT: if not isinstance(retry_on_timeout, bool): raise TypeError("'retry_on_timeout' must be of type 'bool'") - client._retry_on_timeout = retry_on_timeout + client._base_client._retry_on_timeout 
= retry_on_timeout else: - client._retry_on_timeout = self._retry_on_timeout + client._base_client._retry_on_timeout = self._base_client._retry_on_timeout return client diff --git a/elasticsearch/_async/client/_base.py b/elasticsearch/_async/client/_base.py index ed61f7bc4..208404c51 100644 --- a/elasticsearch/_async/client/_base.py +++ b/elasticsearch/_async/client/_base.py @@ -210,49 +210,17 @@ def _default_sniffed_node_callback( class BaseClient: - def __init__(self, _transport: AsyncTransport) -> None: + def __init__(self, _transport: AsyncTransport, headers: HttpHeaders) -> None: self._transport = _transport self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT - self._headers = HttpHeaders() + self._headers = headers self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT - self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() - @property - def transport(self) -> AsyncTransport: - return self._transport - - async def perform_request( - self, - method: str, - path: str, - *, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Any] = None, - endpoint_id: Optional[str] = None, - path_parts: Optional[Mapping[str, Any]] = None, - ) -> ApiResponse[Any]: - with self._otel.span( - method, - endpoint_id=endpoint_id, - path_parts=path_parts or {}, - ) as otel_span: - response = await self._perform_request( - method, - path, - params=params, - headers=headers, - body=body, - otel_span=otel_span, - ) - otel_span.set_elastic_cloud_metadata(response.meta.headers) - return response - async def _perform_request( self, method: str, @@ -287,7 +255,7 @@ def mimetype_header_to_compat(header: str) -> None: else: target = path - 
meta, resp_body = await self.transport.perform_request( + meta, resp_body = await self._transport.perform_request( method, target, headers=request_headers, @@ -376,10 +344,9 @@ def mimetype_header_to_compat(header: str) -> None: return response -class NamespacedClient(BaseClient): - def __init__(self, client: "BaseClient") -> None: - self._client = client - super().__init__(self._client.transport) +class NamespacedClient: + def __init__(self, client: BaseClient) -> None: + self._base_client = client async def perform_request( self, @@ -392,14 +359,18 @@ async def perform_request( endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: - # Use the internal clients .perform_request() implementation - # so we take advantage of their transport options. - return await self._client.perform_request( + with self._base_client._otel.span( method, - path, - params=params, - headers=headers, - body=body, endpoint_id=endpoint_id, - path_parts=path_parts, - ) + path_parts=path_parts or {}, + ) as otel_span: + response = await self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index b39cbae26..44d8309ae 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -20,9 +20,11 @@ import typing as t from elastic_transport import ( + ApiResponse, BaseNode, BinaryApiResponse, HeadApiResponse, + HttpHeaders, NodeConfig, NodePool, NodeSelector, @@ -97,7 +99,7 @@ SelfType = t.TypeVar("SelfType", bound="Elasticsearch") -class Elasticsearch(BaseClient): +class Elasticsearch: """ Elasticsearch low-level client. Provides a straightforward mapping from Python to Elasticsearch REST APIs. 
@@ -224,6 +226,18 @@ def __init__( ): sniff_callback = default_sniff_callback + headers = HttpHeaders() + if headers is not DEFAULT and headers is not None: + headers.update(headers) + if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] + headers["x-opaque-id"] = opaque_id + headers = resolve_auth_headers( + headers, + api_key=api_key, + basic_auth=basic_auth, + bearer_auth=bearer_auth, + ) + if _transport is None: node_configs = client_node_configs( hosts, @@ -295,72 +309,92 @@ def __init__( **transport_kwargs, ) - super().__init__(_transport) + self._base_client = BaseClient(_transport, headers=headers) # These are set per-request so are stored separately. - self._request_timeout = request_timeout - self._max_retries = max_retries - self._retry_on_timeout = retry_on_timeout + self._base_client._request_timeout = request_timeout + self._base_client._max_retries = max_retries + self._base_client._retry_on_timeout = retry_on_timeout if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - self._retry_on_status = retry_on_status + self._base_client._retry_on_status = retry_on_status else: - super().__init__(_transport) - - if headers is not DEFAULT and headers is not None: - self._headers.update(headers) - if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - self._headers["x-opaque-id"] = opaque_id - self._headers = resolve_auth_headers( - self._headers, - api_key=api_key, - basic_auth=basic_auth, - bearer_auth=bearer_auth, - ) + self._base_client = BaseClient(_transport, headers=headers) # namespaced clients for compatibility with API names - self.async_search = AsyncSearchClient(self) - self.autoscaling = AutoscalingClient(self) - self.cat = CatClient(self) - self.cluster = ClusterClient(self) - self.connector = ConnectorClient(self) - self.fleet = FleetClient(self) - self.features = FeaturesClient(self) - self.indices = IndicesClient(self) - self.inference = 
InferenceClient(self) - self.ingest = IngestClient(self) - self.nodes = NodesClient(self) - self.snapshot = SnapshotClient(self) - self.tasks = TasksClient(self) - - self.xpack = XPackClient(self) - self.ccr = CcrClient(self) - self.dangling_indices = DanglingIndicesClient(self) - self.enrich = EnrichClient(self) - self.eql = EqlClient(self) - self.esql = EsqlClient(self) - self.graph = GraphClient(self) - self.ilm = IlmClient(self) - self.license = LicenseClient(self) - self.logstash = LogstashClient(self) - self.migration = MigrationClient(self) - self.ml = MlClient(self) - self.monitoring = MonitoringClient(self) - self.query_rules = QueryRulesClient(self) - self.rollup = RollupClient(self) - self.search_application = SearchApplicationClient(self) - self.searchable_snapshots = SearchableSnapshotsClient(self) - self.security = SecurityClient(self) - self.slm = SlmClient(self) - self.simulate = SimulateClient(self) - self.shutdown = ShutdownClient(self) - self.sql = SqlClient(self) - self.ssl = SslClient(self) - self.synonyms = SynonymsClient(self) - self.text_structure = TextStructureClient(self) - self.transform = TransformClient(self) - self.watcher = WatcherClient(self) + self.async_search = AsyncSearchClient(self._base_client) + self.autoscaling = AutoscalingClient(self._base_client) + self.cat = CatClient(self._base_client) + self.cluster = ClusterClient(self._base_client) + self.connector = ConnectorClient(self._base_client) + self.fleet = FleetClient(self._base_client) + self.features = FeaturesClient(self._base_client) + self.indices = IndicesClient(self._base_client) + self.inference = InferenceClient(self._base_client) + self.ingest = IngestClient(self._base_client) + self.nodes = NodesClient(self._base_client) + self.snapshot = SnapshotClient(self._base_client) + self.tasks = TasksClient(self._base_client) + + self.xpack = XPackClient(self._base_client) + self.ccr = CcrClient(self._base_client) + self.dangling_indices = 
DanglingIndicesClient(self._base_client) + self.enrich = EnrichClient(self._base_client) + self.eql = EqlClient(self._base_client) + self.esql = EsqlClient(self._base_client) + self.graph = GraphClient(self._base_client) + self.ilm = IlmClient(self._base_client) + self.license = LicenseClient(self._base_client) + self.logstash = LogstashClient(self._base_client) + self.migration = MigrationClient(self._base_client) + self.ml = MlClient(self._base_client) + self.monitoring = MonitoringClient(self._base_client) + self.query_rules = QueryRulesClient(self._base_client) + self.rollup = RollupClient(self._base_client) + self.search_application = SearchApplicationClient(self._base_client) + self.searchable_snapshots = SearchableSnapshotsClient(self._base_client) + self.security = SecurityClient(self._base_client) + self.slm = SlmClient(self._base_client) + self.simulate = SimulateClient(self._base_client) + self.shutdown = ShutdownClient(self._base_client) + self.sql = SqlClient(self._base_client) + self.ssl = SslClient(self._base_client) + self.synonyms = SynonymsClient(self._base_client) + self.text_structure = TextStructureClient(self._base_client) + self.transform = TransformClient(self._base_client) + self.watcher = WatcherClient(self._base_client) + + @property + def transport(self) -> Transport: + return self._base_client._transport + + def perform_request( + self, + method: str, + path: str, + *, + params: t.Optional[t.Mapping[str, t.Any]] = None, + headers: t.Optional[t.Mapping[str, str]] = None, + body: t.Optional[t.Any] = None, + endpoint_id: t.Optional[str] = None, + path_parts: t.Optional[t.Mapping[str, t.Any]] = None, + ) -> ApiResponse[t.Any]: + with self._base_client._otel.span( + method, + endpoint_id=endpoint_id, + path_parts=path_parts or {}, + ) as otel_span: + response = self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + 
otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response def __repr__(self) -> str: try: @@ -413,44 +447,44 @@ def options( resolved_headers["x-opaque-id"] = resolved_opaque_id if resolved_headers: - new_headers = self._headers.copy() + new_headers = self._base_client._headers.copy() new_headers.update(resolved_headers) - client._headers = new_headers + client._base_client._headers = new_headers else: - client._headers = self._headers.copy() + client._base_client._headers = self._headers.copy() if request_timeout is not DEFAULT: - client._request_timeout = request_timeout + client._base_client._request_timeout = request_timeout else: - client._request_timeout = self._request_timeout + client._base_client._request_timeout = self._base_client._request_timeout if ignore_status is not DEFAULT: if isinstance(ignore_status, int): ignore_status = (ignore_status,) - client._ignore_status = ignore_status + client._base_client._ignore_status = ignore_status else: - client._ignore_status = self._ignore_status + client._base_client._ignore_status = self._base_client._ignore_status if max_retries is not DEFAULT: if not isinstance(max_retries, int): raise TypeError("'max_retries' must be of type 'int'") - client._max_retries = max_retries + client._base_client._max_retries = max_retries else: - client._max_retries = self._max_retries + client._base_client._max_retries = self._base_client._max_retries if retry_on_status is not DEFAULT: if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - client._retry_on_status = retry_on_status + client._base_client._retry_on_status = retry_on_status else: - client._retry_on_status = self._retry_on_status + client._base_client._retry_on_status = self._base_client._retry_on_status if retry_on_timeout is not DEFAULT: if not isinstance(retry_on_timeout, bool): raise TypeError("'retry_on_timeout' must be of type 'bool'") - client._retry_on_timeout = retry_on_timeout + client._base_client._retry_on_timeout 
= retry_on_timeout else: - client._retry_on_timeout = self._retry_on_timeout + client._base_client._retry_on_timeout = self._base_client._retry_on_timeout return client diff --git a/elasticsearch/_sync/client/_base.py b/elasticsearch/_sync/client/_base.py index 7d4617f74..2f13290e2 100644 --- a/elasticsearch/_sync/client/_base.py +++ b/elasticsearch/_sync/client/_base.py @@ -210,49 +210,17 @@ def _default_sniffed_node_callback( class BaseClient: - def __init__(self, _transport: Transport) -> None: + def __init__(self, _transport: Transport, headers: HttpHeaders) -> None: self._transport = _transport self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT - self._headers = HttpHeaders() + self._headers = headers self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT - self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() - @property - def transport(self) -> Transport: - return self._transport - - def perform_request( - self, - method: str, - path: str, - *, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Any] = None, - endpoint_id: Optional[str] = None, - path_parts: Optional[Mapping[str, Any]] = None, - ) -> ApiResponse[Any]: - with self._otel.span( - method, - endpoint_id=endpoint_id, - path_parts=path_parts or {}, - ) as otel_span: - response = self._perform_request( - method, - path, - params=params, - headers=headers, - body=body, - otel_span=otel_span, - ) - otel_span.set_elastic_cloud_metadata(response.meta.headers) - return response - def _perform_request( self, method: str, @@ -287,7 +255,7 @@ def mimetype_header_to_compat(header: str) -> None: else: target = path - meta, resp_body = 
self.transport.perform_request( + meta, resp_body = self._transport.perform_request( method, target, headers=request_headers, @@ -376,10 +344,9 @@ def mimetype_header_to_compat(header: str) -> None: return response -class NamespacedClient(BaseClient): - def __init__(self, client: "BaseClient") -> None: - self._client = client - super().__init__(self._client.transport) +class NamespacedClient: + def __init__(self, client: BaseClient) -> None: + self._base_client = client def perform_request( self, @@ -392,14 +359,18 @@ def perform_request( endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: - # Use the internal clients .perform_request() implementation - # so we take advantage of their transport options. - return self._client.perform_request( + with self._base_client._otel.span( method, - path, - params=params, - headers=headers, - body=body, endpoint_id=endpoint_id, - path_parts=path_parts, - ) + path_parts=path_parts or {}, + ) as otel_span: + response = self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response From 9299e0ee238d772037d95bc58d093cb5a8328d40 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 25 Mar 2025 02:40:46 +0400 Subject: [PATCH 2/6] Fix lint --- elasticsearch/_async/client/__init__.py | 2 +- elasticsearch/_async/client/_base.py | 1 + elasticsearch/_async/helpers.py | 4 ++-- elasticsearch/_sync/client/__init__.py | 2 +- elasticsearch/_sync/client/_base.py | 1 + elasticsearch/helpers/actions.py | 10 +++++----- 6 files changed, 11 insertions(+), 9 deletions(-) diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 27db48ca4..48022cd8c 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -451,7 +451,7 @@ def options( 
new_headers.update(resolved_headers) client._base_client._headers = new_headers else: - client._base_client._headers = self._headers.copy() + client._base_client._headers = self._base_client._headers.copy() if request_timeout is not DEFAULT: client._base_client._request_timeout = request_timeout diff --git a/elasticsearch/_async/client/_base.py b/elasticsearch/_async/client/_base.py index 208404c51..ef1bf3065 100644 --- a/elasticsearch/_async/client/_base.py +++ b/elasticsearch/_async/client/_base.py @@ -218,6 +218,7 @@ def __init__(self, _transport: AsyncTransport, headers: HttpHeaders) -> None: self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT + self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() diff --git a/elasticsearch/_async/helpers.py b/elasticsearch/_async/helpers.py index 7acc41ecd..51b15e0ba 100644 --- a/elasticsearch/_async/helpers.py +++ b/elasticsearch/_async/helpers.py @@ -216,7 +216,7 @@ async def async_streaming_bulk( """ client = client.options() - client._client_meta = (("h", "bp"),) + client._base_client._client_meta = (("h", "bp"),) if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) @@ -429,7 +429,7 @@ def pop_transport_kwargs(kw: MutableMapping[str, Any]) -> MutableMapping[str, An client = client.options( request_timeout=request_timeout, **pop_transport_kwargs(kwargs) ) - client._client_meta = (("h", "s"),) + client._base_client._client_meta = (("h", "s"),) # Setting query={"from": ...} would make 'from' be used # as a keyword argument instead of 'from_'. We handle that here. 
diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index 44d8309ae..807f5b7d0 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -451,7 +451,7 @@ def options( new_headers.update(resolved_headers) client._base_client._headers = new_headers else: - client._base_client._headers = self._headers.copy() + client._base_client._headers = self._base_client._headers.copy() if request_timeout is not DEFAULT: client._base_client._request_timeout = request_timeout diff --git a/elasticsearch/_sync/client/_base.py b/elasticsearch/_sync/client/_base.py index 2f13290e2..e869e2d64 100644 --- a/elasticsearch/_sync/client/_base.py +++ b/elasticsearch/_sync/client/_base.py @@ -218,6 +218,7 @@ def __init__(self, _transport: Transport, headers: HttpHeaders) -> None: self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT + self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() diff --git a/elasticsearch/helpers/actions.py b/elasticsearch/helpers/actions.py index 25c21cdd4..f1a095c9f 100644 --- a/elasticsearch/helpers/actions.py +++ b/elasticsearch/helpers/actions.py @@ -334,7 +334,7 @@ def _process_bulk_chunk( """ Send a bulk request to elasticsearch and process the output. 
""" - with client._otel.use_span(otel_span): + with client._base_client._otel.use_span(otel_span): if isinstance(ignore_status, int): ignore_status = (ignore_status,) @@ -416,9 +416,9 @@ def streaming_bulk( :arg yield_ok: if set to False will skip successful documents in the output :arg ignore_status: list of HTTP status code that you want to ignore """ - with client._otel.helpers_span(span_name) as otel_span: + with client._base_client._otel.helpers_span(span_name) as otel_span: client = client.options() - client._client_meta = (("h", "bp"),) + client._base_client._client_meta = (("h", "bp"),) if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) @@ -608,7 +608,7 @@ def _setup_queues(self) -> None: ] = Queue(max(queue_size, thread_count)) self._quick_put = self._inqueue.put - with client._otel.helpers_span("helpers.parallel_bulk") as otel_span: + with client._base_client._otel.helpers_span("helpers.parallel_bulk") as otel_span: pool = BlockingPool(thread_count) try: @@ -711,7 +711,7 @@ def pop_transport_kwargs(kw: MutableMapping[str, Any]) -> Dict[str, Any]: client = client.options( request_timeout=request_timeout, **pop_transport_kwargs(kwargs) ) - client._client_meta = (("h", "s"),) + client._base_client._client_meta = (("h", "s"),) # Setting query={"from": ...} would make 'from' be used # as a keyword argument instead of 'from_'. We handle that here. 
From e03a3972b6dfe21f2c9788c09fb1872061f148a3 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 25 Mar 2025 02:51:44 +0400 Subject: [PATCH 3/6] Fix otel tests --- test_elasticsearch/test_otel.py | 6 +++--- test_elasticsearch/test_server/test_otel.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test_elasticsearch/test_otel.py b/test_elasticsearch/test_otel.py index 48eb9ea58..682ee1f86 100644 --- a/test_elasticsearch/test_otel.py +++ b/test_elasticsearch/test_otel.py @@ -112,11 +112,11 @@ def test_forward_otel_context_to_subthreads( ): tracer, memory_exporter = setup_tracing() es_client = Elasticsearch("http://localhost:9200") - es_client._otel = OpenTelemetry(enabled=True, tracer=tracer) + es_client._base_client._otel = OpenTelemetry(enabled=True, tracer=tracer) _call_bulk_mock.return_value = mock.Mock() actions = ({"x": i} for i in range(100)) list(helpers.parallel_bulk(es_client, actions, chunk_size=4)) # Ensures that the OTEL context has been forwarded to all chunks - assert es_client._otel.helpers_span.call_count == 1 - assert es_client._otel.use_span.call_count == 25 + assert es_client._base_client._otel.helpers_span.call_count == 1 + assert es_client._base_client._otel.use_span.call_count == 25 diff --git a/test_elasticsearch/test_server/test_otel.py b/test_elasticsearch/test_server/test_otel.py index 3f8033d7b..39da182b1 100644 --- a/test_elasticsearch/test_server/test_otel.py +++ b/test_elasticsearch/test_server/test_otel.py @@ -61,7 +61,7 @@ def test_otel_bulk(sync_client, elasticsearch_url, bulk_helper_name): # Create a new client with our tracer sync_client = sync_client.options() - sync_client._otel.tracer = tracer + sync_client._base_client._otel.tracer = tracer # "Disable" options to keep our custom tracer sync_client.options = lambda: sync_client From 428203a1ba345f8763bb8c88b4d1b67906b4c6c6 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Tue, 25 Mar 2025 03:02:43 +0400 Subject: [PATCH 4/6] Add more _base_client 
attributes --- test_elasticsearch/test_dsl/test_connections.py | 2 +- test_elasticsearch/test_server/test_otel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test_elasticsearch/test_dsl/test_connections.py b/test_elasticsearch/test_dsl/test_connections.py index dcaa59a98..757e088b0 100644 --- a/test_elasticsearch/test_dsl/test_connections.py +++ b/test_elasticsearch/test_dsl/test_connections.py @@ -123,7 +123,7 @@ def test_connection_has_correct_user_agent() -> None: c.create_connection("testing", hosts=["https://es.com:9200"]) assert ( c.get_connection("testing") - ._headers["user-agent"] + ._base_client._headers["user-agent"] .startswith("elasticsearch-dsl-py/") ) diff --git a/test_elasticsearch/test_server/test_otel.py b/test_elasticsearch/test_server/test_otel.py index 39da182b1..11b7642eb 100644 --- a/test_elasticsearch/test_server/test_otel.py +++ b/test_elasticsearch/test_server/test_otel.py @@ -34,7 +34,7 @@ def test_otel_end_to_end(sync_client): tracer, memory_exporter = setup_tracing() - sync_client._otel.tracer = tracer + sync_client._base_client._otel.tracer = tracer resp = sync_client.search(index="logs-*", query={"match_all": {}}) assert resp.meta.status == 200 From 84ce3bcd40237fd4cf068068cadfe32f9a66e542 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Wed, 26 Mar 2025 19:22:58 +0400 Subject: [PATCH 5/6] Fix headers handling --- elasticsearch/_async/client/__init__.py | 28 +++++++++---------- elasticsearch/_async/client/_base.py | 4 +-- elasticsearch/_sync/client/__init__.py | 28 +++++++++---------- elasticsearch/_sync/client/_base.py | 4 +-- .../test_client/test_options.py | 4 +-- 5 files changed, 32 insertions(+), 36 deletions(-) diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 48022cd8c..e43e1d6a7 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -25,7 +25,6 @@ BaseNode, BinaryApiResponse, HeadApiResponse, - 
HttpHeaders, NodeConfig, NodePool, NodeSelector, @@ -226,18 +225,6 @@ def __init__( ): sniff_callback = default_sniff_callback - headers = HttpHeaders() - if headers is not DEFAULT and headers is not None: - headers.update(headers) - if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - headers["x-opaque-id"] = opaque_id - headers = resolve_auth_headers( - headers, - api_key=api_key, - basic_auth=basic_auth, - bearer_auth=bearer_auth, - ) - if _transport is None: node_configs = client_node_configs( hosts, @@ -309,7 +296,7 @@ def __init__( **transport_kwargs, ) - self._base_client = BaseClient(_transport, headers=headers) + self._base_client = BaseClient(_transport) # These are set per-request so are stored separately. self._base_client._request_timeout = request_timeout @@ -320,7 +307,18 @@ def __init__( self._base_client._retry_on_status = retry_on_status else: - self._base_client = BaseClient(_transport, headers=headers) + self._base_client = BaseClient(_transport) + + if headers is not DEFAULT and headers is not None: + self._base_client._headers.update(headers) + if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] + self._base_client._headers["x-opaque-id"] = opaque_id + self._base_client._headers = resolve_auth_headers( + self._base_client._headers, + api_key=api_key, + basic_auth=basic_auth, + bearer_auth=bearer_auth, + ) # namespaced clients for compatibility with API names self.async_search = AsyncSearchClient(self._base_client) diff --git a/elasticsearch/_async/client/_base.py b/elasticsearch/_async/client/_base.py index ef1bf3065..8e6e5fbd9 100644 --- a/elasticsearch/_async/client/_base.py +++ b/elasticsearch/_async/client/_base.py @@ -210,10 +210,10 @@ def _default_sniffed_node_callback( class BaseClient: - def __init__(self, _transport: AsyncTransport, headers: HttpHeaders) -> None: + def __init__(self, _transport: AsyncTransport) -> None: self._transport = _transport 
self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT - self._headers = headers + self._headers = HttpHeaders() self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index 807f5b7d0..0c6db746e 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -24,7 +24,6 @@ BaseNode, BinaryApiResponse, HeadApiResponse, - HttpHeaders, NodeConfig, NodePool, NodeSelector, @@ -226,18 +225,6 @@ def __init__( ): sniff_callback = default_sniff_callback - headers = HttpHeaders() - if headers is not DEFAULT and headers is not None: - headers.update(headers) - if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - headers["x-opaque-id"] = opaque_id - headers = resolve_auth_headers( - headers, - api_key=api_key, - basic_auth=basic_auth, - bearer_auth=bearer_auth, - ) - if _transport is None: node_configs = client_node_configs( hosts, @@ -309,7 +296,7 @@ def __init__( **transport_kwargs, ) - self._base_client = BaseClient(_transport, headers=headers) + self._base_client = BaseClient(_transport) # These are set per-request so are stored separately. 
self._base_client._request_timeout = request_timeout @@ -320,7 +307,18 @@ def __init__( self._base_client._retry_on_status = retry_on_status else: - self._base_client = BaseClient(_transport, headers=headers) + self._base_client = BaseClient(_transport) + + if headers is not DEFAULT and headers is not None: + self._base_client._headers.update(headers) + if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] + self._base_client._headers["x-opaque-id"] = opaque_id + self._base_client._headers = resolve_auth_headers( + self._base_client._headers, + api_key=api_key, + basic_auth=basic_auth, + bearer_auth=bearer_auth, + ) # namespaced clients for compatibility with API names self.async_search = AsyncSearchClient(self._base_client) diff --git a/elasticsearch/_sync/client/_base.py b/elasticsearch/_sync/client/_base.py index e869e2d64..02542eb43 100644 --- a/elasticsearch/_sync/client/_base.py +++ b/elasticsearch/_sync/client/_base.py @@ -210,10 +210,10 @@ def _default_sniffed_node_callback( class BaseClient: - def __init__(self, _transport: Transport, headers: HttpHeaders) -> None: + def __init__(self, _transport: Transport) -> None: self._transport = _transport self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT - self._headers = headers + self._headers = HttpHeaders() self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT diff --git a/test_elasticsearch/test_client/test_options.py b/test_elasticsearch/test_client/test_options.py index c2050d186..05486cab4 100644 --- a/test_elasticsearch/test_client/test_options.py +++ b/test_elasticsearch/test_client/test_options.py @@ -290,7 +290,7 @@ def test_default_node_configs(self): headers={"key": "val"}, basic_auth=("username", "password"), ) - assert client._headers == { + assert client._base_client._headers == { "key": "val", 
"authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", } @@ -347,7 +347,7 @@ def test_http_headers_overrides(self): "authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", "user-agent": USER_AGENT, } - assert client._headers == {"key": "val"} + assert client._base_client._headers == {"key": "val"} def test_user_agent_override(self): client = Elasticsearch( From 214046fa78da3c345905843418d3e9fea1228720 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Wed, 26 Mar 2025 19:41:54 +0400 Subject: [PATCH 6/6] Fix user agent handling in DSL --- elasticsearch/dsl/connections.py | 10 +++++----- test_elasticsearch/test_dsl/test_connections.py | 8 ++++++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/elasticsearch/dsl/connections.py b/elasticsearch/dsl/connections.py index 8acd80c6e..350a250fd 100644 --- a/elasticsearch/dsl/connections.py +++ b/elasticsearch/dsl/connections.py @@ -117,15 +117,15 @@ def get_connection(self, alias: Union[str, _T] = "default") -> _T: def _with_user_agent(self, conn: _T) -> _T: # try to inject our user agent - if hasattr(conn, "_headers"): - is_frozen = conn._headers.frozen + if hasattr(conn, "_base_client") and hasattr(conn._base_client, "_headers"): + is_frozen = conn._base_client._headers.frozen if is_frozen: - conn._headers = conn._headers.copy() - conn._headers.update( + conn._base_client._headers = conn._base_client._headers.copy() + conn._base_client._headers.update( {"user-agent": f"elasticsearch-dsl-py/{__versionstr__}"} ) if is_frozen: - conn._headers.freeze() + conn._base_client._headers.freeze() return conn diff --git a/test_elasticsearch/test_dsl/test_connections.py b/test_elasticsearch/test_dsl/test_connections.py index 757e088b0..f7825f9b0 100644 --- a/test_elasticsearch/test_dsl/test_connections.py +++ b/test_elasticsearch/test_dsl/test_connections.py @@ -130,12 +130,16 @@ def test_connection_has_correct_user_agent() -> None: my_client = Elasticsearch(hosts=["http://localhost:9200"]) my_client = 
my_client.options(headers={"user-agent": "my-user-agent/1.0"}) c.add_connection("default", my_client) - assert c.get_connection()._headers["user-agent"].startswith("elasticsearch-dsl-py/") + assert ( + c.get_connection() + ._base_client._headers["user-agent"] + .startswith("elasticsearch-dsl-py/") + ) my_client = Elasticsearch(hosts=["http://localhost:9200"]) assert ( c.get_connection(my_client) - ._headers["user-agent"] + ._base_client._headers["user-agent"] .startswith("elasticsearch-dsl-py/") )