diff --git a/mypy_protobuf/main.py b/mypy_protobuf/main.py index 47fe3e7e..78a3cde3 100644 --- a/mypy_protobuf/main.py +++ b/mypy_protobuf/main.py @@ -781,6 +781,7 @@ def write_grpc_services( self, services: Iterable[d.ServiceDescriptorProto], scl_prefix: SourceCodeLocation, + async_only: bool, ) -> None: wl = self._write_line for i, service in enumerate(services): @@ -797,21 +798,25 @@ def write_grpc_services( with self._indent(): if self._write_comments(scl): wl("") - # To support casting into FooAsyncStub, allow both Channel and aio.Channel here. - channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]" + # To support casting into FooAsyncStub, allow both Channel and aio.Channel here, + # but only if we are generating both sync and async stubs. + channel = self._import("grpc.aio", "Channel") + if not async_only: + channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {channel}]" wl("def __init__(self, channel: {}) -> None: ...", channel) - self.write_grpc_stub_methods(service, scl) + self.write_grpc_stub_methods(service, scl, is_async=async_only) - # The (fake) async stub client - wl( - "class {}AsyncStub:", - service.name, - ) - with self._indent(): - if self._write_comments(scl): - wl("") - # No __init__ since this isn't a real class (yet), and requires manual casting to work. - self.write_grpc_stub_methods(service, scl, is_async=True) + if not async_only: + # The (fake) async stub client + wl( + "class {}AsyncStub:", + service.name, + ) + with self._indent(): + if self._write_comments(scl): + wl("") + # No __init__ since this isn't a real class (yet), and requires manual casting to work. + self.write_grpc_stub_methods(service, scl, is_async=True) # The service definition interface wl( @@ -1009,6 +1014,7 @@ def generate_mypy_stubs( def generate_mypy_grpc_stubs( descriptors: Descriptors, response: plugin_pb2.CodeGeneratorResponse, + async_only: bool, quiet: bool, readable_stubs: bool, relax_strict_optional_primitives: bool, @@ -1022,7 +1028,7 @@ def generate_mypy_grpc_stubs( grpc=True, ) pkg_writer.write_grpc_async_hacks() - pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]) + pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER], async_only) assert name == fd.name assert fd.name.endswith(".proto") @@ -1079,6 +1085,7 @@ def grpc() -> None: generate_mypy_grpc_stubs( Descriptors(request), response, + "async_only" in request.parameter, "quiet" in request.parameter, "readable_stubs" in request.parameter, "relax_strict_optional_primitives" in request.parameter,