Skip to content

Remove generic type from BaseApiClient #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,34 @@

- HTTP2 keep-alive is now enabled by default, with an interval of 60 seconds between pings, and a 20 second timeout for responses from the service. These values are configurable and may be updated based on specific requirements.

* The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. To convert you client:

```python
# Old
from my_service_pb2_grpc import MyServiceStub
class MyApiClient(BaseApiClient[MyServiceStub]):
def __init__(self, server_url: str, *, ...) -> None:
super().__init__(server_url, MyServiceStub, ...)
...

# New
from typing import cast
from my_service_pb2_grpc import MyServiceStub, MyServiceAsyncStub
class MyApiClient(BaseApiClient):
def __init__(self, server_url: str, *, ...) -> None:
super().__init__(server_url, connect=connect)
self._stub = cast(MyServiceAsyncStub, MyServiceStub(self.channel))
...
Comment on lines +32 to +36
Copy link
Contributor

@shsms shsms Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just ask them to pass cast(MyServiceAsyncStub, MyServiceStub(service.channel)) as an argument, and make the base client generic over MyServiceAsyncStub instead?

Copy link
Contributor Author

@llucax llucax Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because __init__ can't have MyServiceAsyncStub as a type hint, to do that we need to take the type as a generic, and we can't use an async stub as a generic because it is parsed by the interpreter, and it can't find it because it is not present in the .py file, only the .pyi file.

This is how great grpio is handling asyncio and type-hinting 😬 (there is a new experimental interface in the cooking, let's see how that turns out).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't parse them with from __future__ import annotations right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

never mind, it will parse them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because it is not an annotation in this case, is is a proper object for the interpreter. You can give it a go if you want, I already tried many different attempts and nothing seemed to work.

$ head -n 100 t.py* t2.py 
==> t.py <==
class IDoExist:
    pass

==> t.pyi <==
class IDoExist:
    pass

class IDontExist:
    pass

==> t2.py <==
from __future__ import annotations

from typing import Generic, TypeVar

from t import IDoExist, IDontExist

T = TypeVar("T")


class G(Generic[T]):
    def __init__(self, x: T):
        self.x = x


class Sub(G[IDoExist]):
    pass


class Sub2(G[IDontExist]):
    pass
$ python t2.py 
Traceback (most recent call last):
  File "/home/luca/devel/client-base/t2.py", line 5, in <module>
    from t import IDoExist, IDontExist
ImportError: cannot import name 'IDontExist' from 't' (/home/luca/devel/client-base/t.py)

Copy link
Contributor Author

@llucax llucax Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put another way, with TYPE_CHECKING:

from t import IDoExist

if TYPE_CHECKING:
    from t import IDontExist

Then the error becomes:

$ python t2.py 
Traceback (most recent call last):
  File "/home/luca/devel/client-base/t2.py", line 22, in <module>
    class Sub2(G[IDontExist]):
                 ^^^^^^^^^^
NameError: name 'IDontExist' is not defined. Did you mean: 'IDoExist'?

Maybe it is more clear to see this way. Only stuff after a : (or used in special constructs, like cast), are type hints, the rest is just code.


@property
def stub(self) -> MyServiceAsyncStub:
if self._channel is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub
```

After this, you should be able to remove a lot of `cast`s or `type: ignore` from the code when calling the stub `async` methods.

## New Features

- Added support for HTTP2 keep-alive.
Expand Down
92 changes: 56 additions & 36 deletions src/frequenz/client/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
import abc
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, Generic, Self, TypeVar, overload
from typing import Any, Self, TypeVar, overload

from grpc.aio import AioRpcError, Channel

from .channel import ChannelOptions, parse_grpc_uri
from .exception import ApiClientError, ClientNotConnected

StubT = TypeVar("StubT")
"""The type of the gRPC stub."""


class BaseApiClient(abc.ABC, Generic[StubT]):
class BaseApiClient(abc.ABC):
"""A base class for API clients.

This class provides a common interface for API clients that communicate with a API
Expand All @@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]):
a class that helps sending messages from a gRPC stream to
a [Broadcast][frequenz.channels.Broadcast] channel.

Note:
Because grpcio doesn't provide proper type hints, a hack is needed to have
propepr async type hints for the stubs generated by protoc. When using
`mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
but in the `.pyi` file, so the type can be used to specify type hints, but
**not** in any other context, as the class doesn't really exist for the Python
interpreter. This include generics, and because of this, this class can't be
even parametrized using the async class, so the instantiation of the stub can't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
even parametrized using the async class, so the instantiation of the stub can't
even parameterized using the async class, so the instantiation of the stub can't

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it confusing :D
I seems I was wrong 👼

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is an irregular language.

Copy link
Contributor Author

@llucax llucax Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed that it says parametrized when defining past tense and past participle, but uses parameterized in the example 🤣

image

be done in the base class.

Because of this, subclasses need to create the stubs by themselves, using the
real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
the async version of the stubs.

It is recommended to define a `stub` property that returns the async stub, so
this hack is completely hidden from clients, even if they need to access the
stub for more advanced uses.

Example:
This example illustrates how to create a simple API client that connects to a
gRPC server and calls a method on a stub.

```python
from collections.abc import AsyncIterable
from typing import cast
from frequenz.client.base.client import BaseApiClient, call_stub_method
from frequenz.client.base.streaming import GrpcStreamBroadcaster
from frequenz.channels import Receiver
Expand All @@ -57,25 +73,51 @@ async def example_method(
) -> ExampleResponse:
...

def example_stream(self) -> AsyncIterable[ExampleResponse]:
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
...

class ExampleAsyncStub:
async def example_method(
self,
request: ExampleRequest # pylint: disable=unused-argument
) -> ExampleResponse:
...

def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
...
# End of generated classes

class ExampleResponseWrapper:
def __init__(self, response: ExampleResponse):
def __init__(self, response: ExampleResponse) -> None:
self.transformed_value = f"{response.float_value:.2f}"

class MyApiClient(BaseApiClient[ExampleStub]):
def __init__(self, server_url: str, *, connect: bool = True):
super().__init__(
server_url, ExampleStub, connect=connect
# Change defaults as needed
DEFAULT_CHANNEL_OPTIONS = ChannelOptions()

class MyApiClient(BaseApiClient):
def __init__(
self,
server_url: str,
*,
connect: bool = True,
channel_defaults: ChannelOptions = DEFAULT_CHANNEL_OPTIONS,
) -> None:
super().__init__(server_url, connect=connect, channel_defaults=channel_defaults)
self._stub = cast(
ExampleAsyncStub, ExampleStub(self.channel)
)
self._broadcaster = GrpcStreamBroadcaster(
"stream",
lambda: self.stub.example_stream(ExampleRequest()),
ExampleResponseWrapper,
)

@property
def stub(self) -> ExampleAsyncStub:
if self._channel is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub

async def example_method(
self, int_value: int, str_value: str
) -> ExampleResponseWrapper:
Expand Down Expand Up @@ -114,7 +156,6 @@ async def main():
def __init__(
self,
server_url: str,
create_stub: Callable[[Channel], StubT],
*,
connect: bool = True,
channel_defaults: ChannelOptions = ChannelOptions(),
Expand All @@ -123,7 +164,6 @@ def __init__(

Args:
server_url: The URL of the server to connect to.
create_stub: A function that creates a stub from a channel.
connect: Whether to connect to the server as soon as a client instance is
created. If `False`, the client will not connect to the server until
[connect()][frequenz.client.base.client.BaseApiClient.connect] is
Expand All @@ -132,10 +172,8 @@ def __init__(
the server URL.
"""
self._server_url: str = server_url
self._create_stub: Callable[[Channel], StubT] = create_stub
self._channel_defaults: ChannelOptions = channel_defaults
self._channel: Channel | None = None
self._stub: StubT | None = None
if connect:
self.connect(server_url)

Expand Down Expand Up @@ -165,22 +203,6 @@ def channel_defaults(self) -> ChannelOptions:
"""The default options for the gRPC channel."""
return self._channel_defaults

@property
def stub(self) -> StubT:
"""The underlying gRPC stub.

Warning:
This stub is provided as a last resort for advanced users. It is not
recommended to use this property directly unless you know what you are
doing and you don't care about being tied to a specific gRPC library.

Raises:
ClientNotConnected: If the client is not connected to the server.
"""
if self._stub is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub

@property
def is_connected(self) -> bool:
"""Whether the client is connected to the server."""
Expand All @@ -202,7 +224,6 @@ def connect(self, server_url: str | None = None) -> None:
elif self.is_connected:
return
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
self._stub = self._create_stub(self._channel)

async def disconnect(self) -> None:
"""Disconnect from the server.
Expand All @@ -227,7 +248,6 @@ async def __aexit__(
return None
result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb)
self._channel = None
self._stub = None
return result


Expand All @@ -240,7 +260,7 @@ async def __aexit__(

@overload
async def call_stub_method(
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand All @@ -250,7 +270,7 @@ async def call_stub_method(

@overload
async def call_stub_method(
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand All @@ -261,7 +281,7 @@ async def call_stub_method(
# We need the `noqa: DOC503` because `pydoclint` can't figure out that
# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
async def call_stub_method( # noqa: DOC503
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand Down
32 changes: 3 additions & 29 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import pytest_mock

from frequenz.client.base.channel import ChannelOptions, SslOptions
from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method
from frequenz.client.base.client import BaseApiClient, call_stub_method
from frequenz.client.base.exception import ClientNotConnected, UnknownError


def _auto_connect_name(auto_connect: bool) -> str:
return f"{auto_connect=}"


def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
def _assert_is_disconnected(client: BaseApiClient) -> None:
"""Assert that the client is disconnected."""
assert not client.is_connected

Expand All @@ -30,17 +30,9 @@ def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
assert exc.server_url == _DEFAULT_SERVER_URL
assert exc.operation == "channel"

with pytest.raises(ClientNotConnected, match=r"") as exc_info:
_ = client.stub
exc = exc_info.value
assert exc.server_url == _DEFAULT_SERVER_URL
assert exc.operation == "stub"


@dataclass(kw_only=True, frozen=True)
class _ClientMocks:
stub: mock.MagicMock
create_stub: mock.MagicMock
channel: mock.MagicMock
parse_grpc_uri: mock.MagicMock

Expand All @@ -54,10 +46,8 @@ def create_client_with_mocks(
auto_connect: bool = True,
server_url: str = _DEFAULT_SERVER_URL,
channel_defaults: ChannelOptions | None = None,
) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]:
) -> tuple[BaseApiClient, _ClientMocks]:
"""Create a BaseApiClient instance with mocks."""
mock_stub = mock.MagicMock(name="stub")
mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub)
mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel)
mock_parse_grpc_uri = mocker.patch(
"frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel
Expand All @@ -67,13 +57,10 @@ def create_client_with_mocks(
kwargs["channel_defaults"] = channel_defaults
client = BaseApiClient(
server_url=server_url,
create_stub=mock_create_stub,
connect=auto_connect,
**kwargs,
)
return client, _ClientMocks(
stub=mock_stub,
create_stub=mock_create_stub,
channel=mock_channel,
parse_grpc_uri=mock_parse_grpc_uri,
)
Expand All @@ -92,13 +79,10 @@ def test_base_api_client_init(
client.server_url, ChannelOptions()
)
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.create_stub.assert_called_once_with(mocks.channel)
else:
_assert_is_disconnected(client)
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()


def test_base_api_client_init_with_channel_defaults(
Expand All @@ -110,9 +94,7 @@ def test_base_api_client_init_with_channel_defaults(
assert client.server_url == _DEFAULT_SERVER_URL
mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults)
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.create_stub.assert_called_once_with(mocks.channel)


@pytest.mark.parametrize(
Expand All @@ -129,12 +111,10 @@ def test_base_api_client_connect(
# We want to check only what happens when we call connect, so we reset the mocks
# that were called during initialization
mocks.parse_grpc_uri.reset_mock()
mocks.create_stub.reset_mock()

client.connect(new_server_url)

assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected

same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL
Expand All @@ -148,12 +128,10 @@ def test_base_api_client_connect(
# reconnect
if auto_connect and same_url:
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()
else:
mocks.parse_grpc_uri.assert_called_once_with(
client.server_url, ChannelOptions()
)
mocks.create_stub.assert_called_once_with(mocks.channel)


async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None:
Expand All @@ -177,23 +155,19 @@ async def test_base_api_client_async_context_manager(
# We want to check only what happens when we enter the context manager, so we reset
# the mocks that were called during initialization
mocks.parse_grpc_uri.reset_mock()
mocks.create_stub.reset_mock()

async with client:
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.channel.__aexit__.assert_not_called()
# If we were previously connected, the client should not reconnect when entering
# the context manager
if auto_connect:
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()
else:
mocks.parse_grpc_uri.assert_called_once_with(
client.server_url, ChannelOptions()
)
mocks.create_stub.assert_called_once_with(mocks.channel)

mocks.channel.__aexit__.assert_called_once_with(None, None, None)
assert client.server_url == _DEFAULT_SERVER_URL
Expand Down
Loading