diff --git a/gapic-generator-fork b/gapic-generator-fork
index b26cda7d1..12c2e8b0c 160000
--- a/gapic-generator-fork
+++ b/gapic-generator-fork
@@ -1 +1 @@
-Subproject commit b26cda7d163d6e0d45c9684f328ca32fb49b799a
+Subproject commit 12c2e8b0c14509521317c7915522ae3f13414258
diff --git a/google/cloud/bigtable/_channel_pooling/dynamic_pooled_channel.py b/google/cloud/bigtable/_channel_pooling/dynamic_pooled_channel.py
new file mode 100644
index 000000000..cc91b4a99
--- /dev/null
+++ b/google/cloud/bigtable/_channel_pooling/dynamic_pooled_channel.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Callable, Coroutine
+
+import asyncio
+from dataclasses import dataclass
+
+from grpc.experimental import aio  # type: ignore
+
+from .pooled_channel import PooledChannel
+from .pooled_channel import StaticPoolOptions
+from .tracked_channel import TrackedChannel
+
+from google.cloud.bigtable._channel_pooling.wrapped_channel import _BackgroundTaskMixin
+
+
+@dataclass
+class DynamicPoolOptions:
+    # starting channel count
+    start_size: int = 3
+    # maximum channels to keep in the pool
+    max_channels: int = 10
+    # minimum channels in pool
+    min_channels: int = 1
+    # if rpcs exceed this number, pool may expand
+    max_rpcs_per_channel: int = 100
+    # if rpcs fall below this number, pool may shrink
+    min_rpcs_per_channel: int = 50
+    # how many channels to add/remove in a single resize event
+    max_resize_delta: int = 2
+    # how many seconds to wait between resize attempts
+    pool_refresh_interval: float = 60.0
+
+
+class DynamicPooledChannel(PooledChannel, _BackgroundTaskMixin):
+    def __init__(
+        self,
+        *args,
+        create_channel_fn: Callable[..., aio.Channel] | None = None,
+        pool_options: StaticPoolOptions | DynamicPoolOptions | None = None,
+        warm_channel_fn: Callable[[aio.Channel], Coroutine[Any, Any, Any]]
+        | None = None,
+        on_remove: Callable[[aio.Channel], Coroutine[Any, Any, Any]] | None = None,
+        **kwargs,
+    ):
+        if create_channel_fn is None:
+            raise ValueError("create_channel_fn is required")
+        if isinstance(pool_options, StaticPoolOptions):
+            raise ValueError(
+                "DynamicPooledChannel cannot be initialized with StaticPoolOptions"
+            )
+        self._pool: list[TrackedChannel] = []
+        self._pool_options = pool_options or DynamicPoolOptions()
+        # create the pool
+        PooledChannel.__init__(
+            self,
+            # create options for starting pool
+            pool_options=StaticPoolOptions(pool_size=self._pool_options.start_size),
+            # all channels must be TrackedChannels
+            create_channel_fn=lambda: TrackedChannel(create_channel_fn(*args, **kwargs)),  # type: ignore
+        )
+        # register callbacks
+        self._on_remove = on_remove
+        self._warm_channel = warm_channel_fn
+        # start background resize task
+        self._background_task: asyncio.Task[None] | None = None
+        self.start_background_task()
+
+    def _background_coroutine(self) -> Coroutine[Any, Any, None]:
+        return self._resize_routine(interval=self._pool_options.pool_refresh_interval)
+
+    @property
+    def _task_description(self) -> str:
+        return "Automatic channel pool resizing"
+
+    async def _resize_routine(self, interval: float = 60):
+        close_tasks: list[asyncio.Task[None]] = []
+        while True:
+            await asyncio.sleep(interval)
+            added, removed = self._attempt_resize()
+            # warm up new channels immediately
+            if self._warm_channel:
+                for channel in added:
+                    await self._warm_channel(channel)
+            # clear completed tasks from list
+            close_tasks = [t for t in close_tasks if not t.done()]
+            # add new tasks to close unneeded channels in the background
+            if self._on_remove:
+                for channel in removed:
+                    close_routine = self._on_remove(channel)
+                    close_tasks.append(asyncio.create_task(close_routine))
+
+    def _attempt_resize(self) -> tuple[list[TrackedChannel], list[TrackedChannel]]:
+        """
+        Called periodically to resize the number of channels based on
+        the number of active RPCs
+        """
+        added_list, removed_list = [], []
+        # estimate the peak rpcs since last resize
+        # peak finds max active value for each channel since last check
+        estimated_peak = sum(
+            [channel.get_and_reset_max_active_rpcs() for channel in self._pool]
+        )
+        # find the minimum number of channels to serve the peak
+        min_channels = estimated_peak // self._pool_options.max_rpcs_per_channel
+        # find the maximum channels we'd want to serve the peak
+        max_channels = estimated_peak // max(self._pool_options.min_rpcs_per_channel, 1)
+        # clamp the number of channels to the min and max
+        min_channels = max(min_channels, self._pool_options.min_channels)
+        max_channels = min(max_channels, self._pool_options.max_channels)
+        # Only resize the pool when thresholds are crossed
+        current_size = len(self._pool)
+        if current_size < min_channels or current_size > max_channels:
+            # try to aim for the middle of the bound, but limit rate of change.
+            tentative_target = (max_channels + min_channels) // 2
+            delta = tentative_target - current_size
+            dampened_delta = min(
+                max(delta, -self._pool_options.max_resize_delta),
+                self._pool_options.max_resize_delta,
+            )
+            dampened_target = current_size + dampened_delta
+            if dampened_target > current_size:
+                added_list = [self._create_channel() for _ in range(dampened_delta)]
+                self._pool.extend(added_list)
+            elif dampened_target < current_size:
+                # reset the next_idx if needed
+                if self._next_idx >= dampened_target:
+                    self._next_idx = 0
+                # trim pool to the right size
+                self._pool, removed_list = (
+                    self._pool[:dampened_target],
+                    self._pool[dampened_target:],
+                )
+        return added_list, removed_list
+
+    async def __aenter__(self):
+        await _BackgroundTaskMixin.__aenter__(self)
+        await PooledChannel.__aenter__(self)
+        return self
+
+    async def close(self, grace=None):
+        await _BackgroundTaskMixin.close(self, grace)
+        await PooledChannel.close(self, grace)
+
+    async def __aexit__(self, *args, **kwargs):
+        await _BackgroundTaskMixin.__aexit__(self, *args, **kwargs)
+        await PooledChannel.__aexit__(self, *args, **kwargs)
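The sizing rule in `_attempt_resize` above is a clamp-and-dampen step: compute the band of acceptable pool sizes from the observed RPC peak, and move toward the middle of that band only when the current size falls outside it. A minimal standalone sketch of the same arithmetic, with the channel bookkeeping stripped out (`plan_resize` and the sample numbers are illustrative, not part of this change):

```python
def plan_resize(
    current_size: int,
    estimated_peak: int,
    *,
    max_rpcs_per_channel: int = 100,
    min_rpcs_per_channel: int = 50,
    min_channels: int = 1,
    max_channels: int = 10,
    max_resize_delta: int = 2,
) -> int:
    """Return the pool size the resize routine would choose next cycle."""
    # fewest channels that keep the peak under max_rpcs_per_channel each
    lower = max(estimated_peak // max_rpcs_per_channel, min_channels)
    # most channels that keep each above min_rpcs_per_channel at the peak
    upper = min(estimated_peak // max(min_rpcs_per_channel, 1), max_channels)
    if lower <= current_size <= upper:
        return current_size  # within thresholds: leave the pool alone
    # aim for the middle of the band, but limit the rate of change
    tentative = (lower + upper) // 2
    delta = tentative - current_size
    dampened = min(max(delta, -max_resize_delta), max_resize_delta)
    return current_size + dampened


# e.g. a pool of 3 channels that saw a peak of 800 concurrent RPCs:
# lower = max(8, 1) = 8, upper = min(16, 10) = 10, tentative = 9,
# delta = +6, dampened to +2 -> grow to 5 channels this cycle
assert plan_resize(3, 800) == 5
```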
diff --git a/google/cloud/bigtable/_channel_pooling/pooled_channel.py b/google/cloud/bigtable/_channel_pooling/pooled_channel.py
new file mode 100644
index 000000000..ea0eb8794
--- /dev/null
+++ b/google/cloud/bigtable/_channel_pooling/pooled_channel.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import (
+    Callable,
+)
+import asyncio
+from dataclasses import dataclass
+from functools import partial
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+from google.cloud.bigtable._channel_pooling.wrapped_channel import (
+    WrappedUnaryUnaryMultiCallable,
+)
+from google.cloud.bigtable._channel_pooling.wrapped_channel import (
+    WrappedUnaryStreamMultiCallable,
+)
+from google.cloud.bigtable._channel_pooling.wrapped_channel import (
+    WrappedStreamUnaryMultiCallable,
+)
+from google.cloud.bigtable._channel_pooling.wrapped_channel import (
+    WrappedStreamStreamMultiCallable,
+)
+
+
+@dataclass
+class StaticPoolOptions:
+    pool_size: int = 3
+
+
+class PooledChannel(aio.Channel):
+    def __init__(
+        self,
+        *args,
+        create_channel_fn: Callable[..., aio.Channel] | None = None,
+        pool_options: StaticPoolOptions | None = None,
+        **kwargs,
+    ):
+        if create_channel_fn is None:
+            raise ValueError("create_channel_fn is required")
+        self._pool: list[aio.Channel] = []
+        self._next_idx = 0
+        self._create_channel: Callable[[], aio.Channel] = partial(
+            create_channel_fn, *args, **kwargs
+        )
+        pool_options = pool_options or StaticPoolOptions()
+        for _ in range(pool_options.pool_size):
+            self._pool.append(self._create_channel())
+
+    def next_channel(self) -> aio.Channel:
+        next_idx = self._next_idx if self._next_idx < len(self._pool) else 0
+        channel = self._pool[next_idx]
+        self._next_idx = (next_idx + 1) % len(self._pool)
+        return channel
+
+    def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable:
+        return WrappedUnaryUnaryMultiCallable(
+            lambda *call_args, **call_kwargs: self.next_channel().unary_unary(
+                *args, **kwargs
+            )(*call_args, **call_kwargs)
+        )
+
+    def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable:
+        return WrappedUnaryStreamMultiCallable(
+            lambda *call_args, **call_kwargs: self.next_channel().unary_stream(
+                *args, **kwargs
+            )(*call_args, **call_kwargs)
+        )
+
+    def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable:
+        return WrappedStreamUnaryMultiCallable(
+            lambda *call_args, **call_kwargs: self.next_channel().stream_unary(
+                *args, **kwargs
+            )(*call_args, **call_kwargs)
+        )
+
+    def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable:
+        return WrappedStreamStreamMultiCallable(
+            lambda *call_args, **call_kwargs: self.next_channel().stream_stream(
+                *args, **kwargs
+            )(*call_args, **call_kwargs)
+        )
+
+    async def close(self, grace=None):
+        close_fns = [channel.close(grace=grace) for channel in self._pool]
+        await asyncio.gather(*close_fns)
+
+    async def channel_ready(self):
+        ready_fns = [channel.channel_ready() for channel in self._pool]
+        await asyncio.gather(*ready_fns)
+
+    async def __aenter__(self):
+        for channel in self._pool:
+            await channel.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        for channel in self._pool:
+            await channel.__aexit__(exc_type, exc_val, exc_tb)
+
+    def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity:
+        raise NotImplementedError("undefined for pool of channels")
+
+    async def wait_for_state_change(self, last_observed_state):
+        raise NotImplementedError("undefined for pool of channels")
+
+    def index_of(self, channel) -> int:
+        try:
+            return self._pool.index(channel)
+        except ValueError:
+            return -1
+
+    @property
+    def channels(self) -> list[aio.Channel]:
+        return self._pool
+
+    def __getitem__(self, item: int) -> aio.Channel:
+        return self._pool[item]
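A usage sketch of the static pool and its round-robin dispatch. The localhost target, `insecure_channel` factory, and pool size are illustrative assumptions, not defaults used by the client:

```python
import asyncio

from grpc.experimental import aio  # type: ignore

from google.cloud.bigtable._channel_pooling.pooled_channel import (
    PooledChannel,
    StaticPoolOptions,
)


async def main():
    # positional args are forwarded to create_channel_fn for every channel
    pool = PooledChannel(
        "localhost:8080",
        create_channel_fn=aio.insecure_channel,
        pool_options=StaticPoolOptions(pool_size=4),
    )
    # next_channel() rotates through the pool in round-robin order;
    # multicallables built from the pool pick a fresh channel per RPC
    assert pool.index_of(pool.next_channel()) == 0
    assert pool.index_of(pool.next_channel()) == 1
    await pool.close()


asyncio.run(main())
```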
diff --git a/google/cloud/bigtable/_channel_pooling/refreshable_channel.py b/google/cloud/bigtable/_channel_pooling/refreshable_channel.py
new file mode 100644
index 000000000..e4b04f377
--- /dev/null
+++ b/google/cloud/bigtable/_channel_pooling/refreshable_channel.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import Any, Callable, Coroutine
+
+import asyncio
+import random
+from time import monotonic
+from functools import partial
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.bigtable._channel_pooling.wrapped_channel import (
+    _WrappedChannel,
+    _BackgroundTaskMixin,
+)
+
+
+class RefreshableChannel(_WrappedChannel, _BackgroundTaskMixin):
+    """
+    A Channel that refreshes itself periodically.
+    """
+
+    def __init__(
+        self,
+        *args,
+        create_channel_fn: Callable[..., aio.Channel] | None = None,
+        refresh_interval_min: float = 60 * 35,
+        refresh_interval_max: float = 60 * 45,
+        warm_channel_fn: Callable[[aio.Channel], Coroutine[Any, Any, Any]]
+        | None = None,
+        on_replace: Callable[[aio.Channel], Coroutine[Any, Any, Any]] | None = None,
+        **kwargs,
+    ):
+        if create_channel_fn is None:
+            raise ValueError("create_channel_fn is required")
+        self._create_channel: Callable[[], aio.Channel] = partial(
+            create_channel_fn, *args, **kwargs
+        )
+        self._warm_channel = warm_channel_fn
+        self._on_replace = on_replace
+        self._channel = self._create_channel()
+        self._refresh_interval_min = refresh_interval_min
+        self._refresh_interval_max = refresh_interval_max
+        self._background_task: asyncio.Task[None] | None = None
+        self.start_background_task()
+
+    def _background_coroutine(self) -> Coroutine[Any, Any, None]:
+        return self._manage_channel_lifecycle(
+            self._refresh_interval_min, self._refresh_interval_max
+        )
+
+    @property
+    def _task_description(self) -> str:
+        return "Background channel refresh"
+
+    async def _manage_channel_lifecycle(
+        self,
+        refresh_interval_min: float = 60 * 35,
+        refresh_interval_max: float = 60 * 45,
+    ) -> None:
+        """
+        Background coroutine that periodically refreshes and warms a grpc channel
+
+        The backend will automatically close channels after 60 minutes, so the
+        refresh interval (plus any grace period applied by `on_replace`) should
+        total less than 60 minutes
+
+        Runs continuously until the client is closed
+
+        Args:
+            refresh_interval_min: minimum interval before initiating refresh
+                process in seconds. Actual interval will be a random value
+                between `refresh_interval_min` and `refresh_interval_max`
+            refresh_interval_max: maximum interval before initiating refresh
+                process in seconds. Actual interval will be a random value
+                between `refresh_interval_min` and `refresh_interval_max`
+        """
+        if self._warm_channel:
+            await self._warm_channel(self._channel)
+        next_sleep = random.uniform(refresh_interval_min, refresh_interval_max)
+        while True:
+            # let the current channel serve traffic for `next_sleep` seconds
+            await asyncio.sleep(next_sleep)
+            # prepare a warmed replacement channel before swapping it in
+            start_timestamp = monotonic()
+            new_channel = self._create_channel()
+            if self._warm_channel:
+                await self._warm_channel(new_channel)
+            await new_channel.channel_ready()
+            old_channel, self._channel = self._channel, new_channel
+            # cycle the old channel out of use; `on_replace` can close it gracefully
+            if self._on_replace:
+                await self._on_replace(old_channel)
+            # find new sleep time based on how long the refresh process took
+            next_refresh = random.uniform(refresh_interval_min, refresh_interval_max)
+            next_sleep = next_refresh - (monotonic() - start_timestamp)
+
+    async def __aenter__(self):
+        await _BackgroundTaskMixin.__aenter__(self)
+        await _WrappedChannel.__aenter__(self)
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await _BackgroundTaskMixin.__aexit__(self, exc_type, exc_val, exc_tb)
+        return await _WrappedChannel.__aexit__(self, exc_type, exc_val, exc_tb)
+
+    async def close(self, grace=None):
+        await _BackgroundTaskMixin.close(self, grace)
+        return await _WrappedChannel.close(self, grace)
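A wiring sketch for `RefreshableChannel`. The target, the warm/retire callbacks, and the 10-minute grace value are illustrative assumptions, not values this change prescribes:

```python
import asyncio

from grpc.experimental import aio  # type: ignore

from google.cloud.bigtable._channel_pooling.refreshable_channel import (
    RefreshableChannel,
)


async def warm(channel: aio.Channel) -> None:
    # e.g. wait for the transport to connect before the swap happens
    await channel.channel_ready()


async def retire(old_channel: aio.Channel) -> None:
    # let in-flight RPCs on the replaced channel drain before closing it
    await old_channel.close(grace=600)


async def main():
    channel = RefreshableChannel(
        "localhost:8080",
        create_channel_fn=aio.insecure_channel,
        refresh_interval_min=60 * 35,
        refresh_interval_max=60 * 45,
        warm_channel_fn=warm,
        on_replace=retire,
    )
    try:
        ...  # issue RPCs as usual; refreshes happen transparently
    finally:
        await channel.close()


asyncio.run(main())
```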
diff --git a/google/cloud/bigtable/_channel_pooling/tracked_channel.py b/google/cloud/bigtable/_channel_pooling/tracked_channel.py
new file mode 100644
index 000000000..69386f8cc
--- /dev/null
+++ b/google/cloud/bigtable/_channel_pooling/tracked_channel.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +from __future__ import annotations + +from contextlib import contextmanager +from functools import partial +from grpc.experimental import aio # type: ignore +from google.api_core.grpc_helpers_async import _WrappedUnaryResponseMixin +from google.api_core.grpc_helpers_async import _WrappedStreamResponseMixin +from google.api_core.grpc_helpers_async import _WrappedStreamRequestMixin + +from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + WrappedUnaryUnaryMultiCallable, +) +from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + WrappedUnaryStreamMultiCallable, +) +from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + WrappedStreamUnaryMultiCallable, +) +from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + WrappedStreamStreamMultiCallable, +) + +from google.cloud.bigtable._channel_pooling.wrapped_channel import _WrappedChannel + + +class _TrackedUnaryResponseMixin(_WrappedUnaryResponseMixin): + def __init__(self, call_fn, channel, *args, **kwargs): + super().__init__() + self._call: aio.UnaryUnaryCall | aio.StreamUnaryCall = call_fn(*args, **kwargs) + self._channel = channel + + def __await__(self): + with self._channel.track_rpc(): + response = yield from self._call.__await__() + return response + + def __getattr__(self, attr): + return getattr(self._call, attr) + + +class _TrackedStreamResponseMixin(_WrappedStreamResponseMixin): + def __init__(self, call_fn, channel, *args, **kwargs): + super().__init__() + self._call: aio.UnaryStreamCall | aio.StreamStreamCall = call_fn( + *args, **kwargs + ) + self._channel = channel + + async def read(self): + with self._channel.track_rpc(): + return await self._call.read() + + async def _wrapped_aiter(self): + with self._channel.track_rpc(): + async for item in self._call: + yield item + + def __getattr__(self, attr): + return getattr(self._call, attr) + + +class TrackedUnaryUnaryCall(_TrackedUnaryResponseMixin, aio.UnaryUnaryCall): + pass + + +class TrackedUnaryStreamCall(_TrackedStreamResponseMixin, aio.UnaryStreamCall): + pass + + +class TrackedStreamUnaryCall( + _TrackedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall +): + pass + + +class TrackedStreamStreamCall( + _TrackedStreamResponseMixin, _WrappedStreamRequestMixin, aio.StreamStreamCall +): + pass + + +class TrackedChannel(_WrappedChannel): + """ + A Channel that tracks the number of active RPCs + """ + + def __init__(self, channel: aio.Channel): + super().__init__(channel) + self.active_rpcs = 0 + self.max_active_rpcs = 0 + + @contextmanager + def track_rpc(self): + self.active_rpcs += 1 + self.max_active_rpcs = max(self.max_active_rpcs, self.active_rpcs) + try: + yield + finally: + self.active_rpcs -= 1 + + def get_and_reset_max_active_rpcs(self) -> int: + current_max, self.max_active_rpcs = self.max_active_rpcs, self.active_rpcs + return current_max + + def unary_unary(self, *args, **kwargs): + multicallable = self._channel.unary_unary(*args, **kwargs) + tracked_multicallable = partial(TrackedUnaryUnaryCall, multicallable, self) + return WrappedUnaryUnaryMultiCallable(tracked_multicallable) + + def unary_stream(self, *args, **kwargs): + multicallable = self._channel.unary_stream(*args, **kwargs) + tracked_multicallable = partial(TrackedUnaryStreamCall, multicallable, self) + return WrappedUnaryStreamMultiCallable(tracked_multicallable) + + def stream_unary(self, *args, **kwargs): + multicallable = self._channel.stream_unary(*args, **kwargs) + tracked_multicallable = partial(TrackedStreamUnaryCall, 
multicallable, self) + return WrappedStreamUnaryMultiCallable(tracked_multicallable) + + def stream_stream(self, *args, **kwargs): + multicallable = self._channel.stream_stream(*args, **kwargs) + tracked_multicallable = partial(TrackedStreamStreamCall, multicallable, self) + return WrappedStreamStreamMultiCallable(tracked_multicallable) diff --git a/google/cloud/bigtable/_channel_pooling/wrapped_channel.py b/google/cloud/bigtable/_channel_pooling/wrapped_channel.py new file mode 100644 index 000000000..bc0a3cc45 --- /dev/null +++ b/google/cloud/bigtable/_channel_pooling/wrapped_channel.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import Callable +from abc import ABC, abstractmethod + +import asyncio +import warnings +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + + +class _WrappedChannel(aio.Channel): + """ + A wrapper around a gRPC channel. All methods are passed + through to the underlying channel. + """ + + def __init__(self, channel: aio.Channel): + self._channel = channel + + def unary_unary(self, *args, **kwargs): + return self._channel.unary_unary(*args, **kwargs) + + def unary_stream(self, *args, **kwargs): + return self._channel.unary_stream(*args, **kwargs) + + def stream_unary(self, *args, **kwargs): + return self._channel.stream_unary(*args, **kwargs) + + def stream_stream(self, *args, **kwargs): + return self._channel.stream_stream(*args, **kwargs) + + async def close(self, grace=None): + return await self._channel.close(grace=grace) + + async def channel_ready(self): + return await self._channel.channel_ready() + + async def __aenter__(self): + await self._channel.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self._channel.__aexit__(exc_type, exc_val, exc_tb) + + def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + return self._channel.get_state(try_to_connect=try_to_connect) + + async def wait_for_state_change(self, last_observed_state): + return await self._channel.wait_for_state_change(last_observed_state) + + @property + def wrapped_channel(self): + return self._channel + + +class _BackgroundTaskMixin(ABC): + """ + A mixin that provides methods to manage a background task that + is run throughout the lifetime of the object. + """ + + def __init__(self): + self._background_task: asyncio.Task[None] | None = None + + def background_task_is_active(self) -> bool: + """ + returns True if the background task is currently running + """ + return self._background_task is not None and not self._background_task.done() + + @abstractmethod + def _background_coroutine(self): + """ + To be implemented by subclasses. Returns the coroutine that will + be run in the background throughout the lifetime of the channel. + """ + pass + + @property + def _task_description(self) -> str: + """ + Describe what the background task does. 
+        The string is displayed along with the warning message to describe
+        the consequences when the task cannot be started.
+
+        Example: "Automatic channel pool resizing"
+        """
+        return "Background task"
+
+    def start_background_task(self):
+        """
+        Start background task to manage channel lifecycle. If the background
+        task is already running, do nothing. If run outside of an asyncio
+        event loop, print a warning and do nothing.
+        """
+        if self.background_task_is_active():
+            return
+        try:
+            asyncio.get_running_loop()
+            self._background_task = asyncio.create_task(self._background_coroutine())
+        except RuntimeError:
+            warnings.warn(
+                f"No event loop detected. {self._task_description} is disabled "
+                "and must be started manually in an asyncio event loop.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            self._background_task = None
+
+    async def __aenter__(self):
+        self.start_background_task()
+
+    async def close(self, grace=None):
+        if self._background_task:
+            self._background_task.cancel()
+            try:
+                await self._background_task
+            except asyncio.CancelledError:
+                pass
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+
+class _WrappedMultiCallable:
+    """
+    Wrapper class that implements the grpc MultiCallable interface.
+    Allows generic functions that return calls to pass checks for
+    MultiCallable objects.
+    """
+
+    def __init__(self, call_factory: Callable[..., aio.Call]):
+        self._call_factory = call_factory
+
+    def __call__(self, *args, **kwargs) -> aio.Call:
+        return self._call_factory(*args, **kwargs)
+
+
+class WrappedUnaryUnaryMultiCallable(
+    _WrappedMultiCallable, aio.UnaryUnaryMultiCallable
+):
+    pass
+
+
+class WrappedUnaryStreamMultiCallable(
+    _WrappedMultiCallable, aio.UnaryStreamMultiCallable
+):
+    pass
+
+
+class WrappedStreamUnaryMultiCallable(
+    _WrappedMultiCallable, aio.StreamUnaryMultiCallable
+):
+    pass
+
+
+class WrappedStreamStreamMultiCallable(
+    _WrappedMultiCallable, aio.StreamStreamMultiCallable
+):
+    pass
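The mixin's contract in miniature: a subclass supplies `_background_coroutine` and `_task_description`, and the mixin owns starting, cancelling, and awaiting the task. `Heartbeat` is a made-up illustration, not part of this change:

```python
import asyncio

from google.cloud.bigtable._channel_pooling.wrapped_channel import (
    _BackgroundTaskMixin,
)


class Heartbeat(_BackgroundTaskMixin):
    def __init__(self):
        super().__init__()
        # warns (instead of raising) when no event loop is running
        self.start_background_task()

    def _background_coroutine(self):
        return self._beat()

    @property
    def _task_description(self) -> str:
        return "Heartbeat logging"

    async def _beat(self):
        while True:
            await asyncio.sleep(30)
            print("still alive")


async def main():
    hb = Heartbeat()
    assert hb.background_task_is_active()
    await hb.close()  # cancels the task and waits for it to unwind


asyncio.run(main())
```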
diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py
index 3921d6640..bda0035ea 100644
--- a/google/cloud/bigtable/client.py
+++ b/google/cloud/bigtable/client.py
@@ -18,31 +18,25 @@
 from typing import (
     cast,
     Any,
+    Coroutine,
     Optional,
     Set,
     Callable,
-    Coroutine,
     TYPE_CHECKING,
 )
 import asyncio
-import grpc
-import time
+from grpc.experimental import aio  # type: ignore
+from functools import partial
 import warnings
-import sys
-import random
-from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta
-from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient
-from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO
-from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
-    PooledBigtableGrpcAsyncIOTransport,
-)
 from google.cloud.client import ClientWithProject
 from google.api_core.exceptions import GoogleAPICallError
 from google.api_core import retry_async as retries
 from google.api_core import exceptions as core_exceptions
-from google.cloud.bigtable._read_rows import _ReadRowsOperation
+
+from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient
+from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO
 import google.auth.credentials
 import google.auth._default
@@ -50,10 +44,28 @@
 from google.cloud.bigtable.row import Row
 from google.cloud.bigtable.read_rows_query import ReadRowsQuery
 from google.cloud.bigtable.iterators import ReadRowsIterator
+from google.cloud.bigtable._read_rows import _ReadRowsOperation
 from google.cloud.bigtable.mutations import Mutation, RowMutationEntry
 from google.cloud.bigtable._mutate_rows import _MutateRowsOperation
 from google.cloud.bigtable._helpers import _make_metadata
 from google.cloud.bigtable._helpers import _convert_retry_deadline
+from google.cloud.bigtable._channel_pooling.dynamic_pooled_channel import (
+    DynamicPooledChannel,
+)
+from google.cloud.bigtable._channel_pooling.dynamic_pooled_channel import (
+    DynamicPoolOptions,
+)
+from google.cloud.bigtable._channel_pooling.refreshable_channel import (
+    RefreshableChannel,
+)
+from google.cloud.bigtable._channel_pooling.tracked_channel import (
+    TrackedChannel,
+)
+from google.cloud.bigtable._channel_pooling.pooled_channel import PooledChannel
+from google.cloud.bigtable._channel_pooling.pooled_channel import StaticPoolOptions
+from google.cloud.bigtable_v2.services.bigtable.transports import (
+    BigtableGrpcAsyncIOTransport,
+)
 if TYPE_CHECKING:
     from google.cloud.bigtable.mutations_batcher import MutationsBatcher
@@ -67,11 +79,11 @@
     def __init__(
         self,
         *,
         project: str | None = None,
-        pool_size: int = 3,
         credentials: google.auth.credentials.Credentials | None = None,
         client_options: dict[str, Any]
         | "google.api_core.client_options.ClientOptions"
         | None = None,
+        channel_pool_options: DynamicPoolOptions | StaticPoolOptions | None = None,
     ):
         """
         Create a client instance for the Bigtable Data API
@@ -82,8 +94,10 @@
             project: the project which the client acts on behalf of.
                 If not passed, falls back to the default inferred
                 from the environment.
-            pool_size: The number of grpc channels to maintain
-                in the internal channel pool.
+            channel_pool_options: options to configure the internal channel pool.
+                Accepts a StaticPoolOptions for a fixed-size pool, or a
+                DynamicPoolOptions for a pool that resizes based on load.
+                If not passed, defaults to DynamicPoolOptions.
             credentials:
                 The OAuth2 Credentials to use for this client. If not passed (and
                 if no ``_http`` object is
@@ -96,10 +106,33 @@
             - RuntimeError if called outside of an async context (no running event loop)
             - ValueError if pool_size is less than 1
         """
-        # set up transport in registry
-        transport_str = f"pooled_grpc_asyncio_{pool_size}"
-        transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size)
-        BigtableClientMeta._transport_registry[transport_str] = transport
+
+        async def destroy_channel_gracefully(channel: aio.Channel):
+            await asyncio.sleep(10)
+            await channel.close(grace=600)
+
+        # set up channel pool
+        create_refreshable_channel = partial(
+            RefreshableChannel,
+            create_channel_fn=BigtableGrpcAsyncIOTransport.create_channel,
+            on_replace=destroy_channel_gracefully,
+            warm_channel_fn=self._ping_and_warm_instances,
+        )
+        if channel_pool_options is None:
+            channel_pool_options = DynamicPoolOptions()
+        if isinstance(channel_pool_options, StaticPoolOptions):
+            create_pool_channel = partial(
+                PooledChannel,
+                create_channel_fn=create_refreshable_channel,
+                pool_options=channel_pool_options,
+            )
+        else:
+            create_pool_channel = partial(
+                DynamicPooledChannel,
+                create_channel_fn=create_refreshable_channel,
+                pool_options=channel_pool_options,
+                on_remove=destroy_channel_gracefully,
+            )
         # set up client info headers for veneer library
         client_info = DEFAULT_CLIENT_INFO
         client_info.client_library_version = client_info.gapic_version
@@ -116,64 +149,67 @@
             project=project,
             client_options=client_options,
         )
-        self._gapic_client = BigtableAsyncClient(
-            transport=transport_str,
-            credentials=credentials,
-            client_options=client_options,
-            client_info=client_info,
-        )
-        self.transport = cast(
-            PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport
-        )
+        # warnings will
be raised by pool and channel if run outside of async context + with warnings.catch_warnings(): + # filter to show a single warning, instead of one for each channel + warnings.simplefilter("module", RuntimeWarning) + self._gapic_client = BigtableAsyncClient( + credentials=credentials, + client_options=client_options, + client_info=client_info, + transport=partial( + BigtableGrpcAsyncIOTransport, channel=create_pool_channel + ), + ) + transport = cast(BigtableGrpcAsyncIOTransport, self._gapic_client.transport) + self._pool = cast(PooledChannel, transport.grpc_channel) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[str] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[str, Set[int]] = {} - # attempt to start background tasks - self._channel_init_time = time.time() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + # raise warning if not started in async context try: - self.start_background_channel_refresh() + asyncio.get_running_loop() except RuntimeError: warnings.warn( - f"{self.__class__.__name__} should be started in an " - "asyncio event loop. Channel refresh will not be started", + f"{self.__class__.__name__} should be initialized in an " + "asyncio event loop. " + "Run start_pool_background_tasks() in an async " + "context to start grpc channel lifecycle management manually.", RuntimeWarning, stacklevel=2, ) - def start_background_channel_refresh(self) -> None: - """ - Starts a background task to ping and warm each channel in the pool - Raises: - - RuntimeError if not called in an asyncio event loop - """ - if not self._channel_refresh_tasks: - # raise RuntimeError if there is no event loop - asyncio.get_running_loop() - for channel_idx in range(self.transport.pool_size): - refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) - if sys.version_info >= (3, 8): - # task names supported in Python 3.8+ - refresh_task.set_name( - f"{self.__class__.__name__} channel refresh {channel_idx}" - ) - self._channel_refresh_tasks.append(refresh_task) + async def start_pool_background_tasks(self): + """ + If client was initialized outside of an async context, async background + tasks will not be started automatically. This method can be called to + start the background tasks manually. 
+ """ + # start dynamic pool task + if isinstance(self._pool, DynamicPooledChannel): + self._pool.start_background_task() + # dynamic pooling wraps refreshable channel in TrackedChannel + refresh_channels = [ + channel.wrapped_channel + for channel in self._pool.channels + if isinstance(channel, TrackedChannel) + ] + else: + refresh_channels = self._pool.channels + # start refreshable channel tasks + for channel in refresh_channels: + channel.start_background_task() - async def close(self, timeout: float = 2.0): + async def close(self): """ Cancel all background tasks """ - for task in self._channel_refresh_tasks: - task.cancel() - group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) - await asyncio.wait_for(group, timeout=timeout) - await self.transport.close() - self._channel_refresh_tasks = [] + await self._gapic_client.transport.close() async def _ping_and_warm_instances( - self, channel: grpc.aio.Channel + self, channel: aio.Channel, instance_id: str | None = None ) -> list[GoogleAPICallError | None]: """ Prepares the backend for requests on a channel @@ -185,68 +221,21 @@ async def _ping_and_warm_instances( Returns: - sequence of results or exceptions from the ping requests """ + instance_list = ( + [instance_id] if instance_id is not None else self._active_instances + ) ping_rpc = channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarmChannel" ) - tasks = [ping_rpc({"name": n}) for n in self._active_instances] + tasks = [ping_rpc({"name": n}) for n in instance_list] return await asyncio.gather(*tasks, return_exceptions=True) - async def _manage_channel( - self, - channel_idx: int, - refresh_interval_min: float = 60 * 35, - refresh_interval_max: float = 60 * 45, - grace_period: float = 60 * 10, - ) -> None: - """ - Background coroutine that periodically refreshes and warms a grpc channel - - The backend will automatically close channels after 60 minutes, so - `refresh_interval` + `grace_period` should be < 60 minutes - - Runs continuously until the client is closed - - Args: - channel_idx: index of the channel in the transport's channel pool - refresh_interval_min: minimum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - refresh_interval_max: maximum interval before initiating refresh - process in seconds. 
Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ - first_refresh = self._channel_init_time + random.uniform( - refresh_interval_min, refresh_interval_max - ) - next_sleep = max(first_refresh - time.time(), 0) - if next_sleep > 0: - # warm the current channel immediately - channel = self.transport.channels[channel_idx] - await self._ping_and_warm_instances(channel) - # continuously refresh the channel every `refresh_interval` seconds - while True: - await asyncio.sleep(next_sleep) - # prepare new channel for use - new_channel = self.transport.grpc_channel._create_channel() - await self._ping_and_warm_instances(new_channel) - # cycle channel out of use, with long grace window before closure - start_timestamp = time.time() - await self.transport.replace_channel( - channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel - ) - # subtract the time spent waiting for the channel to be replaced - next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.time() - start_timestamp) - async def _register_instance(self, instance_id: str, owner: Table) -> None: """ Registers an instance with the client, and warms the channel pool for the instance The client will periodically refresh grpc channel pool used to make requests, and new channels will be warmed for each registered instance - Channels will not be refreshed unless at least one instance is registered Args: - instance_id: id of the instance to register. @@ -258,14 +247,9 @@ async def _register_instance(self, instance_id: str, owner: Table) -> None: self._instance_owners.setdefault(instance_name, set()).add(id(owner)) if instance_name not in self._active_instances: self._active_instances.add(instance_name) - if self._channel_refresh_tasks: - # refresh tasks already running - # call ping and warm on all existing channels - for channel in self.transport.channels: - await self._ping_and_warm_instances(channel) - else: - # refresh tasks aren't active. 
start them as background tasks - self.start_background_channel_refresh() + # call ping and warm on all existing channels + for channel in self._pool.channels: + await self._ping_and_warm_instances(channel, instance_name) async def _remove_instance_registration( self, instance_id: str, owner: Table @@ -324,7 +308,8 @@ def get_table( ) async def __aenter__(self): - self.start_background_channel_refresh() + # ensure wrapped grpc background tasks are running + await self.start_pool_background_tasks() return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -437,7 +422,6 @@ async def read_rows_stream( - GoogleAPIError: raised if the request encounters an unrecoverable error - IdleTimeout: if iterator was abandoned """ - operation_timeout = operation_timeout or self.default_operation_timeout per_request_timeout = per_request_timeout or self.default_per_request_timeout diff --git a/google/cloud/bigtable_v2/__init__.py b/google/cloud/bigtable_v2/__init__.py index 342718dea..ee3bd8c0c 100644 --- a/google/cloud/bigtable_v2/__init__.py +++ b/google/cloud/bigtable_v2/__init__.py @@ -31,6 +31,7 @@ from .types.bigtable import MutateRowsResponse from .types.bigtable import PingAndWarmRequest from .types.bigtable import PingAndWarmResponse +from .types.bigtable import RateLimitInfo from .types.bigtable import ReadChangeStreamRequest from .types.bigtable import ReadChangeStreamResponse from .types.bigtable import ReadModifyWriteRowRequest @@ -54,6 +55,7 @@ from .types.data import StreamPartition from .types.data import TimestampRange from .types.data import ValueRange +from .types.feature_flags import FeatureFlags from .types.request_stats import FullReadStatsView from .types.request_stats import ReadIterationStats from .types.request_stats import RequestLatencyStats @@ -69,6 +71,7 @@ "Column", "ColumnRange", "Family", + "FeatureFlags", "FullReadStatsView", "GenerateInitialChangeStreamPartitionsRequest", "GenerateInitialChangeStreamPartitionsResponse", @@ -79,6 +82,7 @@ "Mutation", "PingAndWarmRequest", "PingAndWarmResponse", + "RateLimitInfo", "ReadChangeStreamRequest", "ReadChangeStreamResponse", "ReadIterationStats", diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py index 3465569b3..c9e259fcc 100644 --- a/google/cloud/bigtable_v2/services/bigtable/async_client.py +++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -169,7 +170,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, BigtableTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, BigtableTransport, Callable[..., BigtableTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -181,10 +184,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.BigtableTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It + transport (Optional[Union[str,BigtableTransport,Callable[..., BigtableTransport]]]): + The transport to use, or a callable that generates one with the + set of initialization arguments. 
+ If set to None, a transport is chosen automatically. + client_options (Optional[ClientOptions]): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index 60622509a..68e949763 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -53,7 +54,6 @@ from .transports.base import BigtableTransport, DEFAULT_CLIENT_INFO from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport -from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .transports.rest import BigtableRestTransport @@ -68,7 +68,6 @@ class BigtableClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport - _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( @@ -367,7 +366,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, BigtableTransport]] = None, + transport: Optional[ + Union[str, BigtableTransport, Callable[..., BigtableTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -379,9 +380,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, BigtableTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,BigtableTransport,Callable[..., BigtableTransport]]]): + The transport to use, or a callable that generates one with the + set of initialization arguments. + If set to None, a transport is chosen automatically. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. It won't take effect if a ``transport`` instance is provided. 
(1) The ``api_endpoint`` property can be used to override the @@ -450,8 +452,12 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(transport) - self._transport = Transport( + transport_init = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else transport + ) + self._transport = transport_init( credentials=credentials, credentials_file=client_options.credentials_file, host=api_endpoint, diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py index e8796bb8c..1b03919f6 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py @@ -19,7 +19,6 @@ from .base import BigtableTransport from .grpc import BigtableGrpcTransport from .grpc_asyncio import BigtableGrpcAsyncIOTransport -from .pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .rest import BigtableRestTransport from .rest import BigtableRestInterceptor @@ -28,14 +27,12 @@ _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport -_transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport __all__ = ( "BigtableTransport", "BigtableGrpcTransport", "BigtableGrpcAsyncIOTransport", - "PooledBigtableGrpcAsyncIOTransport", "BigtableRestTransport", "BigtableRestInterceptor", ) diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py index b9e073e8a..11f4b71a2 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py @@ -51,7 +51,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,8 +77,10 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a callable + that generates one with the set of initialization arguments. + If set to None, a channel is created automatically. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -118,7 +120,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -159,7 +161,8 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py index 8bf02ce77..aa59dabf4 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py @@ -96,7 +96,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -123,8 +123,10 @@ def __init__( scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a callable + that generates one with the set of initialization arguments. + If set to None, a channel is created automatically. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -164,7 +166,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -204,7 +206,8 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py deleted file mode 100644 index 372e5796d..000000000 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ /dev/null @@ -1,426 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import asyncio -import warnings -from functools import partialmethod -from functools import partial -from typing import ( - Awaitable, - Callable, - Dict, - Optional, - Sequence, - Tuple, - Union, - List, - Type, -) - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.bigtable_v2.types import bigtable -from .base import BigtableTransport, DEFAULT_CLIENT_INFO -from .grpc_asyncio import BigtableGrpcAsyncIOTransport - - -class PooledMultiCallable: - def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): - self._init_args = args - self._init_kwargs = kwargs - self.next_channel_fn = channel_pool.next_channel - - -class PooledUnaryUnaryMultiCallable(PooledMultiCallable, aio.UnaryUnaryMultiCallable): - def __call__(self, *args, **kwargs) -> aio.UnaryUnaryCall: - return self.next_channel_fn().unary_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledUnaryStreamMultiCallable(PooledMultiCallable, aio.UnaryStreamMultiCallable): - def __call__(self, *args, **kwargs) -> aio.UnaryStreamCall: - return self.next_channel_fn().unary_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamUnaryMultiCallable(PooledMultiCallable, aio.StreamUnaryMultiCallable): - def __call__(self, *args, **kwargs) -> aio.StreamUnaryCall: - return self.next_channel_fn().stream_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamStreamMultiCallable( - PooledMultiCallable, aio.StreamStreamMultiCallable -): - def __call__(self, *args, **kwargs) -> aio.StreamStreamCall: - return self.next_channel_fn().stream_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledChannel(aio.Channel): - def __init__( - self, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - quota_project_id: Optional[str] = None, - default_scopes: Optional[Sequence[str]] = None, - scopes: Optional[Sequence[str]] = None, - default_host: Optional[str] = None, - insecure: bool = False, - **kwargs, - ): - self._pool: List[aio.Channel] = [] - self._next_idx = 0 - if insecure: - self._create_channel = partial(aio.insecure_channel, host) - else: - self._create_channel = partial( - grpc_helpers_async.create_channel, - target=host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=default_scopes, - scopes=scopes, - default_host=default_host, - **kwargs, - ) - for i in range(pool_size): - self._pool.append(self._create_channel()) - - def next_channel(self) -> aio.Channel: - channel = self._pool[self._next_idx] - self._next_idx = (self._next_idx + 1) % len(self._pool) - return channel - - def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable: - return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) - - def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable: - return PooledUnaryStreamMultiCallable(self, *args, **kwargs) - - def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: - return PooledStreamUnaryMultiCallable(self, *args, **kwargs) - - def stream_stream(self, *args, **kwargs) -> 
grpc.aio.StreamStreamMultiCallable: - return PooledStreamStreamMultiCallable(self, *args, **kwargs) - - async def close(self, grace=None): - close_fns = [channel.close(grace=grace) for channel in self._pool] - return await asyncio.gather(*close_fns) - - async def channel_ready(self): - ready_fns = [channel.channel_ready() for channel in self._pool] - return asyncio.gather(*ready_fns) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: - raise NotImplementedError() - - async def wait_for_state_change(self, last_observed_state): - raise NotImplementedError() - - async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None - ) -> aio.Channel: - """ - Replaces a channel in the pool with a fresh one. - - The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one - new_channel(grpc.aio.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. - """ - if channel_idx >= len(self._pool) or channel_idx < 0: - raise ValueError( - f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}" - ) - if new_channel is None: - new_channel = self._create_channel() - old_channel = self._pool[channel_idx] - self._pool[channel_idx] = new_channel - await asyncio.sleep(swap_sleep) - await old_channel.close(grace=grace) - return new_channel - - -class PooledBigtableGrpcAsyncIOTransport(BigtableGrpcAsyncIOTransport): - """Pooled gRPC AsyncIO backend transport for Bigtable. - - Service for reading from and writing to existing Bigtable - tables. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - - This class allows channel pooling, so multiple channels can be used concurrently - when making requests. Channels are rotated in a round-robin fashion. - """ - - @classmethod - def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcAsyncIOTransport"]: - """ - Creates a new class with a fixed channel pool size. - - A fixed channel pool makes compatibility with other transports easier, - as the initializer signature is the same. 
- """ - - class PooledTransportFixed(cls): - __init__ = partialmethod(cls.__init__, pool_size=pool_size) - - PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" - PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ - return PooledTransportFixed - - @classmethod - def create_channel( - cls, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a PooledChannel object, representing a pool of gRPC AsyncIO channels - Args: - pool_size (int): The number of channels in the pool. - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - PooledChannel: a channel pool object - """ - - return PooledChannel( - pool_size, - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs, - ) - - def __init__( - self, - *, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - pool_size (int): the number of grpc channels to maintain in a pool - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. 
- api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - ValueError: if ``pool_size`` <= 0 - """ - if pool_size <= 0: - raise ValueError(f"invalid pool_size: {pool_size}") - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - BigtableTransport.__init__( - self, - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - self._quota_project_id = quota_project_id - self._grpc_channel = type(self).create_channel( - pool_size, - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. 
- credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=self._quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @property - def pool_size(self) -> int: - """The number of grpc channels in the pool.""" - return len(self._grpc_channel._pool) - - @property - def channels(self) -> List[grpc.Channel]: - """Access the internal list of grpc channels.""" - return self._grpc_channel._pool - - async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None - ) -> aio.Channel: - """ - Replaces a channel in the pool with a fresh one. - - The `new_channel` will start processing new requests immediately, - but the old channel will continue serving existing clients for `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one - new_channel(grpc.aio.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. - """ - return await self._grpc_channel.replace_channel( - channel_idx, grace, swap_sleep, new_channel - ) - - -__all__ = ("PooledBigtableGrpcAsyncIOTransport",) diff --git a/google/cloud/bigtable_v2/types/__init__.py b/google/cloud/bigtable_v2/types/__init__.py index bb2533e33..9f15efaf5 100644 --- a/google/cloud/bigtable_v2/types/__init__.py +++ b/google/cloud/bigtable_v2/types/__init__.py @@ -24,6 +24,7 @@ MutateRowsResponse, PingAndWarmRequest, PingAndWarmResponse, + RateLimitInfo, ReadChangeStreamRequest, ReadChangeStreamResponse, ReadModifyWriteRowRequest, @@ -50,6 +51,9 @@ TimestampRange, ValueRange, ) +from .feature_flags import ( + FeatureFlags, +) from .request_stats import ( FullReadStatsView, ReadIterationStats, @@ -71,6 +75,7 @@ "MutateRowsResponse", "PingAndWarmRequest", "PingAndWarmResponse", + "RateLimitInfo", "ReadChangeStreamRequest", "ReadChangeStreamResponse", "ReadModifyWriteRowRequest", @@ -94,6 +99,7 @@ "StreamPartition", "TimestampRange", "ValueRange", + "FeatureFlags", "FullReadStatsView", "ReadIterationStats", "RequestLatencyStats", diff --git a/google/cloud/bigtable_v2/types/bigtable.py b/google/cloud/bigtable_v2/types/bigtable.py index ea97588c2..13f6ac0db 100644 --- a/google/cloud/bigtable_v2/types/bigtable.py +++ b/google/cloud/bigtable_v2/types/bigtable.py @@ -38,6 +38,7 @@ "MutateRowResponse", "MutateRowsRequest", "MutateRowsResponse", + "RateLimitInfo", "CheckAndMutateRowRequest", "CheckAndMutateRowResponse", "PingAndWarmRequest", @@ -61,8 +62,9 @@ class ReadRowsRequest(proto.Message): Values are of the form ``projects//instances//tables/``. app_profile_id (str): - This value specifies routing for replication. This API only - accepts the empty value of app_profile_id. + This value specifies routing for replication. + If not specified, the "default" application + profile will be used. rows (google.cloud.bigtable_v2.types.RowSet): The row keys and/or ranges to read sequentially.
If not specified, reads from all @@ -469,10 +471,19 @@ class Entry(proto.Message): class MutateRowsResponse(proto.Message): r"""Response message for BigtableService.MutateRows. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: entries (MutableSequence[google.cloud.bigtable_v2.types.MutateRowsResponse.Entry]): One or more results for Entries from the batch request. + rate_limit_info (google.cloud.bigtable_v2.types.RateLimitInfo): + Information about how client should limit the + rate (QPS). Primarily used by supported official + Cloud Bigtable clients. If unset, the rate limit + info is not provided by the server. + + This field is a member of `oneof`_ ``_rate_limit_info``. """ class Entry(proto.Message): @@ -506,6 +517,50 @@ class Entry(proto.Message): number=1, message=Entry, ) + rate_limit_info: "RateLimitInfo" = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message="RateLimitInfo", + ) + + +class RateLimitInfo(proto.Message): + r"""Information about how client should adjust the load to + Bigtable. + + Attributes: + period (google.protobuf.duration_pb2.Duration): + Time that clients should wait before + adjusting the target rate again. If clients + adjust rate too frequently, the impact of the + previous adjustment may not have been taken into + account and may over-throttle or under-throttle. + If clients adjust rate too slowly, they will not + be responsive to load changes on server side, + and may over-throttle or under-throttle. + factor (float): + If it has been at least one ``period`` since the last load + adjustment, the client should multiply the current load by + this value to get the new target load. For example, if the + current load is 100 and ``factor`` is 0.8, the new target + load should be 80. After adjusting, the client should ignore + ``factor`` until another ``period`` has passed. + + The client can measure its load using any unit that's + comparable over time. For example, QPS can be used as long as + each request involves a similar amount of work. + """ + + period: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=1, + message=duration_pb2.Duration, + ) + factor: float = proto.Field( + proto.DOUBLE, + number=2, + ) class CheckAndMutateRowRequest(proto.Message): diff --git a/google/cloud/bigtable_v2/types/feature_flags.py b/google/cloud/bigtable_v2/types/feature_flags.py new file mode 100644 index 000000000..1b5f76e24 --- /dev/null +++ b/google/cloud/bigtable_v2/types/feature_flags.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.bigtable.v2", + manifest={ + "FeatureFlags", + }, +) + + +class FeatureFlags(proto.Message): + r"""Feature flags supported by a client.
This is intended to be sent as + part of request metadata to assure the server that certain behaviors + are safe to enable. This proto is meant to be serialized and + websafe-base64 encoded under the ``bigtable-features`` metadata key. + The value will remain constant for the lifetime of a client and due + to HTTP2's HPACK compression, the request overhead will be tiny. + This is an internal implementation detail and should not be used by + end users directly. + + Attributes: + mutate_rows_rate_limit (bool): + Notify the server that the client enables + batch write flow control by requesting + RateLimitInfo from MutateRowsResponse. + """ + + mutate_rows_rate_limit: bool = proto.Field( + proto.BOOL, + number=3, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/tests/unit/_channel_pooling/__init__.py b/tests/unit/_channel_pooling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/_channel_pooling/test_dynamic_pooled_channel.py b/tests/unit/_channel_pooling/test_dynamic_pooled_channel.py new file mode 100644 index 000000000..70b5b5b2b --- /dev/null +++ b/tests/unit/_channel_pooling/test_dynamic_pooled_channel.py @@ -0,0 +1,152 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
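The RateLimitInfo message above implies a simple client-side loop: when a MutateRowsResponse carries rate-limit info, multiply the current target load by ``factor``, then ignore further adjustments until ``period`` has elapsed. A minimal sketch of that bookkeeping, assuming a caller that tracks its own target load (``apply_rate_limit_info`` and its arguments are illustrative names, not library API):

import time

def apply_rate_limit_info(target_load, info, last_adjusted):
    # `info` is a google.cloud.bigtable_v2.types.RateLimitInfo message.
    now = time.monotonic()
    period_s = info.period.seconds + info.period.nanos / 1e9
    if now - last_adjusted < period_s:
        # Too soon: the previous adjustment may not have taken effect yet.
        return target_load, last_adjusted
    # Per the docstring above: a load of 100 with factor 0.8 targets 80.
    return target_load * info.factor, now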
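Similarly, the FeatureFlags docstring above spells out the wire format: serialize the proto and websafe-base64 encode it under the ``bigtable-features`` metadata key. A hedged sketch of what that encoding might look like on the client (the helper name is hypothetical, not part of the library):

import base64

from google.cloud.bigtable_v2.types import FeatureFlags

def _feature_flags_metadata():
    # Serialize the proto-plus message and websafe-base64 encode it, as the
    # docstring prescribes; the value is constant for the client's lifetime.
    flags = FeatureFlags(mutate_rows_rate_limit=True)
    encoded = base64.urlsafe_b64encode(FeatureFlags.serialize(flags)).decode("utf-8")
    return ("bigtable-features", encoded)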
+from __future__ import annotations + +import pytest + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + +from .test_pooled_channel import TestPooledChannel +from .test_wrapped_channel import TestBackgroundTaskMixin + + +class TestDynamicPooledChannel(TestPooledChannel, TestBackgroundTaskMixin): + + def _get_target(self): + from google.cloud.bigtable._channel_pooling.dynamic_pooled_channel import DynamicPooledChannel + + return DynamicPooledChannel + + def _make_one(self, *args, init_background_task=False, async_mock=True, mock_track_wrapper=True, **kwargs): + import warnings + from google.cloud.bigtable._channel_pooling.dynamic_pooled_channel import DynamicPoolOptions + mock_type = AsyncMock if async_mock else mock.Mock + kwargs.setdefault("create_channel_fn", lambda *args, **kwargs: mock_type()) + pool_size = kwargs.pop("pool_size", 3) + kwargs.setdefault("pool_options", DynamicPoolOptions(start_size=pool_size)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + if init_background_task: + instance = self._get_target()(*args, **kwargs) + else: + with mock.patch.object(self._get_target(), "start_background_task"): + instance = self._get_target()(*args, **kwargs) + if mock_track_wrapper: + # replace tracked channels with mocks + instance._create_channel = kwargs["create_channel_fn"] + for i in range(len(instance._pool)): + instance._pool[i] = mock_type() + return instance + + def test_ctor(self): + """ + test that constructor sets starting values + """ + from google.cloud.bigtable._channel_pooling.tracked_channel import TrackedChannel + channel_fn = mock.Mock() + channel_fn.side_effect = lambda *args, **kwargs: mock.Mock() + expected_pool_options = mock.Mock() + expected_pool_size = 12 + expected_pool_options.start_size = expected_pool_size + extra_args = ["a", "b"] + extra_kwargs = {"c": "d"} + warm_fn = AsyncMock() + on_remove = AsyncMock() + instance = self._make_one( + *extra_args, + create_channel_fn=channel_fn, + pool_options=expected_pool_options, + warm_channel_fn=warm_fn, + on_remove=on_remove, + mock_track_wrapper=False, + **extra_kwargs, + ) + assert len(instance.channels) == expected_pool_size + assert channel_fn.call_count == expected_pool_size + assert channel_fn.call_args == mock.call(*extra_args, **extra_kwargs) + # create function should create a tracked channel + assert isinstance(instance._create_channel(), TrackedChannel) + # no args in outer function + assert instance._create_channel.args == tuple() + assert instance._create_channel.keywords == {} + # ensure each channel is unique + assert len(set(instance.channels)) == expected_pool_size + # callbacks are set up + assert instance._on_remove == on_remove + assert instance._warm_channel == warm_fn + + def test_ctor_defaults(self): + """ + test with minimal arguments + """ + from google.cloud.bigtable._channel_pooling.tracked_channel import TrackedChannel + from google.cloud.bigtable._channel_pooling.dynamic_pooled_channel import DynamicPoolOptions + channel_fn = mock.Mock() + channel_fn.side_effect = lambda *args, **kwargs: mock.Mock() + instance = self._make_one( + create_channel_fn=channel_fn, + mock_track_wrapper=False, + ) + assert len(instance.channels) == 3 + assert channel_fn.call_count == 3 + assert channel_fn.call_args == mock.call() + # create function should create a
tracked channel + assert isinstance(instance._create_channel(), TrackedChannel) + # no args in outer function + assert instance._create_channel.args == tuple() + assert instance._create_channel.keywords == {} + # ensure each channel is unique + assert len(set(instance.channels)) == 3 + # callbacks are empty + assert instance._on_remove is None + assert instance._warm_channel is None + # DynamicPoolOptions created automatically + assert isinstance(instance._pool_options, DynamicPoolOptions) + + @pytest.mark.asyncio + async def test_resize_routine(self): + pass + + @pytest.mark.asyncio + @pytest.mark.parametrize("start_size,max_rpcs_per_channel,min_rpcs_per_channel,max_channels,min_channels,max_delta,usages,new_size", [ + (1, 10, 0, 10, 1, 1, [0, 0, 0], 1), + ]) + async def test_attempt_resize( + self, + start_size, + max_rpcs_per_channel, + min_rpcs_per_channel, + max_channels, + min_channels, + max_delta, + usages, + new_size, + ): + """ + test different resize scenarios + """ + pass + + @pytest.mark.asyncio + async def test_resize_reset_next_idx(self): + """ + if the pool shrinks below next_idx, next_idx should be set to 0 + """ + pass diff --git a/tests/unit/_channel_pooling/test_pooled_channel.py b/tests/unit/_channel_pooling/test_pooled_channel.py new file mode 100644 index 000000000..7eed26674 --- /dev/null +++ b/tests/unit/_channel_pooling/test_pooled_channel.py @@ -0,0 +1,332 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
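The static-pool tests that follow repeatedly lean on one contract: with ``pool_size`` channels, the k-th RPC lands on channel ``k % pool_size``. A small sketch of how the per-channel call counts in the ``test_rpc_call_rotates`` parametrization below can be derived (an illustrative helper, not part of the test suite):

def expected_call_counts(pool_size, calls):
    # Round-robin: call k goes to channel k % pool_size.
    counts = [0] * pool_size
    for k in range(calls):
        counts[k % pool_size] += 1
    return counts

# Matches the parametrized case (4, 7, [2, 2, 2, 1]) below.
assert expected_call_counts(4, 7) == [2, 2, 2, 1]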
+from __future__ import annotations + +import pytest + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + + +class TestPooledChannel: + def _make_one_with_channel_mock(self, *args, async_mock=True, **kwargs): + channel = AsyncMock() if async_mock else mock.Mock() + return ( + self._make_one(*args, create_channel_fn=lambda: channel, **kwargs), + channel, + ) + + def _get_target(self): + from google.cloud.bigtable._channel_pooling.pooled_channel import PooledChannel + + return PooledChannel + + def _make_one(self, *args, async_mock=True, **kwargs): + from google.cloud.bigtable._channel_pooling.pooled_channel import ( + StaticPoolOptions, + ) + + mock_type = AsyncMock if async_mock else mock.Mock + kwargs.setdefault("create_channel_fn", lambda *args, **kwargs: mock_type()) + pool_size = kwargs.pop("pool_size", 3) + kwargs.setdefault("pool_options", StaticPoolOptions(pool_size=pool_size)) + return self._get_target()(*args, **kwargs) + + def test_ctor(self): + """ + test that constructor sets starting values + """ + channel_fn = mock.Mock() + channel_fn.side_effect = lambda *args, **kwargs: mock.Mock() + expected_pool_options = mock.Mock() + expected_pool_size = 12 + expected_pool_options.pool_size = expected_pool_size + extra_args = ["a", "b"] + extra_kwargs = {"c": "d"} + instance = self._make_one( + *extra_args, + create_channel_fn=channel_fn, + pool_options=expected_pool_options, + **extra_kwargs, + ) + assert instance._create_channel.func == channel_fn + assert instance._create_channel.args == tuple(extra_args) + assert instance._create_channel.keywords == extra_kwargs + assert len(instance.channels) == expected_pool_size + assert channel_fn.call_count == expected_pool_size + assert channel_fn.call_args == mock.call(*extra_args, **extra_kwargs) + # ensure each channel is unique + assert len(set(instance.channels)) == expected_pool_size + + def test_ctor_defaults(self): + """ + test with minimal arguments + """ + channel_fn = mock.Mock() + channel_fn.side_effect = lambda *args, **kwargs: mock.Mock() + instance = self._make_one( + create_channel_fn=channel_fn, + ) + assert instance._create_channel.func == channel_fn + assert instance._create_channel.args == tuple() + assert instance._create_channel.keywords == {} + assert len(instance.channels) == 3 # default size + assert channel_fn.call_count == 3 + assert channel_fn.call_args == mock.call() + # ensure each channel is unique + assert len(set(instance.channels)) == 3 + + def test_ctor_no_create_fn(self): + """ + test that constructor raises error if no create_channel_fn is provided + """ + with pytest.raises(ValueError) as exc: + self._get_target()() + assert "create_channel_fn" in str(exc.value) + + @pytest.mark.parametrize("pool_size", [1, 2, 7, 10, 100]) + @pytest.mark.asyncio + async def test_next_channel(self, pool_size): + """ + next_channel should rotate between channels + """ + async with self._make_one(pool_size=pool_size) as instance: + # ensure each channel is unique + assert len(set(instance.channels)) == pool_size + # make sure next_channel loops through all channels as expected + instance._next_idx = 0 + expected_results = [ + instance.channels[i % pool_size] for i in range(pool_size * 2) + ] + for idx, expected_channel in enumerate(expected_results): + expected_next_idx = idx % pool_size + assert instance._next_idx ==
expected_next_idx + assert instance.next_channel() is expected_channel + # next_idx should be updated + assert instance._next_idx == (expected_next_idx + 1) % pool_size + + def test___getitem__(self): + """ + should be able to index on pool directly + """ + instance = self._make_one(pool_size=3) + assert instance[1] is instance.channels[1] + + @pytest.mark.asyncio + async def test_next_channel_unexpected_next(self): + """ + if _next_idx ends up out of bounds, it should be reset to 0 + """ + pool_size = 3 + async with self._make_one(pool_size=pool_size) as instance: + assert instance._next_idx == 0 + instance.next_channel() + assert instance._next_idx == 1 + # try to put it out of bounds + instance._next_idx = 100 + assert instance._next_idx == 100 + instance.next_channel() + assert instance._next_idx == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_unary", "stream_unary"]) + async def test_unary_call_api_passthrough(self, method_name): + """ + rpc call methods should use underlying channel calls + """ + mock_rpc = AsyncMock() + call_mock = mock_rpc.call() + callable_mock = lambda: call_mock # noqa: E731 + instance, channel = self._make_one_with_channel_mock( + async_mock=False, pool_size=1 + ) + with mock.patch.object(instance, "next_channel") as next_channel_mock: + next_channel_mock.return_value = channel + channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = callable_mock + # call rpc to get Multicallable + arg_mock = mock.Mock() + found_callable = wrapper_method(arg_mock) + # assert that response was passed through + found_call = found_callable() + await found_call + # assert that wrapped channel method was called + assert mock_rpc.call.await_count == 1 + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == [arg_mock] + assert channel_method.call_count == 1 + assert next_channel_mock.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_stream", "stream_stream"]) + async def test_stream_call_api_passthrough(self, method_name): + """ + rpc call methods should use underlying channel calls + """ + expected_result = mock.Mock() + + async def mock_stream(): + yield expected_result + + instance, channel = self._make_one_with_channel_mock( + async_mock=False, pool_size=1 + ) + with mock.patch.object(instance, "next_channel") as next_channel_mock: + next_channel_mock.return_value = channel + channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = lambda: mock_stream() + # call rpc to get Multicallable + arg_mock = mock.Mock() + found_callable = wrapper_method(arg_mock) + # assert that response was passed through + found_call = found_callable() + results = [item async for item in found_call] + assert results == [expected_result] + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == [arg_mock] + assert channel_method.call_count == 1 + assert next_channel_mock.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pool_size,calls,expected_call_count", + [ + (1, 1, [1]), + (1, 2, [2]), + (2, 1, [1, 0]), + (4, 4, [1, 1, 1, 1]), + (4, 5, [2, 1, 1, 1]), + (4, 6, [2, 2, 1, 1]), + (4, 7, [2, 2, 2, 1]), + (4, 0, [0, 0, 0, 0]), + ], + ) + async def test_rpc_call_rotates(self, 
pool_size, calls, expected_call_count): + """ + rpc call methods should rotate between channels in the pool + """ + for method_name in [ + "unary_unary", + "stream_unary", + "unary_stream", + "stream_stream", + ]: + instance = self._make_one(async_mock=False, pool_size=pool_size) + wrapper_method = getattr(instance, method_name) + for call_num in range(calls): + found_callable = wrapper_method() + found_callable() + assert len(expected_call_count) == len(instance.channels) + for channel_num, expected_channel_call_count in enumerate( + expected_call_count + ): + channel_method = getattr(instance[channel_num], method_name) + assert channel_method.call_count == expected_channel_call_count + + @pytest.mark.asyncio + async def test_unimplemented_state_methods(self): + """ + get_state and wait_for_state_change should raise NotImplementedError, because + behavior is not defined for a pool of channels + """ + instance = self._make_one(async_mock=False, pool_size=1) + with pytest.raises(NotImplementedError): + instance.get_state() + with pytest.raises(NotImplementedError): + await instance.wait_for_state_change(mock.Mock()) + + @pytest.mark.parametrize( + "method_name,arg_num", + [ + ("close", 1), + ("__aenter__", 0), + ("__aexit__", 3), + ], + ) + @pytest.mark.asyncio + async def test_async_api_passthrough(self, method_name, arg_num): + """ + Wrapper should respond to full grpc Channel API, and pass through + responses to all wrapped channels + """ + pool_size = 5 + instance = self._make_one(async_mock=True, pool_size=pool_size) + wrapper_method = getattr(instance, method_name) + # make function call + args = [mock.Mock() for _ in range(arg_num)] + await wrapper_method(*args) + # assert that wrapped channel method was called for each channel + for channel in instance.channels: + channel_method = getattr(channel, method_name) + assert channel_method.call_count == 1 + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == args + + @pytest.mark.asyncio + async def test_channel_ready(self): + """ + channel ready should block until all channels are ready + """ + pool_size = 5 + instance = self._make_one(async_mock=True, pool_size=pool_size) + await instance.channel_ready() + for channel in instance.channels: + assert channel.channel_ready.call_count == 1 + assert channel.channel_ready.await_count == 1 + + @pytest.mark.asyncio + async def test_context_manager(self): + """ + entering and exiting should call enter and exit on all channels + """ + channel_list = [] + async with self._make_one() as instance: + for channel in instance.channels: + assert channel.__aenter__.call_count == 1 + assert channel.__aenter__.await_count == 1 + assert channel.__aexit__.call_count == 0 + channel_list = instance.channels + for channel in channel_list: + assert channel.__aexit__.call_count == 1 + assert channel.__aexit__.await_count == 1 + + def test_index_of(self): + """ + index_of should return the index for each channel, or -1 if not in pool + """ + pool_size = 5 + instance = self._make_one(async_mock=True, pool_size=pool_size) + for channel_num, channel in enumerate(instance.channels): + found_idx = instance.index_of(channel) + assert found_idx == channel_num + # test for channels not in list + fake_channel = mock.Mock() + found_idx = instance.index_of(fake_channel) + assert found_idx == -1 diff --git a/tests/unit/_channel_pooling/test_refreshable_channel.py b/tests/unit/_channel_pooling/test_refreshable_channel.py new file mode 100644
index 000000000..f77ed9305 --- /dev/null +++ b/tests/unit/_channel_pooling/test_refreshable_channel.py @@ -0,0 +1,231 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import pytest +import asyncio + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + +from .test_wrapped_channel import TestWrappedChannel +from .test_wrapped_channel import TestBackgroundTaskMixin + + +class TestRefreshableChannel(TestWrappedChannel, TestBackgroundTaskMixin): + def _make_one_with_channel_mock(self, *args, async_mock=True, **kwargs): + channel = AsyncMock() if async_mock else mock.Mock() + create_fn = mock.Mock() + create_fn.return_value = channel + return ( + self._make_one(*args, create_channel_fn=create_fn, **kwargs), + channel, + ) + + def _get_target(self): + from google.cloud.bigtable._channel_pooling.refreshable_channel import ( + RefreshableChannel, + ) + + return RefreshableChannel + + def _make_one(self, *args, init_background_task=False, **kwargs): + import warnings + + kwargs.setdefault("create_channel_fn", lambda *args, **kwargs: AsyncMock()) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + if init_background_task: + return self._get_target()(*args, **kwargs) + else: + with mock.patch.object(self._get_target(), "start_background_task"): + return self._get_target()(*args, **kwargs) + + def test_ctor(self): + """ + test that constructor sets starting values + """ + expected_channel = mock.Mock() + channel_fn = mock.Mock() + channel_fn.return_value = expected_channel + warm_fn = lambda: AsyncMock() # noqa: E731 + replace_fn = lambda: AsyncMock() # noqa: E731 + min_refresh = 4 + max_refresh = 5 + extra_args = ["a", "b"] + extra_kwargs = {"c": "d"} + with mock.patch.object( + self._get_target(), "start_background_task" + ) as start_background_task_mock: + instance = self._make_one( + *extra_args, + init_background_task=True, + create_channel_fn=channel_fn, + refresh_interval_min=min_refresh, + refresh_interval_max=max_refresh, + warm_channel_fn=warm_fn, + on_replace=replace_fn, + **extra_kwargs, + ) + assert instance._create_channel.func == channel_fn + assert instance._create_channel.args == tuple(extra_args) + assert instance._create_channel.keywords == extra_kwargs + assert instance._refresh_interval_min == min_refresh + assert instance._refresh_interval_max == max_refresh + assert instance._warm_channel == warm_fn + assert instance._on_replace == replace_fn + assert instance._background_task is None + assert instance._channel == expected_channel + assert start_background_task_mock.call_count == 1 + assert channel_fn.call_count == 1 + assert channel_fn.call_args == mock.call(*extra_args, **extra_kwargs) + + def test_ctor_defaults(self): + """ + test with minimal arguments + """ + expected_channel
= mock.Mock() + channel_fn = lambda: expected_channel # noqa: E731 + with mock.patch.object( + self._get_target(), "start_background_task" + ) as start_background_task_mock: + instance = self._make_one( + create_channel_fn=channel_fn, init_background_task=True + ) + assert instance._create_channel.func == channel_fn + assert instance._create_channel.args == tuple() + assert instance._create_channel.keywords == {} + assert instance._refresh_interval_min == 60 * 35 + assert instance._refresh_interval_max == 60 * 45 + assert instance._warm_channel is None + assert instance._on_replace is None + assert instance._background_task is None + assert instance._channel == expected_channel + assert start_background_task_mock.call_count == 1 + + def test_ctor_no_create_fn(self): + """ + test that constructor raises error if no create_channel_fn is provided + """ + with pytest.raises(ValueError) as exc: + self._get_target()() + assert "create_channel_fn" in str(exc.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [ + (None, 1, 60 * 35), + (10, 10, 100), + (10, 1, 10), + (1, 10, 10), + (30, 10, 300), + ], + ) + async def test__manage_channel_sleeps( + self, refresh_interval, num_cycles, expected_sleep + ): + """ + Ensure manage_channel_lifecycle sleeps for the correct amount of time in between refreshes + """ + import time + import random + + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time: + time.return_value = 0 + with mock.patch.object(asyncio, "sleep") as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + try: + instance, _ = self._make_one_with_channel_mock() + args = ( + (refresh_interval, refresh_interval) + if refresh_interval + else tuple() + ) + await instance._manage_channel_lifecycle(*args) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + + @pytest.mark.asyncio + async def test__manage_channel_random(self): + """ + Should use random to add noise to sleep times + """ + import random + + with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + instance = self._make_one_with_channel_mock()[0] + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + await instance._manage_channel_lifecycle(min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 3 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.asyncio + async def test__manage_channel_callbacks(self): + """ + Should call warm_channel_fn when creating a new channel, + and on_replace when replacing a channel + """ + instance, orig_channel = self._make_one_with_channel_mock() + new_channel = AsyncMock() + instance._warm_channel = AsyncMock() + instance._on_replace = AsyncMock() +
instance._create_channel = lambda: new_channel + with mock.patch.object(asyncio, "sleep", AsyncMock()) as sleep: + # break out after second sleep + sleep.side_effect = [None, asyncio.CancelledError] + try: + await instance._manage_channel_lifecycle() + except asyncio.CancelledError: + pass + assert instance._channel == new_channel + # should call warm_channel_fn on old channel at start, then new channel after replacement + assert instance._warm_channel.call_count == 2 + assert instance._warm_channel.call_args_list[0][0][0] == orig_channel + assert instance._warm_channel.call_args_list[1][0][0] == new_channel + # should only call on_replace on old channel after replacement + assert instance._on_replace.call_count == 1 + assert instance._on_replace.call_args_list[0][0][0] == orig_channel diff --git a/tests/unit/_channel_pooling/test_tracked_channel.py b/tests/unit/_channel_pooling/test_tracked_channel.py new file mode 100644 index 000000000..941a50259 --- /dev/null +++ b/tests/unit/_channel_pooling/test_tracked_channel.py @@ -0,0 +1,153 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import pytest +import asyncio + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + +from .test_wrapped_channel import TestWrappedChannel + + +class TestTrackedChannel(TestWrappedChannel): + def _make_one_with_channel_mock(self, *args, async_mock=True, **kwargs): + channel = AsyncMock() if async_mock else mock.Mock() + return self._make_one(channel, *args, **kwargs), channel + + def _get_target(self): + from google.cloud.bigtable._channel_pooling.tracked_channel import ( + TrackedChannel, + ) + + return TrackedChannel + + def _make_one(self, *args, **kwargs): + return self._get_target()(*args, **kwargs) + + def test_ctor(self): + expected_channel = mock.Mock() + instance = self._make_one(expected_channel) + assert instance._channel is expected_channel + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 0 + + def test_track_rpc(self): + """ + counts should be incremented when track_rpc is used + """ + instance = self._make_one(mock.Mock()) + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 0 + # start a single rpc + with instance.track_rpc(): + assert instance.active_rpcs == 1 + assert instance.max_active_rpcs == 1 + # start a second rpc + with instance.track_rpc(): + assert instance.active_rpcs == 2 + assert instance.max_active_rpcs == 2 + # end the second rpc + assert instance.active_rpcs == 1 + assert instance.max_active_rpcs == 2 + # end the first rpc + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 2 + + def test_get_and_reset_max_active_rpcs(self): + instance = self._make_one(mock.Mock()) + expected_max = 11 + instance.max_active_rpcs = expected_max
+ found_max = instance.get_and_reset_max_active_rpcs() + assert found_max == expected_max + assert instance.max_active_rpcs == 0 + + @pytest.mark.parametrize( + "method_name", ["unary_unary", "unary_stream", "stream_unary", "stream_stream"] + ) + def test_calls_wrapped(self, method_name): + from google.cloud.bigtable._channel_pooling.tracked_channel import ( + _TrackedStreamResponseMixin, + _TrackedUnaryResponseMixin, + ) + from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + _WrappedMultiCallable, + ) + + instance, channel = self._make_one_with_channel_mock(async_mock=False) + found_multicallable = getattr(instance, method_name)() + assert isinstance(found_multicallable, _WrappedMultiCallable) + found_call = found_multicallable() + assert isinstance( + found_call, (_TrackedUnaryResponseMixin, _TrackedStreamResponseMixin) + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_unary", "stream_unary"]) + async def test_unary_calls_tracked(self, method_name): + """ + unary_unary calls should update track count when called + """ + import time + + instance, channel = self._make_one_with_channel_mock(async_mock=False) + getattr( + channel, method_name + ).return_value = lambda *args, **kwargs: asyncio.sleep(0.01) + found_callable = getattr(instance, method_name)() + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 0 + await found_callable() + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 1 + # try with multiple calls at once + start_time = time.monotonic() + await asyncio.gather(*[found_callable(None, None) for _ in range(10)]) + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 10 + # make sure rpcs ran in parallel + assert time.monotonic() - start_time < 0.2 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_stream", "stream_stream"]) + async def test_stream_calls_tracked(self, method_name): + """ + stream calls should update track count when called + """ + + async def mock_stream(): + for i in range(3): + yield i + + instance, channel = self._make_one_with_channel_mock(async_mock=False) + getattr(channel, method_name).return_value = lambda: mock_stream() + found_multicallable = getattr(instance, method_name)() + found_stream = found_multicallable() + assert instance.active_rpcs == 0 + assert instance.max_active_rpcs == 0 + async for _ in found_stream: + assert instance.active_rpcs == 1 + assert instance.max_active_rpcs >= 1 + new_stream = found_multicallable() + assert instance.active_rpcs == 1 + assert instance.max_active_rpcs >= 1 + async for _ in new_stream: + assert instance.active_rpcs == 2 + assert instance.max_active_rpcs == 2 diff --git a/tests/unit/_channel_pooling/test_wrapped_channel.py b/tests/unit/_channel_pooling/test_wrapped_channel.py new file mode 100644 index 000000000..fd19adb4b --- /dev/null +++ b/tests/unit/_channel_pooling/test_wrapped_channel.py @@ -0,0 +1,345 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
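The tracked-channel tests above pin down a small invariant: ``active_rpcs`` is a live count, ``max_active_rpcs`` is a high-water mark, and reading the maximum resets it. A standalone sketch of that bookkeeping, assuming only what the tests assert (this is not the library's TrackedChannel, which additionally wraps a grpc channel):

import contextlib

class _RpcCounter:
    def __init__(self):
        self.active_rpcs = 0
        self.max_active_rpcs = 0

    @contextlib.contextmanager
    def track_rpc(self):
        # Increment the live count and bump the high-water mark on entry.
        self.active_rpcs += 1
        self.max_active_rpcs = max(self.max_active_rpcs, self.active_rpcs)
        try:
            yield
        finally:
            self.active_rpcs -= 1

    def get_and_reset_max_active_rpcs(self):
        # Read-and-reset, as exercised by test_get_and_reset_max_active_rpcs.
        value, self.max_active_rpcs = self.max_active_rpcs, 0
        return value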
+from __future__ import annotations + +import asyncio +import pytest + +from grpc.experimental import aio # type: ignore + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + + +class TestWrappedChannel: + """ + WrappedChannel should delegate all methods to the wrapped channel + """ + + def _make_one_with_channel_mock(self, *args, async_mock=True, **kwargs): + channel = AsyncMock() if async_mock else mock.Mock() + return self._make_one(channel), channel + + def _make_one(self, *args, **kwargs): + from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + _WrappedChannel, + ) + + return _WrappedChannel(*args, **kwargs) + + def test_ctor(self): + channel = mock.Mock() + instance = self._make_one(channel) + assert instance._channel is channel + + def test_is_channel_instance(self): + """ + should pass isinstance check for aio.Channel + """ + instance = self._make_one(mock.Mock()) + assert isinstance(instance, aio.Channel) + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_unary", "stream_unary"]) + async def test_unary_call_api_passthrough(self, method_name): + """ + rpc call methods should use underlying channel calls + """ + mock_rpc = AsyncMock() + call_mock = mock_rpc.call() + callable_mock = lambda: call_mock # noqa: E731 + instance, channel = self._make_one_with_channel_mock(async_mock=False) + channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = callable_mock + # call rpc to get Multicallable + arg_mock = mock.Mock() + found_callable = wrapper_method(arg_mock) + # assert that response was passed through + found_call = found_callable() + await found_call + # assert that wrapped channel method was called + assert mock_rpc.call.await_count == 1 + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == [arg_mock] + assert channel_method.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method_name", ["unary_stream", "stream_stream"]) + async def test_stream_call_api_passthrough(self, method_name): + """ + rpc call methods should use underlying channel calls + """ + expected_result = mock.Mock() + + async def mock_stream(): + yield expected_result + + instance, channel = self._make_one_with_channel_mock(async_mock=False) + channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = lambda: mock_stream() + # call rpc to get Multicallable + arg_mock = mock.Mock() + found_callable = wrapper_method(arg_mock) + # assert that response was passed through + found_call = found_callable() + results = [item async for item in found_call] + assert results == [expected_result] + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == [arg_mock] + assert channel_method.call_count == 1 + + @pytest.mark.parametrize("method_name", ["get_state"]) + def test_sync_api_passthrough(self, method_name): + """ + Wrapper should respond to full grpc Channel API, and pass through + responses to wrapped channel + """ + expected_response = "response" + instance, channel = self._make_one_with_channel_mock(async_mock=False) +
channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = expected_response + # make function call + arg_mock = mock.Mock() + found_response = wrapper_method(arg_mock) + # assert that wrapped channel method was called + assert channel_method.call_count == 1 + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == [arg_mock] + # assert that response was passed through + assert found_response == expected_response + + @pytest.mark.parametrize( + "method_name,arg_num", + [ + ("close", 1), + ("channel_ready", 0), + ("__aenter__", 0), + ("__aexit__", 3), + ("wait_for_state_change", 1), + ], + ) + @pytest.mark.asyncio + async def test_async_api_passthrough(self, method_name, arg_num): + """ + Wrapper should respond to full grpc Channel API, and pass through + responses to wrapped channel + """ + expected_response = "response" + instance, channel = self._make_one_with_channel_mock(async_mock=True) + channel_method = getattr(channel, method_name) + wrapper_method = getattr(instance, method_name) + channel_method.return_value = expected_response + # make function call + args = [mock.Mock() for _ in range(arg_num)] + found_response = await wrapper_method(*args) + # assert that wrapped channel method was called + channel_method.assert_called_once() + channel_method.assert_awaited_once() + # combine args and kwargs + all_args = list(channel_method.call_args.args) + list( + channel_method.call_args.kwargs.values() + ) + assert all_args == args + # assert that response was passed through + if not method_name == "__aenter__": + assert found_response == expected_response + else: + # __aenter__ is special case: should return self + assert found_response is instance + + def test_wrapped_channel_property(self): + """ + Should be able to access wrapped channel + """ + instance, channel = self._make_one_with_channel_mock() + assert instance.wrapped_channel is channel + + @pytest.mark.asyncio + async def test_context_manager(self): + """ + entering and exiting should call enter and exit on wrapped channel + """ + instance, channel = self._make_one_with_channel_mock() + assert channel.__aenter__.call_count == 0 + async with instance: + assert channel.__aenter__.call_count == 1 + assert channel.__aenter__.await_count == 1 + assert channel.__aexit__.call_count == 0 + assert channel.__aexit__.call_count == 1 + assert channel.__aexit__.await_count == 1 + + +class TestBackgroundTaskMixin: + def _make_one(self, *args, **kwargs): + from google.cloud.bigtable._channel_pooling.wrapped_channel import ( + _BackgroundTaskMixin, + ) + + class ConcreteBackgroundTask(_BackgroundTaskMixin): + @property + def _task_description(self): + return "Fake task" + + def _background_coroutine(self): + return self._fake_background_coroutine() + + async def _fake_background_coroutine(self): + await asyncio.sleep(0.1) + return "fake response" + + return ConcreteBackgroundTask(*args, **kwargs) + + def test_ctor(self): + """all _BackgroundTaskMixin classes should have a _background_task attribute""" + instance = self._make_one() + assert hasattr(instance, "_background_task") + + @pytest.mark.asyncio + async def test_aenter_starts_task(self): + """ + Context manager should start background task + """ + instance = self._make_one() + with mock.patch.object(instance, "start_background_task") as start_mock: + async with instance: + start_mock.assert_called_once() + + @pytest.mark.asyncio +
async def test_aexit_stops_task(self): + """ + Context manager should stop background task + """ + instance = self._make_one() + async with instance: + assert instance._background_task is not None + assert instance._background_task.cancelled() is False + assert instance._background_task.cancelled() is True + + @pytest.mark.asyncio + async def test_close_stops_task(self): + """Calling close directly should cancel background task""" + instance = self._make_one() + instance.start_background_task() + assert instance._background_task is not None + assert instance._background_task.cancelled() is False + await instance.close() + assert instance._background_task.cancelled() is True + + @pytest.mark.asyncio + async def test_start_background_task(self): + """test that task can be started properly""" + instance = self._make_one() + assert instance._background_task is None + instance.start_background_task() + assert instance._background_task is not None + assert instance._background_task.done() is False + assert instance._background_task.cancelled() is False + assert isinstance(instance._background_task, asyncio.Task) + await instance.close() + + @pytest.mark.asyncio + async def test_start_background_task_idempotent(self): + """Duplicate calls to start_background_task should be no-ops""" + with mock.patch("asyncio.get_running_loop") as get_loop_mock: + instance = self._make_one() + assert get_loop_mock.call_count == 0 + instance.start_background_task() + assert get_loop_mock.call_count == 1 + instance.start_background_task() + assert get_loop_mock.call_count == 1 + await instance.close() + + def test_start_background_task_sync(self): + """In sync context, should raise RuntimeWarning that routine can't be started""" + instance = self._make_one() + with pytest.warns(RuntimeWarning) as warnings: + instance.start_background_task() + assert instance._background_task is None + assert len(warnings) == 1 + assert "No event loop detected." 
in str(warnings[0].message) + assert instance._task_description + " is disabled" in str(warnings[0].message) + + def test__task_description(self): + """all _BackgroundTaskMixin classes should have a _task_description property""" + instance = self._make_one() + assert isinstance(instance._task_description, str) + # should start with a capital letter for proper formatting in start_background_task warning + assert instance._task_description[0].isupper() + + @pytest.mark.parametrize( + "task,is_done,expected", + [(None, None, False), (mock.Mock(), False, True), (mock.Mock(), True, False)], + ) + def test_is_active_w_mock(self, task, is_done, expected): + """ + test all possible branches in background_task_is_active with mocks + """ + instance = self._make_one() + instance._background_task = task + if is_done is not None: + instance._background_task.done.return_value = is_done + assert instance.background_task_is_active() == expected + + @pytest.mark.asyncio + async def test_is_active(self): + """ + test background_task_is_active with real task + """ + instance = self._make_one() + assert instance.background_task_is_active() is False + instance.start_background_task() + assert instance.background_task_is_active() is True + await instance.close() + assert instance.background_task_is_active() is False + + +class _WrappedMultiCallableBase: + """ + Base class for testing wrapped multicallables + """ + + pass + + +class TestWrappedUnaryUnaryMultiCallable(_WrappedMultiCallableBase): + pass + + +class TestWrappedUnaryStreamMultiCallable(_WrappedMultiCallableBase): + pass + + +class TestWrappedStreamUnaryMultiCallable(_WrappedMultiCallableBase): + pass + + +class TestWrappedStreamStreamMultiCallable(_WrappedMultiCallableBase): + pass diff --git a/tests/unit/gapic/bigtable_v2/test_bigtable.py b/tests/unit/gapic/bigtable_v2/test_bigtable.py index b1500aa48..03ba3044f 100644 --- a/tests/unit/gapic/bigtable_v2/test_bigtable.py +++ b/tests/unit/gapic/bigtable_v2/test_bigtable.py @@ -100,7 +100,6 @@ def test__get_default_mtls_endpoint(): [ (BigtableClient, "grpc"), (BigtableAsyncClient, "grpc_asyncio"), - (BigtableAsyncClient, "pooled_grpc_asyncio"), (BigtableClient, "rest"), ], ) @@ -117,7 +116,7 @@ def test_bigtable_client_from_service_account_info(client_class, transport_name) assert client.transport._host == ( "bigtable.googleapis.com:443" - if transport_name in ["grpc", "grpc_asyncio", "pooled_grpc_asyncio"] + if transport_name in ["grpc", "grpc_asyncio"] else "https://bigtable.googleapis.com" ) @@ -127,7 +126,6 @@ def test_bigtable_client_from_service_account_info(client_class, transport_name) [ (transports.BigtableGrpcTransport, "grpc"), (transports.BigtableGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.PooledBigtableGrpcAsyncIOTransport, "pooled_grpc_asyncio"), (transports.BigtableRestTransport, "rest"), ], ) @@ -154,7 +152,6 @@ def test_bigtable_client_service_account_always_use_jwt( [ (BigtableClient, "grpc"), (BigtableAsyncClient, "grpc_asyncio"), - (BigtableAsyncClient, "pooled_grpc_asyncio"), (BigtableClient, "rest"), ], ) @@ -178,7 +175,7 @@ def test_bigtable_client_from_service_account_file(client_class, transport_name) assert client.transport._host == ( "bigtable.googleapis.com:443" - if transport_name in ["grpc", "grpc_asyncio", "pooled_grpc_asyncio"] + if transport_name in ["grpc", "grpc_asyncio"] else "https://bigtable.googleapis.com" ) @@ -200,11 +197,6 @@ def test_bigtable_client_get_transport_class(): [ (BigtableClient, transports.BigtableGrpcTransport, "grpc"), (BigtableAsyncClient,
transports.BigtableGrpcAsyncIOTransport, "grpc_asyncio"), - ( - BigtableAsyncClient, - transports.PooledBigtableGrpcAsyncIOTransport, - "pooled_grpc_asyncio", - ), (BigtableClient, transports.BigtableRestTransport, "rest"), ], ) @@ -340,12 +332,6 @@ def test_bigtable_client_client_options(client_class, transport_class, transport "grpc_asyncio", "true", ), - ( - BigtableAsyncClient, - transports.PooledBigtableGrpcAsyncIOTransport, - "pooled_grpc_asyncio", - "true", - ), (BigtableClient, transports.BigtableGrpcTransport, "grpc", "false"), ( BigtableAsyncClient, @@ -353,12 +339,6 @@ def test_bigtable_client_client_options(client_class, transport_class, transport "grpc_asyncio", "false", ), - ( - BigtableAsyncClient, - transports.PooledBigtableGrpcAsyncIOTransport, - "pooled_grpc_asyncio", - "false", - ), (BigtableClient, transports.BigtableRestTransport, "rest", "true"), (BigtableClient, transports.BigtableRestTransport, "rest", "false"), ], @@ -550,11 +530,6 @@ def test_bigtable_client_get_mtls_endpoint_and_cert_source(client_class): [ (BigtableClient, transports.BigtableGrpcTransport, "grpc"), (BigtableAsyncClient, transports.BigtableGrpcAsyncIOTransport, "grpc_asyncio"), - ( - BigtableAsyncClient, - transports.PooledBigtableGrpcAsyncIOTransport, - "pooled_grpc_asyncio", - ), (BigtableClient, transports.BigtableRestTransport, "rest"), ], ) @@ -591,12 +566,6 @@ def test_bigtable_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), - ( - BigtableAsyncClient, - transports.PooledBigtableGrpcAsyncIOTransport, - "pooled_grpc_asyncio", - grpc_helpers_async, - ), (BigtableClient, transports.BigtableRestTransport, "rest", None), ], ) @@ -743,35 +712,6 @@ def test_read_rows(request_type, transport: str = "grpc"): assert isinstance(message, bigtable.ReadRowsResponse) -def test_read_rows_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.read_rows(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.read_rows(request) - assert next_channel.call_count == i - - def test_read_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -991,35 +931,6 @@ def test_sample_row_keys(request_type, transport: str = "grpc"): assert isinstance(message, bigtable.SampleRowKeysResponse) -def test_sample_row_keys_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.sample_row_keys(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.sample_row_keys(request) - assert next_channel.call_count == i - - def test_sample_row_keys_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1238,35 +1149,6 @@ def test_mutate_row(request_type, transport: str = "grpc"): assert isinstance(response, bigtable.MutateRowResponse) -def test_mutate_row_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.mutate_row(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.mutate_row(request) - assert next_channel.call_count == i - - def test_mutate_row_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1530,35 +1412,6 @@ def test_mutate_rows(request_type, transport: str = "grpc"): assert isinstance(message, bigtable.MutateRowsResponse) -def test_mutate_rows_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.mutate_rows(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.mutate_rows(request) - assert next_channel.call_count == i - - def test_mutate_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1792,35 +1645,6 @@ def test_check_and_mutate_row(request_type, transport: str = "grpc"): assert response.predicate_matched is True -def test_check_and_mutate_row_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.check_and_mutate_row(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.check_and_mutate_row(request) - assert next_channel.call_count == i - - def test_check_and_mutate_row_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2198,35 +2022,6 @@ def test_ping_and_warm(request_type, transport: str = "grpc"): assert isinstance(response, bigtable.PingAndWarmResponse) -def test_ping_and_warm_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.ping_and_warm(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.ping_and_warm(request) - assert next_channel.call_count == i - - def test_ping_and_warm_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2447,35 +2242,6 @@ def test_read_modify_write_row(request_type, transport: str = "grpc"): assert isinstance(response, bigtable.ReadModifyWriteRowResponse) -def test_read_modify_write_row_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.read_modify_write_row(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.read_modify_write_row(request) - assert next_channel.call_count == i - - def test_read_modify_write_row_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2735,37 +2501,6 @@ def test_generate_initial_change_stream_partitions( ) -def test_generate_initial_change_stream_partitions_pooled_rotation( - transport: str = "pooled_grpc_asyncio", -): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.generate_initial_change_stream_partitions(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.generate_initial_change_stream_partitions(request) - assert next_channel.call_count == i - - def test_generate_initial_change_stream_partitions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3025,35 +2760,6 @@ def test_read_change_stream(request_type, transport: str = "grpc"): assert isinstance(message, bigtable.ReadChangeStreamResponse) -def test_read_change_stream_pooled_rotation(transport: str = "pooled_grpc_asyncio"): - with mock.patch.object( - transports.pooled_grpc_asyncio.PooledChannel, "next_channel" - ) as next_channel: - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = {} - - channel = client.transport._grpc_channel._pool[ - client.transport._grpc_channel._next_idx - ] - next_channel.return_value = channel - - response = client.read_change_stream(request) - - # Establish that next_channel was called - next_channel.assert_called_once() - # Establish that subsequent calls all call next_channel - starting_idx = client.transport._grpc_channel._next_idx - for i in range(2, 10): - response = client.read_change_stream(request) - assert next_channel.call_count == i - - def test_read_change_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5957,7 +5663,6 @@ def test_transport_get_channel(): [ transports.BigtableGrpcTransport, transports.BigtableGrpcAsyncIOTransport, - transports.PooledBigtableGrpcAsyncIOTransport, transports.BigtableRestTransport, ], ) @@ -6105,7 +5810,6 @@ def test_bigtable_auth_adc(): [ transports.BigtableGrpcTransport, transports.BigtableGrpcAsyncIOTransport, - transports.PooledBigtableGrpcAsyncIOTransport, ], ) def test_bigtable_transport_auth_adc(transport_class): @@ -6133,7 +5837,6 @@ def test_bigtable_transport_auth_adc(transport_class): [ transports.BigtableGrpcTransport, transports.BigtableGrpcAsyncIOTransport, - transports.PooledBigtableGrpcAsyncIOTransport, transports.BigtableRestTransport, ], ) @@ -6236,61 +5939,6 @@ def test_bigtable_grpc_transport_client_cert_source_for_mtls(transport_class): ) -@pytest.mark.parametrize( - "transport_class", [transports.PooledBigtableGrpcAsyncIOTransport] -) -def test_bigtable_pooled_grpc_transport_client_cert_source_for_mtls(transport_class): - cred = ga_credentials.AnonymousCredentials() - - # test with invalid pool size - with pytest.raises(ValueError): - transport_class( - host="squid.clam.whelk", - credentials=cred, - pool_size=0, - ) - - # Check ssl_channel_credentials is used if provided. - for pool_num in range(1, 5): - with mock.patch.object( - transport_class, "create_channel" - ) as mock_create_channel: - mock_ssl_channel_creds = mock.Mock() - transport_class( - host="squid.clam.whelk", - credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, - pool_size=pool_num, - ) - mock_create_channel.assert_called_with( - pool_num, - "squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_channel_creds, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - assert mock_create_channel.call_count == 1 - - # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls - # is used. 
- with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): - with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: - transport_class( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, - ) - expected_cert, expected_key = client_cert_source_callback() - mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key - ) - - def test_bigtable_http_transport_client_cert_source_for_mtls(): cred = ga_credentials.AnonymousCredentials() with mock.patch( @@ -6307,7 +5955,6 @@ def test_bigtable_http_transport_client_cert_source_for_mtls(): [ "grpc", "grpc_asyncio", - "pooled_grpc_asyncio", "rest", ], ) @@ -6321,7 +5968,7 @@ def test_bigtable_host_no_port(transport_name): ) assert client.transport._host == ( "bigtable.googleapis.com:443" - if transport_name in ["grpc", "grpc_asyncio", "pooled_grpc_asyncio"] + if transport_name in ["grpc", "grpc_asyncio"] else "https://bigtable.googleapis.com" ) @@ -6331,7 +5978,6 @@ def test_bigtable_host_no_port(transport_name): [ "grpc", "grpc_asyncio", - "pooled_grpc_asyncio", "rest", ], ) @@ -6345,7 +5991,7 @@ def test_bigtable_host_with_port(transport_name): ) assert client.transport._host == ( "bigtable.googleapis.com:8000" - if transport_name in ["grpc", "grpc_asyncio", "pooled_grpc_asyncio"] + if transport_name in ["grpc", "grpc_asyncio"] else "https://bigtable.googleapis.com:8000" ) @@ -6701,24 +6347,6 @@ async def test_transport_close_async(): async with client: close.assert_not_called() close.assert_called_once() - close.assert_awaited() - - -@pytest.mark.asyncio -async def test_pooled_transport_close_async(): - client = BigtableAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="pooled_grpc_asyncio", - ) - num_channels = len(client.transport._grpc_channel._pool) - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" - ) as close: - async with client: - close.assert_not_called() - close.assert_called() - assert close.call_count == num_channels - close.assert_awaited() def test_transport_close(): @@ -6785,128 +6413,3 @@ def test_api_key_credentials(client_class, transport_class): always_use_jwt_access=True, api_audience=None, ) - - -@pytest.mark.asyncio -async def test_pooled_transport_replace_default(): - client = BigtableClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="pooled_grpc_asyncio", - ) - num_channels = len(client.transport._grpc_channel._pool) - for replace_idx in range(num_channels): - prev_pool = [channel for channel in client.transport._grpc_channel._pool] - grace_period = 4 - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" - ) as close: - await client.transport.replace_channel(replace_idx, grace=grace_period) - close.assert_called_once() - close.assert_awaited() - close.assert_called_with(grace=grace_period) - assert isinstance( - client.transport._grpc_channel._pool[replace_idx], grpc.aio.Channel - ) - # only the specified channel should be replaced - for i in range(num_channels): - if i == replace_idx: - assert client.transport._grpc_channel._pool[i] != prev_pool[i] - else: - assert client.transport._grpc_channel._pool[i] == prev_pool[i] - with pytest.raises(ValueError): - await client.transport.replace_channel(num_channels + 1) - with pytest.raises(ValueError): - await client.transport.replace_channel(-1) - - -@pytest.mark.asyncio -async def test_pooled_transport_replace_explicit(): - client = BigtableClient( - 
credentials=ga_credentials.AnonymousCredentials(), - transport="pooled_grpc_asyncio", - ) - num_channels = len(client.transport._grpc_channel._pool) - for replace_idx in range(num_channels): - prev_pool = [channel for channel in client.transport._grpc_channel._pool] - grace_period = 0 - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" - ) as close: - new_channel = grpc.aio.insecure_channel("localhost:8080") - await client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once() - close.assert_awaited() - close.assert_called_with(grace=grace_period) - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - # only the specified channel should be replaced - for i in range(num_channels): - if i == replace_idx: - assert client.transport._grpc_channel._pool[i] != prev_pool[i] - else: - assert client.transport._grpc_channel._pool[i] == prev_pool[i] - - -def test_pooled_transport_next_channel(): - num_channels = 10 - transport = transports.PooledBigtableGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), - pool_size=num_channels, - ) - assert len(transport._grpc_channel._pool) == num_channels - transport._grpc_channel._next_idx = 0 - # rotate through all channels multiple times - num_cycles = 4 - for _ in range(num_cycles): - for i in range(num_channels - 1): - assert transport._grpc_channel._next_idx == i - got_channel = transport._grpc_channel.next_channel() - assert got_channel == transport._grpc_channel._pool[i] - assert transport._grpc_channel._next_idx == (i + 1) - # test wrap around - assert transport._grpc_channel._next_idx == num_channels - 1 - got_channel = transport._grpc_channel.next_channel() - assert got_channel == transport._grpc_channel._pool[num_channels - 1] - assert transport._grpc_channel._next_idx == 0 - - -def test_pooled_transport_pool_unique_channels(): - num_channels = 50 - - transport = transports.PooledBigtableGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), - pool_size=num_channels, - ) - channel_list = [channel for channel in transport._grpc_channel._pool] - channel_set = set(channel_list) - assert len(channel_list) == num_channels - assert len(channel_set) == num_channels - for channel in channel_list: - assert isinstance(channel, grpc.aio.Channel) - - -def test_pooled_transport_pool_creation(): - # channels should be created with the specified options - num_channels = 50 - creds = ga_credentials.AnonymousCredentials() - scopes = ["test1", "test2"] - quota_project_id = "test3" - host = "testhost:8080" - with mock.patch( - "google.api_core.grpc_helpers_async.create_channel" - ) as create_channel: - transport = transports.PooledBigtableGrpcAsyncIOTransport( - credentials=creds, - pool_size=num_channels, - scopes=scopes, - quota_project_id=quota_project_id, - host=host, - ) - assert create_channel.call_count == num_channels - for i in range(num_channels): - kwargs = create_channel.call_args_list[i][1] - assert kwargs["target"] == host - assert kwargs["credentials"] == creds - assert kwargs["scopes"] == scopes - assert kwargs["quota_project_id"] == quota_project_id diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index be3703a23..01e467423 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -16,7 +16,6 @@ import grpc import asyncio import re -import sys import pytest @@ -52,20 +51,15 @@ def _make_one(self, *args, **kwargs): @pytest.mark.asyncio async def test_ctor(self): 
expected_project = "project-id" - expected_pool_size = 11 expected_credentials = AnonymousCredentials() - client = self._make_one( + async with self._make_one( project="project-id", - pool_size=expected_pool_size, credentials=expected_credentials, - ) - await asyncio.sleep(0.1) - assert client.project == expected_project - assert len(client.transport._grpc_channel._pool) == expected_pool_size - assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size - assert client.transport._credentials == expected_credentials - await client.close() + ) as client: + await asyncio.sleep(0.1) + assert client.project == expected_project + assert not client._active_instances + assert client._gapic_client.transport._credentials == expected_credentials @pytest.mark.asyncio async def test_ctor_super_inits(self): @@ -76,11 +70,9 @@ async def test_ctor_super_inits(self): from google.api_core import client_options as client_options_lib project = "project-id" - pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - transport_str = f"pooled_grpc_asyncio_{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( @@ -90,7 +82,6 @@ async def test_ctor_super_inits(self): try: self._make_one( project=project, - pool_size=pool_size, credentials=credentials, client_options=options_parsed, ) @@ -99,7 +90,6 @@ async def test_ctor_super_inits(self): # test gapic superclass init was called assert bigtable_client_init.call_count == 1 kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed # test mixin superclass init was called @@ -115,7 +105,6 @@ async def test_ctor_dict_options(self): BigtableAsyncClient, ) from google.api_core.client_options import ClientOptions - from google.cloud.bigtable.client import BigtableDataClient client_options = {"api_endpoint": "foo.bar:1234"} with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: @@ -128,12 +117,6 @@ async def test_ctor_dict_options(self): called_options = kwargs["client_options"] assert called_options.api_endpoint == "foo.bar:1234" assert isinstance(called_options, ClientOptions) - with mock.patch.object( - BigtableDataClient, "start_background_channel_refresh" - ) as start_background_refresh: - client = self._make_one(client_options=client_options) - start_background_refresh.assert_called_once() - await client.close() @pytest.mark.asyncio async def test_veneer_grpc_headers(self): @@ -156,131 +139,12 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() - @pytest.mark.asyncio - async def test_channel_pool_creation(self): - pool_size = 14 - with mock.patch( - "google.api_core.grpc_helpers_async.create_channel" - ) as create_channel: - create_channel.return_value = AsyncMock() - client = self._make_one(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - await client.close() - # channels should be unique - client = self._make_one(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - await client.close() - - 
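The rotation tests deleted in this hunk exercised round-robin selection over the transport's channel pool: next_channel() returns _pool[_next_idx] and then advances the index with wrap-around. The new _channel_pooling PooledChannel keeps the same _pool/_next_idx scheme, so the selection logic implied by the removed assertions can be sketched as follows (class name hypothetical, details simplified):

class _RoundRobinPoolSketch:
    def __init__(self, channels: list):
        self._pool = list(channels)  # grpc.aio.Channel instances in practice
        self._next_idx = 0

    def next_channel(self):
        channel = self._pool[self._next_idx]
        # advance with wrap-around, as the removed next_channel tests assert
        self._next_idx = (self._next_idx + 1) % len(self._pool)
        return channel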
@pytest.mark.asyncio - async def test_channel_pool_rotation(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel, - ) - - pool_size = 7 - - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_one(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - # calling an rpc `pool_size` times should use a different channel each time - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - await client.close() - - @pytest.mark.asyncio - async def test_channel_pool_replace(self): - with mock.patch.object(asyncio, "sleep"): - pool_size = 7 - client = self._make_one(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" - ) as close: - new_channel = grpc.aio.insecure_channel("localhost:8080") - await client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - await client.close() - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_start_background_channel_refresh_sync(self): - # should raise RuntimeError if called in a sync context - client = self._make_one(project="project-id") - with pytest.raises(RuntimeError): - client.start_background_channel_refresh() - - @pytest.mark.asyncio - async def test_start_background_channel_refresh_tasks_exist(self): - # if tasks exist, should do nothing - client = self._make_one(project="project-id") - with mock.patch.object(asyncio, "create_task") as create_task: - client.start_background_channel_refresh() - create_task.assert_not_called() - await client.close() - - @pytest.mark.asyncio - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - async def test_start_background_channel_refresh(self, pool_size): - # should create background tasks for each channel - client = self._make_one(project="project-id", pool_size=pool_size) - ping_and_warm = AsyncMock() - client._ping_and_warm_instances = ping_and_warm - client.start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - assert isinstance(task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) - await client.close() - - @pytest.mark.asyncio - @pytest.mark.skipif( - sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" - ) - async def test_start_background_channel_refresh_tasks_names(self): - # if tasks 
exist, should do nothing - pool_size = 3 - client = self._make_one(project="project-id", pool_size=pool_size) - for i in range(pool_size): - name = client._channel_refresh_tasks[i].get_name() - assert str(i) in name - assert "BigtableDataClient channel refresh " in name - await client.close() - @pytest.mark.asyncio async def test__ping_and_warm_instances(self): # test with no instances with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - client = self._make_one(project="project-id", pool_size=1) - channel = client.transport._grpc_channel._pool[0] + client = self._make_one(project="project-id") + channel = client._pool[0] await client._ping_and_warm_instances(channel) gather.assert_called_once() gather.assert_awaited_once() @@ -304,253 +168,66 @@ async def test__ping_and_warm_instances(self): call._request["name"] = client._active_instances[idx] await client.close() - @pytest.mark.asyncio - @pytest.mark.parametrize( - "refresh_interval, wait_time, expected_sleep", - [ - (0, 0, 0), - (0, 1, 0), - (10, 0, 10), - (10, 5, 5), - (10, 10, 0), - (10, 15, 0), - ], - ) - async def test__manage_channel_first_sleep( - self, refresh_interval, wait_time, expected_sleep - ): - # first sleep time should be `refresh_interval` seconds after client init - import time - - with mock.patch.object(time, "time") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: - sleep.side_effect = asyncio.CancelledError - try: - client = self._make_one(project="project-id") - client._channel_init_time = -wait_time - await client._manage_channel(0, refresh_interval, refresh_interval) - except asyncio.CancelledError: - pass - sleep.assert_called_once() - call_time = sleep.call_args[0][0] - assert ( - abs(call_time - expected_sleep) < 0.1 - ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" - await client.close() - - @pytest.mark.asyncio - async def test__manage_channel_ping_and_warm(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - - # should ping an warm all new channels, and old channels if sleeping - client = self._make_one(project="project-id") - new_channel = grpc.aio.insecure_channel("localhost:8080") - with mock.patch.object(asyncio, "sleep"): - create_channel = mock.Mock() - create_channel.return_value = new_channel - client.transport.grpc_channel._create_channel = create_channel - with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" - ) as replace_channel: - replace_channel.side_effect = asyncio.CancelledError - # should ping and warm old channel then new if sleep > 0 - with mock.patch.object( - type(self._make_one()), "_ping_and_warm_instances" - ) as ping_and_warm: - try: - channel_idx = 2 - old_channel = client.transport._grpc_channel._pool[channel_idx] - await client._manage_channel(channel_idx, 10) - except asyncio.CancelledError: - pass - assert ping_and_warm.call_count == 2 - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - # should ping and warm instantly new channel only if not sleeping - with mock.patch.object( - type(self._make_one()), "_ping_and_warm_instances" - ) as ping_and_warm: - try: - await client._manage_channel(0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) - await client.close() - - @pytest.mark.asyncio 
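For context, the _manage_channel tests removed below encode the old client's refresh schedule: an initial sleep of the refresh interval minus the channel's age, then a loop that replaces the channel with a grace period and sleeps a uniformly jittered interval between replacements. A sketch of that schedule, where the 60 * 35 minimum comes from the parametrized expectations and the other defaults and names are assumptions:

from __future__ import annotations

import asyncio
import random
import time


async def _manage_channel_sketch(
    replace_channel,  # async callable that swaps in a fresh channel
    refresh_interval_min: float = 60 * 35,
    refresh_interval_max: float = 60 * 45,
    grace_period: float = 60 * 10,
    channel_init_time: float | None = None,
):
    if channel_init_time is None:
        channel_init_time = time.time()
    # the first sleep is measured from channel creation, not from task start
    await asyncio.sleep(max(0, refresh_interval_min - (time.time() - channel_init_time)))
    while True:
        await replace_channel(grace=grace_period)
        # jitter the interval so pooled channels do not all refresh at once
        await asyncio.sleep(random.uniform(refresh_interval_min, refresh_interval_max))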
- @pytest.mark.parametrize( - "refresh_interval, num_cycles, expected_sleep", - [ - (None, 1, 60 * 35), - (10, 10, 100), - (10, 1, 10), - ], - ) - async def test__manage_channel_sleeps( - self, refresh_interval, num_cycles, expected_sleep - ): - # make sure that sleeps work as expected - import time - import random - - channel_idx = 1 - with mock.patch.object(random, "uniform") as uniform: - uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: - sleep.side_effect = [None for i in range(num_cycles - 1)] + [ - asyncio.CancelledError - ] - try: - client = self._make_one(project="project-id") - if refresh_interval is not None: - await client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - await client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - assert ( - abs(total_sleep - expected_sleep) < 0.1 - ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" - await client.close() - - @pytest.mark.asyncio - async def test__manage_channel_random(self): - import random - - with mock.patch.object(asyncio, "sleep") as sleep: - with mock.patch.object(random, "uniform") as uniform: - uniform.return_value = 0 - try: - uniform.side_effect = asyncio.CancelledError - client = self._make_one(project="project-id", pool_size=1) - except asyncio.CancelledError: - uniform.side_effect = None - uniform.reset_mock() - sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] - try: - await client._manage_channel(0, min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 2 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val - - @pytest.mark.asyncio - @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) - async def test__manage_channel_refresh(self, num_cycles): - # make sure that channels are properly refreshed - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.api_core import grpc_helpers_async - - expected_grace = 9 - expected_refresh = 0.5 - channel_idx = 1 - new_channel = grpc.aio.insecure_channel("localhost:8080") - - with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" - ) as replace_channel: - with mock.patch.object(asyncio, "sleep") as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - grpc_helpers_async, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - client = self._make_one(project="project-id") - create_channel.reset_mock() - try: - await client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - args, kwargs = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace 
-            assert kwargs["new_channel"] == new_channel
-        await client.close()
-
    @pytest.mark.asyncio
    @pytest.mark.filterwarnings("ignore::RuntimeWarning")
    async def test__register_instance(self):
-        # create the client without calling start_background_channel_refresh
-        with mock.patch.object(asyncio, "get_running_loop") as get_event_loop:
-            get_event_loop.side_effect = RuntimeError("no event loop")
-            client = self._make_one(project="project-id")
-        assert not client._channel_refresh_tasks
-        # first call should start background refresh
+        """
+        _register_instance should add the instance to _active_instances
+        """
+        client = self._make_one(project="project-id")
+        # first call should register the instance
        assert client._active_instances == set()
        await client._register_instance("instance-1", mock.Mock())
        assert len(client._active_instances) == 1
        assert client._active_instances == {"projects/project-id/instances/instance-1"}
-        assert client._channel_refresh_tasks
-        # next call should not
-        with mock.patch.object(
-            type(self._make_one()), "start_background_channel_refresh"
-        ) as refresh_mock:
-            await client._register_instance("instance-2", mock.Mock())
-            assert len(client._active_instances) == 2
-            assert client._active_instances == {
-                "projects/project-id/instances/instance-1",
-                "projects/project-id/instances/instance-2",
-            }
-            refresh_mock.assert_not_called()
+        # duplicate registrations should be no-ops
+        await client._register_instance("instance-1", mock.Mock())
+        assert len(client._active_instances) == 1
+        # test with a new instance
+        await client._register_instance("instance-2", mock.Mock())
+        assert len(client._active_instances) == 2
+        assert client._active_instances == {
+            "projects/project-id/instances/instance-1",
+            "projects/project-id/instances/instance-2",
+        }
+
+    @pytest.mark.asyncio
+    async def test__register_instance_owners(self):
+        """
+        _register_instance should add callers to _instance_owners
+        """
+        client = self._make_one(project="project-id")
+        owner1 = mock.Mock()
+        await client._register_instance("instance-1", owner1)
+        full_name = "projects/project-id/instances/instance-1"
+        assert client._instance_owners[full_name] == {id(owner1)}
+        # duplicates should have no effect
+        await client._register_instance("instance-1", owner1)
+        assert client._instance_owners[full_name] == {id(owner1)}
+        # should support multiple owners
+        owner2 = mock.Mock()
+        await client._register_instance("instance-1", owner2)
+        assert client._instance_owners[full_name] == {id(owner1), id(owner2)}
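The two tests above pin down the client's registration bookkeeping: instance names are kept in a set, and every registering caller is remembered by id() so a shared instance can later be released owner by owner. A minimal sketch of that bookkeeping (hypothetical class name; per the next test, the real client also pings and warms each pooled channel when a new instance appears):

from __future__ import annotations


class _InstanceRegistrySketch:
    def __init__(self, project: str):
        self._project = project
        self._active_instances: set[str] = set()
        self._instance_owners: dict[str, set[int]] = {}

    async def _register_instance(self, instance_id: str, owner: object) -> None:
        instance_name = f"projects/{self._project}/instances/{instance_id}"
        # owners are tracked by id(); re-registering the same owner is a no-op
        self._instance_owners.setdefault(instance_name, set()).add(id(owner))
        if instance_name not in self._active_instances:
            self._active_instances.add(instance_name)
            # the real client would ping and warm every pooled channel here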

    @pytest.mark.asyncio
    @pytest.mark.filterwarnings("ignore::RuntimeWarning")
    async def test__register_instance_ping_and_warm(self):
-        # should ping and warm each new instance
-        pool_size = 7
-        with mock.patch.object(asyncio, "get_running_loop") as get_event_loop:
-            get_event_loop.side_effect = RuntimeError("no event loop")
-            client = self._make_one(project="project-id", pool_size=pool_size)
-        # first call should start background refresh
-        assert not client._channel_refresh_tasks
-        await client._register_instance("instance-1", mock.Mock())
-        client = self._make_one(project="project-id", pool_size=pool_size)
-        assert len(client._channel_refresh_tasks) == pool_size
-        assert not client._active_instances
-        # next calls should trigger ping and warm
-        with mock.patch.object(
-            type(self._make_one()), "_ping_and_warm_instances"
-        ) as ping_mock:
-            # new instance should trigger ping and warm
-            await client._register_instance("instance-2", mock.Mock())
-            assert ping_mock.call_count == pool_size
-            await client._register_instance("instance-3", mock.Mock())
-            assert ping_mock.call_count == pool_size * 2
-            # duplcate instances should not trigger ping and warm
-            await client._register_instance("instance-3", mock.Mock())
-            assert ping_mock.call_count == pool_size * 2
-        await client.close()
+        """
+        Calls to _register_instance should call _ping_and_warm_instances
+        on each active channel
+        """
+        async with self._make_one(project="project-id") as client:
+            channel_list = client._pool.channels
+            with mock.patch.object(
+                client, "_ping_and_warm_instances", AsyncMock()
+            ) as ping_mock:
+                await client._register_instance("instance-1", mock.Mock())
+                assert ping_mock.call_count == len(channel_list)
+                for i, (args, _) in enumerate(ping_mock.call_args_list):
+                    assert args[0] == channel_list[i]
+                    assert "instance-1" in args[1]

    @pytest.mark.asyncio
    async def test__remove_instance_registration(self):
@@ -696,55 +373,19 @@ async def test_get_table_context_manager(self):
        assert table.instance_name in client._active_instances
        assert close_mock.call_count == 1

-    @pytest.mark.asyncio
-    async def test_multiple_pool_sizes(self):
-        # should be able to create multiple clients with different pool sizes without issue
-        pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-        for pool_size in pool_sizes:
-            client = self._make_one(project="project-id", pool_size=pool_size)
-            assert len(client._channel_refresh_tasks) == pool_size
-            client_duplicate = self._make_one(project="project-id", pool_size=pool_size)
-            assert len(client_duplicate._channel_refresh_tasks) == pool_size
-            assert str(pool_size) in str(client.transport)
-            await client.close()
-            await client_duplicate.close()
-
    @pytest.mark.asyncio
    async def test_close(self):
-        from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
-            PooledBigtableGrpcAsyncIOTransport,
+        from google.cloud.bigtable_v2.services.bigtable.transports.grpc_asyncio import (
+            BigtableGrpcAsyncIOTransport,
        )

-        pool_size = 7
-        client = self._make_one(project="project-id", pool_size=pool_size)
-        assert len(client._channel_refresh_tasks) == pool_size
-        tasks_list = list(client._channel_refresh_tasks)
-        for task in client._channel_refresh_tasks:
-            assert not task.done()
+        client = self._make_one(project="project-id")
        with mock.patch.object(
-            PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock()
+            BigtableGrpcAsyncIOTransport, "close", AsyncMock()
        ) as close_mock:
            await client.close()
            close_mock.assert_called_once()
            close_mock.assert_awaited()
-        for task in tasks_list:
-            assert task.done()
-            assert task.cancelled()
-        assert client._channel_refresh_tasks == []
-
-    @pytest.mark.asyncio
-    async def test_close_with_timeout(self):
-        pool_size = 7
-        expected_timeout = 19
-        client = self._make_one(project="project-id", pool_size=pool_size)
-        tasks = list(client._channel_refresh_tasks)
-        with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock:
-            await client.close(timeout=expected_timeout)
-            wait_for_mock.assert_called_once()
-            wait_for_mock.assert_awaited()
-            assert wait_for_mock.call_args[1]["timeout"] == expected_timeout
-        client._channel_refresh_tasks = tasks
-        await client.close()

    @pytest.mark.asyncio
    async def test_context_manager(self):
@@ -754,8 +395,6 @@ async def test_context_manager(self):
        async with self._make_one(project="project-id") as client:
            true_close = client.close()
            client.close = close_mock
-            for task in client._channel_refresh_tasks:
-                assert not task.done()
            assert client.project == "project-id"
            assert client._active_instances == set()
            close_mock.assert_not_called()
@@ -770,13 +409,19 @@
    def test_client_ctor_sync(self):
        with pytest.warns(RuntimeWarning) as warnings:
            client = BigtableDataClient(project="project-id")
-        expected_warning = [w for w in warnings if "client.py" in w.filename]
-        assert len(expected_warning) == 1
-        assert "BigtableDataClient should be started in an asyncio event loop." in str(
-            expected_warning[0].message
+        assert client.project == "project-id"
+        # should get warnings from the pool, the refreshable channel, and the client
+        assert len(warnings) == 3
+        assert any(
+            "refreshable_channel" in str(warning.filename) for warning in warnings
+        )
+        assert any(
+            "dynamic_pooled_channel" in str(warning.filename) for warning in warnings
+        )
+        assert any("client" in str(warning.filename) for warning in warnings)
+        assert all(
+            "asyncio event loop" in str(warning.message) for warning in warnings
        )
-        assert client.project == "project-id"
-        assert client._channel_refresh_tasks == []


class TestTable: