From a8081289fadbb99948461eb1f1bd5b802ec739c6 Mon Sep 17 00:00:00 2001 From: Mira Leung Date: Tue, 11 May 2021 14:42:13 -0700 Subject: [PATCH 1/4] feat: support protobuf method deprecation option --- .../%sub/services/%service/client.py.j2 | 3 + gapic/schema/wrappers.py | 1974 ++++++++--------- .../%sub/services/%service/client.py.j2 | 7 +- test_utils/test_utils.py | 632 +++--- tests/unit/schema/wrappers/test_method.py | 831 ++++--- 5 files changed, 1711 insertions(+), 1736 deletions(-) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 index 3e02216712..bb9425bbfa 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 @@ -334,6 +334,9 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): {{ method.client_output.meta.doc|rst(width=72, indent=16, source_format='rst') }} {% endif %} """ + {% if method.is_deprecated %} + warnings.warn("{{ method.name|snake_case }} is deprecated", warnings.DeprecationWarning) + {% endif %} {% if not method.client_streaming %} # Create or coerce a protobuf request object. {% if method.flattened_fields %} diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 962407aa9c..fe7f0f8117 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Module containing wrapper classes around meta-descriptors. This module contains dataclasses which wrap the descriptor protos @@ -31,13 +30,13 @@ import dataclasses import re from itertools import chain -from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, - ClassVar, Optional, Sequence, Set, Tuple, Union) -from google.api import annotations_pb2 # type: ignore +from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, ClassVar, + Optional, Sequence, Set, Tuple, Union) +from google.api import annotations_pb2 # type: ignore from google.api import client_pb2 from google.api import field_behavior_pb2 from google.api import resource_pb2 -from google.api_core import exceptions # type: ignore +from google.api_core import exceptions # type: ignore from google.protobuf import descriptor_pb2 # type: ignore from google.protobuf.json_format import MessageToDict # type: ignore @@ -47,344 +46,329 @@ @dataclasses.dataclass(frozen=True) class Field: - """Description of a field.""" - field_pb: descriptor_pb2.FieldDescriptorProto - message: Optional['MessageType'] = None - enum: Optional['EnumType'] = None - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, + """Description of a field.""" + field_pb: descriptor_pb2.FieldDescriptorProto + message: Optional['MessageType'] = None + enum: Optional['EnumType'] = None + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + oneof: Optional[str] = None + + def __getattr__(self, name): + return getattr(self.field_pb, name) + + def __hash__(self): + # The only sense in which it is meaningful to say a field is equal to + # another field is if they are the same, i.e. they live in the same + # message type under the same moniker, i.e. they have the same id. + return id(self) + + @property + def name(self) -> str: + """Used to prevent collisions with python keywords""" + name = self.field_pb.name + return name + '_' if name in utils.RESERVED_NAMES else name + + @utils.cached_property + def ident(self) -> metadata.FieldIdentifier: + """Return the identifier to be used in templates.""" + return metadata.FieldIdentifier( + ident=self.type.ident, + repeated=self.repeated, ) - oneof: Optional[str] = None - - def __getattr__(self, name): - return getattr(self.field_pb, name) - - def __hash__(self): - # The only sense in which it is meaningful to say a field is equal to - # another field is if they are the same, i.e. they live in the same - # message type under the same moniker, i.e. they have the same id. - return id(self) - - @property - def name(self) -> str: - """Used to prevent collisions with python keywords""" - name = self.field_pb.name - return name + "_" if name in utils.RESERVED_NAMES else name - - @utils.cached_property - def ident(self) -> metadata.FieldIdentifier: - """Return the identifier to be used in templates.""" - return metadata.FieldIdentifier( - ident=self.type.ident, - repeated=self.repeated, - ) - - @property - def is_primitive(self) -> bool: - """Return True if the field is a primitive, False otherwise.""" - return isinstance(self.type, PrimitiveType) - - @property - def map(self) -> bool: - """Return True if this field is a map, False otherwise.""" - return bool(self.repeated and self.message and self.message.map) - - @utils.cached_property - def mock_value(self) -> str: - visited_fields: Set["Field"] = set() - stack = [self] - answer = "{}" - while stack: - expr = stack.pop() - answer = answer.format(expr.inner_mock(stack, visited_fields)) - - return answer - - def inner_mock(self, stack, visited_fields): - """Return a repr of a valid, usually truthy mock value.""" - # For primitives, send a truthy value computed from the - # field name. - answer = 'None' - if isinstance(self.type, PrimitiveType): - if self.type.python_type == bool: - answer = 'True' - elif self.type.python_type == str: - answer = f"'{self.name}_value'" - elif self.type.python_type == bytes: - answer = f"b'{self.name}_blob'" - elif self.type.python_type == int: - answer = f'{sum([ord(i) for i in self.name])}' - elif self.type.python_type == float: - answer = f'0.{sum([ord(i) for i in self.name])}' - else: # Impossible; skip coverage checks. - raise TypeError('Unrecognized PrimitiveType. This should ' - 'never happen; please file an issue.') - - # If this is an enum, select the first truthy value (or the zero - # value if nothing else exists). - if isinstance(self.type, EnumType): - # Note: The slightly-goofy [:2][-1] lets us gracefully fall - # back to index 0 if there is only one element. - mock_value = self.type.values[:2][-1] - answer = f'{self.type.ident}.{mock_value.name}' - - # If this is another message, set one value on the message. - if ( - not self.map # Maps are handled separately - and isinstance(self.type, MessageType) - and len(self.type.fields) - # Nested message types need to terminate eventually - and self not in visited_fields - ): - sub = next(iter(self.type.fields.values())) - stack.append(sub) - visited_fields.add(self) - # Don't do the recursive rendering here, just set up - # where the nested value should go with the double {}. - answer = f'{self.type.ident}({sub.name}={{}})' - - if self.map: - # Maps are a special case beacuse they're represented internally as - # a list of a generated type with two fields: 'key' and 'value'. - answer = '{{{}: {}}}'.format( - self.type.fields["key"].mock_value, - self.type.fields["value"].mock_value, - ) - elif self.repeated: - # If this is a repeated field, then the mock answer should - # be a list. - answer = f'[{answer}]' - - # Done; return the mock value. - return answer - - @property - def proto_type(self) -> str: - """Return the proto type constant to be used in templates.""" - return cast(str, descriptor_pb2.FieldDescriptorProto.Type.Name( - self.field_pb.type, - ))[len('TYPE_'):] - - @property - def repeated(self) -> bool: - """Return True if this is a repeated field, False otherwise. + + @property + def is_primitive(self) -> bool: + """Return True if the field is a primitive, False otherwise.""" + return isinstance(self.type, PrimitiveType) + + @property + def map(self) -> bool: + """Return True if this field is a map, False otherwise.""" + return bool(self.repeated and self.message and self.message.map) + + @utils.cached_property + def mock_value(self) -> str: + visited_fields: Set['Field'] = set() + stack = [self] + answer = '{}' + while stack: + expr = stack.pop() + answer = answer.format(expr.inner_mock(stack, visited_fields)) + + return answer + + def inner_mock(self, stack, visited_fields): + """Return a repr of a valid, usually truthy mock value.""" + # For primitives, send a truthy value computed from the + # field name. + answer = 'None' + if isinstance(self.type, PrimitiveType): + if self.type.python_type == bool: + answer = 'True' + elif self.type.python_type == str: + answer = f"'{self.name}_value'" + elif self.type.python_type == bytes: + answer = f"b'{self.name}_blob'" + elif self.type.python_type == int: + answer = f'{sum([ord(i) for i in self.name])}' + elif self.type.python_type == float: + answer = f'0.{sum([ord(i) for i in self.name])}' + else: # Impossible; skip coverage checks. + raise TypeError('Unrecognized PrimitiveType. This should ' + 'never happen; please file an issue.') + + # If this is an enum, select the first truthy value (or the zero + # value if nothing else exists). + if isinstance(self.type, EnumType): + # Note: The slightly-goofy [:2][-1] lets us gracefully fall + # back to index 0 if there is only one element. + mock_value = self.type.values[:2][-1] + answer = f'{self.type.ident}.{mock_value.name}' + + # If this is another message, set one value on the message. + if (not self.map # Maps are handled separately + and isinstance(self.type, MessageType) and len(self.type.fields) + # Nested message types need to terminate eventually + and self not in visited_fields): + sub = next(iter(self.type.fields.values())) + stack.append(sub) + visited_fields.add(self) + # Don't do the recursive rendering here, just set up + # where the nested value should go with the double {}. + answer = f'{self.type.ident}({sub.name}={{}})' + + if self.map: + # Maps are a special case beacuse they're represented internally as + # a list of a generated type with two fields: 'key' and 'value'. + answer = '{{{}: {}}}'.format( + self.type.fields['key'].mock_value, + self.type.fields['value'].mock_value, + ) + elif self.repeated: + # If this is a repeated field, then the mock answer should + # be a list. + answer = f'[{answer}]' + + # Done; return the mock value. + return answer + + @property + def proto_type(self) -> str: + """Return the proto type constant to be used in templates.""" + return cast( + str, descriptor_pb2.FieldDescriptorProto.Type.Name( + self.field_pb.type,))[len('TYPE_'):] + + @property + def repeated(self) -> bool: + """Return True if this is a repeated field, False otherwise. Returns: bool: Whether this field is repeated. """ - return self.label == \ - descriptor_pb2.FieldDescriptorProto.Label.Value( - 'LABEL_REPEATED') # type: ignore + return self.label == \ + descriptor_pb2.FieldDescriptorProto.Label.Value( + 'LABEL_REPEATED') # type: ignore - @property - def required(self) -> bool: - """Return True if this is a required field, False otherwise. + @property + def required(self) -> bool: + """Return True if this is a required field, False otherwise. Returns: bool: Whether this field is required. """ - return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in - self.options.Extensions[field_behavior_pb2.field_behavior]) - - @utils.cached_property - def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: - """Return the type of this field.""" - # If this is a message or enum, return the appropriate thing. - if self.type_name and self.message: - return self.message - if self.type_name and self.enum: - return self.enum - - # This is a primitive. Return the corresponding Python type. - # The enum values used here are defined in: - # Repository: https://github.com/google/protobuf/ - # Path: src/google/protobuf/descriptor.proto - # - # The values are used here because the code would be excessively - # verbose otherwise, and this is guaranteed never to change. - # - # 10, 11, and 14 are intentionally missing. They correspond to - # group (unused), message (covered above), and enum (covered above). - if self.field_pb.type in (1, 2): - return PrimitiveType.build(float) - if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): - return PrimitiveType.build(int) - if self.field_pb.type == 8: - return PrimitiveType.build(bool) - if self.field_pb.type == 9: - return PrimitiveType.build(str) - if self.field_pb.type == 12: - return PrimitiveType.build(bytes) - - # This should never happen. - raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. ' - 'This code should not be reachable; please file a bug.') - - def with_context( - self, - *, - collisions: FrozenSet[str], - visited_messages: FrozenSet["MessageType"], - ) -> 'Field': - """Return a derivative of this field with the provided context. + return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') + in self.options.Extensions[field_behavior_pb2.field_behavior]) + + @utils.cached_property + def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: + """Return the type of this field.""" + # If this is a message or enum, return the appropriate thing. + if self.type_name and self.message: + return self.message + if self.type_name and self.enum: + return self.enum + + # This is a primitive. Return the corresponding Python type. + # The enum values used here are defined in: + # Repository: https://github.com/google/protobuf/ + # Path: src/google/protobuf/descriptor.proto + # + # The values are used here because the code would be excessively + # verbose otherwise, and this is guaranteed never to change. + # + # 10, 11, and 14 are intentionally missing. They correspond to + # group (unused), message (covered above), and enum (covered above). + if self.field_pb.type in (1, 2): + return PrimitiveType.build(float) + if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): + return PrimitiveType.build(int) + if self.field_pb.type == 8: + return PrimitiveType.build(bool) + if self.field_pb.type == 9: + return PrimitiveType.build(str) + if self.field_pb.type == 12: + return PrimitiveType.build(bytes) + + # This should never happen. + raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. ' + 'This code should not be reachable; please file a bug.') + + def with_context( + self, + *, + collisions: FrozenSet[str], + visited_messages: FrozenSet['MessageType'], + ) -> 'Field': + """Return a derivative of this field with the provided context. This method is used to address naming collisions. The returned ``Field`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( - self, - message=self.message.with_context( - collisions=collisions, - skip_fields=self.message in visited_messages, - visited_messages=visited_messages, - ) if self.message else None, - enum=self.enum.with_context(collisions=collisions) - if self.enum else None, - meta=self.meta.with_context(collisions=collisions), - ) + return dataclasses.replace( + self, + message=self.message.with_context( + collisions=collisions, + skip_fields=self.message in visited_messages, + visited_messages=visited_messages, + ) if self.message else None, + enum=self.enum.with_context( + collisions=collisions) if self.enum else None, + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class Oneof: - """Description of a field.""" - oneof_pb: descriptor_pb2.OneofDescriptorProto + """Description of a field.""" + oneof_pb: descriptor_pb2.OneofDescriptorProto - def __getattr__(self, name): - return getattr(self.oneof_pb, name) + def __getattr__(self, name): + return getattr(self.oneof_pb, name) @dataclasses.dataclass(frozen=True) class MessageType: - """Description of a message (defined with the ``message`` keyword).""" - # Class attributes - PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}') - - # Instance attributes - message_pb: descriptor_pb2.DescriptorProto - fields: Mapping[str, Field] - nested_enums: Mapping[str, 'EnumType'] - nested_messages: Mapping[str, 'MessageType'] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, + """Description of a message (defined with the ``message`` keyword).""" + # Class attributes + PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}') + + # Instance attributes + message_pb: descriptor_pb2.DescriptorProto + fields: Mapping[str, Field] + nested_enums: Mapping[str, 'EnumType'] + nested_messages: Mapping[str, 'MessageType'] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + oneofs: Optional[Mapping[str, 'Oneof']] = None + + def __getattr__(self, name): + return getattr(self.message_pb, name) + + def __hash__(self): + # Identity is sufficiently unambiguous. + return hash(self.ident) + + def oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + + @utils.cached_property + def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + answer = tuple(field.type + for field in self.fields.values() + if field.message or field.enum) + + return answer + + @utils.cached_property + def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + """Return all composite fields used in this proto's messages.""" + types: Set[Union['MessageType', 'EnumType']] = set() + + stack = [iter(self.fields.values())] + while stack: + fields_iter = stack.pop() + for field in fields_iter: + if field.message and field.type not in types: + stack.append(iter(field.message.fields.values())) + if not field.is_primitive: + types.add(field.type) + + return tuple(types) + + @utils.cached_property + def recursive_resource_fields(self) -> FrozenSet[Field]: + all_fields = chain( + self.fields.values(), + (field for t in self.recursive_field_types + if isinstance(t, MessageType) for field in t.fields.values()), ) - oneofs: Optional[Mapping[str, 'Oneof']] = None - - def __getattr__(self, name): - return getattr(self.message_pb, name) - - def __hash__(self): - # Identity is sufficiently unambiguous. - return hash(self.ident) - - def oneof_fields(self, include_optional=False): - oneof_fields = collections.defaultdict(list) - for field in self.fields.values(): - # Only include proto3 optional oneofs if explicitly looked for. - if field.oneof and not field.proto3_optional or include_optional: - oneof_fields[field.oneof].append(field) - - return oneof_fields - - @utils.cached_property - def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - answer = tuple( - field.type - for field in self.fields.values() - if field.message or field.enum - ) - - return answer - - @utils.cached_property - def recursive_field_types(self) -> Sequence[ - Union['MessageType', 'EnumType'] - ]: - """Return all composite fields used in this proto's messages.""" - types: Set[Union['MessageType', 'EnumType']] = set() - - stack = [iter(self.fields.values())] - while stack: - fields_iter = stack.pop() - for field in fields_iter: - if field.message and field.type not in types: - stack.append(iter(field.message.fields.values())) - if not field.is_primitive: - types.add(field.type) - - return tuple(types) - - @utils.cached_property - def recursive_resource_fields(self) -> FrozenSet[Field]: - all_fields = chain( - self.fields.values(), - (field - for t in self.recursive_field_types if isinstance(t, MessageType) - for field in t.fields.values()), - ) - return frozenset( - f - for f in all_fields - if (f.options.Extensions[resource_pb2.resource_reference].type or - f.options.Extensions[resource_pb2.resource_reference].child_type) - ) - - @property - def map(self) -> bool: - """Return True if the given message is a map, False otherwise.""" - return self.message_pb.options.map_entry - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - @property - def resource_path(self) -> Optional[str]: - """If this message describes a resource, return the path to the resource. - If there are multiple paths, returns the first one.""" - return next( - iter(self.options.Extensions[resource_pb2.resource].pattern), - None - ) - - @property - def resource_type(self) -> Optional[str]: - resource = self.options.Extensions[resource_pb2.resource] - return resource.type[resource.type.find('/') + 1:] if resource else None - - @property - def resource_path_args(self) -> Sequence[str]: - return self.PATH_ARG_RE.findall(self.resource_path or '') - - @utils.cached_property - def path_regex_str(self) -> str: - # The indirection here is a little confusing: - # we're using the resource path template as the base of a regex, - # with each resource ID segment being captured by a regex. - # E.g., the path schema - # kingdoms/{kingdom}/phyla/{phylum} - # becomes the regex - # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ - parsing_regex_str = ( - "^" + - self.PATH_ARG_RE.sub( - # We can't just use (?P[^/]+) because segments may be - # separated by delimiters other than '/'. - # Multiple delimiter characters within one schema are allowed, - # e.g. - # as/{a}-{b}/cs/{c}%{d}_{e} - # This is discouraged but permitted by AIP4231 - lambda m: "(?P<{name}>.+?)".format(name=m.groups()[0]), - self.resource_path or '' - ) + - "$" - ) - return parsing_regex_str - - def get_field(self, *field_path: str, - collisions: FrozenSet[str] = frozenset()) -> Field: - """Return a field arbitrarily deep in this message's structure. + return frozenset( + f for f in all_fields + if (f.options.Extensions[resource_pb2.resource_reference].type or + f.options.Extensions[resource_pb2.resource_reference].child_type)) + + @property + def map(self) -> bool: + """Return True if the given message is a map, False otherwise.""" + return self.message_pb.options.map_entry + + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address + + @property + def resource_path(self) -> Optional[str]: + """If this message describes a resource, return the path to the resource. + + If there are multiple paths, returns the first one. + """ + return next( + iter(self.options.Extensions[resource_pb2.resource].pattern), None) + + @property + def resource_type(self) -> Optional[str]: + resource = self.options.Extensions[resource_pb2.resource] + return resource.type[resource.type.find('/') + 1:] if resource else None + + @property + def resource_path_args(self) -> Sequence[str]: + return self.PATH_ARG_RE.findall(self.resource_path or '') + + @utils.cached_property + def path_regex_str(self) -> str: + # The indirection here is a little confusing: + # we're using the resource path template as the base of a regex, + # with each resource ID segment being captured by a regex. + # E.g., the path schema + # kingdoms/{kingdom}/phyla/{phylum} + # becomes the regex + # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ + parsing_regex_str = ( + '^' + self.PATH_ARG_RE.sub( + # We can't just use (?P[^/]+) because segments may be + # separated by delimiters other than '/'. + # Multiple delimiter characters within one schema are allowed, + # e.g. + # as/{a}-{b}/cs/{c}%{d}_{e} + # This is discouraged but permitted by AIP4231 + lambda m: '(?P<{name}>.+?)'.format(name=m.groups()[0]), + self.resource_path or '') + '$') + return parsing_regex_str + + def get_field( + self, *field_path: str, + collisions: FrozenSet[str] = frozenset()) -> Field: + """Return a field arbitrarily deep in this message's structure. This method recursively traverses the message tree to return the requested inner-field. @@ -402,55 +386,55 @@ def get_field(self, *field_path: str, KeyError: If a repeated field is used in the non-terminal position in the path. """ - # If collisions are not explicitly specified, retrieve them - # from this message's address. - # This ensures that calls to `get_field` will return a field with - # the same context, regardless of the number of levels through the - # chain (in order to avoid infinite recursion on circular references, - # we only shallowly bind message references held by fields; this - # binds deeply in the one spot where that might be a problem). - collisions = collisions or self.meta.address.collisions - - # Get the first field in the path. - first_field = field_path[0] - cursor = self.fields[first_field + - ('_' if first_field in utils.RESERVED_NAMES else '')] - - # Base case: If this is the last field in the path, return it outright. - if len(field_path) == 1: - return cursor.with_context( - collisions=collisions, - visited_messages=frozenset({self}), - ) - - # Sanity check: If cursor is a repeated field, then raise an exception. - # Repeated fields are only permitted in the terminal position. - if cursor.repeated: - raise KeyError( - f'The {cursor.name} field is repeated; unable to use ' - '`get_field` to retrieve its children.\n' - 'This exception usually indicates that a ' - 'google.api.method_signature annotation uses a repeated field ' - 'in the fields list in a position other than the end.', - ) - - # Sanity check: If this cursor has no message, there is a problem. - if not cursor.message: - raise KeyError( - f'Field {".".join(field_path)} could not be resolved from ' - f'{cursor.name}.', - ) - - # Recursion case: Pass the remainder of the path to the sub-field's - # message. - return cursor.message.get_field(*field_path[1:], collisions=collisions) - - def with_context(self, *, - collisions: FrozenSet[str], - skip_fields: bool = False, - visited_messages: FrozenSet["MessageType"] = frozenset(), - ) -> 'MessageType': - """Return a derivative of this message with the provided context. + # If collisions are not explicitly specified, retrieve them + # from this message's address. + # This ensures that calls to `get_field` will return a field with + # the same context, regardless of the number of levels through the + # chain (in order to avoid infinite recursion on circular references, + # we only shallowly bind message references held by fields; this + # binds deeply in the one spot where that might be a problem). + collisions = collisions or self.meta.address.collisions + + # Get the first field in the path. + first_field = field_path[0] + cursor = self.fields[first_field + + ('_' if first_field in utils.RESERVED_NAMES else '')] + + # Base case: If this is the last field in the path, return it outright. + if len(field_path) == 1: + return cursor.with_context( + collisions=collisions, + visited_messages=frozenset({self}), + ) + + # Sanity check: If cursor is a repeated field, then raise an exception. + # Repeated fields are only permitted in the terminal position. + if cursor.repeated: + raise KeyError( + f'The {cursor.name} field is repeated; unable to use ' + '`get_field` to retrieve its children.\n' + 'This exception usually indicates that a ' + 'google.api.method_signature annotation uses a repeated field ' + 'in the fields list in a position other than the end.',) + + # Sanity check: If this cursor has no message, there is a problem. + if not cursor.message: + raise KeyError( + f'Field {".".join(field_path)} could not be resolved from ' + f'{cursor.name}.',) + + # Recursion case: Pass the remainder of the path to the sub-field's + # message. + return cursor.message.get_field(*field_path[1:], collisions=collisions) + + def with_context( + self, + *, + collisions: FrozenSet[str], + skip_fields: bool = False, + visited_messages: FrozenSet['MessageType'] = frozenset(), + ) -> 'MessageType': + """Return a derivative of this message with the provided context. This method is used to address naming collisions. The returned ``MessageType`` object aliases module names to avoid naming collisions @@ -460,226 +444,218 @@ def with_context(self, *, underlying fields. This provides for an "exit" in the case of circular references. """ - visited_messages = visited_messages | {self} - return dataclasses.replace( - self, - fields={ - k: v.with_context( - collisions=collisions, - visited_messages=visited_messages - ) for k, v in self.fields.items() - } if not skip_fields else self.fields, - nested_enums={ - k: v.with_context(collisions=collisions) - for k, v in self.nested_enums.items() - }, - nested_messages={ - k: v.with_context( - collisions=collisions, - skip_fields=skip_fields, - visited_messages=visited_messages, - ) - for k, v in self.nested_messages.items() - }, - meta=self.meta.with_context(collisions=collisions), - ) + visited_messages = visited_messages | {self} + return dataclasses.replace( + self, + fields={ + k: v.with_context( + collisions=collisions, visited_messages=visited_messages) + for k, v in self.fields.items() + } if not skip_fields else self.fields, + nested_enums={ + k: v.with_context(collisions=collisions) + for k, v in self.nested_enums.items() + }, + nested_messages={ + k: v.with_context( + collisions=collisions, + skip_fields=skip_fields, + visited_messages=visited_messages, + ) for k, v in self.nested_messages.items() + }, + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class EnumValueType: - """Description of an enum value.""" - enum_value_pb: descriptor_pb2.EnumValueDescriptorProto - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, - ) + """Description of an enum value.""" + enum_value_pb: descriptor_pb2.EnumValueDescriptorProto + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) - def __getattr__(self, name): - return getattr(self.enum_value_pb, name) + def __getattr__(self, name): + return getattr(self.enum_value_pb, name) @dataclasses.dataclass(frozen=True) class EnumType: - """Description of an enum (defined with the ``enum`` keyword.)""" - enum_pb: descriptor_pb2.EnumDescriptorProto - values: List[EnumValueType] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, - ) - - def __hash__(self): - # Identity is sufficiently unambiguous. - return hash(self.ident) - - def __getattr__(self, name): - return getattr(self.enum_pb, name) - - @property - def resource_path(self) -> Optional[str]: - # This is a minor duck-typing workaround for the resource_messages - # property in the Service class: we need to check fields recursively - # to see if they're resources, and recursive_field_types includes enums - return None - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': - """Return a derivative of this enum with the provided context. + """Description of an enum (defined with the ``enum`` keyword.)""" + enum_pb: descriptor_pb2.EnumDescriptorProto + values: List[EnumValueType] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + def __hash__(self): + # Identity is sufficiently unambiguous. + return hash(self.ident) + + def __getattr__(self, name): + return getattr(self.enum_pb, name) + + @property + def resource_path(self) -> Optional[str]: + # This is a minor duck-typing workaround for the resource_messages + # property in the Service class: we need to check fields recursively + # to see if they're resources, and recursive_field_types includes enums + return None + + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address + + def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': + """Return a derivative of this enum with the provided context. This method is used to address naming collisions. The returned ``EnumType`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( - self, - meta=self.meta.with_context(collisions=collisions), - ) if collisions else self + return dataclasses.replace( + self, + meta=self.meta.with_context(collisions=collisions), + ) if collisions else self - @property - def options_dict(self) -> Dict: - """Return the EnumOptions (if present) as a dict. + @property + def options_dict(self) -> Dict: + """Return the EnumOptions (if present) as a dict. This is a hack to support a pythonic structure representation for the generator templates. """ - return MessageToDict( - self.enum_pb.options, - preserving_proto_field_name=True - ) + return MessageToDict(self.enum_pb.options, preserving_proto_field_name=True) @dataclasses.dataclass(frozen=True) class PythonType: - """Wrapper class for Python types. + """Wrapper class for Python types. This exists for interface consistency, so that methods like :meth:`Field.type` can return an object and the caller can be confident that a ``name`` property will be present. """ - meta: metadata.Metadata + meta: metadata.Metadata - def __eq__(self, other): - return self.meta == other.meta + def __eq__(self, other): + return self.meta == other.meta - def __ne__(self, other): - return not self == other + def __ne__(self, other): + return not self == other - @utils.cached_property - def ident(self) -> metadata.Address: - """Return the identifier to be used in templates.""" - return self.meta.address + @utils.cached_property + def ident(self) -> metadata.Address: + """Return the identifier to be used in templates.""" + return self.meta.address - @property - def name(self) -> str: - return self.ident.name + @property + def name(self) -> str: + return self.ident.name - @property - def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - return tuple() + @property + def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + return tuple() @dataclasses.dataclass(frozen=True) class PrimitiveType(PythonType): - """A representation of a Python primitive type.""" - python_type: Optional[type] + """A representation of a Python primitive type.""" + python_type: Optional[type] - @classmethod - def build(cls, primitive_type: Optional[type]): - """Return a PrimitiveType object for the given Python primitive type. + @classmethod + def build(cls, primitive_type: Optional[type]): + """Return a PrimitiveType object for the given Python primitive type. Args: - primitive_type (cls): A Python primitive type, such as - :class:`int` or :class:`str`. Despite not being a type, - ``None`` is also accepted here. + primitive_type (cls): A Python primitive type, such as :class:`int` + or :class:`str`. Despite not being a type, ``None`` is also + accepted here. Returns: ~.PrimitiveType: The instantiated PrimitiveType object. """ - # Primitives have no import, and no module to reference, so the - # address just uses the name of the class (e.g. "int", "str"). - return cls(meta=metadata.Metadata(address=metadata.Address( - name='None' if primitive_type is None else primitive_type.__name__, - )), python_type=primitive_type) - - def __eq__(self, other): - # If we are sent the actual Python type (not the PrimitiveType object), - # claim to be equal to that. - if not hasattr(other, 'meta'): - return self.python_type is other - return super().__eq__(other) + # Primitives have no import, and no module to reference, so the + # address just uses the name of the class (e.g. "int", "str"). + return cls( + meta=metadata.Metadata( + address=metadata.Address( + name='None' if primitive_type is None else primitive_type + .__name__,)), + python_type=primitive_type) + + def __eq__(self, other): + # If we are sent the actual Python type (not the PrimitiveType object), + # claim to be equal to that. + if not hasattr(other, 'meta'): + return self.python_type is other + return super().__eq__(other) @dataclasses.dataclass(frozen=True) class OperationInfo: - """Representation of long-running operation info.""" - response_type: MessageType - metadata_type: MessageType + """Representation of long-running operation info.""" + response_type: MessageType + metadata_type: MessageType - def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': - """Return a derivative of this OperationInfo with the provided context. + def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': + """Return a derivative of this OperationInfo with the provided context. This method is used to address naming collisions. The returned - ``OperationInfo`` object aliases module names to avoid naming collisions + ``OperationInfo`` object aliases module names to avoid naming + collisions in the file being written. """ - return dataclasses.replace( - self, - response_type=self.response_type.with_context( - collisions=collisions - ), - metadata_type=self.metadata_type.with_context( - collisions=collisions - ), - ) + return dataclasses.replace( + self, + response_type=self.response_type.with_context(collisions=collisions), + metadata_type=self.metadata_type.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class RetryInfo: - """Representation of the method's retry behavior.""" - max_attempts: int - initial_backoff: float - max_backoff: float - backoff_multiplier: float - retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError] + """Representation of the method's retry behavior.""" + max_attempts: int + initial_backoff: float + max_backoff: float + backoff_multiplier: float + retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError] @dataclasses.dataclass(frozen=True) class Method: - """Description of a method (defined with the ``rpc`` keyword).""" - method_pb: descriptor_pb2.MethodDescriptorProto - input: MessageType - output: MessageType - lro: Optional[OperationInfo] = dataclasses.field(default=None) - retry: Optional[RetryInfo] = dataclasses.field(default=None) - timeout: Optional[float] = None - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, - ) - - def __getattr__(self, name): - return getattr(self.method_pb, name) - - @utils.cached_property - def client_output(self): - return self._client_output(enable_asyncio=False) - - @utils.cached_property - def client_output_async(self): - return self._client_output(enable_asyncio=True) - - def flattened_oneof_fields(self, include_optional=False): - oneof_fields = collections.defaultdict(list) - for field in self.flattened_fields.values(): - # Only include proto3 optional oneofs if explicitly looked for. - if field.oneof and not field.proto3_optional or include_optional: - oneof_fields[field.oneof].append(field) - - return oneof_fields - - def _client_output(self, enable_asyncio: bool): - """Return the output from the client layer. + """Description of a method (defined with the ``rpc`` keyword).""" + method_pb: descriptor_pb2.MethodDescriptorProto + input: MessageType + output: MessageType + lro: Optional[OperationInfo] = dataclasses.field(default=None) + retry: Optional[RetryInfo] = dataclasses.field(default=None) + timeout: Optional[float] = None + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + def __getattr__(self, name): + return getattr(self.method_pb, name) + + @utils.cached_property + def client_output(self): + return self._client_output(enable_asyncio=False) + + @utils.cached_property + def client_output_async(self): + return self._client_output(enable_asyncio=True) + + def flattened_oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.flattened_fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + + def _client_output(self, enable_asyncio: bool): + """Return the output from the client layer. This takes into account transformations made by the outer GAPIC client to transform the output from the transport. @@ -688,534 +664,530 @@ def _client_output(self, enable_asyncio: bool): Union[~.MessageType, ~.PythonType]: A description of the return type. """ - # Void messages ultimately return None. - if self.void: - return PrimitiveType.build(None) - - # If this method is an LRO, return a PythonType instance representing - # that. - if self.lro: - return PythonType(meta=metadata.Metadata( - address=metadata.Address( - name='AsyncOperation' if enable_asyncio else 'Operation', - module='operation_async' if enable_asyncio else 'operation', - package=('google', 'api_core'), - collisions=self.lro.response_type.ident.collisions, - ), - documentation=utils.doc( - 'An object representing a long-running operation. \n\n' - 'The result type for the operation will be ' - ':class:`{ident}` {doc}'.format( - doc=self.lro.response_type.meta.doc, - ident=self.lro.response_type.ident.sphinx, - ), - ), - )) - - # If this method is paginated, return that method's pager class. - if self.paged_result_field: - return PythonType(meta=metadata.Metadata( - address=metadata.Address( - name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager', - package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( - 'services', - utils.to_snake_case(self.ident.parent[-1]), - ), - module='pagers', - collisions=self.input.ident.collisions, - ), - documentation=utils.doc( - f'{self.output.meta.doc}\n\n' - 'Iterating over this object will yield results and ' - 'resolve additional pages automatically.', - ), - )) - - # Return the usual output. - return self.output - - # TODO(yon-mg): remove or rewrite: don't think it performs as intended - # e.g. doesn't work with basic case of gRPC transcoding - @property - def field_headers(self) -> Sequence[str]: - """Return the field headers defined for this method.""" - http = self.options.Extensions[annotations_pb2.http] - - pattern = re.compile(r'\{([a-z][\w\d_.]+)=') - - potential_verbs = [ - http.get, - http.put, - http.post, - http.delete, - http.patch, - http.custom.path, - ] - - return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) - - @property - def http_opt(self) -> Optional[Dict[str, str]]: - """Return the http option for this method. + # Void messages ultimately return None. + if self.void: + return PrimitiveType.build(None) + + # If this method is an LRO, return a PythonType instance representing + # that. + if self.lro: + return PythonType( + meta=metadata.Metadata( + address=metadata.Address( + name='AsyncOperation' if enable_asyncio else 'Operation', + module='operation_async' if enable_asyncio else 'operation', + package=('google', 'api_core'), + collisions=self.lro.response_type.ident.collisions, + ), + documentation=utils.doc( + 'An object representing a long-running operation. \n\n' + 'The result type for the operation will be ' + ':class:`{ident}` {doc}'.format( + doc=self.lro.response_type.meta.doc, + ident=self.lro.response_type.ident.sphinx, + ),), + )) + + # If this method is paginated, return that method's pager class. + if self.paged_result_field: + return PythonType( + meta=metadata.Metadata( + address=metadata.Address( + name=f'{self.name}AsyncPager' + if enable_asyncio else f'{self.name}Pager', + package=self.ident.api_naming.module_namespace + + (self.ident.api_naming.versioned_module_name,) + + self.ident.subpackage + ( + 'services', + utils.to_snake_case(self.ident.parent[-1]), + ), + module='pagers', + collisions=self.input.ident.collisions, + ), + documentation=utils.doc( + f'{self.output.meta.doc}\n\n' + 'Iterating over this object will yield results and ' + 'resolve additional pages automatically.',), + )) + + # Return the usual output. + return self.output + + @property + def is_deprecated(self) -> bool: + """Returns true if the method is deprecated, false otherwise.""" + return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') + + # TODO(yon-mg): remove or rewrite: don't think it performs as intended + # e.g. doesn't work with basic case of gRPC transcoding + @property + def field_headers(self) -> Sequence[str]: + """Return the field headers defined for this method.""" + http = self.options.Extensions[annotations_pb2.http] + + pattern = re.compile(r'\{([a-z][\w\d_.]+)=') + + potential_verbs = [ + http.get, + http.put, + http.post, + http.delete, + http.patch, + http.custom.path, + ] + + return next( + (tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) + + @property + def http_opt(self) -> Optional[Dict[str, str]]: + """Return the http option for this method. e.g. {'verb': 'post' 'url': '/some/path' 'body': '*'} """ - http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] - http = self.options.Extensions[annotations_pb2.http].ListFields() - - if len(http) < 1: - return None - - http_method = http[0] - answer: Dict[str, str] = { - 'verb': http_method[0].name, - 'url': http_method[1], - } - if len(http) > 1: - body_spec = http[1] - answer[body_spec[0].name] = body_spec[1] - - # TODO(yon-mg): handle nested fields & fields past body i.e. 'additional bindings' - # TODO(yon-mg): enums for http verbs? - return answer - - @property - def path_params(self) -> Sequence[str]: - """Return the path parameters found in the http annotation path template""" - # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) - if self.http_opt is None: - return [] - - pattern = r'\{(\w+)\}' - return re.findall(pattern, self.http_opt['url']) - - @property - def query_params(self) -> Set[str]: - """Return query parameters for API call as determined by http annotation and grpc transcoding""" - # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) - # TODO(yon-mg): remove this method and move logic to generated client - if self.http_opt is None: - return set() - - params = set(self.path_params) - body = self.http_opt.get('body') - if body: - params.add(body) - - return set(self.input.fields) - params - - # TODO(yon-mg): refactor as there may be more than one method signature - @utils.cached_property - def flattened_fields(self) -> Mapping[str, Field]: - """Return the signature defined for this method.""" - cross_pkg_request = self.input.ident.package != self.ident.package - - def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: - for f in sig.split(','): - if not f: - # Special case for an empty signature - continue - name = f.strip() - field = self.input.get_field(*name.split('.')) - name += '_' if field.field_pb.name in utils.RESERVED_NAMES else '' - if cross_pkg_request and not field.is_primitive: - # This is not a proto-plus wrapped message type, - # and setting a non-primitive field directly is verboten. - continue - - yield name, field - - signatures = self.options.Extensions[client_pb2.method_signature] - answer: Dict[str, Field] = collections.OrderedDict( - name_and_field - for sig in signatures - for name_and_field in filter_fields(sig) - ) - - return answer - - @utils.cached_property - def flattened_field_to_key(self): - return {field.name: key for key, field in self.flattened_fields.items()} - - @utils.cached_property - def legacy_flattened_fields(self) -> Mapping[str, Field]: - """Return the legacy flattening interface: top level fields only, - required fields first""" - required, optional = utils.partition(lambda f: f.required, - self.input.fields.values()) - return collections.OrderedDict((f.name, f) - for f in chain(required, optional)) - - @property - def grpc_stub_type(self) -> str: - """Return the type of gRPC stub to use.""" - return '{client}_{server}'.format( - client='stream' if self.client_streaming else 'unary', - server='stream' if self.server_streaming else 'unary', - ) - - # TODO(yon-mg): figure out why idempotent is reliant on http annotation - @utils.cached_property - def idempotent(self) -> bool: - """Return True if we know this method is idempotent, False otherwise. + http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] + http = self.options.Extensions[annotations_pb2.http].ListFields() + + if len(http) < 1: + return None + + http_method = http[0] + answer: Dict[str, str] = { + 'verb': http_method[0].name, + 'url': http_method[1], + } + if len(http) > 1: + body_spec = http[1] + answer[body_spec[0].name] = body_spec[1] + + # TODO(yon-mg): handle nested fields & fields past body i.e. 'additional bindings' + # TODO(yon-mg): enums for http verbs? + return answer + + @property + def path_params(self) -> Sequence[str]: + """Return the path parameters found in the http annotation path template""" + # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) + if self.http_opt is None: + return [] + + pattern = r'\{(\w+)\}' + return re.findall(pattern, self.http_opt['url']) + + @property + def query_params(self) -> Set[str]: + """Return query parameters for API call as determined by http annotation and grpc transcoding""" + # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) + # TODO(yon-mg): remove this method and move logic to generated client + if self.http_opt is None: + return set() + + params = set(self.path_params) + body = self.http_opt.get('body') + if body: + params.add(body) + + return set(self.input.fields) - params + + # TODO(yon-mg): refactor as there may be more than one method signature + @utils.cached_property + def flattened_fields(self) -> Mapping[str, Field]: + """Return the signature defined for this method.""" + cross_pkg_request = self.input.ident.package != self.ident.package + + def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: + for f in sig.split(','): + if not f: + # Special case for an empty signature + continue + name = f.strip() + field = self.input.get_field(*name.split('.')) + name += '_' if field.field_pb.name in utils.RESERVED_NAMES else '' + if cross_pkg_request and not field.is_primitive: + # This is not a proto-plus wrapped message type, + # and setting a non-primitive field directly is verboten. + continue + + yield name, field + + signatures = self.options.Extensions[client_pb2.method_signature] + answer: Dict[str, Field] = collections.OrderedDict( + name_and_field for sig in signatures + for name_and_field in filter_fields(sig)) + + return answer + + @utils.cached_property + def flattened_field_to_key(self): + return {field.name: key for key, field in self.flattened_fields.items()} + + @utils.cached_property + def legacy_flattened_fields(self) -> Mapping[str, Field]: + """Return the legacy flattening interface: top level fields only, + + required fields first + """ + required, optional = utils.partition(lambda f: f.required, + self.input.fields.values()) + return collections.OrderedDict( + (f.name, f) for f in chain(required, optional)) + + @property + def grpc_stub_type(self) -> str: + """Return the type of gRPC stub to use.""" + return '{client}_{server}'.format( + client='stream' if self.client_streaming else 'unary', + server='stream' if self.server_streaming else 'unary', + ) + + # TODO(yon-mg): figure out why idempotent is reliant on http annotation + @utils.cached_property + def idempotent(self) -> bool: + """Return True if we know this method is idempotent, False otherwise. Note: We are intentionally conservative here. It is far less bad to falsely believe an idempotent method is non-idempotent than the converse. """ - return bool(self.options.Extensions[annotations_pb2.http].get) - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - @utils.cached_property - def paged_result_field(self) -> Optional[Field]: - """Return the response pagination field if the method is paginated.""" - # If the request field lacks any of the expected pagination fields, - # then the method is not paginated. - - # The request must have page_token and next_page_token as they keep track of pages - for source, source_type, name in ((self.input, str, 'page_token'), - (self.output, str, 'next_page_token')): - field = source.fields.get(name, None) - if not field or field.type != source_type: - return None - - # The request must have max_results or page_size - page_fields = (self.input.fields.get('max_results', None), - self.input.fields.get('page_size', None)) - page_field_size = next( - (field for field in page_fields if field), None) - if not page_field_size or page_field_size.type != int: - return None - - # Return the first repeated field. - for field in self.output.fields.values(): - if field.repeated: - return field - - # We found no repeated fields. Return None. + return bool(self.options.Extensions[annotations_pb2.http].get) + + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address + + @utils.cached_property + def paged_result_field(self) -> Optional[Field]: + """Return the response pagination field if the method is paginated.""" + # If the request field lacks any of the expected pagination fields, + # then the method is not paginated. + + # The request must have page_token and next_page_token as they keep track of pages + for source, source_type, name in ((self.input, str, 'page_token'), + (self.output, str, 'next_page_token')): + field = source.fields.get(name, None) + if not field or field.type != source_type: return None - @utils.cached_property - def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: - return self._ref_types(True) - - @utils.cached_property - def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: - return self._ref_types(False) - - def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: - """Return types referenced by this method.""" - # Begin with the input (request) and output (response) messages. - answer: List[Union[MessageType, EnumType]] = [self.input] - types: Iterable[Union[MessageType, EnumType]] = ( - self.input.recursive_field_types if recursive - else ( - f.type - for f in self.flattened_fields.values() - if f.message or f.enum - ) - ) - answer.extend(types) - - if not self.void: - answer.append(self.client_output) - answer.extend(self.client_output.field_types) - answer.append(self.client_output_async) - answer.extend(self.client_output_async.field_types) - - # If this method has LRO, it is possible (albeit unlikely) that - # the LRO messages reside in a different module. - if self.lro: - answer.append(self.lro.response_type) - answer.append(self.lro.metadata_type) - - # If this message paginates its responses, it is possible - # that the individual result messages reside in a different module. - if self.paged_result_field and self.paged_result_field.message: - answer.append(self.paged_result_field.message) - - # Done; return the answer. - return tuple(answer) - - @property - def void(self) -> bool: - """Return True if this method has no return value, False otherwise.""" - return self.output.ident.proto == 'google.protobuf.Empty' - - def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': - """Return a derivative of this method with the provided context. + # The request must have max_results or page_size + page_fields = (self.input.fields.get('max_results', None), + self.input.fields.get('page_size', None)) + page_field_size = next((field for field in page_fields if field), None) + if not page_field_size or page_field_size.type != int: + return None + + # Return the first repeated field. + for field in self.output.fields.values(): + if field.repeated: + return field + + # We found no repeated fields. Return None. + return None + + @utils.cached_property + def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: + return self._ref_types(True) + + @utils.cached_property + def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: + return self._ref_types(False) + + def _ref_types(self, + recursive: bool) -> Sequence[Union[MessageType, EnumType]]: + """Return types referenced by this method.""" + # Begin with the input (request) and output (response) messages. + answer: List[Union[MessageType, EnumType]] = [self.input] + types: Iterable[Union[MessageType, EnumType]] = ( + self.input.recursive_field_types if recursive else + (f.type for f in self.flattened_fields.values() if f.message or f.enum)) + answer.extend(types) + + if not self.void: + answer.append(self.client_output) + answer.extend(self.client_output.field_types) + answer.append(self.client_output_async) + answer.extend(self.client_output_async.field_types) + + # If this method has LRO, it is possible (albeit unlikely) that + # the LRO messages reside in a different module. + if self.lro: + answer.append(self.lro.response_type) + answer.append(self.lro.metadata_type) + + # If this message paginates its responses, it is possible + # that the individual result messages reside in a different module. + if self.paged_result_field and self.paged_result_field.message: + answer.append(self.paged_result_field.message) + + # Done; return the answer. + return tuple(answer) + + @property + def void(self) -> bool: + """Return True if this method has no return value, False otherwise.""" + return self.output.ident.proto == 'google.protobuf.Empty' + + def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': + """Return a derivative of this method with the provided context. This method is used to address naming collisions. The returned ``Method`` object aliases module names to avoid naming collisions in the file being written. """ - maybe_lro = None - if self.lro: - maybe_lro = self.lro.with_context( - collisions=collisions - ) if collisions else self.lro - - return dataclasses.replace( - self, - lro=maybe_lro, - input=self.input.with_context(collisions=collisions), - output=self.output.with_context(collisions=collisions), - meta=self.meta.with_context(collisions=collisions), - ) + maybe_lro = None + if self.lro: + maybe_lro = self.lro.with_context( + collisions=collisions) if collisions else self.lro + + return dataclasses.replace( + self, + lro=maybe_lro, + input=self.input.with_context(collisions=collisions), + output=self.output.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class CommonResource: - type_name: str - pattern: str - - @classmethod - def build(cls, resource: resource_pb2.ResourceDescriptor): - return cls( - type_name=resource.type, - pattern=next(iter(resource.pattern)) - ) - - @utils.cached_property - def message_type(self): - message_pb = descriptor_pb2.DescriptorProto() - res_pb = message_pb.options.Extensions[resource_pb2.resource] - res_pb.type = self.type_name - res_pb.pattern.append(self.pattern) - - return MessageType( - message_pb=message_pb, - fields={}, - nested_enums={}, - nested_messages={}, - ) + type_name: str + pattern: str + + @classmethod + def build(cls, resource: resource_pb2.ResourceDescriptor): + return cls(type_name=resource.type, pattern=next(iter(resource.pattern))) + + @utils.cached_property + def message_type(self): + message_pb = descriptor_pb2.DescriptorProto() + res_pb = message_pb.options.Extensions[resource_pb2.resource] + res_pb.type = self.type_name + res_pb.pattern.append(self.pattern) + + return MessageType( + message_pb=message_pb, + fields={}, + nested_enums={}, + nested_messages={}, + ) @dataclasses.dataclass(frozen=True) class Service: - """Description of a service (defined with the ``service`` keyword).""" - service_pb: descriptor_pb2.ServiceDescriptorProto - methods: Mapping[str, Method] - # N.B.: visible_resources is intended to be a read-only view - # whose backing store is owned by the API. - # This is represented by a types.MappingProxyType instance. - visible_resources: Mapping[str, MessageType] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata, - ) - - common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( - default={ - "cloudresourcemanager.googleapis.com/Project": CommonResource( - "cloudresourcemanager.googleapis.com/Project", - "projects/{project}", - ), - "cloudresourcemanager.googleapis.com/Organization": CommonResource( - "cloudresourcemanager.googleapis.com/Organization", - "organizations/{organization}", - ), - "cloudresourcemanager.googleapis.com/Folder": CommonResource( - "cloudresourcemanager.googleapis.com/Folder", - "folders/{folder}", - ), - "cloudbilling.googleapis.com/BillingAccount": CommonResource( - "cloudbilling.googleapis.com/BillingAccount", - "billingAccounts/{billing_account}", - ), - "locations.googleapis.com/Location": CommonResource( - "locations.googleapis.com/Location", - "projects/{project}/locations/{location}", - ), - }, - init=False, - compare=False, - ) - - def __getattr__(self, name): - return getattr(self.service_pb, name) - - @property - def client_name(self) -> str: - """Returns the name of the generated client class""" - return self.name + "Client" - - @property - def async_client_name(self) -> str: - """Returns the name of the generated AsyncIO client class""" - return self.name + "AsyncClient" - - @property - def transport_name(self): - return self.name + "Transport" - - @property - def grpc_transport_name(self): - return self.name + "GrpcTransport" - - @property - def grpc_asyncio_transport_name(self): - return self.name + "GrpcAsyncIOTransport" - - @property - def rest_transport_name(self): - return self.name + "RestTransport" - - @property - def has_lro(self) -> bool: - """Return whether the service has a long-running method.""" - return any([m.lro for m in self.methods.values()]) - - @property - def has_pagers(self) -> bool: - """Return whether the service has paged methods.""" - return any(m.paged_result_field for m in self.methods.values()) - - @property - def host(self) -> str: - """Return the hostname for this service, if specified. + """Description of a service (defined with the ``service`` keyword).""" + service_pb: descriptor_pb2.ServiceDescriptorProto + methods: Mapping[str, Method] + # N.B.: visible_resources is intended to be a read-only view + # whose backing store is owned by the API. + # This is represented by a types.MappingProxyType instance. + visible_resources: Mapping[str, MessageType] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( + default={ + 'cloudresourcemanager.googleapis.com/Project': + CommonResource( + 'cloudresourcemanager.googleapis.com/Project', + 'projects/{project}', + ), + 'cloudresourcemanager.googleapis.com/Organization': + CommonResource( + 'cloudresourcemanager.googleapis.com/Organization', + 'organizations/{organization}', + ), + 'cloudresourcemanager.googleapis.com/Folder': + CommonResource( + 'cloudresourcemanager.googleapis.com/Folder', + 'folders/{folder}', + ), + 'cloudbilling.googleapis.com/BillingAccount': + CommonResource( + 'cloudbilling.googleapis.com/BillingAccount', + 'billingAccounts/{billing_account}', + ), + 'locations.googleapis.com/Location': + CommonResource( + 'locations.googleapis.com/Location', + 'projects/{project}/locations/{location}', + ), + }, + init=False, + compare=False, + ) + + def __getattr__(self, name): + return getattr(self.service_pb, name) + + @property + def client_name(self) -> str: + """Returns the name of the generated client class""" + return self.name + 'Client' + + @property + def async_client_name(self) -> str: + """Returns the name of the generated AsyncIO client class""" + return self.name + 'AsyncClient' + + @property + def transport_name(self): + return self.name + 'Transport' + + @property + def grpc_transport_name(self): + return self.name + 'GrpcTransport' + + @property + def grpc_asyncio_transport_name(self): + return self.name + 'GrpcAsyncIOTransport' + + @property + def rest_transport_name(self): + return self.name + 'RestTransport' + + @property + def has_lro(self) -> bool: + """Return whether the service has a long-running method.""" + return any([m.lro for m in self.methods.values()]) + + @property + def has_pagers(self) -> bool: + """Return whether the service has paged methods.""" + return any(m.paged_result_field for m in self.methods.values()) + + @property + def host(self) -> str: + """Return the hostname for this service, if specified. Returns: str: The hostname, with no protocol and no trailing ``/``. """ - if self.options.Extensions[client_pb2.default_host]: - return self.options.Extensions[client_pb2.default_host] - return '' + if self.options.Extensions[client_pb2.default_host]: + return self.options.Extensions[client_pb2.default_host] + return '' - @property - def shortname(self) -> str: - """Return the API short name. DRIFT uses this to identify + @property + def shortname(self) -> str: + """Return the API short name. + + DRIFT uses this to identify APIs. Returns: str: The api shortname. """ - # Get the shortname from the host - # Real APIs are expected to have format: - # "{api_shortname}.googleapis.com" - return self.host.split(".")[0] + # Get the shortname from the host + # Real APIs are expected to have format: + # "{api_shortname}.googleapis.com" + return self.host.split('.')[0] - @property - def oauth_scopes(self) -> Sequence[str]: - """Return a sequence of oauth scopes, if applicable. + @property + def oauth_scopes(self) -> Sequence[str]: + """Return a sequence of oauth scopes, if applicable. Returns: Sequence[str]: A sequence of OAuth scopes. """ - # Return the OAuth scopes, split on comma. - return tuple( - i.strip() - for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') - if i - ) + # Return the OAuth scopes, split on comma. + return tuple( + i.strip() + for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') + if i) - @property - def module_name(self) -> str: - """Return the appropriate module name for this service. + @property + def module_name(self) -> str: + """Return the appropriate module name for this service. Returns: str: The service name, in snake case. """ - return utils.to_snake_case(self.name) + return utils.to_snake_case(self.name) - @utils.cached_property - def names(self) -> FrozenSet[str]: - """Return a set of names used in this service. + @utils.cached_property + def names(self) -> FrozenSet[str]: + """Return a set of names used in this service. This is used for detecting naming collisions in the module names used for imports. """ - # Put together a set of the service and method names. - answer = {self.name, self.client_name, self.async_client_name} - answer.update( - utils.to_snake_case(i.name) for i in self.methods.values() - ) - - # Identify any import module names where the same module name is used - # from distinct packages. - modules: Dict[str, Set[str]] = collections.defaultdict(set) - for m in self.methods.values(): - for t in m.ref_types: - modules[t.ident.module].add(t.ident.package) - - answer.update( - module_name - for module_name, packages in modules.items() - if len(packages) > 1 - ) - - # Done; return the answer. - return frozenset(answer) - - @utils.cached_property - def resource_messages(self) -> FrozenSet[MessageType]: - """Returns all the resource message types used in all - request and response fields in the service.""" - def gen_resources(message): - if message.resource_path: - yield message - - for type_ in message.recursive_field_types: - if type_.resource_path: - yield type_ - - def gen_indirect_resources_used(message): - for field in message.recursive_resource_fields: - resource = field.options.Extensions[ - resource_pb2.resource_reference] - resource_type = resource.type or resource.child_type - # The resource may not be visible if the resource type is one of - # the common_resources (see the class var in class definition) - # or if it's something unhelpful like '*'. - resource = self.visible_resources.get(resource_type) - if resource: - yield resource - - return frozenset( - msg - for method in self.methods.values() - for msg in chain( - gen_resources(method.input), - gen_resources( - method.lro.response_type if method.lro else method.output - ), - gen_indirect_resources_used(method.input), - gen_indirect_resources_used( - method.lro.response_type if method.lro else method.output - ), - ) - ) - - @utils.cached_property - def any_client_streaming(self) -> bool: - return any(m.client_streaming for m in self.methods.values()) - - @utils.cached_property - def any_server_streaming(self) -> bool: - return any(m.server_streaming for m in self.methods.values()) - - def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': - """Return a derivative of this service with the provided context. + # Put together a set of the service and method names. + answer = {self.name, self.client_name, self.async_client_name} + answer.update(utils.to_snake_case(i.name) for i in self.methods.values()) + + # Identify any import module names where the same module name is used + # from distinct packages. + modules: Dict[str, Set[str]] = collections.defaultdict(set) + for m in self.methods.values(): + for t in m.ref_types: + modules[t.ident.module].add(t.ident.package) + + answer.update(module_name for module_name, packages in modules.items() + if len(packages) > 1) + + # Done; return the answer. + return frozenset(answer) + + @utils.cached_property + def resource_messages(self) -> FrozenSet[MessageType]: + """Returns all the resource message types used in all + + request and response fields in the service. + """ + + def gen_resources(message): + if message.resource_path: + yield message + + for type_ in message.recursive_field_types: + if type_.resource_path: + yield type_ + + def gen_indirect_resources_used(message): + for field in message.recursive_resource_fields: + resource = field.options.Extensions[resource_pb2.resource_reference] + resource_type = resource.type or resource.child_type + # The resource may not be visible if the resource type is one of + # the common_resources (see the class var in class definition) + # or if it's something unhelpful like '*'. + resource = self.visible_resources.get(resource_type) + if resource: + yield resource + + return frozenset(msg for method in self.methods.values() for msg in chain( + gen_resources(method.input), + gen_resources( + method.lro.response_type if method.lro else method.output), + gen_indirect_resources_used(method.input), + gen_indirect_resources_used( + method.lro.response_type if method.lro else method.output), + )) + + @utils.cached_property + def any_client_streaming(self) -> bool: + return any(m.client_streaming for m in self.methods.values()) + + @utils.cached_property + def any_server_streaming(self) -> bool: + return any(m.server_streaming for m in self.methods.values()) + + def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': + """Return a derivative of this service with the provided context. This method is used to address naming collisions. The returned ``Service`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( - self, - methods={ - k: v.with_context( - # A method's flattened fields create additional names - # that may conflict with module imports. - collisions=collisions | frozenset(v.flattened_fields.keys())) - for k, v in self.methods.items() - }, - meta=self.meta.with_context(collisions=collisions), - ) + return dataclasses.replace( + self, + methods={ + k: v.with_context( + # A method's flattened fields create additional names + # that may conflict with module imports. + collisions=collisions | frozenset(v.flattened_fields.keys())) + for k, v in self.methods.items() + }, + meta=self.meta.with_context(collisions=collisions), + ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 281913acd3..9db0f092d1 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -359,6 +359,9 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): {{ method.client_output.meta.doc|rst(width=72, indent=16, source_format="rst") }} {% endif %} """ + {% if method.is_deprecated %} + warnings.warn("{{ method.name|snake_case }} is deprecated", warnings.DeprecationWarning) + {% endif %} {% if not method.client_streaming %} # Create or coerce a protobuf request object. {% if method.flattened_fields %} @@ -476,9 +479,9 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): metadata: Sequence[Tuple[str, str]] = (), ) -> policy_pb2.Policy: r"""Sets the IAM access control policy on the specified function. - + Replaces any existing policy. - + Args: request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): The request object. Request message for `SetIamPolicy` diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index 2aafab454a..d444c56bf6 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -25,28 +25,27 @@ def make_service( - name: str = "Placeholder", - host: str = "", + name: str = 'Placeholder', + host: str = '', methods: typing.Tuple[wrappers.Method] = (), scopes: typing.Tuple[str] = (), - visible_resources: typing.Optional[ - typing.Mapping[str, wrappers.CommonResource] - ] = None, + visible_resources: typing.Optional[typing.Mapping[ + str, wrappers.CommonResource]] = None, ) -> wrappers.Service: - visible_resources = visible_resources or {} - # Define a service descriptor, and set a host and oauth scopes if - # appropriate. - service_pb = desc.ServiceDescriptorProto(name=name) - if host: - service_pb.options.Extensions[client_pb2.default_host] = host - service_pb.options.Extensions[client_pb2.oauth_scopes] = ','.join(scopes) - - # Return a service object to test. - return wrappers.Service( - service_pb=service_pb, - methods={m.name: m for m in methods}, - visible_resources=visible_resources, - ) + visible_resources = visible_resources or {} + # Define a service descriptor, and set a host and oauth scopes if + # appropriate. + service_pb = desc.ServiceDescriptorProto(name=name) + if host: + service_pb.options.Extensions[client_pb2.default_host] = host + service_pb.options.Extensions[client_pb2.oauth_scopes] = ','.join(scopes) + + # Return a service object to test. + return wrappers.Service( + service_pb=service_pb, + methods={m.name: m for m in methods}, + visible_resources=visible_resources, + ) # FIXME (lukesneeringer): This test method is convoluted and it makes these @@ -56,192 +55,196 @@ def make_service_with_method_options( http_rule: http_pb2.HttpRule = None, method_signature: str = '', in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - visible_resources: typing.Optional[typing.Mapping[str, wrappers.CommonResource]] = None, + visible_resources: typing.Optional[typing.Mapping[ + str, wrappers.CommonResource]] = None, ) -> wrappers.Service: - # Declare a method with options enabled for long-running operations and - # field headers. - method = get_method( - 'DoBigThing', - 'foo.bar.ThingRequest', - 'google.longrunning.operations_pb2.Operation', - lro_response_type='foo.baz.ThingResponse', - lro_metadata_type='foo.qux.ThingMetadata', - in_fields=in_fields, - http_rule=http_rule, - method_signature=method_signature, - ) - - # Define a service descriptor. - service_pb = desc.ServiceDescriptorProto(name='ThingDoer') - - # Return a service object to test. - return wrappers.Service( - service_pb=service_pb, - methods={method.name: method}, - visible_resources=visible_resources or {}, - ) - - -def get_method(name: str, - in_type: str, - out_type: str, - lro_response_type: str = '', - lro_metadata_type: str = '', *, - in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - http_rule: http_pb2.HttpRule = None, - method_signature: str = '', - ) -> wrappers.Method: - input_ = get_message(in_type, fields=in_fields) - output = get_message(out_type) - lro = None - - # Define a method descriptor. Set the field headers if appropriate. - method_pb = desc.MethodDescriptorProto( - name=name, - input_type=input_.ident.proto, - output_type=output.ident.proto, - ) - if lro_response_type: - lro = wrappers.OperationInfo( - response_type=get_message(lro_response_type), - metadata_type=get_message(lro_metadata_type), - ) - if http_rule: - ext_key = annotations_pb2.http - method_pb.options.Extensions[ext_key].MergeFrom(http_rule) - if method_signature: - ext_key = client_pb2.method_signature - method_pb.options.Extensions[ext_key].append(method_signature) - - return wrappers.Method( - method_pb=method_pb, - input=input_, - output=output, - lro=lro, - meta=input_.meta, - ) - - -def get_message(dot_path: str, *, - fields: typing.Tuple[desc.FieldDescriptorProto] = (), - ) -> wrappers.MessageType: - # Pass explicit None through (for lro_metadata). - if dot_path is None: - return None - - # Note: The `dot_path` here is distinct from the canonical proto path - # because it includes the module, which the proto path does not. - # - # So, if trying to test the DescriptorProto message here, the path - # would be google.protobuf.descriptor.DescriptorProto (whereas the proto - # path is just google.protobuf.DescriptorProto). - pieces = dot_path.split('.') - pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] - - return wrappers.MessageType( - fields={i.name: wrappers.Field( - field_pb=i, - enum=get_enum(i.type_name) if i.type_name else None, - ) for i in fields}, - nested_messages={}, - nested_enums={}, - message_pb=desc.DescriptorProto(name=name, field=fields), - meta=metadata.Metadata(address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), - ) - - -def make_method( - name: str, - input_message: wrappers.MessageType = None, - output_message: wrappers.MessageType = None, - package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', - module: str = 'baz', - http_rule: http_pb2.HttpRule = None, - signatures: typing.Sequence[str] = (), - **kwargs + # Declare a method with options enabled for long-running operations and + # field headers. + method = get_method( + 'DoBigThing', + 'foo.bar.ThingRequest', + 'google.longrunning.operations_pb2.Operation', + lro_response_type='foo.baz.ThingResponse', + lro_metadata_type='foo.qux.ThingMetadata', + in_fields=in_fields, + http_rule=http_rule, + method_signature=method_signature, + ) + + # Define a service descriptor. + service_pb = desc.ServiceDescriptorProto(name='ThingDoer') + + # Return a service object to test. + return wrappers.Service( + service_pb=service_pb, + methods={method.name: method}, + visible_resources=visible_resources or {}, + ) + + +def get_method( + name: str, + in_type: str, + out_type: str, + lro_response_type: str = '', + lro_metadata_type: str = '', + *, + in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), + http_rule: http_pb2.HttpRule = None, + method_signature: str = '', ) -> wrappers.Method: - # Use default input and output messages if they are not provided. - input_message = input_message or make_message('MethodInput') - output_message = output_message or make_message('MethodOutput') - - # Create the method pb2. - method_pb = desc.MethodDescriptorProto( - name=name, - input_type=str(input_message.meta.address), - output_type=str(output_message.meta.address), - **kwargs - ) - - # If there is an HTTP rule, process it. - if http_rule: - ext_key = annotations_pb2.http - method_pb.options.Extensions[ext_key].MergeFrom(http_rule) - - # If there are signatures, include them. - for sig in signatures: - ext_key = client_pb2.method_signature - method_pb.options.Extensions[ext_key].append(sig) - - if isinstance(package, str): - package = tuple(package.split('.')) - - # Instantiate the wrapper class. - return wrappers.Method( - method_pb=method_pb, - input=input_message, - output=output_message, - meta=metadata.Metadata(address=metadata.Address( - name=name, - package=package, - module=module, - parent=(f'{name}Service',), - )), - ) - - -def make_field( - name: str = 'my_field', - number: int = 1, - repeated: bool = False, - message: wrappers.MessageType = None, - enum: wrappers.EnumType = None, - meta: metadata.Metadata = None, - oneof: str = None, - **kwargs -) -> wrappers.Field: - T = desc.FieldDescriptorProto.Type - - if message: - kwargs.setdefault('type_name', str(message.meta.address)) - kwargs['type'] = 'TYPE_MESSAGE' - elif enum: - kwargs.setdefault('type_name', str(enum.meta.address)) - kwargs['type'] = 'TYPE_ENUM' - else: - kwargs.setdefault('type', T.Value('TYPE_BOOL')) - - if isinstance(kwargs['type'], str): - kwargs['type'] = T.Value(kwargs['type']) - - label = kwargs.pop('label', 3 if repeated else 1) - field_pb = desc.FieldDescriptorProto( - name=name, - label=label, - number=number, - **kwargs - ) - - return wrappers.Field( - field_pb=field_pb, - enum=enum, - message=message, - meta=meta or metadata.Metadata(), - oneof=oneof, + input_ = get_message(in_type, fields=in_fields) + output = get_message(out_type) + lro = None + + # Define a method descriptor. Set the field headers if appropriate. + method_pb = desc.MethodDescriptorProto( + name=name, + input_type=input_.ident.proto, + output_type=output.ident.proto, + ) + if lro_response_type: + lro = wrappers.OperationInfo( + response_type=get_message(lro_response_type), + metadata_type=get_message(lro_metadata_type), ) + if http_rule: + ext_key = annotations_pb2.http + method_pb.options.Extensions[ext_key].MergeFrom(http_rule) + if method_signature: + ext_key = client_pb2.method_signature + method_pb.options.Extensions[ext_key].append(method_signature) + + return wrappers.Method( + method_pb=method_pb, + input=input_, + output=output, + lro=lro, + meta=input_.meta, + ) + + +def get_message( + dot_path: str, + *, + fields: typing.Tuple[desc.FieldDescriptorProto] = (), +) -> wrappers.MessageType: + # Pass explicit None through (for lro_metadata). + if dot_path is None: + return None + + # Note: The `dot_path` here is distinct from the canonical proto path + # because it includes the module, which the proto path does not. + # + # So, if trying to test the DescriptorProto message here, the path + # would be google.protobuf.descriptor.DescriptorProto (whereas the proto + # path is just google.protobuf.DescriptorProto). + pieces = dot_path.split('.') + pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] + + return wrappers.MessageType( + fields={ + i.name: wrappers.Field( + field_pb=i, + enum=get_enum(i.type_name) if i.type_name else None, + ) for i in fields + }, + nested_messages={}, + nested_enums={}, + message_pb=desc.DescriptorProto(name=name, field=fields), + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), + ) + + +def make_method(name: str, + input_message: wrappers.MessageType = None, + output_message: wrappers.MessageType = None, + package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', + module: str = 'baz', + http_rule: http_pb2.HttpRule = None, + signatures: typing.Sequence[str] = (), + is_deprecated: bool = False, + **kwargs) -> wrappers.Method: + # Use default input and output messages if they are not provided. + input_message = input_message or make_message('MethodInput') + output_message = output_message or make_message('MethodOutput') + + # Create the method pb2. + method_pb = desc.MethodDescriptorProto( + name=name, + input_type=str(input_message.meta.address), + output_type=str(output_message.meta.address), + **kwargs) + + # If there is an HTTP rule, process it. + if http_rule: + ext_key = annotations_pb2.http + method_pb.options.Extensions[ext_key].MergeFrom(http_rule) + + # If there are signatures, include them. + for sig in signatures: + ext_key = client_pb2.method_signature + method_pb.options.Extensions[ext_key].append(sig) + + if isinstance(package, str): + package = tuple(package.split('.')) + + if is_deprecated: + method_pb.options.deprecated = True + + # Instantiate the wrapper class. + return wrappers.Method( + method_pb=method_pb, + input=input_message, + output=output_message, + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=package, + module=module, + parent=(f'{name}Service',), + )), + ) + + +def make_field(name: str = 'my_field', + number: int = 1, + repeated: bool = False, + message: wrappers.MessageType = None, + enum: wrappers.EnumType = None, + meta: metadata.Metadata = None, + oneof: str = None, + **kwargs) -> wrappers.Field: + T = desc.FieldDescriptorProto.Type + + if message: + kwargs.setdefault('type_name', str(message.meta.address)) + kwargs['type'] = 'TYPE_MESSAGE' + elif enum: + kwargs.setdefault('type_name', str(enum.meta.address)) + kwargs['type'] = 'TYPE_ENUM' + else: + kwargs.setdefault('type', T.Value('TYPE_BOOL')) + + if isinstance(kwargs['type'], str): + kwargs['type'] = T.Value(kwargs['type']) + + label = kwargs.pop('label', 3 if repeated else 1) + field_pb = desc.FieldDescriptorProto( + name=name, label=label, number=number, **kwargs) + + return wrappers.Field( + field_pb=field_pb, + enum=enum, + message=message, + meta=meta or metadata.Metadata(), + oneof=oneof, + ) def make_message( @@ -252,36 +255,38 @@ def make_message( meta: metadata.Metadata = None, options: desc.MethodOptions = None, ) -> wrappers.MessageType: - message_pb = desc.DescriptorProto( - name=name, - field=[i.field_pb for i in fields], - options=options, - ) - return wrappers.MessageType( - message_pb=message_pb, - fields=collections.OrderedDict((i.name, i) for i in fields), - nested_messages={}, - nested_enums={}, - meta=meta or metadata.Metadata(address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), - ) + message_pb = desc.DescriptorProto( + name=name, + field=[i.field_pb for i in fields], + options=options, + ) + return wrappers.MessageType( + message_pb=message_pb, + fields=collections.OrderedDict((i.name, i) for i in fields), + nested_messages={}, + nested_enums={}, + meta=meta or metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), + ) def get_enum(dot_path: str) -> wrappers.EnumType: - pieces = dot_path.split('.') - pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] - return wrappers.EnumType( - enum_pb=desc.EnumDescriptorProto(name=name), - meta=metadata.Metadata(address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), - values=[], - ) + pieces = dot_path.split('.') + pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] + return wrappers.EnumType( + enum_pb=desc.EnumDescriptorProto(name=name), + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), + values=[], + ) def make_enum( @@ -292,102 +297,101 @@ def make_enum( meta: metadata.Metadata = None, options: desc.EnumOptions = None, ) -> wrappers.EnumType: - enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=i[0], number=i[1]) - for i in values - ] - enum_pb = desc.EnumDescriptorProto( - name=name, - value=enum_value_pbs, - options=options, - ) - return wrappers.EnumType( - enum_pb=enum_pb, - values=[wrappers.EnumValueType(enum_value_pb=evpb) - for evpb in enum_value_pbs], - meta=meta or metadata.Metadata(address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), - ) + enum_value_pbs = [ + desc.EnumValueDescriptorProto(name=i[0], number=i[1]) for i in values + ] + enum_pb = desc.EnumDescriptorProto( + name=name, + value=enum_value_pbs, + options=options, + ) + return wrappers.EnumType( + enum_pb=enum_pb, + values=[ + wrappers.EnumValueType(enum_value_pb=evpb) for evpb in enum_value_pbs + ], + meta=meta or metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), + ) def make_naming(**kwargs) -> naming.Naming: - kwargs.setdefault('name', 'Hatstand') - kwargs.setdefault('namespace', ('Google', 'Cloud')) - kwargs.setdefault('version', 'v1') - kwargs.setdefault('product_name', 'Hatstand') - return naming.NewNaming(**kwargs) + kwargs.setdefault('name', 'Hatstand') + kwargs.setdefault('namespace', ('Google', 'Cloud')) + kwargs.setdefault('version', 'v1') + kwargs.setdefault('product_name', 'Hatstand') + return naming.NewNaming(**kwargs) -def make_enum_pb2( +def make_enum_pb2(name: str, *values: typing.Sequence[str], + **kwargs) -> desc.EnumDescriptorProto: + enum_value_pbs = [ + desc.EnumValueDescriptorProto(name=n, number=i) + for i, n in enumerate(values) + ] + enum_pb = desc.EnumDescriptorProto(name=name, value=enum_value_pbs, **kwargs) + return enum_pb + + +def make_message_pb2(name: str, + fields: tuple = (), + oneof_decl: tuple = (), + **kwargs) -> desc.DescriptorProto: + return desc.DescriptorProto( + name=name, field=fields, oneof_decl=oneof_decl, **kwargs) + + +def make_field_pb2( name: str, - *values: typing.Sequence[str], - **kwargs -) -> desc.EnumDescriptorProto: - enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=n, number=i) - for i, n in enumerate(values) - ] - enum_pb = desc.EnumDescriptorProto(name=name, value=enum_value_pbs, **kwargs) - return enum_pb - - -def make_message_pb2( - name: str, - fields: tuple = (), - oneof_decl: tuple = (), - **kwargs -) -> desc.DescriptorProto: - return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs) - - -def make_field_pb2(name: str, number: int, - type: int = 11, # 11 == message - type_name: str = None, - oneof_index: int = None - ) -> desc.FieldDescriptorProto: - return desc.FieldDescriptorProto( - name=name, - number=number, - type=type, - type_name=type_name, - oneof_index=oneof_index, - ) + number: int, + type: int = 11, # 11 == message + type_name: str = None, + oneof_index: int = None) -> desc.FieldDescriptorProto: + return desc.FieldDescriptorProto( + name=name, + number=number, + type=type, + type_name=type_name, + oneof_index=oneof_index, + ) + def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto: - return desc.OneofDescriptorProto( - name=name, - ) + return desc.OneofDescriptorProto(name=name,) -def make_file_pb2(name: str = 'my_proto.proto', package: str = 'example.v1', *, - messages: typing.Sequence[desc.DescriptorProto] = (), - enums: typing.Sequence[desc.EnumDescriptorProto] = (), - services: typing.Sequence[desc.ServiceDescriptorProto] = (), - locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), - ) -> desc.FileDescriptorProto: - return desc.FileDescriptorProto( - name=name, - package=package, - message_type=messages, - enum_type=enums, - service=services, - source_code_info=desc.SourceCodeInfo(location=locations), - ) +def make_file_pb2( + name: str = 'my_proto.proto', + package: str = 'example.v1', + *, + messages: typing.Sequence[desc.DescriptorProto] = (), + enums: typing.Sequence[desc.EnumDescriptorProto] = (), + services: typing.Sequence[desc.ServiceDescriptorProto] = (), + locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), +) -> desc.FileDescriptorProto: + return desc.FileDescriptorProto( + name=name, + package=package, + message_type=messages, + enum_type=enums, + service=services, + source_code_info=desc.SourceCodeInfo(location=locations), + ) def make_doc_meta( - *, - leading: str = '', - trailing: str = '', - detached: typing.List[str] = [], + *, + leading: str = '', + trailing: str = '', + detached: typing.List[str] = [], ) -> desc.SourceCodeInfo.Location: - return metadata.Metadata( - documentation=desc.SourceCodeInfo.Location( - leading_comments=leading, - trailing_comments=trailing, - leading_detached_comments=detached, - ), - ) + return metadata.Metadata( + documentation=desc.SourceCodeInfo.Location( + leading_comments=leading, + trailing_comments=trailing, + leading_detached_comments=detached, + ),) diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 2162effbbb..6a47bd42f7 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -31,528 +31,521 @@ def test_method_types(): - input_msg = make_message(name='Input', module='baz') - output_msg = make_message(name='Output', module='baz') - method = make_method('DoSomething', input_msg, output_msg, - package='foo.bar', module='bacon') - assert method.name == 'DoSomething' - assert method.input.name == 'Input' - assert method.output.name == 'Output' + input_msg = make_message(name='Input', module='baz') + output_msg = make_message(name='Output', module='baz') + method = make_method( + 'DoSomething', input_msg, output_msg, package='foo.bar', module='bacon') + assert method.name == 'DoSomething' + assert method.input.name == 'Input' + assert method.output.name == 'Output' def test_method_void(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.void + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.void def test_method_not_void(): - not_empty = make_message(name='OutputMessage', package='foo.bar.v1') - method = make_method('Meh', output_message=not_empty) - assert not method.void + not_empty = make_message(name='OutputMessage', package='foo.bar.v1') + method = make_method('Meh', output_message=not_empty) + assert not method.void + + +def test_method_deprecated(): + method = make_method('DeprecatedMethod', is_deprecated=True) + assert method.is_deprecated def test_method_client_output(): - output = make_message(name='Input', module='baz') - method = make_method('DoStuff', output_message=output) - assert method.client_output is method.output + output = make_message(name='Input', module='baz') + method = make_method('DoStuff', output_message=output) + assert method.client_output is method.output def test_method_client_output_empty(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.client_output == wrappers.PrimitiveType.build(None) + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.client_output == wrappers.PrimitiveType.build(None) def test_method_client_output_paged(): - paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - parent = make_field(name='parent', type=9) # str - page_size = make_field(name='page_size', type=5) # int - page_token = make_field(name='page_token', type=9) # str - - input_msg = make_message(name='ListFoosRequest', fields=( - parent, - page_size, - page_token, - )) - output_msg = make_message(name='ListFoosResponse', fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListFoosPager' - - max_results = make_field(name='max_results', type=5) # int - input_msg = make_message(name='ListFoosRequest', fields=( - parent, - max_results, - page_token, - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListFoosPager' + paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int + page_token = make_field(name='page_token', type=9) # str + + input_msg = make_message( + name='ListFoosRequest', fields=( + parent, + page_size, + page_token, + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' + + max_results = make_field(name='max_results', type=5) # int + input_msg = make_message( + name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' def test_method_client_output_async_empty(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.client_output_async == wrappers.PrimitiveType.build(None) + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.client_output_async == wrappers.PrimitiveType.build(None) def test_method_paged_result_field_not_first(): - paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - input_msg = make_message(name='ListFoosRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message(name='ListFoosResponse', fields=( - make_field(name='next_page_token', type=9), # str - paged, - )) - method = make_method('ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged + paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + input_msg = make_message( + name='ListFoosRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + make_field(name='next_page_token', type=9), # str + paged, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged def test_method_paged_result_field_no_page_field(): - input_msg = make_message(name='ListFoosRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message(name='ListFoosResponse', fields=( - make_field(name='foos', message=make_message('Foo'), repeated=False), - make_field(name='next_page_token', type=9), # str - )) - method = make_method('ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field is None - - method = make_method( - name='Foo', - input_message=make_message( - name='FooRequest', - fields=(make_field(name='page_token', type=9),) # str - ), - output_message=make_message( - name='FooResponse', - fields=(make_field(name='next_page_token', type=9),) # str - ) - ) - assert method.paged_result_field is None + input_msg = make_message( + name='ListFoosRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + make_field(name='foos', message=make_message('Foo'), repeated=False), + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field is None + + method = make_method( + name='Foo', + input_message=make_message( + name='FooRequest', + fields=(make_field(name='page_token', type=9),) # str + ), + output_message=make_message( + name='FooResponse', + fields=(make_field(name='next_page_token', type=9),) # str + )) + assert method.paged_result_field is None def test_method_paged_result_ref_types(): - input_msg = make_message( - name='ListSquidsRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - ), - module='squid', - ) - mollusc_msg = make_message('Mollusc', module='mollusc') - output_msg = make_message( - name='ListMolluscsResponse', - fields=( - make_field(name='molluscs', message=mollusc_msg, repeated=True), - make_field(name='next_page_token', type=9) # str - ), - module='mollusc' - ) - method = make_method( - 'ListSquids', - input_message=input_msg, - output_message=output_msg, - module='squid' - ) - - ref_type_names = {t.name for t in method.ref_types} - assert ref_type_names == { - 'ListSquidsRequest', - 'ListSquidsPager', - 'ListSquidsAsyncPager', - 'Mollusc', - } + input_msg = make_message( + name='ListSquidsRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + ), + module='squid', + ) + mollusc_msg = make_message('Mollusc', module='mollusc') + output_msg = make_message( + name='ListMolluscsResponse', + fields=( + make_field(name='molluscs', message=mollusc_msg, repeated=True), + make_field(name='next_page_token', type=9) # str + ), + module='mollusc') + method = make_method( + 'ListSquids', + input_message=input_msg, + output_message=output_msg, + module='squid') + + ref_type_names = {t.name for t in method.ref_types} + assert ref_type_names == { + 'ListSquidsRequest', + 'ListSquidsPager', + 'ListSquidsAsyncPager', + 'Mollusc', + } def test_flattened_ref_types(): - method = make_method( - 'IdentifyMollusc', - input_message=make_message( - 'IdentifyMolluscRequest', - fields=( - make_field( - 'cephalopod', - message=make_message( - 'Cephalopod', - fields=( - make_field('mass_kg', type='TYPE_INT32'), - make_field( - 'squid', - number=2, - message=make_message('Squid'), - ), - make_field( - 'clam', - number=3, - message=make_message('Clam'), - ), - ), - ), - ), - make_field( - 'stratum', - enum=make_enum( - 'Stratum', - ) - ), - ), - ), - signatures=('cephalopod.squid,stratum',), - output_message=make_message('Mollusc'), - ) - - expected_flat_ref_type_names = { - 'IdentifyMolluscRequest', - 'Squid', - 'Stratum', - 'Mollusc', - } - actual_flat_ref_type_names = {t.name for t in method.flat_ref_types} - assert expected_flat_ref_type_names == actual_flat_ref_type_names + method = make_method( + 'IdentifyMollusc', + input_message=make_message( + 'IdentifyMolluscRequest', + fields=( + make_field( + 'cephalopod', + message=make_message( + 'Cephalopod', + fields=( + make_field('mass_kg', type='TYPE_INT32'), + make_field( + 'squid', + number=2, + message=make_message('Squid'), + ), + make_field( + 'clam', + number=3, + message=make_message('Clam'), + ), + ), + ), + ), + make_field('stratum', enum=make_enum('Stratum',)), + ), + ), + signatures=('cephalopod.squid,stratum',), + output_message=make_message('Mollusc'), + ) + + expected_flat_ref_type_names = { + 'IdentifyMolluscRequest', + 'Squid', + 'Stratum', + 'Mollusc', + } + actual_flat_ref_type_names = {t.name for t in method.flat_ref_types} + assert expected_flat_ref_type_names == actual_flat_ref_type_names def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) # str - input_msg = make_message( - name='ListSquidsRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - ), - ) - output_msg = make_message(name='ListFoosResponse', fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListSquids', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListSquidsPager' + paged = make_field(name='squids', type=9, repeated=True) # str + input_msg = make_message( + name='ListSquidsRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + ), + ) + output_msg = make_message( + name='ListFoosResponse', + fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListSquids', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListSquidsPager' def test_method_field_headers_none(): - method = make_method('DoSomething') - assert isinstance(method.field_headers, collections.abc.Sequence) + method = make_method('DoSomething') + assert isinstance(method.field_headers, collections.abc.Sequence) def test_method_field_headers_present(): - verbs = [ - 'get', - 'put', - 'post', - 'delete', - 'patch', - ] + verbs = [ + 'get', + 'put', + 'post', + 'delete', + 'patch', + ] - for v in verbs: - rule = http_pb2.HttpRule(**{v: '/v1/{parent=projects/*}/topics'}) - method = make_method('DoSomething', http_rule=rule) - assert method.field_headers == ('parent',) + for v in verbs: + rule = http_pb2.HttpRule(**{v: '/v1/{parent=projects/*}/topics'}) + method = make_method('DoSomething', http_rule=rule) + assert method.field_headers == ('parent',) def test_method_http_opt(): - http_rule = http_pb2.HttpRule( - post='/v1/{parent=projects/*}/topics', - body='*' - ) - method = make_method('DoSomething', http_rule=http_rule) - assert method.http_opt == { - 'verb': 'post', - 'url': '/v1/{parent=projects/*}/topics', - 'body': '*' - } + http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics', body='*') + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_opt == { + 'verb': 'post', + 'url': '/v1/{parent=projects/*}/topics', + 'body': '*' + } + + # TODO(yon-mg) to test: grpc transcoding, # correct handling of path/query params # correct handling of body & additional binding def test_method_http_opt_no_body(): - http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.http_opt == { - 'verb': 'post', - 'url': '/v1/{parent=projects/*}/topics' - } + http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_opt == { + 'verb': 'post', + 'url': '/v1/{parent=projects/*}/topics' + } def test_method_http_opt_no_http_rule(): - method = make_method('DoSomething') - assert method.http_opt == None + method = make_method('DoSomething') + assert method.http_opt == None def test_method_path_params(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.path_params == ['project'] + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.path_params == ['project'] def test_method_path_params_no_http_rule(): - method = make_method('DoSomething') - assert method.path_params == [] + method = make_method('DoSomething') + assert method.path_params == [] def test_method_query_params(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule( - post='/v1/{project}/topics', - body='address' - ) - input_message = make_message( - 'MethodInput', - fields=( - make_field('region'), - make_field('project'), - make_field('address') - ) - ) - method = make_method('DoSomething', http_rule=http_rule, - input_message=input_message) - assert method.query_params == {'region'} + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics', body='address') + input_message = make_message( + 'MethodInput', + fields=(make_field('region'), make_field('project'), + make_field('address'))) + method = make_method( + 'DoSomething', http_rule=http_rule, input_message=input_message) + assert method.query_params == {'region'} def test_method_query_params_no_body(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') - input_message = make_message( - 'MethodInput', - fields=( - make_field('region'), - make_field('project'), - ) - ) - method = make_method('DoSomething', http_rule=http_rule, - input_message=input_message) - assert method.query_params == {'region'} + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') + input_message = make_message( + 'MethodInput', fields=( + make_field('region'), + make_field('project'), + )) + method = make_method( + 'DoSomething', http_rule=http_rule, input_message=input_message) + assert method.query_params == {'region'} def test_method_query_params_no_http_rule(): - method = make_method('DoSomething') - assert method.query_params == set() + method = make_method('DoSomething') + assert method.query_params == set() def test_method_idempotent_yes(): - http_rule = http_pb2.HttpRule(get='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.idempotent is True + http_rule = http_pb2.HttpRule(get='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.idempotent is True def test_method_idempotent_no(): - http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.idempotent is False + http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.idempotent is False def test_method_idempotent_no_http_rule(): - method = make_method('DoSomething') - assert method.idempotent is False + method = make_method('DoSomething') + assert method.idempotent is False def test_method_unary_unary(): - method = make_method('F', client_streaming=False, server_streaming=False) - assert method.grpc_stub_type == 'unary_unary' + method = make_method('F', client_streaming=False, server_streaming=False) + assert method.grpc_stub_type == 'unary_unary' def test_method_unary_stream(): - method = make_method('F', client_streaming=False, server_streaming=True) - assert method.grpc_stub_type == 'unary_stream' + method = make_method('F', client_streaming=False, server_streaming=True) + assert method.grpc_stub_type == 'unary_stream' def test_method_stream_unary(): - method = make_method('F', client_streaming=True, server_streaming=False) - assert method.grpc_stub_type == 'stream_unary' + method = make_method('F', client_streaming=True, server_streaming=False) + assert method.grpc_stub_type == 'stream_unary' def test_method_stream_stream(): - method = make_method('F', client_streaming=True, server_streaming=True) - assert method.grpc_stub_type == 'stream_stream' + method = make_method('F', client_streaming=True, server_streaming=True) + assert method.grpc_stub_type == 'stream_stream' def test_method_flattened_fields(): - a = make_field('a', type=5) # int - b = make_field('b', type=5) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('a,b',)) - assert len(method.flattened_fields) == 2 - assert 'a' in method.flattened_fields - assert 'b' in method.flattened_fields + a = make_field('a', type=5) # int + b = make_field('b', type=5) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('a,b',)) + assert len(method.flattened_fields) == 2 + assert 'a' in method.flattened_fields + assert 'b' in method.flattened_fields def test_method_flattened_fields_empty_sig(): - a = make_field('a', type=5) # int - b = make_field('b', type=5) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('',)) - assert len(method.flattened_fields) == 0 + a = make_field('a', type=5) # int + b = make_field('b', type=5) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('',)) + assert len(method.flattened_fields) == 0 def test_method_flattened_fields_different_package_non_primitive(): - # This test verifies that method flattening handles a special case where: - # * the method's request message type lives in a different package and - # * a field in the method_signature is a non-primitive. - # - # If the message is defined in a different package it is not guaranteed to - # be a proto-plus wrapped type, which puts restrictions on assigning - # directly to its fields, which complicates request construction. - # The easiest solution in this case is to just prohibit these fields - # in the method flattening. - message = make_message('Mantle', - package="mollusc.cephalopod.v1", module="squid") - mantle = make_field('mantle', type=11, type_name='Mantle', - message=message, meta=message.meta) - arms_count = make_field('arms_count', type=5, meta=message.meta) - input_message = make_message( - 'Squid', fields=(mantle, arms_count), - package=".".join(message.meta.address.package), - module=message.meta.address.module - ) - method = make_method('PutSquid', input_message=input_message, - package="remote.package.v1", module="module", signatures=("mantle,arms_count",)) - assert set(method.flattened_fields) == {'arms_count'} + # This test verifies that method flattening handles a special case where: + # * the method's request message type lives in a different package and + # * a field in the method_signature is a non-primitive. + # + # If the message is defined in a different package it is not guaranteed to + # be a proto-plus wrapped type, which puts restrictions on assigning + # directly to its fields, which complicates request construction. + # The easiest solution in this case is to just prohibit these fields + # in the method flattening. + message = make_message( + 'Mantle', package='mollusc.cephalopod.v1', module='squid') + mantle = make_field( + 'mantle', type=11, type_name='Mantle', message=message, meta=message.meta) + arms_count = make_field('arms_count', type=5, meta=message.meta) + input_message = make_message( + 'Squid', + fields=(mantle, arms_count), + package='.'.join(message.meta.address.package), + module=message.meta.address.module) + method = make_method( + 'PutSquid', + input_message=input_message, + package='remote.package.v1', + module='module', + signatures=('mantle,arms_count',)) + assert set(method.flattened_fields) == {'arms_count'} def test_method_include_flattened_message_fields(): - a = make_field('a', type=5) - b = make_field('b', type=11, type_name='Eggs', - message=make_message('Eggs')) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('a,b',)) - assert len(method.flattened_fields) == 2 + a = make_field('a', type=5) + b = make_field('b', type=11, type_name='Eggs', message=make_message('Eggs')) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('a,b',)) + assert len(method.flattened_fields) == 2 def test_method_legacy_flattened_fields(): - required_options = descriptor_pb2.FieldOptions() - required_options.Extensions[field_behavior_pb2.field_behavior].append( - field_behavior_pb2.FieldBehavior.Value("REQUIRED")) - - # Cephalopods are required. - squid = make_field(name="squid", options=required_options) - octopus = make_field( - name="octopus", - message=make_message( - name="Octopus", - fields=[make_field(name="mass", options=required_options)] - ), - options=required_options) - - # Bivalves are optional. - clam = make_field(name="clam") - oyster = make_field( - name="oyster", - message=make_message( - name="Oyster", - fields=[make_field(name="has_pearl")] - ) - ) - - # Interleave required and optional fields to make sure - # that, in the legacy flattening, required fields are always first. - request = make_message("request", fields=[squid, clam, octopus, oyster]) - - method = make_method( - name="CreateMolluscs", - input_message=request, - # Signatures should be ignored. - signatures=[ - "squid,octopus.mass", - "squid,octopus,oyster.has_pearl" - ] - ) - - # Use an ordered dict because ordering is important: - # required fields should come first. - expected = collections.OrderedDict([ - ("squid", squid), - ("octopus", octopus), - ("clam", clam), - ("oyster", oyster) - ]) - - assert method.legacy_flattened_fields == expected + required_options = descriptor_pb2.FieldOptions() + required_options.Extensions[field_behavior_pb2.field_behavior].append( + field_behavior_pb2.FieldBehavior.Value('REQUIRED')) + + # Cephalopods are required. + squid = make_field(name='squid', options=required_options) + octopus = make_field( + name='octopus', + message=make_message( + name='Octopus', + fields=[make_field(name='mass', options=required_options)]), + options=required_options) + + # Bivalves are optional. + clam = make_field(name='clam') + oyster = make_field( + name='oyster', + message=make_message( + name='Oyster', fields=[make_field(name='has_pearl')])) + + # Interleave required and optional fields to make sure + # that, in the legacy flattening, required fields are always first. + request = make_message('request', fields=[squid, clam, octopus, oyster]) + + method = make_method( + name='CreateMolluscs', + input_message=request, + # Signatures should be ignored. + signatures=['squid,octopus.mass', 'squid,octopus,oyster.has_pearl']) + + # Use an ordered dict because ordering is important: + # required fields should come first. + expected = collections.OrderedDict([('squid', squid), ('octopus', octopus), + ('clam', clam), ('oyster', oyster)]) + + assert method.legacy_flattened_fields == expected def test_flattened_oneof_fields(): - mass_kg = make_field(name="mass_kg", oneof="mass", type=5) - mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) - - length_m = make_field(name="length_m", oneof="length", type=5) - length_f = make_field(name="length_f", oneof="length", type=5) - - color = make_field(name="color", type=5) - mantle = make_field( - name="mantle", - message=make_message( - name="Mantle", - fields=( - make_field(name="color", type=5), - mass_kg, - mass_lbs, - ), - ), - ) - request = make_message( - name="CreateMolluscReuqest", - fields=( - length_m, - length_f, - color, - mantle, - ), - ) - method = make_method( - name="CreateMollusc", - input_message=request, - signatures=[ - "length_m,", - "length_f,", - "mantle.mass_kg,", - "mantle.mass_lbs,", - "color", - ] - ) - - expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} - actual = method.flattened_oneof_fields() - assert expected == actual - - # Check this method too becasue the setup is a lot of work. - expected = { - "color": "color", - "length_m": "length_m", - "length_f": "length_f", - "mass_kg": "mantle.mass_kg", - "mass_lbs": "mantle.mass_lbs", - } - actual = method.flattened_field_to_key - assert expected == actual + mass_kg = make_field(name='mass_kg', oneof='mass', type=5) + mass_lbs = make_field(name='mass_lbs', oneof='mass', type=5) + + length_m = make_field(name='length_m', oneof='length', type=5) + length_f = make_field(name='length_f', oneof='length', type=5) + + color = make_field(name='color', type=5) + mantle = make_field( + name='mantle', + message=make_message( + name='Mantle', + fields=( + make_field(name='color', type=5), + mass_kg, + mass_lbs, + ), + ), + ) + request = make_message( + name='CreateMolluscReuqest', + fields=( + length_m, + length_f, + color, + mantle, + ), + ) + method = make_method( + name='CreateMollusc', + input_message=request, + signatures=[ + 'length_m,', + 'length_f,', + 'mantle.mass_kg,', + 'mantle.mass_lbs,', + 'color', + ]) + + expected = {'mass': [mass_kg, mass_lbs], 'length': [length_m, length_f]} + actual = method.flattened_oneof_fields() + assert expected == actual + + # Check this method too becasue the setup is a lot of work. + expected = { + 'color': 'color', + 'length_m': 'length_m', + 'length_f': 'length_f', + 'mass_kg': 'mantle.mass_kg', + 'mass_lbs': 'mantle.mass_lbs', + } + actual = method.flattened_field_to_key + assert expected == actual From 96a2cad5235decacc1503cca32591def6e1fc11a Mon Sep 17 00:00:00 2001 From: Mira Leung Date: Tue, 11 May 2021 14:44:56 -0700 Subject: [PATCH 2/4] fix: linter fixes --- gapic/schema/wrappers.py | 2165 +++++++++++---------- test_utils/test_utils.py | 497 ++--- tests/unit/schema/wrappers/test_method.py | 817 ++++---- 3 files changed, 1743 insertions(+), 1736 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index fe7f0f8117..19403c442f 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -46,1148 +46,1151 @@ @dataclasses.dataclass(frozen=True) class Field: - """Description of a field.""" - field_pb: descriptor_pb2.FieldDescriptorProto - message: Optional['MessageType'] = None - enum: Optional['EnumType'] = None - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) - oneof: Optional[str] = None - - def __getattr__(self, name): - return getattr(self.field_pb, name) - - def __hash__(self): - # The only sense in which it is meaningful to say a field is equal to - # another field is if they are the same, i.e. they live in the same - # message type under the same moniker, i.e. they have the same id. - return id(self) - - @property - def name(self) -> str: - """Used to prevent collisions with python keywords""" - name = self.field_pb.name - return name + '_' if name in utils.RESERVED_NAMES else name - - @utils.cached_property - def ident(self) -> metadata.FieldIdentifier: - """Return the identifier to be used in templates.""" - return metadata.FieldIdentifier( - ident=self.type.ident, - repeated=self.repeated, - ) - - @property - def is_primitive(self) -> bool: - """Return True if the field is a primitive, False otherwise.""" - return isinstance(self.type, PrimitiveType) - - @property - def map(self) -> bool: - """Return True if this field is a map, False otherwise.""" - return bool(self.repeated and self.message and self.message.map) - - @utils.cached_property - def mock_value(self) -> str: - visited_fields: Set['Field'] = set() - stack = [self] - answer = '{}' - while stack: - expr = stack.pop() - answer = answer.format(expr.inner_mock(stack, visited_fields)) - - return answer - - def inner_mock(self, stack, visited_fields): - """Return a repr of a valid, usually truthy mock value.""" - # For primitives, send a truthy value computed from the - # field name. - answer = 'None' - if isinstance(self.type, PrimitiveType): - if self.type.python_type == bool: - answer = 'True' - elif self.type.python_type == str: - answer = f"'{self.name}_value'" - elif self.type.python_type == bytes: - answer = f"b'{self.name}_blob'" - elif self.type.python_type == int: - answer = f'{sum([ord(i) for i in self.name])}' - elif self.type.python_type == float: - answer = f'0.{sum([ord(i) for i in self.name])}' - else: # Impossible; skip coverage checks. - raise TypeError('Unrecognized PrimitiveType. This should ' - 'never happen; please file an issue.') - - # If this is an enum, select the first truthy value (or the zero - # value if nothing else exists). - if isinstance(self.type, EnumType): - # Note: The slightly-goofy [:2][-1] lets us gracefully fall - # back to index 0 if there is only one element. - mock_value = self.type.values[:2][-1] - answer = f'{self.type.ident}.{mock_value.name}' - - # If this is another message, set one value on the message. - if (not self.map # Maps are handled separately - and isinstance(self.type, MessageType) and len(self.type.fields) - # Nested message types need to terminate eventually - and self not in visited_fields): - sub = next(iter(self.type.fields.values())) - stack.append(sub) - visited_fields.add(self) - # Don't do the recursive rendering here, just set up - # where the nested value should go with the double {}. - answer = f'{self.type.ident}({sub.name}={{}})' - - if self.map: - # Maps are a special case beacuse they're represented internally as - # a list of a generated type with two fields: 'key' and 'value'. - answer = '{{{}: {}}}'.format( - self.type.fields['key'].mock_value, - self.type.fields['value'].mock_value, - ) - elif self.repeated: - # If this is a repeated field, then the mock answer should - # be a list. - answer = f'[{answer}]' - - # Done; return the mock value. - return answer - - @property - def proto_type(self) -> str: - """Return the proto type constant to be used in templates.""" - return cast( - str, descriptor_pb2.FieldDescriptorProto.Type.Name( - self.field_pb.type,))[len('TYPE_'):] - - @property - def repeated(self) -> bool: - """Return True if this is a repeated field, False otherwise. - - Returns: - bool: Whether this field is repeated. - """ - return self.label == \ - descriptor_pb2.FieldDescriptorProto.Label.Value( - 'LABEL_REPEATED') # type: ignore - - @property - def required(self) -> bool: - """Return True if this is a required field, False otherwise. - - Returns: - bool: Whether this field is required. - """ - return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') - in self.options.Extensions[field_behavior_pb2.field_behavior]) - - @utils.cached_property - def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: - """Return the type of this field.""" - # If this is a message or enum, return the appropriate thing. - if self.type_name and self.message: - return self.message - if self.type_name and self.enum: - return self.enum - - # This is a primitive. Return the corresponding Python type. - # The enum values used here are defined in: - # Repository: https://github.com/google/protobuf/ - # Path: src/google/protobuf/descriptor.proto - # - # The values are used here because the code would be excessively - # verbose otherwise, and this is guaranteed never to change. - # - # 10, 11, and 14 are intentionally missing. They correspond to - # group (unused), message (covered above), and enum (covered above). - if self.field_pb.type in (1, 2): - return PrimitiveType.build(float) - if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): - return PrimitiveType.build(int) - if self.field_pb.type == 8: - return PrimitiveType.build(bool) - if self.field_pb.type == 9: - return PrimitiveType.build(str) - if self.field_pb.type == 12: - return PrimitiveType.build(bytes) - - # This should never happen. - raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. ' - 'This code should not be reachable; please file a bug.') - - def with_context( - self, - *, - collisions: FrozenSet[str], - visited_messages: FrozenSet['MessageType'], - ) -> 'Field': - """Return a derivative of this field with the provided context. - - This method is used to address naming collisions. The returned - ``Field`` object aliases module names to avoid naming collisions - in the file being written. - """ - return dataclasses.replace( + """Description of a field.""" + field_pb: descriptor_pb2.FieldDescriptorProto + message: Optional['MessageType'] = None + enum: Optional['EnumType'] = None + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + oneof: Optional[str] = None + + def __getattr__(self, name): + return getattr(self.field_pb, name) + + def __hash__(self): + # The only sense in which it is meaningful to say a field is equal to + # another field is if they are the same, i.e. they live in the same + # message type under the same moniker, i.e. they have the same id. + return id(self) + + @property + def name(self) -> str: + """Used to prevent collisions with python keywords""" + name = self.field_pb.name + return name + '_' if name in utils.RESERVED_NAMES else name + + @utils.cached_property + def ident(self) -> metadata.FieldIdentifier: + """Return the identifier to be used in templates.""" + return metadata.FieldIdentifier( + ident=self.type.ident, + repeated=self.repeated, + ) + + @property + def is_primitive(self) -> bool: + """Return True if the field is a primitive, False otherwise.""" + return isinstance(self.type, PrimitiveType) + + @property + def map(self) -> bool: + """Return True if this field is a map, False otherwise.""" + return bool(self.repeated and self.message and self.message.map) + + @utils.cached_property + def mock_value(self) -> str: + visited_fields: Set['Field'] = set() + stack = [self] + answer = '{}' + while stack: + expr = stack.pop() + answer = answer.format(expr.inner_mock(stack, visited_fields)) + + return answer + + def inner_mock(self, stack, visited_fields): + """Return a repr of a valid, usually truthy mock value.""" + # For primitives, send a truthy value computed from the + # field name. + answer = 'None' + if isinstance(self.type, PrimitiveType): + if self.type.python_type == bool: + answer = 'True' + elif self.type.python_type == str: + answer = f"'{self.name}_value'" + elif self.type.python_type == bytes: + answer = f"b'{self.name}_blob'" + elif self.type.python_type == int: + answer = f'{sum([ord(i) for i in self.name])}' + elif self.type.python_type == float: + answer = f'0.{sum([ord(i) for i in self.name])}' + else: # Impossible; skip coverage checks. + raise TypeError('Unrecognized PrimitiveType. This should ' + 'never happen; please file an issue.') + + # If this is an enum, select the first truthy value (or the zero + # value if nothing else exists). + if isinstance(self.type, EnumType): + # Note: The slightly-goofy [:2][-1] lets us gracefully fall + # back to index 0 if there is only one element. + mock_value = self.type.values[:2][-1] + answer = f'{self.type.ident}.{mock_value.name}' + + # If this is another message, set one value on the message. + if (not self.map # Maps are handled separately + and isinstance(self.type, MessageType) and len(self.type.fields) + # Nested message types need to terminate eventually + and self not in visited_fields): + sub = next(iter(self.type.fields.values())) + stack.append(sub) + visited_fields.add(self) + # Don't do the recursive rendering here, just set up + # where the nested value should go with the double {}. + answer = f'{self.type.ident}({sub.name}={{}})' + + if self.map: + # Maps are a special case beacuse they're represented internally as + # a list of a generated type with two fields: 'key' and 'value'. + answer = '{{{}: {}}}'.format( + self.type.fields['key'].mock_value, + self.type.fields['value'].mock_value, + ) + elif self.repeated: + # If this is a repeated field, then the mock answer should + # be a list. + answer = f'[{answer}]' + + # Done; return the mock value. + return answer + + @property + def proto_type(self) -> str: + """Return the proto type constant to be used in templates.""" + return cast( + str, descriptor_pb2.FieldDescriptorProto.Type.Name( + self.field_pb.type,))[len('TYPE_'):] + + @property + def repeated(self) -> bool: + """Return True if this is a repeated field, False otherwise. + + Returns: + bool: Whether this field is repeated. + """ + return self.label == \ + descriptor_pb2.FieldDescriptorProto.Label.Value( + 'LABEL_REPEATED') # type: ignore + + @property + def required(self) -> bool: + """Return True if this is a required field, False otherwise. + + Returns: + bool: Whether this field is required. + """ + return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') + in self.options.Extensions[field_behavior_pb2.field_behavior]) + + @utils.cached_property + def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: + """Return the type of this field.""" + # If this is a message or enum, return the appropriate thing. + if self.type_name and self.message: + return self.message + if self.type_name and self.enum: + return self.enum + + # This is a primitive. Return the corresponding Python type. + # The enum values used here are defined in: + # Repository: https://github.com/google/protobuf/ + # Path: src/google/protobuf/descriptor.proto + # + # The values are used here because the code would be excessively + # verbose otherwise, and this is guaranteed never to change. + # + # 10, 11, and 14 are intentionally missing. They correspond to + # group (unused), message (covered above), and enum (covered above). + if self.field_pb.type in (1, 2): + return PrimitiveType.build(float) + if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): + return PrimitiveType.build(int) + if self.field_pb.type == 8: + return PrimitiveType.build(bool) + if self.field_pb.type == 9: + return PrimitiveType.build(str) + if self.field_pb.type == 12: + return PrimitiveType.build(bytes) + + # This should never happen. + raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. ' + 'This code should not be reachable; please file a bug.') + + def with_context( self, - message=self.message.with_context( - collisions=collisions, - skip_fields=self.message in visited_messages, - visited_messages=visited_messages, - ) if self.message else None, - enum=self.enum.with_context( - collisions=collisions) if self.enum else None, - meta=self.meta.with_context(collisions=collisions), - ) + *, + collisions: FrozenSet[str], + visited_messages: FrozenSet['MessageType'], + ) -> 'Field': + """Return a derivative of this field with the provided context. + + This method is used to address naming collisions. The returned + ``Field`` object aliases module names to avoid naming collisions + in the file being written. + """ + return dataclasses.replace( + self, + message=self.message.with_context( + collisions=collisions, + skip_fields=self.message in visited_messages, + visited_messages=visited_messages, + ) if self.message else None, + enum=self.enum.with_context( + collisions=collisions) if self.enum else None, + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class Oneof: - """Description of a field.""" - oneof_pb: descriptor_pb2.OneofDescriptorProto + """Description of a field.""" + oneof_pb: descriptor_pb2.OneofDescriptorProto - def __getattr__(self, name): - return getattr(self.oneof_pb, name) + def __getattr__(self, name): + return getattr(self.oneof_pb, name) @dataclasses.dataclass(frozen=True) class MessageType: - """Description of a message (defined with the ``message`` keyword).""" - # Class attributes - PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}') - - # Instance attributes - message_pb: descriptor_pb2.DescriptorProto - fields: Mapping[str, Field] - nested_enums: Mapping[str, 'EnumType'] - nested_messages: Mapping[str, 'MessageType'] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) - oneofs: Optional[Mapping[str, 'Oneof']] = None - - def __getattr__(self, name): - return getattr(self.message_pb, name) - - def __hash__(self): - # Identity is sufficiently unambiguous. - return hash(self.ident) - - def oneof_fields(self, include_optional=False): - oneof_fields = collections.defaultdict(list) - for field in self.fields.values(): - # Only include proto3 optional oneofs if explicitly looked for. - if field.oneof and not field.proto3_optional or include_optional: - oneof_fields[field.oneof].append(field) - - return oneof_fields - - @utils.cached_property - def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - answer = tuple(field.type - for field in self.fields.values() - if field.message or field.enum) - - return answer - - @utils.cached_property - def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - """Return all composite fields used in this proto's messages.""" - types: Set[Union['MessageType', 'EnumType']] = set() - - stack = [iter(self.fields.values())] - while stack: - fields_iter = stack.pop() - for field in fields_iter: - if field.message and field.type not in types: - stack.append(iter(field.message.fields.values())) - if not field.is_primitive: - types.add(field.type) - - return tuple(types) - - @utils.cached_property - def recursive_resource_fields(self) -> FrozenSet[Field]: - all_fields = chain( - self.fields.values(), - (field for t in self.recursive_field_types - if isinstance(t, MessageType) for field in t.fields.values()), - ) - return frozenset( - f for f in all_fields - if (f.options.Extensions[resource_pb2.resource_reference].type or - f.options.Extensions[resource_pb2.resource_reference].child_type)) - - @property - def map(self) -> bool: - """Return True if the given message is a map, False otherwise.""" - return self.message_pb.options.map_entry - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - @property - def resource_path(self) -> Optional[str]: - """If this message describes a resource, return the path to the resource. - - If there are multiple paths, returns the first one. - """ - return next( - iter(self.options.Extensions[resource_pb2.resource].pattern), None) - - @property - def resource_type(self) -> Optional[str]: - resource = self.options.Extensions[resource_pb2.resource] - return resource.type[resource.type.find('/') + 1:] if resource else None - - @property - def resource_path_args(self) -> Sequence[str]: - return self.PATH_ARG_RE.findall(self.resource_path or '') - - @utils.cached_property - def path_regex_str(self) -> str: - # The indirection here is a little confusing: - # we're using the resource path template as the base of a regex, - # with each resource ID segment being captured by a regex. - # E.g., the path schema - # kingdoms/{kingdom}/phyla/{phylum} - # becomes the regex - # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ - parsing_regex_str = ( - '^' + self.PATH_ARG_RE.sub( - # We can't just use (?P[^/]+) because segments may be - # separated by delimiters other than '/'. - # Multiple delimiter characters within one schema are allowed, - # e.g. - # as/{a}-{b}/cs/{c}%{d}_{e} - # This is discouraged but permitted by AIP4231 - lambda m: '(?P<{name}>.+?)'.format(name=m.groups()[0]), - self.resource_path or '') + '$') - return parsing_regex_str - - def get_field( - self, *field_path: str, - collisions: FrozenSet[str] = frozenset()) -> Field: - """Return a field arbitrarily deep in this message's structure. - - This method recursively traverses the message tree to return the - requested inner-field. - - Traversing through repeated fields is not supported; a repeated field - may be specified if and only if it is the last field in the path. - - Args: - field_path (Sequence[str]): The field path. - - Returns: - ~.Field: A field object. - - Raises: - KeyError: If a repeated field is used in the non-terminal position - in the path. - """ - # If collisions are not explicitly specified, retrieve them - # from this message's address. - # This ensures that calls to `get_field` will return a field with - # the same context, regardless of the number of levels through the - # chain (in order to avoid infinite recursion on circular references, - # we only shallowly bind message references held by fields; this - # binds deeply in the one spot where that might be a problem). - collisions = collisions or self.meta.address.collisions - - # Get the first field in the path. - first_field = field_path[0] - cursor = self.fields[first_field + - ('_' if first_field in utils.RESERVED_NAMES else '')] - - # Base case: If this is the last field in the path, return it outright. - if len(field_path) == 1: - return cursor.with_context( - collisions=collisions, - visited_messages=frozenset({self}), - ) - - # Sanity check: If cursor is a repeated field, then raise an exception. - # Repeated fields are only permitted in the terminal position. - if cursor.repeated: - raise KeyError( - f'The {cursor.name} field is repeated; unable to use ' - '`get_field` to retrieve its children.\n' - 'This exception usually indicates that a ' - 'google.api.method_signature annotation uses a repeated field ' - 'in the fields list in a position other than the end.',) - - # Sanity check: If this cursor has no message, there is a problem. - if not cursor.message: - raise KeyError( - f'Field {".".join(field_path)} could not be resolved from ' - f'{cursor.name}.',) - - # Recursion case: Pass the remainder of the path to the sub-field's - # message. - return cursor.message.get_field(*field_path[1:], collisions=collisions) - - def with_context( - self, - *, - collisions: FrozenSet[str], - skip_fields: bool = False, - visited_messages: FrozenSet['MessageType'] = frozenset(), - ) -> 'MessageType': - """Return a derivative of this message with the provided context. - - This method is used to address naming collisions. The returned - ``MessageType`` object aliases module names to avoid naming collisions - in the file being written. - - The ``skip_fields`` argument will omit applying the context to the - underlying fields. This provides for an "exit" in the case of circular - references. + """Description of a message (defined with the ``message`` keyword).""" + # Class attributes + PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}') + + # Instance attributes + message_pb: descriptor_pb2.DescriptorProto + fields: Mapping[str, Field] + nested_enums: Mapping[str, 'EnumType'] + nested_messages: Mapping[str, 'MessageType'] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + oneofs: Optional[Mapping[str, 'Oneof']] = None + + def __getattr__(self, name): + return getattr(self.message_pb, name) + + def __hash__(self): + # Identity is sufficiently unambiguous. + return hash(self.ident) + + def oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + + @utils.cached_property + def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + answer = tuple(field.type + for field in self.fields.values() + if field.message or field.enum) + + return answer + + @utils.cached_property + def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + """Return all composite fields used in this proto's messages.""" + types: Set[Union['MessageType', 'EnumType']] = set() + + stack = [iter(self.fields.values())] + while stack: + fields_iter = stack.pop() + for field in fields_iter: + if field.message and field.type not in types: + stack.append(iter(field.message.fields.values())) + if not field.is_primitive: + types.add(field.type) + + return tuple(types) + + @utils.cached_property + def recursive_resource_fields(self) -> FrozenSet[Field]: + all_fields = chain( + self.fields.values(), + (field for t in self.recursive_field_types + if isinstance(t, MessageType) for field in t.fields.values()), + ) + return frozenset( + f for f in all_fields + if (f.options.Extensions[resource_pb2.resource_reference].type or + f.options.Extensions[resource_pb2.resource_reference].child_type)) + + @property + def map(self) -> bool: + """Return True if the given message is a map, False otherwise.""" + return self.message_pb.options.map_entry + + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address + + @property + def resource_path(self) -> Optional[str]: + """If this message describes a resource, return the path to the resource. + + If there are multiple paths, returns the first one. """ - visited_messages = visited_messages | {self} - return dataclasses.replace( - self, - fields={ - k: v.with_context( - collisions=collisions, visited_messages=visited_messages) - for k, v in self.fields.items() - } if not skip_fields else self.fields, - nested_enums={ - k: v.with_context(collisions=collisions) - for k, v in self.nested_enums.items() - }, - nested_messages={ - k: v.with_context( + return next( + iter(self.options.Extensions[resource_pb2.resource].pattern), None) + + @property + def resource_type(self) -> Optional[str]: + resource = self.options.Extensions[resource_pb2.resource] + return resource.type[resource.type.find('/') + 1:] if resource else None + + @property + def resource_path_args(self) -> Sequence[str]: + return self.PATH_ARG_RE.findall(self.resource_path or '') + + @utils.cached_property + def path_regex_str(self) -> str: + # The indirection here is a little confusing: + # we're using the resource path template as the base of a regex, + # with each resource ID segment being captured by a regex. + # E.g., the path schema + # kingdoms/{kingdom}/phyla/{phylum} + # becomes the regex + # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ + parsing_regex_str = ( + '^' + self.PATH_ARG_RE.sub( + # We can't just use (?P[^/]+) because segments may be + # separated by delimiters other than '/'. + # Multiple delimiter characters within one schema are allowed, + # e.g. + # as/{a}-{b}/cs/{c}%{d}_{e} + # This is discouraged but permitted by AIP4231 + lambda m: '(?P<{name}>.+?)'.format(name=m.groups()[0]), + self.resource_path or '') + '$') + return parsing_regex_str + + def get_field( + self, *field_path: str, + collisions: FrozenSet[str] = frozenset()) -> Field: + """Return a field arbitrarily deep in this message's structure. + + This method recursively traverses the message tree to return the + requested inner-field. + + Traversing through repeated fields is not supported; a repeated field + may be specified if and only if it is the last field in the path. + + Args: + field_path (Sequence[str]): The field path. + + Returns: + ~.Field: A field object. + + Raises: + KeyError: If a repeated field is used in the non-terminal position + in the path. + """ + # If collisions are not explicitly specified, retrieve them + # from this message's address. + # This ensures that calls to `get_field` will return a field with + # the same context, regardless of the number of levels through the + # chain (in order to avoid infinite recursion on circular references, + # we only shallowly bind message references held by fields; this + # binds deeply in the one spot where that might be a problem). + collisions = collisions or self.meta.address.collisions + + # Get the first field in the path. + first_field = field_path[0] + cursor = self.fields[first_field + + ('_' if first_field in utils.RESERVED_NAMES else '')] + + # Base case: If this is the last field in the path, return it outright. + if len(field_path) == 1: + return cursor.with_context( collisions=collisions, - skip_fields=skip_fields, - visited_messages=visited_messages, - ) for k, v in self.nested_messages.items() - }, - meta=self.meta.with_context(collisions=collisions), - ) + visited_messages=frozenset({self}), + ) + + # Sanity check: If cursor is a repeated field, then raise an exception. + # Repeated fields are only permitted in the terminal position. + if cursor.repeated: + raise KeyError( + f'The {cursor.name} field is repeated; unable to use ' + '`get_field` to retrieve its children.\n' + 'This exception usually indicates that a ' + 'google.api.method_signature annotation uses a repeated field ' + 'in the fields list in a position other than the end.',) + + # Sanity check: If this cursor has no message, there is a problem. + if not cursor.message: + raise KeyError( + f'Field {".".join(field_path)} could not be resolved from ' + f'{cursor.name}.',) + + # Recursion case: Pass the remainder of the path to the sub-field's + # message. + return cursor.message.get_field(*field_path[1:], collisions=collisions) + + def with_context( + self, + *, + collisions: FrozenSet[str], + skip_fields: bool = False, + visited_messages: FrozenSet['MessageType'] = frozenset(), + ) -> 'MessageType': + """Return a derivative of this message with the provided context. + + This method is used to address naming collisions. The returned + ``MessageType`` object aliases module names to avoid naming collisions + in the file being written. + + The ``skip_fields`` argument will omit applying the context to the + underlying fields. This provides for an "exit" in the case of circular + references. + """ + visited_messages = visited_messages | {self} + return dataclasses.replace( + self, + fields={ + k: v.with_context( + collisions=collisions, visited_messages=visited_messages) + for k, v in self.fields.items() + } if not skip_fields else self.fields, + nested_enums={ + k: v.with_context(collisions=collisions) + for k, v in self.nested_enums.items() + }, + nested_messages={ + k: v.with_context( + collisions=collisions, + skip_fields=skip_fields, + visited_messages=visited_messages, + ) for k, v in self.nested_messages.items() + }, + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class EnumValueType: - """Description of an enum value.""" - enum_value_pb: descriptor_pb2.EnumValueDescriptorProto - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + """Description of an enum value.""" + enum_value_pb: descriptor_pb2.EnumValueDescriptorProto + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) - def __getattr__(self, name): - return getattr(self.enum_value_pb, name) + def __getattr__(self, name): + return getattr(self.enum_value_pb, name) @dataclasses.dataclass(frozen=True) class EnumType: - """Description of an enum (defined with the ``enum`` keyword.)""" - enum_pb: descriptor_pb2.EnumDescriptorProto - values: List[EnumValueType] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) - - def __hash__(self): - # Identity is sufficiently unambiguous. - return hash(self.ident) - - def __getattr__(self, name): - return getattr(self.enum_pb, name) - - @property - def resource_path(self) -> Optional[str]: - # This is a minor duck-typing workaround for the resource_messages - # property in the Service class: we need to check fields recursively - # to see if they're resources, and recursive_field_types includes enums - return None - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': - """Return a derivative of this enum with the provided context. - - This method is used to address naming collisions. The returned - ``EnumType`` object aliases module names to avoid naming collisions in - the file being written. - """ - return dataclasses.replace( - self, - meta=self.meta.with_context(collisions=collisions), - ) if collisions else self + """Description of an enum (defined with the ``enum`` keyword.)""" + enum_pb: descriptor_pb2.EnumDescriptorProto + values: List[EnumValueType] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + def __hash__(self): + # Identity is sufficiently unambiguous. + return hash(self.ident) + + def __getattr__(self, name): + return getattr(self.enum_pb, name) + + @property + def resource_path(self) -> Optional[str]: + # This is a minor duck-typing workaround for the resource_messages + # property in the Service class: we need to check fields recursively + # to see if they're resources, and recursive_field_types includes enums + return None - @property - def options_dict(self) -> Dict: - """Return the EnumOptions (if present) as a dict. + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address - This is a hack to support a pythonic structure representation for - the generator templates. - """ - return MessageToDict(self.enum_pb.options, preserving_proto_field_name=True) + def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': + """Return a derivative of this enum with the provided context. + + This method is used to address naming collisions. The returned + ``EnumType`` object aliases module names to avoid naming collisions in + the file being written. + """ + return dataclasses.replace( + self, + meta=self.meta.with_context(collisions=collisions), + ) if collisions else self + + @property + def options_dict(self) -> Dict: + """Return the EnumOptions (if present) as a dict. + + This is a hack to support a pythonic structure representation for + the generator templates. + """ + return MessageToDict(self.enum_pb.options, preserving_proto_field_name=True) @dataclasses.dataclass(frozen=True) class PythonType: - """Wrapper class for Python types. + """Wrapper class for Python types. - This exists for interface consistency, so that methods like - :meth:`Field.type` can return an object and the caller can be confident - that a ``name`` property will be present. - """ - meta: metadata.Metadata + This exists for interface consistency, so that methods like + :meth:`Field.type` can return an object and the caller can be confident + that a ``name`` property will be present. + """ + meta: metadata.Metadata - def __eq__(self, other): - return self.meta == other.meta + def __eq__(self, other): + return self.meta == other.meta - def __ne__(self, other): - return not self == other + def __ne__(self, other): + return not self == other - @utils.cached_property - def ident(self) -> metadata.Address: - """Return the identifier to be used in templates.""" - return self.meta.address + @utils.cached_property + def ident(self) -> metadata.Address: + """Return the identifier to be used in templates.""" + return self.meta.address - @property - def name(self) -> str: - return self.ident.name + @property + def name(self) -> str: + return self.ident.name - @property - def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - return tuple() + @property + def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + return tuple() @dataclasses.dataclass(frozen=True) class PrimitiveType(PythonType): - """A representation of a Python primitive type.""" - python_type: Optional[type] - - @classmethod - def build(cls, primitive_type: Optional[type]): - """Return a PrimitiveType object for the given Python primitive type. - - Args: - primitive_type (cls): A Python primitive type, such as :class:`int` - or :class:`str`. Despite not being a type, ``None`` is also - accepted here. - - Returns: - ~.PrimitiveType: The instantiated PrimitiveType object. - """ - # Primitives have no import, and no module to reference, so the - # address just uses the name of the class (e.g. "int", "str"). - return cls( - meta=metadata.Metadata( - address=metadata.Address( - name='None' if primitive_type is None else primitive_type - .__name__,)), - python_type=primitive_type) - - def __eq__(self, other): - # If we are sent the actual Python type (not the PrimitiveType object), - # claim to be equal to that. - if not hasattr(other, 'meta'): - return self.python_type is other - return super().__eq__(other) + """A representation of a Python primitive type.""" + python_type: Optional[type] + + @classmethod + def build(cls, primitive_type: Optional[type]): + """Return a PrimitiveType object for the given Python primitive type. + + Args: + primitive_type (cls): A Python primitive type, such as :class:`int` + or :class:`str`. Despite not being a type, ``None`` is also + accepted here. + + Returns: + ~.PrimitiveType: The instantiated PrimitiveType object. + """ + # Primitives have no import, and no module to reference, so the + # address just uses the name of the class (e.g. "int", "str"). + return cls( + meta=metadata.Metadata( + address=metadata.Address( + name='None' if primitive_type is None else primitive_type + .__name__,)), + python_type=primitive_type) + + def __eq__(self, other): + # If we are sent the actual Python type (not the PrimitiveType object), + # claim to be equal to that. + if not hasattr(other, 'meta'): + return self.python_type is other + return super().__eq__(other) @dataclasses.dataclass(frozen=True) class OperationInfo: - """Representation of long-running operation info.""" - response_type: MessageType - metadata_type: MessageType - - def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': - """Return a derivative of this OperationInfo with the provided context. - - This method is used to address naming collisions. The returned - ``OperationInfo`` object aliases module names to avoid naming - collisions - in the file being written. - """ - return dataclasses.replace( - self, - response_type=self.response_type.with_context(collisions=collisions), - metadata_type=self.metadata_type.with_context(collisions=collisions), - ) + """Representation of long-running operation info.""" + response_type: MessageType + metadata_type: MessageType + + def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': + """Return a derivative of this OperationInfo with the provided context. + + This method is used to address naming collisions. The returned + ``OperationInfo`` object aliases module names to avoid naming + collisions + in the file being written. + """ + return dataclasses.replace( + self, + response_type=self.response_type.with_context( + collisions=collisions), + metadata_type=self.metadata_type.with_context( + collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class RetryInfo: - """Representation of the method's retry behavior.""" - max_attempts: int - initial_backoff: float - max_backoff: float - backoff_multiplier: float - retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError] + """Representation of the method's retry behavior.""" + max_attempts: int + initial_backoff: float + max_backoff: float + backoff_multiplier: float + retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError] @dataclasses.dataclass(frozen=True) class Method: - """Description of a method (defined with the ``rpc`` keyword).""" - method_pb: descriptor_pb2.MethodDescriptorProto - input: MessageType - output: MessageType - lro: Optional[OperationInfo] = dataclasses.field(default=None) - retry: Optional[RetryInfo] = dataclasses.field(default=None) - timeout: Optional[float] = None - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) - - def __getattr__(self, name): - return getattr(self.method_pb, name) - - @utils.cached_property - def client_output(self): - return self._client_output(enable_asyncio=False) - - @utils.cached_property - def client_output_async(self): - return self._client_output(enable_asyncio=True) - - def flattened_oneof_fields(self, include_optional=False): - oneof_fields = collections.defaultdict(list) - for field in self.flattened_fields.values(): - # Only include proto3 optional oneofs if explicitly looked for. - if field.oneof and not field.proto3_optional or include_optional: - oneof_fields[field.oneof].append(field) - - return oneof_fields - - def _client_output(self, enable_asyncio: bool): - """Return the output from the client layer. - - This takes into account transformations made by the outer GAPIC - client to transform the output from the transport. - - Returns: - Union[~.MessageType, ~.PythonType]: - A description of the return type. - """ - # Void messages ultimately return None. - if self.void: - return PrimitiveType.build(None) - - # If this method is an LRO, return a PythonType instance representing - # that. - if self.lro: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name='AsyncOperation' if enable_asyncio else 'Operation', - module='operation_async' if enable_asyncio else 'operation', - package=('google', 'api_core'), - collisions=self.lro.response_type.ident.collisions, - ), - documentation=utils.doc( - 'An object representing a long-running operation. \n\n' - 'The result type for the operation will be ' - ':class:`{ident}` {doc}'.format( - doc=self.lro.response_type.meta.doc, - ident=self.lro.response_type.ident.sphinx, - ),), - )) - - # If this method is paginated, return that method's pager class. - if self.paged_result_field: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name=f'{self.name}AsyncPager' - if enable_asyncio else f'{self.name}Pager', - package=self.ident.api_naming.module_namespace + - (self.ident.api_naming.versioned_module_name,) + - self.ident.subpackage + ( - 'services', - utils.to_snake_case(self.ident.parent[-1]), - ), - module='pagers', - collisions=self.input.ident.collisions, - ), - documentation=utils.doc( - f'{self.output.meta.doc}\n\n' - 'Iterating over this object will yield results and ' - 'resolve additional pages automatically.',), - )) - - # Return the usual output. - return self.output - - @property - def is_deprecated(self) -> bool: - """Returns true if the method is deprecated, false otherwise.""" - return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') - - # TODO(yon-mg): remove or rewrite: don't think it performs as intended - # e.g. doesn't work with basic case of gRPC transcoding - @property - def field_headers(self) -> Sequence[str]: - """Return the field headers defined for this method.""" - http = self.options.Extensions[annotations_pb2.http] - - pattern = re.compile(r'\{([a-z][\w\d_.]+)=') - - potential_verbs = [ - http.get, - http.put, - http.post, - http.delete, - http.patch, - http.custom.path, - ] - - return next( - (tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) - - @property - def http_opt(self) -> Optional[Dict[str, str]]: - """Return the http option for this method. - - e.g. {'verb': 'post' - 'url': '/some/path' - 'body': '*'} - - """ - http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] - http = self.options.Extensions[annotations_pb2.http].ListFields() - - if len(http) < 1: - return None - - http_method = http[0] - answer: Dict[str, str] = { - 'verb': http_method[0].name, - 'url': http_method[1], - } - if len(http) > 1: - body_spec = http[1] - answer[body_spec[0].name] = body_spec[1] - - # TODO(yon-mg): handle nested fields & fields past body i.e. 'additional bindings' - # TODO(yon-mg): enums for http verbs? - return answer - - @property - def path_params(self) -> Sequence[str]: - """Return the path parameters found in the http annotation path template""" - # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) - if self.http_opt is None: - return [] - - pattern = r'\{(\w+)\}' - return re.findall(pattern, self.http_opt['url']) - - @property - def query_params(self) -> Set[str]: - """Return query parameters for API call as determined by http annotation and grpc transcoding""" - # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) - # TODO(yon-mg): remove this method and move logic to generated client - if self.http_opt is None: - return set() - - params = set(self.path_params) - body = self.http_opt.get('body') - if body: - params.add(body) - - return set(self.input.fields) - params - - # TODO(yon-mg): refactor as there may be more than one method signature - @utils.cached_property - def flattened_fields(self) -> Mapping[str, Field]: - """Return the signature defined for this method.""" - cross_pkg_request = self.input.ident.package != self.ident.package - - def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: - for f in sig.split(','): - if not f: - # Special case for an empty signature - continue - name = f.strip() - field = self.input.get_field(*name.split('.')) - name += '_' if field.field_pb.name in utils.RESERVED_NAMES else '' - if cross_pkg_request and not field.is_primitive: - # This is not a proto-plus wrapped message type, - # and setting a non-primitive field directly is verboten. - continue - - yield name, field - - signatures = self.options.Extensions[client_pb2.method_signature] - answer: Dict[str, Field] = collections.OrderedDict( - name_and_field for sig in signatures - for name_and_field in filter_fields(sig)) - - return answer - - @utils.cached_property - def flattened_field_to_key(self): - return {field.name: key for key, field in self.flattened_fields.items()} - - @utils.cached_property - def legacy_flattened_fields(self) -> Mapping[str, Field]: - """Return the legacy flattening interface: top level fields only, - - required fields first - """ - required, optional = utils.partition(lambda f: f.required, - self.input.fields.values()) - return collections.OrderedDict( - (f.name, f) for f in chain(required, optional)) - - @property - def grpc_stub_type(self) -> str: - """Return the type of gRPC stub to use.""" - return '{client}_{server}'.format( - client='stream' if self.client_streaming else 'unary', - server='stream' if self.server_streaming else 'unary', - ) - - # TODO(yon-mg): figure out why idempotent is reliant on http annotation - @utils.cached_property - def idempotent(self) -> bool: - """Return True if we know this method is idempotent, False otherwise. - - Note: We are intentionally conservative here. It is far less bad - to falsely believe an idempotent method is non-idempotent than - the converse. + """Description of a method (defined with the ``rpc`` keyword).""" + method_pb: descriptor_pb2.MethodDescriptorProto + input: MessageType + output: MessageType + lro: Optional[OperationInfo] = dataclasses.field(default=None) + retry: Optional[RetryInfo] = dataclasses.field(default=None) + timeout: Optional[float] = None + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + def __getattr__(self, name): + return getattr(self.method_pb, name) + + @utils.cached_property + def client_output(self): + return self._client_output(enable_asyncio=False) + + @utils.cached_property + def client_output_async(self): + return self._client_output(enable_asyncio=True) + + def flattened_oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.flattened_fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + + def _client_output(self, enable_asyncio: bool): + """Return the output from the client layer. + + This takes into account transformations made by the outer GAPIC + client to transform the output from the transport. + + Returns: + Union[~.MessageType, ~.PythonType]: + A description of the return type. + """ + # Void messages ultimately return None. + if self.void: + return PrimitiveType.build(None) + + # If this method is an LRO, return a PythonType instance representing + # that. + if self.lro: + return PythonType( + meta=metadata.Metadata( + address=metadata.Address( + name='AsyncOperation' if enable_asyncio else 'Operation', + module='operation_async' if enable_asyncio else 'operation', + package=('google', 'api_core'), + collisions=self.lro.response_type.ident.collisions, + ), + documentation=utils.doc( + 'An object representing a long-running operation. \n\n' + 'The result type for the operation will be ' + ':class:`{ident}` {doc}'.format( + doc=self.lro.response_type.meta.doc, + ident=self.lro.response_type.ident.sphinx, + ),), + )) + + # If this method is paginated, return that method's pager class. + if self.paged_result_field: + return PythonType( + meta=metadata.Metadata( + address=metadata.Address( + name=f'{self.name}AsyncPager' + if enable_asyncio else f'{self.name}Pager', + package=self.ident.api_naming.module_namespace + + (self.ident.api_naming.versioned_module_name,) + + self.ident.subpackage + ( + 'services', + utils.to_snake_case(self.ident.parent[-1]), + ), + module='pagers', + collisions=self.input.ident.collisions, + ), + documentation=utils.doc( + f'{self.output.meta.doc}\n\n' + 'Iterating over this object will yield results and ' + 'resolve additional pages automatically.',), + )) + + # Return the usual output. + return self.output + + @property + def is_deprecated(self) -> bool: + """Returns true if the method is deprecated, false otherwise.""" + return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') + + # TODO(yon-mg): remove or rewrite: don't think it performs as intended + # e.g. doesn't work with basic case of gRPC transcoding + @property + def field_headers(self) -> Sequence[str]: + """Return the field headers defined for this method.""" + http = self.options.Extensions[annotations_pb2.http] + + pattern = re.compile(r'\{([a-z][\w\d_.]+)=') + + potential_verbs = [ + http.get, + http.put, + http.post, + http.delete, + http.patch, + http.custom.path, + ] + + return next( + (tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) + + @property + def http_opt(self) -> Optional[Dict[str, str]]: + """Return the http option for this method. + + e.g. {'verb': 'post' + 'url': '/some/path' + 'body': '*'} + + """ + http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] + http = self.options.Extensions[annotations_pb2.http].ListFields() + + if len(http) < 1: + return None + + http_method = http[0] + answer: Dict[str, str] = { + 'verb': http_method[0].name, + 'url': http_method[1], + } + if len(http) > 1: + body_spec = http[1] + answer[body_spec[0].name] = body_spec[1] + + # TODO(yon-mg): handle nested fields & fields past body i.e. 'additional bindings' + # TODO(yon-mg): enums for http verbs? + return answer + + @property + def path_params(self) -> Sequence[str]: + """Return the path parameters found in the http annotation path template""" + # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) + if self.http_opt is None: + return [] + + pattern = r'\{(\w+)\}' + return re.findall(pattern, self.http_opt['url']) + + @property + def query_params(self) -> Set[str]: + """Return query parameters for API call as determined by http annotation and grpc transcoding""" + # TODO(yon-mg): fully implement grpc transcoding (currently only handles basic case) + # TODO(yon-mg): remove this method and move logic to generated client + if self.http_opt is None: + return set() + + params = set(self.path_params) + body = self.http_opt.get('body') + if body: + params.add(body) + + return set(self.input.fields) - params + + # TODO(yon-mg): refactor as there may be more than one method signature + @utils.cached_property + def flattened_fields(self) -> Mapping[str, Field]: + """Return the signature defined for this method.""" + cross_pkg_request = self.input.ident.package != self.ident.package + + def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: + for f in sig.split(','): + if not f: + # Special case for an empty signature + continue + name = f.strip() + field = self.input.get_field(*name.split('.')) + name += '_' if field.field_pb.name in utils.RESERVED_NAMES else '' + if cross_pkg_request and not field.is_primitive: + # This is not a proto-plus wrapped message type, + # and setting a non-primitive field directly is verboten. + continue + + yield name, field + + signatures = self.options.Extensions[client_pb2.method_signature] + answer: Dict[str, Field] = collections.OrderedDict( + name_and_field for sig in signatures + for name_and_field in filter_fields(sig)) + + return answer + + @utils.cached_property + def flattened_field_to_key(self): + return {field.name: key for key, field in self.flattened_fields.items()} + + @utils.cached_property + def legacy_flattened_fields(self) -> Mapping[str, Field]: + """Return the legacy flattening interface: top level fields only, + + required fields first """ - return bool(self.options.Extensions[annotations_pb2.http].get) - - @property - def ident(self) -> metadata.Address: - """Return the identifier data to be used in templates.""" - return self.meta.address - - @utils.cached_property - def paged_result_field(self) -> Optional[Field]: - """Return the response pagination field if the method is paginated.""" - # If the request field lacks any of the expected pagination fields, - # then the method is not paginated. - - # The request must have page_token and next_page_token as they keep track of pages - for source, source_type, name in ((self.input, str, 'page_token'), - (self.output, str, 'next_page_token')): - field = source.fields.get(name, None) - if not field or field.type != source_type: + required, optional = utils.partition(lambda f: f.required, + self.input.fields.values()) + return collections.OrderedDict( + (f.name, f) for f in chain(required, optional)) + + @property + def grpc_stub_type(self) -> str: + """Return the type of gRPC stub to use.""" + return '{client}_{server}'.format( + client='stream' if self.client_streaming else 'unary', + server='stream' if self.server_streaming else 'unary', + ) + + # TODO(yon-mg): figure out why idempotent is reliant on http annotation + @utils.cached_property + def idempotent(self) -> bool: + """Return True if we know this method is idempotent, False otherwise. + + Note: We are intentionally conservative here. It is far less bad + to falsely believe an idempotent method is non-idempotent than + the converse. + """ + return bool(self.options.Extensions[annotations_pb2.http].get) + + @property + def ident(self) -> metadata.Address: + """Return the identifier data to be used in templates.""" + return self.meta.address + + @utils.cached_property + def paged_result_field(self) -> Optional[Field]: + """Return the response pagination field if the method is paginated.""" + # If the request field lacks any of the expected pagination fields, + # then the method is not paginated. + + # The request must have page_token and next_page_token as they keep track of pages + for source, source_type, name in ((self.input, str, 'page_token'), + (self.output, str, 'next_page_token')): + field = source.fields.get(name, None) + if not field or field.type != source_type: + return None + + # The request must have max_results or page_size + page_fields = (self.input.fields.get('max_results', None), + self.input.fields.get('page_size', None)) + page_field_size = next((field for field in page_fields if field), None) + if not page_field_size or page_field_size.type != int: + return None + + # Return the first repeated field. + for field in self.output.fields.values(): + if field.repeated: + return field + + # We found no repeated fields. Return None. return None - # The request must have max_results or page_size - page_fields = (self.input.fields.get('max_results', None), - self.input.fields.get('page_size', None)) - page_field_size = next((field for field in page_fields if field), None) - if not page_field_size or page_field_size.type != int: - return None - - # Return the first repeated field. - for field in self.output.fields.values(): - if field.repeated: - return field - - # We found no repeated fields. Return None. - return None - - @utils.cached_property - def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: - return self._ref_types(True) - - @utils.cached_property - def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: - return self._ref_types(False) - - def _ref_types(self, - recursive: bool) -> Sequence[Union[MessageType, EnumType]]: - """Return types referenced by this method.""" - # Begin with the input (request) and output (response) messages. - answer: List[Union[MessageType, EnumType]] = [self.input] - types: Iterable[Union[MessageType, EnumType]] = ( - self.input.recursive_field_types if recursive else - (f.type for f in self.flattened_fields.values() if f.message or f.enum)) - answer.extend(types) - - if not self.void: - answer.append(self.client_output) - answer.extend(self.client_output.field_types) - answer.append(self.client_output_async) - answer.extend(self.client_output_async.field_types) - - # If this method has LRO, it is possible (albeit unlikely) that - # the LRO messages reside in a different module. - if self.lro: - answer.append(self.lro.response_type) - answer.append(self.lro.metadata_type) - - # If this message paginates its responses, it is possible - # that the individual result messages reside in a different module. - if self.paged_result_field and self.paged_result_field.message: - answer.append(self.paged_result_field.message) - - # Done; return the answer. - return tuple(answer) - - @property - def void(self) -> bool: - """Return True if this method has no return value, False otherwise.""" - return self.output.ident.proto == 'google.protobuf.Empty' - - def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': - """Return a derivative of this method with the provided context. - - This method is used to address naming collisions. The returned - ``Method`` object aliases module names to avoid naming collisions - in the file being written. - """ - maybe_lro = None - if self.lro: - maybe_lro = self.lro.with_context( - collisions=collisions) if collisions else self.lro - - return dataclasses.replace( - self, - lro=maybe_lro, - input=self.input.with_context(collisions=collisions), - output=self.output.with_context(collisions=collisions), - meta=self.meta.with_context(collisions=collisions), - ) + @utils.cached_property + def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: + return self._ref_types(True) + + @utils.cached_property + def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: + return self._ref_types(False) + + def _ref_types(self, + recursive: bool) -> Sequence[Union[MessageType, EnumType]]: + """Return types referenced by this method.""" + # Begin with the input (request) and output (response) messages. + answer: List[Union[MessageType, EnumType]] = [self.input] + types: Iterable[Union[MessageType, EnumType]] = ( + self.input.recursive_field_types if recursive else + (f.type for f in self.flattened_fields.values() if f.message or f.enum)) + answer.extend(types) + + if not self.void: + answer.append(self.client_output) + answer.extend(self.client_output.field_types) + answer.append(self.client_output_async) + answer.extend(self.client_output_async.field_types) + + # If this method has LRO, it is possible (albeit unlikely) that + # the LRO messages reside in a different module. + if self.lro: + answer.append(self.lro.response_type) + answer.append(self.lro.metadata_type) + + # If this message paginates its responses, it is possible + # that the individual result messages reside in a different module. + if self.paged_result_field and self.paged_result_field.message: + answer.append(self.paged_result_field.message) + + # Done; return the answer. + return tuple(answer) + + @property + def void(self) -> bool: + """Return True if this method has no return value, False otherwise.""" + return self.output.ident.proto == 'google.protobuf.Empty' + + def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': + """Return a derivative of this method with the provided context. + + This method is used to address naming collisions. The returned + ``Method`` object aliases module names to avoid naming collisions + in the file being written. + """ + maybe_lro = None + if self.lro: + maybe_lro = self.lro.with_context( + collisions=collisions) if collisions else self.lro + + return dataclasses.replace( + self, + lro=maybe_lro, + input=self.input.with_context(collisions=collisions), + output=self.output.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions), + ) @dataclasses.dataclass(frozen=True) class CommonResource: - type_name: str - pattern: str - - @classmethod - def build(cls, resource: resource_pb2.ResourceDescriptor): - return cls(type_name=resource.type, pattern=next(iter(resource.pattern))) - - @utils.cached_property - def message_type(self): - message_pb = descriptor_pb2.DescriptorProto() - res_pb = message_pb.options.Extensions[resource_pb2.resource] - res_pb.type = self.type_name - res_pb.pattern.append(self.pattern) - - return MessageType( - message_pb=message_pb, - fields={}, - nested_enums={}, - nested_messages={}, - ) + type_name: str + pattern: str + @classmethod + def build(cls, resource: resource_pb2.ResourceDescriptor): + return cls(type_name=resource.type, pattern=next(iter(resource.pattern))) -@dataclasses.dataclass(frozen=True) -class Service: - """Description of a service (defined with the ``service`` keyword).""" - service_pb: descriptor_pb2.ServiceDescriptorProto - methods: Mapping[str, Method] - # N.B.: visible_resources is intended to be a read-only view - # whose backing store is owned by the API. - # This is represented by a types.MappingProxyType instance. - visible_resources: Mapping[str, MessageType] - meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) - - common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( - default={ - 'cloudresourcemanager.googleapis.com/Project': - CommonResource( - 'cloudresourcemanager.googleapis.com/Project', - 'projects/{project}', - ), - 'cloudresourcemanager.googleapis.com/Organization': - CommonResource( - 'cloudresourcemanager.googleapis.com/Organization', - 'organizations/{organization}', - ), - 'cloudresourcemanager.googleapis.com/Folder': - CommonResource( - 'cloudresourcemanager.googleapis.com/Folder', - 'folders/{folder}', - ), - 'cloudbilling.googleapis.com/BillingAccount': - CommonResource( - 'cloudbilling.googleapis.com/BillingAccount', - 'billingAccounts/{billing_account}', - ), - 'locations.googleapis.com/Location': - CommonResource( - 'locations.googleapis.com/Location', - 'projects/{project}/locations/{location}', - ), - }, - init=False, - compare=False, - ) - - def __getattr__(self, name): - return getattr(self.service_pb, name) - - @property - def client_name(self) -> str: - """Returns the name of the generated client class""" - return self.name + 'Client' - - @property - def async_client_name(self) -> str: - """Returns the name of the generated AsyncIO client class""" - return self.name + 'AsyncClient' - - @property - def transport_name(self): - return self.name + 'Transport' - - @property - def grpc_transport_name(self): - return self.name + 'GrpcTransport' - - @property - def grpc_asyncio_transport_name(self): - return self.name + 'GrpcAsyncIOTransport' - - @property - def rest_transport_name(self): - return self.name + 'RestTransport' - - @property - def has_lro(self) -> bool: - """Return whether the service has a long-running method.""" - return any([m.lro for m in self.methods.values()]) - - @property - def has_pagers(self) -> bool: - """Return whether the service has paged methods.""" - return any(m.paged_result_field for m in self.methods.values()) - - @property - def host(self) -> str: - """Return the hostname for this service, if specified. - - Returns: - str: The hostname, with no protocol and no trailing ``/``. - """ - if self.options.Extensions[client_pb2.default_host]: - return self.options.Extensions[client_pb2.default_host] - return '' - - @property - def shortname(self) -> str: - """Return the API short name. + @utils.cached_property + def message_type(self): + message_pb = descriptor_pb2.DescriptorProto() + res_pb = message_pb.options.Extensions[resource_pb2.resource] + res_pb.type = self.type_name + res_pb.pattern.append(self.pattern) - DRIFT uses this to identify - APIs. + return MessageType( + message_pb=message_pb, + fields={}, + nested_enums={}, + nested_messages={}, + ) - Returns: - str: The api shortname. - """ - # Get the shortname from the host - # Real APIs are expected to have format: - # "{api_shortname}.googleapis.com" - return self.host.split('.')[0] - @property - def oauth_scopes(self) -> Sequence[str]: - """Return a sequence of oauth scopes, if applicable. +@dataclasses.dataclass(frozen=True) +class Service: + """Description of a service (defined with the ``service`` keyword).""" + service_pb: descriptor_pb2.ServiceDescriptorProto + methods: Mapping[str, Method] + # N.B.: visible_resources is intended to be a read-only view + # whose backing store is owned by the API. + # This is represented by a types.MappingProxyType instance. + visible_resources: Mapping[str, MessageType] + meta: metadata.Metadata = dataclasses.field( + default_factory=metadata.Metadata,) + + common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( + default={ + 'cloudresourcemanager.googleapis.com/Project': + CommonResource( + 'cloudresourcemanager.googleapis.com/Project', + 'projects/{project}', + ), + 'cloudresourcemanager.googleapis.com/Organization': + CommonResource( + 'cloudresourcemanager.googleapis.com/Organization', + 'organizations/{organization}', + ), + 'cloudresourcemanager.googleapis.com/Folder': + CommonResource( + 'cloudresourcemanager.googleapis.com/Folder', + 'folders/{folder}', + ), + 'cloudbilling.googleapis.com/BillingAccount': + CommonResource( + 'cloudbilling.googleapis.com/BillingAccount', + 'billingAccounts/{billing_account}', + ), + 'locations.googleapis.com/Location': + CommonResource( + 'locations.googleapis.com/Location', + 'projects/{project}/locations/{location}', + ), + }, + init=False, + compare=False, + ) - Returns: - Sequence[str]: A sequence of OAuth scopes. + def __getattr__(self, name): + return getattr(self.service_pb, name) + + @property + def client_name(self) -> str: + """Returns the name of the generated client class""" + return self.name + 'Client' + + @property + def async_client_name(self) -> str: + """Returns the name of the generated AsyncIO client class""" + return self.name + 'AsyncClient' + + @property + def transport_name(self): + return self.name + 'Transport' + + @property + def grpc_transport_name(self): + return self.name + 'GrpcTransport' + + @property + def grpc_asyncio_transport_name(self): + return self.name + 'GrpcAsyncIOTransport' + + @property + def rest_transport_name(self): + return self.name + 'RestTransport' + + @property + def has_lro(self) -> bool: + """Return whether the service has a long-running method.""" + return any([m.lro for m in self.methods.values()]) + + @property + def has_pagers(self) -> bool: + """Return whether the service has paged methods.""" + return any(m.paged_result_field for m in self.methods.values()) + + @property + def host(self) -> str: + """Return the hostname for this service, if specified. + + Returns: + str: The hostname, with no protocol and no trailing ``/``. + """ + if self.options.Extensions[client_pb2.default_host]: + return self.options.Extensions[client_pb2.default_host] + return '' + + @property + def shortname(self) -> str: + """Return the API short name. + + DRIFT uses this to identify + APIs. + + Returns: + str: The api shortname. + """ + # Get the shortname from the host + # Real APIs are expected to have format: + # "{api_shortname}.googleapis.com" + return self.host.split('.')[0] + + @property + def oauth_scopes(self) -> Sequence[str]: + """Return a sequence of oauth scopes, if applicable. + + Returns: + Sequence[str]: A sequence of OAuth scopes. + """ + # Return the OAuth scopes, split on comma. + return tuple( + i.strip() + for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') + if i) + + @property + def module_name(self) -> str: + """Return the appropriate module name for this service. + + Returns: + str: The service name, in snake case. + """ + return utils.to_snake_case(self.name) + + @utils.cached_property + def names(self) -> FrozenSet[str]: + """Return a set of names used in this service. + + This is used for detecting naming collisions in the module names + used for imports. + """ + # Put together a set of the service and method names. + answer = {self.name, self.client_name, self.async_client_name} + answer.update(utils.to_snake_case(i.name) + for i in self.methods.values()) + + # Identify any import module names where the same module name is used + # from distinct packages. + modules: Dict[str, Set[str]] = collections.defaultdict(set) + for m in self.methods.values(): + for t in m.ref_types: + modules[t.ident.module].add(t.ident.package) + + answer.update(module_name for module_name, packages in modules.items() + if len(packages) > 1) + + # Done; return the answer. + return frozenset(answer) + + @utils.cached_property + def resource_messages(self) -> FrozenSet[MessageType]: + """Returns all the resource message types used in all + + request and response fields in the service. """ - # Return the OAuth scopes, split on comma. - return tuple( - i.strip() - for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') - if i) - - @property - def module_name(self) -> str: - """Return the appropriate module name for this service. - - Returns: - str: The service name, in snake case. - """ - return utils.to_snake_case(self.name) - - @utils.cached_property - def names(self) -> FrozenSet[str]: - """Return a set of names used in this service. - This is used for detecting naming collisions in the module names - used for imports. - """ - # Put together a set of the service and method names. - answer = {self.name, self.client_name, self.async_client_name} - answer.update(utils.to_snake_case(i.name) for i in self.methods.values()) - - # Identify any import module names where the same module name is used - # from distinct packages. - modules: Dict[str, Set[str]] = collections.defaultdict(set) - for m in self.methods.values(): - for t in m.ref_types: - modules[t.ident.module].add(t.ident.package) - - answer.update(module_name for module_name, packages in modules.items() - if len(packages) > 1) - - # Done; return the answer. - return frozenset(answer) - - @utils.cached_property - def resource_messages(self) -> FrozenSet[MessageType]: - """Returns all the resource message types used in all - - request and response fields in the service. - """ - - def gen_resources(message): - if message.resource_path: - yield message - - for type_ in message.recursive_field_types: - if type_.resource_path: - yield type_ - - def gen_indirect_resources_used(message): - for field in message.recursive_resource_fields: - resource = field.options.Extensions[resource_pb2.resource_reference] - resource_type = resource.type or resource.child_type - # The resource may not be visible if the resource type is one of - # the common_resources (see the class var in class definition) - # or if it's something unhelpful like '*'. - resource = self.visible_resources.get(resource_type) - if resource: - yield resource - - return frozenset(msg for method in self.methods.values() for msg in chain( - gen_resources(method.input), - gen_resources( - method.lro.response_type if method.lro else method.output), - gen_indirect_resources_used(method.input), - gen_indirect_resources_used( - method.lro.response_type if method.lro else method.output), - )) - - @utils.cached_property - def any_client_streaming(self) -> bool: - return any(m.client_streaming for m in self.methods.values()) - - @utils.cached_property - def any_server_streaming(self) -> bool: - return any(m.server_streaming for m in self.methods.values()) - - def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': - """Return a derivative of this service with the provided context. - - This method is used to address naming collisions. The returned - ``Service`` object aliases module names to avoid naming collisions - in the file being written. - """ - return dataclasses.replace( - self, - methods={ - k: v.with_context( - # A method's flattened fields create additional names - # that may conflict with module imports. - collisions=collisions | frozenset(v.flattened_fields.keys())) - for k, v in self.methods.items() - }, - meta=self.meta.with_context(collisions=collisions), - ) + def gen_resources(message): + if message.resource_path: + yield message + + for type_ in message.recursive_field_types: + if type_.resource_path: + yield type_ + + def gen_indirect_resources_used(message): + for field in message.recursive_resource_fields: + resource = field.options.Extensions[resource_pb2.resource_reference] + resource_type = resource.type or resource.child_type + # The resource may not be visible if the resource type is one of + # the common_resources (see the class var in class definition) + # or if it's something unhelpful like '*'. + resource = self.visible_resources.get(resource_type) + if resource: + yield resource + + return frozenset(msg for method in self.methods.values() for msg in chain( + gen_resources(method.input), + gen_resources( + method.lro.response_type if method.lro else method.output), + gen_indirect_resources_used(method.input), + gen_indirect_resources_used( + method.lro.response_type if method.lro else method.output), + )) + + @utils.cached_property + def any_client_streaming(self) -> bool: + return any(m.client_streaming for m in self.methods.values()) + + @utils.cached_property + def any_server_streaming(self) -> bool: + return any(m.server_streaming for m in self.methods.values()) + + def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': + """Return a derivative of this service with the provided context. + + This method is used to address naming collisions. The returned + ``Service`` object aliases module names to avoid naming collisions + in the file being written. + """ + return dataclasses.replace( + self, + methods={ + k: v.with_context( + # A method's flattened fields create additional names + # that may conflict with module imports. + collisions=collisions | frozenset(v.flattened_fields.keys())) + for k, v in self.methods.items() + }, + meta=self.meta.with_context(collisions=collisions), + ) diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index d444c56bf6..69c3b7cf07 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -32,20 +32,20 @@ def make_service( visible_resources: typing.Optional[typing.Mapping[ str, wrappers.CommonResource]] = None, ) -> wrappers.Service: - visible_resources = visible_resources or {} - # Define a service descriptor, and set a host and oauth scopes if - # appropriate. - service_pb = desc.ServiceDescriptorProto(name=name) - if host: - service_pb.options.Extensions[client_pb2.default_host] = host - service_pb.options.Extensions[client_pb2.oauth_scopes] = ','.join(scopes) - - # Return a service object to test. - return wrappers.Service( - service_pb=service_pb, - methods={m.name: m for m in methods}, - visible_resources=visible_resources, - ) + visible_resources = visible_resources or {} + # Define a service descriptor, and set a host and oauth scopes if + # appropriate. + service_pb = desc.ServiceDescriptorProto(name=name) + if host: + service_pb.options.Extensions[client_pb2.default_host] = host + service_pb.options.Extensions[client_pb2.oauth_scopes] = ','.join(scopes) + + # Return a service object to test. + return wrappers.Service( + service_pb=service_pb, + methods={m.name: m for m in methods}, + visible_resources=visible_resources, + ) # FIXME (lukesneeringer): This test method is convoluted and it makes these @@ -58,28 +58,28 @@ def make_service_with_method_options( visible_resources: typing.Optional[typing.Mapping[ str, wrappers.CommonResource]] = None, ) -> wrappers.Service: - # Declare a method with options enabled for long-running operations and - # field headers. - method = get_method( - 'DoBigThing', - 'foo.bar.ThingRequest', - 'google.longrunning.operations_pb2.Operation', - lro_response_type='foo.baz.ThingResponse', - lro_metadata_type='foo.qux.ThingMetadata', - in_fields=in_fields, - http_rule=http_rule, - method_signature=method_signature, - ) - - # Define a service descriptor. - service_pb = desc.ServiceDescriptorProto(name='ThingDoer') - - # Return a service object to test. - return wrappers.Service( - service_pb=service_pb, - methods={method.name: method}, - visible_resources=visible_resources or {}, - ) + # Declare a method with options enabled for long-running operations and + # field headers. + method = get_method( + 'DoBigThing', + 'foo.bar.ThingRequest', + 'google.longrunning.operations_pb2.Operation', + lro_response_type='foo.baz.ThingResponse', + lro_metadata_type='foo.qux.ThingMetadata', + in_fields=in_fields, + http_rule=http_rule, + method_signature=method_signature, + ) + + # Define a service descriptor. + service_pb = desc.ServiceDescriptorProto(name='ThingDoer') + + # Return a service object to test. + return wrappers.Service( + service_pb=service_pb, + methods={method.name: method}, + visible_resources=visible_resources or {}, + ) def get_method( @@ -93,35 +93,35 @@ def get_method( http_rule: http_pb2.HttpRule = None, method_signature: str = '', ) -> wrappers.Method: - input_ = get_message(in_type, fields=in_fields) - output = get_message(out_type) - lro = None - - # Define a method descriptor. Set the field headers if appropriate. - method_pb = desc.MethodDescriptorProto( - name=name, - input_type=input_.ident.proto, - output_type=output.ident.proto, - ) - if lro_response_type: - lro = wrappers.OperationInfo( - response_type=get_message(lro_response_type), - metadata_type=get_message(lro_metadata_type), + input_ = get_message(in_type, fields=in_fields) + output = get_message(out_type) + lro = None + + # Define a method descriptor. Set the field headers if appropriate. + method_pb = desc.MethodDescriptorProto( + name=name, + input_type=input_.ident.proto, + output_type=output.ident.proto, + ) + if lro_response_type: + lro = wrappers.OperationInfo( + response_type=get_message(lro_response_type), + metadata_type=get_message(lro_metadata_type), + ) + if http_rule: + ext_key = annotations_pb2.http + method_pb.options.Extensions[ext_key].MergeFrom(http_rule) + if method_signature: + ext_key = client_pb2.method_signature + method_pb.options.Extensions[ext_key].append(method_signature) + + return wrappers.Method( + method_pb=method_pb, + input=input_, + output=output, + lro=lro, + meta=input_.meta, ) - if http_rule: - ext_key = annotations_pb2.http - method_pb.options.Extensions[ext_key].MergeFrom(http_rule) - if method_signature: - ext_key = client_pb2.method_signature - method_pb.options.Extensions[ext_key].append(method_signature) - - return wrappers.Method( - method_pb=method_pb, - input=input_, - output=output, - lro=lro, - meta=input_.meta, - ) def get_message( @@ -129,36 +129,36 @@ def get_message( *, fields: typing.Tuple[desc.FieldDescriptorProto] = (), ) -> wrappers.MessageType: - # Pass explicit None through (for lro_metadata). - if dot_path is None: - return None - - # Note: The `dot_path` here is distinct from the canonical proto path - # because it includes the module, which the proto path does not. - # - # So, if trying to test the DescriptorProto message here, the path - # would be google.protobuf.descriptor.DescriptorProto (whereas the proto - # path is just google.protobuf.DescriptorProto). - pieces = dot_path.split('.') - pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] - - return wrappers.MessageType( - fields={ - i.name: wrappers.Field( - field_pb=i, - enum=get_enum(i.type_name) if i.type_name else None, - ) for i in fields - }, - nested_messages={}, - nested_enums={}, - message_pb=desc.DescriptorProto(name=name, field=fields), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), - ) + # Pass explicit None through (for lro_metadata). + if dot_path is None: + return None + + # Note: The `dot_path` here is distinct from the canonical proto path + # because it includes the module, which the proto path does not. + # + # So, if trying to test the DescriptorProto message here, the path + # would be google.protobuf.descriptor.DescriptorProto (whereas the proto + # path is just google.protobuf.DescriptorProto). + pieces = dot_path.split('.') + pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] + + return wrappers.MessageType( + fields={ + i.name: wrappers.Field( + field_pb=i, + enum=get_enum(i.type_name) if i.type_name else None, + ) for i in fields + }, + nested_messages={}, + nested_enums={}, + message_pb=desc.DescriptorProto(name=name, field=fields), + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), + ) def make_method(name: str, @@ -170,46 +170,46 @@ def make_method(name: str, signatures: typing.Sequence[str] = (), is_deprecated: bool = False, **kwargs) -> wrappers.Method: - # Use default input and output messages if they are not provided. - input_message = input_message or make_message('MethodInput') - output_message = output_message or make_message('MethodOutput') - - # Create the method pb2. - method_pb = desc.MethodDescriptorProto( - name=name, - input_type=str(input_message.meta.address), - output_type=str(output_message.meta.address), - **kwargs) - - # If there is an HTTP rule, process it. - if http_rule: - ext_key = annotations_pb2.http - method_pb.options.Extensions[ext_key].MergeFrom(http_rule) - - # If there are signatures, include them. - for sig in signatures: - ext_key = client_pb2.method_signature - method_pb.options.Extensions[ext_key].append(sig) - - if isinstance(package, str): - package = tuple(package.split('.')) - - if is_deprecated: - method_pb.options.deprecated = True - - # Instantiate the wrapper class. - return wrappers.Method( - method_pb=method_pb, - input=input_message, - output=output_message, - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=package, - module=module, - parent=(f'{name}Service',), - )), - ) + # Use default input and output messages if they are not provided. + input_message = input_message or make_message('MethodInput') + output_message = output_message or make_message('MethodOutput') + + # Create the method pb2. + method_pb = desc.MethodDescriptorProto( + name=name, + input_type=str(input_message.meta.address), + output_type=str(output_message.meta.address), + **kwargs) + + # If there is an HTTP rule, process it. + if http_rule: + ext_key = annotations_pb2.http + method_pb.options.Extensions[ext_key].MergeFrom(http_rule) + + # If there are signatures, include them. + for sig in signatures: + ext_key = client_pb2.method_signature + method_pb.options.Extensions[ext_key].append(sig) + + if isinstance(package, str): + package = tuple(package.split('.')) + + if is_deprecated: + method_pb.options.deprecated = True + + # Instantiate the wrapper class. + return wrappers.Method( + method_pb=method_pb, + input=input_message, + output=output_message, + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=package, + module=module, + parent=(f'{name}Service',), + )), + ) def make_field(name: str = 'my_field', @@ -220,31 +220,31 @@ def make_field(name: str = 'my_field', meta: metadata.Metadata = None, oneof: str = None, **kwargs) -> wrappers.Field: - T = desc.FieldDescriptorProto.Type - - if message: - kwargs.setdefault('type_name', str(message.meta.address)) - kwargs['type'] = 'TYPE_MESSAGE' - elif enum: - kwargs.setdefault('type_name', str(enum.meta.address)) - kwargs['type'] = 'TYPE_ENUM' - else: - kwargs.setdefault('type', T.Value('TYPE_BOOL')) - - if isinstance(kwargs['type'], str): - kwargs['type'] = T.Value(kwargs['type']) - - label = kwargs.pop('label', 3 if repeated else 1) - field_pb = desc.FieldDescriptorProto( - name=name, label=label, number=number, **kwargs) - - return wrappers.Field( - field_pb=field_pb, - enum=enum, - message=message, - meta=meta or metadata.Metadata(), - oneof=oneof, - ) + T = desc.FieldDescriptorProto.Type + + if message: + kwargs.setdefault('type_name', str(message.meta.address)) + kwargs['type'] = 'TYPE_MESSAGE' + elif enum: + kwargs.setdefault('type_name', str(enum.meta.address)) + kwargs['type'] = 'TYPE_ENUM' + else: + kwargs.setdefault('type', T.Value('TYPE_BOOL')) + + if isinstance(kwargs['type'], str): + kwargs['type'] = T.Value(kwargs['type']) + + label = kwargs.pop('label', 3 if repeated else 1) + field_pb = desc.FieldDescriptorProto( + name=name, label=label, number=number, **kwargs) + + return wrappers.Field( + field_pb=field_pb, + enum=enum, + message=message, + meta=meta or metadata.Metadata(), + oneof=oneof, + ) def make_message( @@ -255,38 +255,38 @@ def make_message( meta: metadata.Metadata = None, options: desc.MethodOptions = None, ) -> wrappers.MessageType: - message_pb = desc.DescriptorProto( - name=name, - field=[i.field_pb for i in fields], - options=options, - ) - return wrappers.MessageType( - message_pb=message_pb, - fields=collections.OrderedDict((i.name, i) for i in fields), - nested_messages={}, - nested_enums={}, - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), - ) + message_pb = desc.DescriptorProto( + name=name, + field=[i.field_pb for i in fields], + options=options, + ) + return wrappers.MessageType( + message_pb=message_pb, + fields=collections.OrderedDict((i.name, i) for i in fields), + nested_messages={}, + nested_enums={}, + meta=meta or metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), + ) def get_enum(dot_path: str) -> wrappers.EnumType: - pieces = dot_path.split('.') - pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] - return wrappers.EnumType( - enum_pb=desc.EnumDescriptorProto(name=name), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), - values=[], - ) + pieces = dot_path.split('.') + pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] + return wrappers.EnumType( + enum_pb=desc.EnumDescriptorProto(name=name), + meta=metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), + values=[], + ) def make_enum( @@ -297,71 +297,72 @@ def make_enum( meta: metadata.Metadata = None, options: desc.EnumOptions = None, ) -> wrappers.EnumType: - enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=i[0], number=i[1]) for i in values - ] - enum_pb = desc.EnumDescriptorProto( - name=name, - value=enum_value_pbs, - options=options, - ) - return wrappers.EnumType( - enum_pb=enum_pb, - values=[ - wrappers.EnumValueType(enum_value_pb=evpb) for evpb in enum_value_pbs - ], - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), - ) + enum_value_pbs = [ + desc.EnumValueDescriptorProto(name=i[0], number=i[1]) for i in values + ] + enum_pb = desc.EnumDescriptorProto( + name=name, + value=enum_value_pbs, + options=options, + ) + return wrappers.EnumType( + enum_pb=enum_pb, + values=[ + wrappers.EnumValueType(enum_value_pb=evpb) for evpb in enum_value_pbs + ], + meta=meta or metadata.Metadata( + address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), + ) def make_naming(**kwargs) -> naming.Naming: - kwargs.setdefault('name', 'Hatstand') - kwargs.setdefault('namespace', ('Google', 'Cloud')) - kwargs.setdefault('version', 'v1') - kwargs.setdefault('product_name', 'Hatstand') - return naming.NewNaming(**kwargs) + kwargs.setdefault('name', 'Hatstand') + kwargs.setdefault('namespace', ('Google', 'Cloud')) + kwargs.setdefault('version', 'v1') + kwargs.setdefault('product_name', 'Hatstand') + return naming.NewNaming(**kwargs) def make_enum_pb2(name: str, *values: typing.Sequence[str], **kwargs) -> desc.EnumDescriptorProto: - enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=n, number=i) - for i, n in enumerate(values) - ] - enum_pb = desc.EnumDescriptorProto(name=name, value=enum_value_pbs, **kwargs) - return enum_pb + enum_value_pbs = [ + desc.EnumValueDescriptorProto(name=n, number=i) + for i, n in enumerate(values) + ] + enum_pb = desc.EnumDescriptorProto( + name=name, value=enum_value_pbs, **kwargs) + return enum_pb def make_message_pb2(name: str, fields: tuple = (), oneof_decl: tuple = (), **kwargs) -> desc.DescriptorProto: - return desc.DescriptorProto( - name=name, field=fields, oneof_decl=oneof_decl, **kwargs) + return desc.DescriptorProto( + name=name, field=fields, oneof_decl=oneof_decl, **kwargs) def make_field_pb2( - name: str, - number: int, - type: int = 11, # 11 == message - type_name: str = None, - oneof_index: int = None) -> desc.FieldDescriptorProto: - return desc.FieldDescriptorProto( - name=name, - number=number, - type=type, - type_name=type_name, - oneof_index=oneof_index, - ) + name: str, + number: int, + type: int = 11, # 11 == message + type_name: str = None, + oneof_index: int = None) -> desc.FieldDescriptorProto: + return desc.FieldDescriptorProto( + name=name, + number=number, + type=type, + type_name=type_name, + oneof_index=oneof_index, + ) def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto: - return desc.OneofDescriptorProto(name=name,) + return desc.OneofDescriptorProto(name=name,) def make_file_pb2( @@ -373,14 +374,14 @@ def make_file_pb2( services: typing.Sequence[desc.ServiceDescriptorProto] = (), locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), ) -> desc.FileDescriptorProto: - return desc.FileDescriptorProto( - name=name, - package=package, - message_type=messages, - enum_type=enums, - service=services, - source_code_info=desc.SourceCodeInfo(location=locations), - ) + return desc.FileDescriptorProto( + name=name, + package=package, + message_type=messages, + enum_type=enums, + service=services, + source_code_info=desc.SourceCodeInfo(location=locations), + ) def make_doc_meta( @@ -389,9 +390,9 @@ def make_doc_meta( trailing: str = '', detached: typing.List[str] = [], ) -> desc.SourceCodeInfo.Location: - return metadata.Metadata( - documentation=desc.SourceCodeInfo.Location( - leading_comments=leading, - trailing_comments=trailing, - leading_detached_comments=detached, - ),) + return metadata.Metadata( + documentation=desc.SourceCodeInfo.Location( + leading_comments=leading, + trailing_comments=trailing, + leading_detached_comments=detached, + ),) diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 6a47bd42f7..6168f58564 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -31,277 +31,279 @@ def test_method_types(): - input_msg = make_message(name='Input', module='baz') - output_msg = make_message(name='Output', module='baz') - method = make_method( - 'DoSomething', input_msg, output_msg, package='foo.bar', module='bacon') - assert method.name == 'DoSomething' - assert method.input.name == 'Input' - assert method.output.name == 'Output' + input_msg = make_message(name='Input', module='baz') + output_msg = make_message(name='Output', module='baz') + method = make_method( + 'DoSomething', input_msg, output_msg, package='foo.bar', module='bacon') + assert method.name == 'DoSomething' + assert method.input.name == 'Input' + assert method.output.name == 'Output' def test_method_void(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.void + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.void def test_method_not_void(): - not_empty = make_message(name='OutputMessage', package='foo.bar.v1') - method = make_method('Meh', output_message=not_empty) - assert not method.void + not_empty = make_message(name='OutputMessage', package='foo.bar.v1') + method = make_method('Meh', output_message=not_empty) + assert not method.void def test_method_deprecated(): - method = make_method('DeprecatedMethod', is_deprecated=True) - assert method.is_deprecated + method = make_method('DeprecatedMethod', is_deprecated=True) + assert method.is_deprecated def test_method_client_output(): - output = make_message(name='Input', module='baz') - method = make_method('DoStuff', output_message=output) - assert method.client_output is method.output + output = make_message(name='Input', module='baz') + method = make_method('DoStuff', output_message=output) + assert method.client_output is method.output def test_method_client_output_empty(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.client_output == wrappers.PrimitiveType.build(None) + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.client_output == wrappers.PrimitiveType.build(None) def test_method_client_output_paged(): - paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - parent = make_field(name='parent', type=9) # str - page_size = make_field(name='page_size', type=5) # int - page_token = make_field(name='page_token', type=9) # str - - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - page_size, - page_token, - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListFoosPager' - - max_results = make_field(name='max_results', type=5) # int - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - max_results, - page_token, - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListFoosPager' + paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int + page_token = make_field(name='page_token', type=9) # str + + input_msg = make_message( + name='ListFoosRequest', fields=( + parent, + page_size, + page_token, + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' + + max_results = make_field(name='max_results', type=5) # int + input_msg = make_message( + name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' def test_method_client_output_async_empty(): - empty = make_message(name='Empty', package='google.protobuf') - method = make_method('Meh', output_message=empty) - assert method.client_output_async == wrappers.PrimitiveType.build(None) + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.client_output_async == wrappers.PrimitiveType.build(None) def test_method_paged_result_field_not_first(): - paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='next_page_token', type=9), # str - paged, - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged + paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + input_msg = make_message( + name='ListFoosRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + make_field(name='next_page_token', type=9), # str + paged, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged def test_method_paged_result_field_no_page_field(): - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='foos', message=make_message('Foo'), repeated=False), - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field is None - - method = make_method( - name='Foo', - input_message=make_message( - name='FooRequest', - fields=(make_field(name='page_token', type=9),) # str - ), - output_message=make_message( - name='FooResponse', - fields=(make_field(name='next_page_token', type=9),) # str - )) - assert method.paged_result_field is None + input_msg = make_message( + name='ListFoosRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message( + name='ListFoosResponse', + fields=( + make_field(name='foos', message=make_message( + 'Foo'), repeated=False), + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field is None + + method = make_method( + name='Foo', + input_message=make_message( + name='FooRequest', + fields=(make_field(name='page_token', type=9),) # str + ), + output_message=make_message( + name='FooResponse', + fields=(make_field(name='next_page_token', type=9),) # str + )) + assert method.paged_result_field is None def test_method_paged_result_ref_types(): - input_msg = make_message( - name='ListSquidsRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - ), - module='squid', - ) - mollusc_msg = make_message('Mollusc', module='mollusc') - output_msg = make_message( - name='ListMolluscsResponse', - fields=( - make_field(name='molluscs', message=mollusc_msg, repeated=True), - make_field(name='next_page_token', type=9) # str - ), - module='mollusc') - method = make_method( - 'ListSquids', - input_message=input_msg, - output_message=output_msg, - module='squid') - - ref_type_names = {t.name for t in method.ref_types} - assert ref_type_names == { - 'ListSquidsRequest', - 'ListSquidsPager', - 'ListSquidsAsyncPager', - 'Mollusc', - } + input_msg = make_message( + name='ListSquidsRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + ), + module='squid', + ) + mollusc_msg = make_message('Mollusc', module='mollusc') + output_msg = make_message( + name='ListMolluscsResponse', + fields=( + make_field(name='molluscs', message=mollusc_msg, repeated=True), + make_field(name='next_page_token', type=9) # str + ), + module='mollusc') + method = make_method( + 'ListSquids', + input_message=input_msg, + output_message=output_msg, + module='squid') + + ref_type_names = {t.name for t in method.ref_types} + assert ref_type_names == { + 'ListSquidsRequest', + 'ListSquidsPager', + 'ListSquidsAsyncPager', + 'Mollusc', + } def test_flattened_ref_types(): - method = make_method( - 'IdentifyMollusc', - input_message=make_message( - 'IdentifyMolluscRequest', - fields=( - make_field( - 'cephalopod', - message=make_message( - 'Cephalopod', - fields=( - make_field('mass_kg', type='TYPE_INT32'), - make_field( - 'squid', - number=2, - message=make_message('Squid'), - ), - make_field( - 'clam', - number=3, - message=make_message('Clam'), - ), - ), - ), - ), - make_field('stratum', enum=make_enum('Stratum',)), - ), - ), - signatures=('cephalopod.squid,stratum',), - output_message=make_message('Mollusc'), - ) - - expected_flat_ref_type_names = { - 'IdentifyMolluscRequest', - 'Squid', - 'Stratum', - 'Mollusc', - } - actual_flat_ref_type_names = {t.name for t in method.flat_ref_types} - assert expected_flat_ref_type_names == actual_flat_ref_type_names + method = make_method( + 'IdentifyMollusc', + input_message=make_message( + 'IdentifyMolluscRequest', + fields=( + make_field( + 'cephalopod', + message=make_message( + 'Cephalopod', + fields=( + make_field('mass_kg', type='TYPE_INT32'), + make_field( + 'squid', + number=2, + message=make_message('Squid'), + ), + make_field( + 'clam', + number=3, + message=make_message('Clam'), + ), + ), + ), + ), + make_field('stratum', enum=make_enum('Stratum',)), + ), + ), + signatures=('cephalopod.squid,stratum',), + output_message=make_message('Mollusc'), + ) + + expected_flat_ref_type_names = { + 'IdentifyMolluscRequest', + 'Squid', + 'Stratum', + 'Mollusc', + } + actual_flat_ref_type_names = {t.name for t in method.flat_ref_types} + assert expected_flat_ref_type_names == actual_flat_ref_type_names def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) # str - input_msg = make_message( - name='ListSquidsRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - ), - ) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListSquids', - input_message=input_msg, - output_message=output_msg, - ) - assert method.paged_result_field == paged - assert method.client_output.ident.name == 'ListSquidsPager' + paged = make_field(name='squids', type=9, repeated=True) # str + input_msg = make_message( + name='ListSquidsRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + ), + ) + output_msg = make_message( + name='ListFoosResponse', + fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) + method = make_method( + 'ListSquids', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListSquidsPager' def test_method_field_headers_none(): - method = make_method('DoSomething') - assert isinstance(method.field_headers, collections.abc.Sequence) + method = make_method('DoSomething') + assert isinstance(method.field_headers, collections.abc.Sequence) def test_method_field_headers_present(): - verbs = [ - 'get', - 'put', - 'post', - 'delete', - 'patch', - ] + verbs = [ + 'get', + 'put', + 'post', + 'delete', + 'patch', + ] - for v in verbs: - rule = http_pb2.HttpRule(**{v: '/v1/{parent=projects/*}/topics'}) - method = make_method('DoSomething', http_rule=rule) - assert method.field_headers == ('parent',) + for v in verbs: + rule = http_pb2.HttpRule(**{v: '/v1/{parent=projects/*}/topics'}) + method = make_method('DoSomething', http_rule=rule) + assert method.field_headers == ('parent',) def test_method_http_opt(): - http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics', body='*') - method = make_method('DoSomething', http_rule=http_rule) - assert method.http_opt == { - 'verb': 'post', - 'url': '/v1/{parent=projects/*}/topics', - 'body': '*' - } + http_rule = http_pb2.HttpRule( + post='/v1/{parent=projects/*}/topics', body='*') + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_opt == { + 'verb': 'post', + 'url': '/v1/{parent=projects/*}/topics', + 'body': '*' + } # TODO(yon-mg) to test: grpc transcoding, @@ -310,242 +312,243 @@ def test_method_http_opt(): def test_method_http_opt_no_body(): - http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.http_opt == { - 'verb': 'post', - 'url': '/v1/{parent=projects/*}/topics' - } + http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_opt == { + 'verb': 'post', + 'url': '/v1/{parent=projects/*}/topics' + } def test_method_http_opt_no_http_rule(): - method = make_method('DoSomething') - assert method.http_opt == None + method = make_method('DoSomething') + assert method.http_opt == None def test_method_path_params(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.path_params == ['project'] + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.path_params == ['project'] def test_method_path_params_no_http_rule(): - method = make_method('DoSomething') - assert method.path_params == [] + method = make_method('DoSomething') + assert method.path_params == [] def test_method_query_params(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics', body='address') - input_message = make_message( - 'MethodInput', - fields=(make_field('region'), make_field('project'), - make_field('address'))) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) - assert method.query_params == {'region'} + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics', body='address') + input_message = make_message( + 'MethodInput', + fields=(make_field('region'), make_field('project'), + make_field('address'))) + method = make_method( + 'DoSomething', http_rule=http_rule, input_message=input_message) + assert method.query_params == {'region'} def test_method_query_params_no_body(): - # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') - input_message = make_message( - 'MethodInput', fields=( - make_field('region'), - make_field('project'), - )) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) - assert method.query_params == {'region'} + # tests only the basic case of grpc transcoding + http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') + input_message = make_message( + 'MethodInput', fields=( + make_field('region'), + make_field('project'), + )) + method = make_method( + 'DoSomething', http_rule=http_rule, input_message=input_message) + assert method.query_params == {'region'} def test_method_query_params_no_http_rule(): - method = make_method('DoSomething') - assert method.query_params == set() + method = make_method('DoSomething') + assert method.query_params == set() def test_method_idempotent_yes(): - http_rule = http_pb2.HttpRule(get='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.idempotent is True + http_rule = http_pb2.HttpRule(get='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.idempotent is True def test_method_idempotent_no(): - http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') - method = make_method('DoSomething', http_rule=http_rule) - assert method.idempotent is False + http_rule = http_pb2.HttpRule(post='/v1/{parent=projects/*}/topics') + method = make_method('DoSomething', http_rule=http_rule) + assert method.idempotent is False def test_method_idempotent_no_http_rule(): - method = make_method('DoSomething') - assert method.idempotent is False + method = make_method('DoSomething') + assert method.idempotent is False def test_method_unary_unary(): - method = make_method('F', client_streaming=False, server_streaming=False) - assert method.grpc_stub_type == 'unary_unary' + method = make_method('F', client_streaming=False, server_streaming=False) + assert method.grpc_stub_type == 'unary_unary' def test_method_unary_stream(): - method = make_method('F', client_streaming=False, server_streaming=True) - assert method.grpc_stub_type == 'unary_stream' + method = make_method('F', client_streaming=False, server_streaming=True) + assert method.grpc_stub_type == 'unary_stream' def test_method_stream_unary(): - method = make_method('F', client_streaming=True, server_streaming=False) - assert method.grpc_stub_type == 'stream_unary' + method = make_method('F', client_streaming=True, server_streaming=False) + assert method.grpc_stub_type == 'stream_unary' def test_method_stream_stream(): - method = make_method('F', client_streaming=True, server_streaming=True) - assert method.grpc_stub_type == 'stream_stream' + method = make_method('F', client_streaming=True, server_streaming=True) + assert method.grpc_stub_type == 'stream_stream' def test_method_flattened_fields(): - a = make_field('a', type=5) # int - b = make_field('b', type=5) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('a,b',)) - assert len(method.flattened_fields) == 2 - assert 'a' in method.flattened_fields - assert 'b' in method.flattened_fields + a = make_field('a', type=5) # int + b = make_field('b', type=5) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('a,b',)) + assert len(method.flattened_fields) == 2 + assert 'a' in method.flattened_fields + assert 'b' in method.flattened_fields def test_method_flattened_fields_empty_sig(): - a = make_field('a', type=5) # int - b = make_field('b', type=5) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('',)) - assert len(method.flattened_fields) == 0 + a = make_field('a', type=5) # int + b = make_field('b', type=5) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('',)) + assert len(method.flattened_fields) == 0 def test_method_flattened_fields_different_package_non_primitive(): - # This test verifies that method flattening handles a special case where: - # * the method's request message type lives in a different package and - # * a field in the method_signature is a non-primitive. - # - # If the message is defined in a different package it is not guaranteed to - # be a proto-plus wrapped type, which puts restrictions on assigning - # directly to its fields, which complicates request construction. - # The easiest solution in this case is to just prohibit these fields - # in the method flattening. - message = make_message( - 'Mantle', package='mollusc.cephalopod.v1', module='squid') - mantle = make_field( - 'mantle', type=11, type_name='Mantle', message=message, meta=message.meta) - arms_count = make_field('arms_count', type=5, meta=message.meta) - input_message = make_message( - 'Squid', - fields=(mantle, arms_count), - package='.'.join(message.meta.address.package), - module=message.meta.address.module) - method = make_method( - 'PutSquid', - input_message=input_message, - package='remote.package.v1', - module='module', - signatures=('mantle,arms_count',)) - assert set(method.flattened_fields) == {'arms_count'} + # This test verifies that method flattening handles a special case where: + # * the method's request message type lives in a different package and + # * a field in the method_signature is a non-primitive. + # + # If the message is defined in a different package it is not guaranteed to + # be a proto-plus wrapped type, which puts restrictions on assigning + # directly to its fields, which complicates request construction. + # The easiest solution in this case is to just prohibit these fields + # in the method flattening. + message = make_message( + 'Mantle', package='mollusc.cephalopod.v1', module='squid') + mantle = make_field( + 'mantle', type=11, type_name='Mantle', message=message, meta=message.meta) + arms_count = make_field('arms_count', type=5, meta=message.meta) + input_message = make_message( + 'Squid', + fields=(mantle, arms_count), + package='.'.join(message.meta.address.package), + module=message.meta.address.module) + method = make_method( + 'PutSquid', + input_message=input_message, + package='remote.package.v1', + module='module', + signatures=('mantle,arms_count',)) + assert set(method.flattened_fields) == {'arms_count'} def test_method_include_flattened_message_fields(): - a = make_field('a', type=5) - b = make_field('b', type=11, type_name='Eggs', message=make_message('Eggs')) - input_msg = make_message('Z', fields=(a, b)) - method = make_method('F', input_message=input_msg, signatures=('a,b',)) - assert len(method.flattened_fields) == 2 + a = make_field('a', type=5) + b = make_field('b', type=11, type_name='Eggs', + message=make_message('Eggs')) + input_msg = make_message('Z', fields=(a, b)) + method = make_method('F', input_message=input_msg, signatures=('a,b',)) + assert len(method.flattened_fields) == 2 def test_method_legacy_flattened_fields(): - required_options = descriptor_pb2.FieldOptions() - required_options.Extensions[field_behavior_pb2.field_behavior].append( - field_behavior_pb2.FieldBehavior.Value('REQUIRED')) - - # Cephalopods are required. - squid = make_field(name='squid', options=required_options) - octopus = make_field( - name='octopus', - message=make_message( - name='Octopus', - fields=[make_field(name='mass', options=required_options)]), - options=required_options) - - # Bivalves are optional. - clam = make_field(name='clam') - oyster = make_field( - name='oyster', - message=make_message( - name='Oyster', fields=[make_field(name='has_pearl')])) - - # Interleave required and optional fields to make sure - # that, in the legacy flattening, required fields are always first. - request = make_message('request', fields=[squid, clam, octopus, oyster]) - - method = make_method( - name='CreateMolluscs', - input_message=request, - # Signatures should be ignored. - signatures=['squid,octopus.mass', 'squid,octopus,oyster.has_pearl']) - - # Use an ordered dict because ordering is important: - # required fields should come first. - expected = collections.OrderedDict([('squid', squid), ('octopus', octopus), - ('clam', clam), ('oyster', oyster)]) - - assert method.legacy_flattened_fields == expected + required_options = descriptor_pb2.FieldOptions() + required_options.Extensions[field_behavior_pb2.field_behavior].append( + field_behavior_pb2.FieldBehavior.Value('REQUIRED')) + + # Cephalopods are required. + squid = make_field(name='squid', options=required_options) + octopus = make_field( + name='octopus', + message=make_message( + name='Octopus', + fields=[make_field(name='mass', options=required_options)]), + options=required_options) + + # Bivalves are optional. + clam = make_field(name='clam') + oyster = make_field( + name='oyster', + message=make_message( + name='Oyster', fields=[make_field(name='has_pearl')])) + + # Interleave required and optional fields to make sure + # that, in the legacy flattening, required fields are always first. + request = make_message('request', fields=[squid, clam, octopus, oyster]) + + method = make_method( + name='CreateMolluscs', + input_message=request, + # Signatures should be ignored. + signatures=['squid,octopus.mass', 'squid,octopus,oyster.has_pearl']) + + # Use an ordered dict because ordering is important: + # required fields should come first. + expected = collections.OrderedDict([('squid', squid), ('octopus', octopus), + ('clam', clam), ('oyster', oyster)]) + + assert method.legacy_flattened_fields == expected def test_flattened_oneof_fields(): - mass_kg = make_field(name='mass_kg', oneof='mass', type=5) - mass_lbs = make_field(name='mass_lbs', oneof='mass', type=5) - - length_m = make_field(name='length_m', oneof='length', type=5) - length_f = make_field(name='length_f', oneof='length', type=5) - - color = make_field(name='color', type=5) - mantle = make_field( - name='mantle', - message=make_message( - name='Mantle', - fields=( - make_field(name='color', type=5), - mass_kg, - mass_lbs, - ), - ), - ) - request = make_message( - name='CreateMolluscReuqest', - fields=( - length_m, - length_f, - color, - mantle, - ), - ) - method = make_method( - name='CreateMollusc', - input_message=request, - signatures=[ - 'length_m,', - 'length_f,', - 'mantle.mass_kg,', - 'mantle.mass_lbs,', - 'color', - ]) - - expected = {'mass': [mass_kg, mass_lbs], 'length': [length_m, length_f]} - actual = method.flattened_oneof_fields() - assert expected == actual - - # Check this method too becasue the setup is a lot of work. - expected = { - 'color': 'color', - 'length_m': 'length_m', - 'length_f': 'length_f', - 'mass_kg': 'mantle.mass_kg', - 'mass_lbs': 'mantle.mass_lbs', - } - actual = method.flattened_field_to_key - assert expected == actual + mass_kg = make_field(name='mass_kg', oneof='mass', type=5) + mass_lbs = make_field(name='mass_lbs', oneof='mass', type=5) + + length_m = make_field(name='length_m', oneof='length', type=5) + length_f = make_field(name='length_f', oneof='length', type=5) + + color = make_field(name='color', type=5) + mantle = make_field( + name='mantle', + message=make_message( + name='Mantle', + fields=( + make_field(name='color', type=5), + mass_kg, + mass_lbs, + ), + ), + ) + request = make_message( + name='CreateMolluscReuqest', + fields=( + length_m, + length_f, + color, + mantle, + ), + ) + method = make_method( + name='CreateMollusc', + input_message=request, + signatures=[ + 'length_m,', + 'length_f,', + 'mantle.mass_kg,', + 'mantle.mass_lbs,', + 'color', + ]) + + expected = {'mass': [mass_kg, mass_lbs], 'length': [length_m, length_f]} + actual = method.flattened_oneof_fields() + assert expected == actual + + # Check this method too becasue the setup is a lot of work. + expected = { + 'color': 'color', + 'length_m': 'length_m', + 'length_f': 'length_f', + 'mass_kg': 'mantle.mass_kg', + 'mass_lbs': 'mantle.mass_lbs', + } + actual = method.flattened_field_to_key + assert expected == actual From c5336f858ce43e155d0da0dacdf6a7151221048d Mon Sep 17 00:00:00 2001 From: Mira Leung Date: Wed, 12 May 2021 16:06:06 -0700 Subject: [PATCH 3/4] fix: style --- gapic/schema/wrappers.py | 553 ++++++++++++---------- test_utils/test_utils.py | 231 +++++---- tests/unit/schema/wrappers/test_method.py | 281 +++++------ 3 files changed, 552 insertions(+), 513 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 19403c442f..7a5d907da9 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Module containing wrapper classes around meta-descriptors. This module contains dataclasses which wrap the descriptor protos @@ -30,13 +31,13 @@ import dataclasses import re from itertools import chain -from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, ClassVar, - Optional, Sequence, Set, Tuple, Union) -from google.api import annotations_pb2 # type: ignore +from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, + ClassVar, Optional, Sequence, Set, Tuple, Union) +from google.api import annotations_pb2 # type: ignore from google.api import client_pb2 from google.api import field_behavior_pb2 from google.api import resource_pb2 -from google.api_core import exceptions # type: ignore +from google.api_core import exceptions # type: ignore from google.protobuf import descriptor_pb2 # type: ignore from google.protobuf.json_format import MessageToDict # type: ignore @@ -51,7 +52,8 @@ class Field: message: Optional['MessageType'] = None enum: Optional['EnumType'] = None meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) oneof: Optional[str] = None def __getattr__(self, name): @@ -67,7 +69,7 @@ def __hash__(self): def name(self) -> str: """Used to prevent collisions with python keywords""" name = self.field_pb.name - return name + '_' if name in utils.RESERVED_NAMES else name + return name + "_" if name in utils.RESERVED_NAMES else name @utils.cached_property def ident(self) -> metadata.FieldIdentifier: @@ -89,9 +91,9 @@ def map(self) -> bool: @utils.cached_property def mock_value(self) -> str: - visited_fields: Set['Field'] = set() + visited_fields: Set["Field"] = set() stack = [self] - answer = '{}' + answer = "{}" while stack: expr = stack.pop() answer = answer.format(expr.inner_mock(stack, visited_fields)) @@ -127,10 +129,13 @@ def inner_mock(self, stack, visited_fields): answer = f'{self.type.ident}.{mock_value.name}' # If this is another message, set one value on the message. - if (not self.map # Maps are handled separately - and isinstance(self.type, MessageType) and len(self.type.fields) - # Nested message types need to terminate eventually - and self not in visited_fields): + if ( + not self.map # Maps are handled separately + and isinstance(self.type, MessageType) + and len(self.type.fields) + # Nested message types need to terminate eventually + and self not in visited_fields + ): sub = next(iter(self.type.fields.values())) stack.append(sub) visited_fields.add(self) @@ -142,8 +147,8 @@ def inner_mock(self, stack, visited_fields): # Maps are a special case beacuse they're represented internally as # a list of a generated type with two fields: 'key' and 'value'. answer = '{{{}: {}}}'.format( - self.type.fields['key'].mock_value, - self.type.fields['value'].mock_value, + self.type.fields["key"].mock_value, + self.type.fields["value"].mock_value, ) elif self.repeated: # If this is a repeated field, then the mock answer should @@ -156,17 +161,17 @@ def inner_mock(self, stack, visited_fields): @property def proto_type(self) -> str: """Return the proto type constant to be used in templates.""" - return cast( - str, descriptor_pb2.FieldDescriptorProto.Type.Name( - self.field_pb.type,))[len('TYPE_'):] + return cast(str, descriptor_pb2.FieldDescriptorProto.Type.Name( + self.field_pb.type, + ))[len('TYPE_'):] @property def repeated(self) -> bool: """Return True if this is a repeated field, False otherwise. - Returns: - bool: Whether this field is repeated. - """ + Returns: + bool: Whether this field is repeated. + """ return self.label == \ descriptor_pb2.FieldDescriptorProto.Label.Value( 'LABEL_REPEATED') # type: ignore @@ -175,11 +180,11 @@ def repeated(self) -> bool: def required(self) -> bool: """Return True if this is a required field, False otherwise. - Returns: - bool: Whether this field is required. - """ - return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') - in self.options.Extensions[field_behavior_pb2.field_behavior]) + Returns: + bool: Whether this field is required. + """ + return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in + self.options.Extensions[field_behavior_pb2.field_behavior]) @utils.cached_property def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: @@ -216,17 +221,17 @@ def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: 'This code should not be reachable; please file a bug.') def with_context( - self, - *, - collisions: FrozenSet[str], - visited_messages: FrozenSet['MessageType'], + self, + *, + collisions: FrozenSet[str], + visited_messages: FrozenSet["MessageType"], ) -> 'Field': """Return a derivative of this field with the provided context. - This method is used to address naming collisions. The returned - ``Field`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Field`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, message=self.message.with_context( @@ -234,8 +239,8 @@ def with_context( skip_fields=self.message in visited_messages, visited_messages=visited_messages, ) if self.message else None, - enum=self.enum.with_context( - collisions=collisions) if self.enum else None, + enum=self.enum.with_context(collisions=collisions) + if self.enum else None, meta=self.meta.with_context(collisions=collisions), ) @@ -261,7 +266,8 @@ class MessageType: nested_enums: Mapping[str, 'EnumType'] nested_messages: Mapping[str, 'MessageType'] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) oneofs: Optional[Mapping[str, 'Oneof']] = None def __getattr__(self, name): @@ -282,14 +288,18 @@ def oneof_fields(self, include_optional=False): @utils.cached_property def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - answer = tuple(field.type - for field in self.fields.values() - if field.message or field.enum) + answer = tuple( + field.type + for field in self.fields.values() + if field.message or field.enum + ) return answer @utils.cached_property - def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + def recursive_field_types(self) -> Sequence[ + Union['MessageType', 'EnumType'] + ]: """Return all composite fields used in this proto's messages.""" types: Set[Union['MessageType', 'EnumType']] = set() @@ -308,13 +318,16 @@ def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: def recursive_resource_fields(self) -> FrozenSet[Field]: all_fields = chain( self.fields.values(), - (field for t in self.recursive_field_types - if isinstance(t, MessageType) for field in t.fields.values()), + (field + for t in self.recursive_field_types if isinstance(t, MessageType) + for field in t.fields.values()), ) return frozenset( - f for f in all_fields + f + for f in all_fields if (f.options.Extensions[resource_pb2.resource_reference].type or - f.options.Extensions[resource_pb2.resource_reference].child_type)) + f.options.Extensions[resource_pb2.resource_reference].child_type) + ) @property def map(self) -> bool: @@ -329,11 +342,11 @@ def ident(self) -> metadata.Address: @property def resource_path(self) -> Optional[str]: """If this message describes a resource, return the path to the resource. - - If there are multiple paths, returns the first one. - """ + If there are multiple paths, returns the first one.""" return next( - iter(self.options.Extensions[resource_pb2.resource].pattern), None) + iter(self.options.Extensions[resource_pb2.resource].pattern), + None + ) @property def resource_type(self) -> Optional[str]: @@ -354,38 +367,41 @@ def path_regex_str(self) -> str: # becomes the regex # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ parsing_regex_str = ( - '^' + self.PATH_ARG_RE.sub( + "^" + + self.PATH_ARG_RE.sub( # We can't just use (?P[^/]+) because segments may be # separated by delimiters other than '/'. # Multiple delimiter characters within one schema are allowed, # e.g. # as/{a}-{b}/cs/{c}%{d}_{e} # This is discouraged but permitted by AIP4231 - lambda m: '(?P<{name}>.+?)'.format(name=m.groups()[0]), - self.resource_path or '') + '$') + lambda m: "(?P<{name}>.+?)".format(name=m.groups()[0]), + self.resource_path or '' + ) + + "$" + ) return parsing_regex_str - def get_field( - self, *field_path: str, - collisions: FrozenSet[str] = frozenset()) -> Field: + def get_field(self, *field_path: str, + collisions: FrozenSet[str] = frozenset()) -> Field: """Return a field arbitrarily deep in this message's structure. - This method recursively traverses the message tree to return the - requested inner-field. + This method recursively traverses the message tree to return the + requested inner-field. - Traversing through repeated fields is not supported; a repeated field - may be specified if and only if it is the last field in the path. + Traversing through repeated fields is not supported; a repeated field + may be specified if and only if it is the last field in the path. - Args: - field_path (Sequence[str]): The field path. + Args: + field_path (Sequence[str]): The field path. - Returns: - ~.Field: A field object. + Returns: + ~.Field: A field object. - Raises: - KeyError: If a repeated field is used in the non-terminal position - in the path. - """ + Raises: + KeyError: If a repeated field is used in the non-terminal position + in the path. + """ # If collisions are not explicitly specified, retrieve them # from this message's address. # This ensures that calls to `get_field` will return a field with @@ -415,42 +431,43 @@ def get_field( '`get_field` to retrieve its children.\n' 'This exception usually indicates that a ' 'google.api.method_signature annotation uses a repeated field ' - 'in the fields list in a position other than the end.',) + 'in the fields list in a position other than the end.', + ) # Sanity check: If this cursor has no message, there is a problem. if not cursor.message: raise KeyError( f'Field {".".join(field_path)} could not be resolved from ' - f'{cursor.name}.',) + f'{cursor.name}.', + ) # Recursion case: Pass the remainder of the path to the sub-field's # message. return cursor.message.get_field(*field_path[1:], collisions=collisions) - def with_context( - self, - *, - collisions: FrozenSet[str], - skip_fields: bool = False, - visited_messages: FrozenSet['MessageType'] = frozenset(), - ) -> 'MessageType': + def with_context(self, *, + collisions: FrozenSet[str], + skip_fields: bool = False, + visited_messages: FrozenSet["MessageType"] = frozenset(), + ) -> 'MessageType': """Return a derivative of this message with the provided context. - This method is used to address naming collisions. The returned - ``MessageType`` object aliases module names to avoid naming collisions - in the file being written. + This method is used to address naming collisions. The returned + ``MessageType`` object aliases module names to avoid naming collisions + in the file being written. - The ``skip_fields`` argument will omit applying the context to the - underlying fields. This provides for an "exit" in the case of circular - references. - """ + The ``skip_fields`` argument will omit applying the context to the + underlying fields. This provides for an "exit" in the case of circular + references. + """ visited_messages = visited_messages | {self} return dataclasses.replace( self, fields={ k: v.with_context( - collisions=collisions, visited_messages=visited_messages) - for k, v in self.fields.items() + collisions=collisions, + visited_messages=visited_messages + ) for k, v in self.fields.items() } if not skip_fields else self.fields, nested_enums={ k: v.with_context(collisions=collisions) @@ -461,7 +478,8 @@ def with_context( collisions=collisions, skip_fields=skip_fields, visited_messages=visited_messages, - ) for k, v in self.nested_messages.items() + ) + for k, v in self.nested_messages.items() }, meta=self.meta.with_context(collisions=collisions), ) @@ -472,7 +490,8 @@ class EnumValueType: """Description of an enum value.""" enum_value_pb: descriptor_pb2.EnumValueDescriptorProto meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __getattr__(self, name): return getattr(self.enum_value_pb, name) @@ -484,7 +503,8 @@ class EnumType: enum_pb: descriptor_pb2.EnumDescriptorProto values: List[EnumValueType] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __hash__(self): # Identity is sufficiently unambiguous. @@ -508,10 +528,10 @@ def ident(self) -> metadata.Address: def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': """Return a derivative of this enum with the provided context. - This method is used to address naming collisions. The returned - ``EnumType`` object aliases module names to avoid naming collisions in - the file being written. - """ + This method is used to address naming collisions. The returned + ``EnumType`` object aliases module names to avoid naming collisions in + the file being written. + """ return dataclasses.replace( self, meta=self.meta.with_context(collisions=collisions), @@ -521,20 +541,23 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': def options_dict(self) -> Dict: """Return the EnumOptions (if present) as a dict. - This is a hack to support a pythonic structure representation for - the generator templates. - """ - return MessageToDict(self.enum_pb.options, preserving_proto_field_name=True) + This is a hack to support a pythonic structure representation for + the generator templates. + """ + return MessageToDict( + self.enum_pb.options, + preserving_proto_field_name=True + ) @dataclasses.dataclass(frozen=True) class PythonType: """Wrapper class for Python types. - This exists for interface consistency, so that methods like - :meth:`Field.type` can return an object and the caller can be confident - that a ``name`` property will be present. - """ + This exists for interface consistency, so that methods like + :meth:`Field.type` can return an object and the caller can be confident + that a ``name`` property will be present. + """ meta: metadata.Metadata def __eq__(self, other): @@ -566,22 +589,19 @@ class PrimitiveType(PythonType): def build(cls, primitive_type: Optional[type]): """Return a PrimitiveType object for the given Python primitive type. - Args: - primitive_type (cls): A Python primitive type, such as :class:`int` - or :class:`str`. Despite not being a type, ``None`` is also - accepted here. + Args: + primitive_type (cls): A Python primitive type, such as + :class:`int` or :class:`str`. Despite not being a type, + ``None`` is also accepted here. - Returns: - ~.PrimitiveType: The instantiated PrimitiveType object. - """ + Returns: + ~.PrimitiveType: The instantiated PrimitiveType object. + """ # Primitives have no import, and no module to reference, so the # address just uses the name of the class (e.g. "int", "str"). - return cls( - meta=metadata.Metadata( - address=metadata.Address( - name='None' if primitive_type is None else primitive_type - .__name__,)), - python_type=primitive_type) + return cls(meta=metadata.Metadata(address=metadata.Address( + name='None' if primitive_type is None else primitive_type.__name__, + )), python_type=primitive_type) def __eq__(self, other): # If we are sent the actual Python type (not the PrimitiveType object), @@ -600,17 +620,18 @@ class OperationInfo: def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': """Return a derivative of this OperationInfo with the provided context. - This method is used to address naming collisions. The returned - ``OperationInfo`` object aliases module names to avoid naming - collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``OperationInfo`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, response_type=self.response_type.with_context( - collisions=collisions), + collisions=collisions + ), metadata_type=self.metadata_type.with_context( - collisions=collisions), + collisions=collisions + ), ) @@ -634,7 +655,8 @@ class Method: retry: Optional[RetryInfo] = dataclasses.field(default=None) timeout: Optional[float] = None meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __getattr__(self, name): return getattr(self.method_pb, name) @@ -659,13 +681,13 @@ def flattened_oneof_fields(self, include_optional=False): def _client_output(self, enable_asyncio: bool): """Return the output from the client layer. - This takes into account transformations made by the outer GAPIC - client to transform the output from the transport. + This takes into account transformations made by the outer GAPIC + client to transform the output from the transport. - Returns: - Union[~.MessageType, ~.PythonType]: - A description of the return type. - """ + Returns: + Union[~.MessageType, ~.PythonType]: + A description of the return type. + """ # Void messages ultimately return None. if self.void: return PrimitiveType.build(None) @@ -673,44 +695,41 @@ def _client_output(self, enable_asyncio: bool): # If this method is an LRO, return a PythonType instance representing # that. if self.lro: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name='AsyncOperation' if enable_asyncio else 'Operation', - module='operation_async' if enable_asyncio else 'operation', - package=('google', 'api_core'), - collisions=self.lro.response_type.ident.collisions, + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name='AsyncOperation' if enable_asyncio else 'Operation', + module='operation_async' if enable_asyncio else 'operation', + package=('google', 'api_core'), + collisions=self.lro.response_type.ident.collisions, + ), + documentation=utils.doc( + 'An object representing a long-running operation. \n\n' + 'The result type for the operation will be ' + ':class:`{ident}` {doc}'.format( + doc=self.lro.response_type.meta.doc, + ident=self.lro.response_type.ident.sphinx, ), - documentation=utils.doc( - 'An object representing a long-running operation. \n\n' - 'The result type for the operation will be ' - ':class:`{ident}` {doc}'.format( - doc=self.lro.response_type.meta.doc, - ident=self.lro.response_type.ident.sphinx, - ),), - )) + ), + )) # If this method is paginated, return that method's pager class. if self.paged_result_field: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name=f'{self.name}AsyncPager' - if enable_asyncio else f'{self.name}Pager', - package=self.ident.api_naming.module_namespace + - (self.ident.api_naming.versioned_module_name,) + - self.ident.subpackage + ( - 'services', - utils.to_snake_case(self.ident.parent[-1]), - ), - module='pagers', - collisions=self.input.ident.collisions, + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager', + package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( + 'services', + utils.to_snake_case(self.ident.parent[-1]), ), - documentation=utils.doc( - f'{self.output.meta.doc}\n\n' - 'Iterating over this object will yield results and ' - 'resolve additional pages automatically.',), - )) + module='pagers', + collisions=self.input.ident.collisions, + ), + documentation=utils.doc( + f'{self.output.meta.doc}\n\n' + 'Iterating over this object will yield results and ' + 'resolve additional pages automatically.', + ), + )) # Return the usual output. return self.output @@ -720,6 +739,7 @@ def is_deprecated(self) -> bool: """Returns true if the method is deprecated, false otherwise.""" return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') + # TODO(yon-mg): remove or rewrite: don't think it performs as intended # e.g. doesn't work with basic case of gRPC transcoding @property @@ -738,18 +758,17 @@ def field_headers(self) -> Sequence[str]: http.custom.path, ] - return next( - (tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) + return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) @property def http_opt(self) -> Optional[Dict[str, str]]: """Return the http option for this method. - e.g. {'verb': 'post' - 'url': '/some/path' - 'body': '*'} + e.g. {'verb': 'post' + 'url': '/some/path' + 'body': '*'} - """ + """ http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] http = self.options.Extensions[annotations_pb2.http].ListFields() @@ -817,8 +836,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: signatures = self.options.Extensions[client_pb2.method_signature] answer: Dict[str, Field] = collections.OrderedDict( - name_and_field for sig in signatures - for name_and_field in filter_fields(sig)) + name_and_field + for sig in signatures + for name_and_field in filter_fields(sig) + ) return answer @@ -829,13 +850,11 @@ def flattened_field_to_key(self): @utils.cached_property def legacy_flattened_fields(self) -> Mapping[str, Field]: """Return the legacy flattening interface: top level fields only, - - required fields first - """ + required fields first""" required, optional = utils.partition(lambda f: f.required, self.input.fields.values()) - return collections.OrderedDict( - (f.name, f) for f in chain(required, optional)) + return collections.OrderedDict((f.name, f) + for f in chain(required, optional)) @property def grpc_stub_type(self) -> str: @@ -850,10 +869,10 @@ def grpc_stub_type(self) -> str: def idempotent(self) -> bool: """Return True if we know this method is idempotent, False otherwise. - Note: We are intentionally conservative here. It is far less bad - to falsely believe an idempotent method is non-idempotent than - the converse. - """ + Note: We are intentionally conservative here. It is far less bad + to falsely believe an idempotent method is non-idempotent than + the converse. + """ return bool(self.options.Extensions[annotations_pb2.http].get) @property @@ -877,7 +896,8 @@ def paged_result_field(self) -> Optional[Field]: # The request must have max_results or page_size page_fields = (self.input.fields.get('max_results', None), self.input.fields.get('page_size', None)) - page_field_size = next((field for field in page_fields if field), None) + page_field_size = next( + (field for field in page_fields if field), None) if not page_field_size or page_field_size.type != int: return None @@ -897,14 +917,18 @@ def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: return self._ref_types(False) - def _ref_types(self, - recursive: bool) -> Sequence[Union[MessageType, EnumType]]: + def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: """Return types referenced by this method.""" # Begin with the input (request) and output (response) messages. answer: List[Union[MessageType, EnumType]] = [self.input] types: Iterable[Union[MessageType, EnumType]] = ( - self.input.recursive_field_types if recursive else - (f.type for f in self.flattened_fields.values() if f.message or f.enum)) + self.input.recursive_field_types if recursive + else ( + f.type + for f in self.flattened_fields.values() + if f.message or f.enum + ) + ) answer.extend(types) if not self.void: @@ -935,14 +959,15 @@ def void(self) -> bool: def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': """Return a derivative of this method with the provided context. - This method is used to address naming collisions. The returned - ``Method`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Method`` object aliases module names to avoid naming collisions + in the file being written. + """ maybe_lro = None if self.lro: maybe_lro = self.lro.with_context( - collisions=collisions) if collisions else self.lro + collisions=collisions + ) if collisions else self.lro return dataclasses.replace( self, @@ -960,7 +985,10 @@ class CommonResource: @classmethod def build(cls, resource: resource_pb2.ResourceDescriptor): - return cls(type_name=resource.type, pattern=next(iter(resource.pattern))) + return cls( + type_name=resource.type, + pattern=next(iter(resource.pattern)) + ) @utils.cached_property def message_type(self): @@ -987,35 +1015,31 @@ class Service: # This is represented by a types.MappingProxyType instance. visible_resources: Mapping[str, MessageType] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( default={ - 'cloudresourcemanager.googleapis.com/Project': - CommonResource( - 'cloudresourcemanager.googleapis.com/Project', - 'projects/{project}', - ), - 'cloudresourcemanager.googleapis.com/Organization': - CommonResource( - 'cloudresourcemanager.googleapis.com/Organization', - 'organizations/{organization}', - ), - 'cloudresourcemanager.googleapis.com/Folder': - CommonResource( - 'cloudresourcemanager.googleapis.com/Folder', - 'folders/{folder}', - ), - 'cloudbilling.googleapis.com/BillingAccount': - CommonResource( - 'cloudbilling.googleapis.com/BillingAccount', - 'billingAccounts/{billing_account}', - ), - 'locations.googleapis.com/Location': - CommonResource( - 'locations.googleapis.com/Location', - 'projects/{project}/locations/{location}', - ), + "cloudresourcemanager.googleapis.com/Project": CommonResource( + "cloudresourcemanager.googleapis.com/Project", + "projects/{project}", + ), + "cloudresourcemanager.googleapis.com/Organization": CommonResource( + "cloudresourcemanager.googleapis.com/Organization", + "organizations/{organization}", + ), + "cloudresourcemanager.googleapis.com/Folder": CommonResource( + "cloudresourcemanager.googleapis.com/Folder", + "folders/{folder}", + ), + "cloudbilling.googleapis.com/BillingAccount": CommonResource( + "cloudbilling.googleapis.com/BillingAccount", + "billingAccounts/{billing_account}", + ), + "locations.googleapis.com/Location": CommonResource( + "locations.googleapis.com/Location", + "projects/{project}/locations/{location}", + ), }, init=False, compare=False, @@ -1027,28 +1051,28 @@ def __getattr__(self, name): @property def client_name(self) -> str: """Returns the name of the generated client class""" - return self.name + 'Client' + return self.name + "Client" @property def async_client_name(self) -> str: """Returns the name of the generated AsyncIO client class""" - return self.name + 'AsyncClient' + return self.name + "AsyncClient" @property def transport_name(self): - return self.name + 'Transport' + return self.name + "Transport" @property def grpc_transport_name(self): - return self.name + 'GrpcTransport' + return self.name + "GrpcTransport" @property def grpc_asyncio_transport_name(self): - return self.name + 'GrpcAsyncIOTransport' + return self.name + "GrpcAsyncIOTransport" @property def rest_transport_name(self): - return self.name + 'RestTransport' + return self.name + "RestTransport" @property def has_lro(self) -> bool: @@ -1064,61 +1088,61 @@ def has_pagers(self) -> bool: def host(self) -> str: """Return the hostname for this service, if specified. - Returns: - str: The hostname, with no protocol and no trailing ``/``. - """ + Returns: + str: The hostname, with no protocol and no trailing ``/``. + """ if self.options.Extensions[client_pb2.default_host]: return self.options.Extensions[client_pb2.default_host] return '' @property def shortname(self) -> str: - """Return the API short name. + """Return the API short name. DRIFT uses this to identify + APIs. - DRIFT uses this to identify - APIs. - - Returns: - str: The api shortname. - """ + Returns: + str: The api shortname. + """ # Get the shortname from the host # Real APIs are expected to have format: # "{api_shortname}.googleapis.com" - return self.host.split('.')[0] + return self.host.split(".")[0] @property def oauth_scopes(self) -> Sequence[str]: """Return a sequence of oauth scopes, if applicable. - Returns: - Sequence[str]: A sequence of OAuth scopes. - """ + Returns: + Sequence[str]: A sequence of OAuth scopes. + """ # Return the OAuth scopes, split on comma. return tuple( i.strip() for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') - if i) + if i + ) @property def module_name(self) -> str: """Return the appropriate module name for this service. - Returns: - str: The service name, in snake case. - """ + Returns: + str: The service name, in snake case. + """ return utils.to_snake_case(self.name) @utils.cached_property def names(self) -> FrozenSet[str]: """Return a set of names used in this service. - This is used for detecting naming collisions in the module names - used for imports. - """ + This is used for detecting naming collisions in the module names + used for imports. + """ # Put together a set of the service and method names. answer = {self.name, self.client_name, self.async_client_name} - answer.update(utils.to_snake_case(i.name) - for i in self.methods.values()) + answer.update( + utils.to_snake_case(i.name) for i in self.methods.values() + ) # Identify any import module names where the same module name is used # from distinct packages. @@ -1127,8 +1151,11 @@ def names(self) -> FrozenSet[str]: for t in m.ref_types: modules[t.ident.module].add(t.ident.package) - answer.update(module_name for module_name, packages in modules.items() - if len(packages) > 1) + answer.update( + module_name + for module_name, packages in modules.items() + if len(packages) > 1 + ) # Done; return the answer. return frozenset(answer) @@ -1136,10 +1163,7 @@ def names(self) -> FrozenSet[str]: @utils.cached_property def resource_messages(self) -> FrozenSet[MessageType]: """Returns all the resource message types used in all - - request and response fields in the service. - """ - + request and response fields in the service.""" def gen_resources(message): if message.resource_path: yield message @@ -1150,7 +1174,8 @@ def gen_resources(message): def gen_indirect_resources_used(message): for field in message.recursive_resource_fields: - resource = field.options.Extensions[resource_pb2.resource_reference] + resource = field.options.Extensions[ + resource_pb2.resource_reference] resource_type = resource.type or resource.child_type # The resource may not be visible if the resource type is one of # the common_resources (see the class var in class definition) @@ -1159,14 +1184,20 @@ def gen_indirect_resources_used(message): if resource: yield resource - return frozenset(msg for method in self.methods.values() for msg in chain( - gen_resources(method.input), - gen_resources( - method.lro.response_type if method.lro else method.output), - gen_indirect_resources_used(method.input), - gen_indirect_resources_used( - method.lro.response_type if method.lro else method.output), - )) + return frozenset( + msg + for method in self.methods.values() + for msg in chain( + gen_resources(method.input), + gen_resources( + method.lro.response_type if method.lro else method.output + ), + gen_indirect_resources_used(method.input), + gen_indirect_resources_used( + method.lro.response_type if method.lro else method.output + ), + ) + ) @utils.cached_property def any_client_streaming(self) -> bool: @@ -1179,10 +1210,10 @@ def any_server_streaming(self) -> bool: def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': """Return a derivative of this service with the provided context. - This method is used to address naming collisions. The returned - ``Service`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Service`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, methods={ diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index 69c3b7cf07..a499606f49 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -25,12 +25,13 @@ def make_service( - name: str = 'Placeholder', - host: str = '', + name: str = "Placeholder", + host: str = "", methods: typing.Tuple[wrappers.Method] = (), scopes: typing.Tuple[str] = (), - visible_resources: typing.Optional[typing.Mapping[ - str, wrappers.CommonResource]] = None, + visible_resources: typing.Optional[ + typing.Mapping[str, wrappers.CommonResource] + ] = None, ) -> wrappers.Service: visible_resources = visible_resources or {} # Define a service descriptor, and set a host and oauth scopes if @@ -55,8 +56,7 @@ def make_service_with_method_options( http_rule: http_pb2.HttpRule = None, method_signature: str = '', in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - visible_resources: typing.Optional[typing.Mapping[ - str, wrappers.CommonResource]] = None, + visible_resources: typing.Optional[typing.Mapping[str, wrappers.CommonResource]] = None, ) -> wrappers.Service: # Declare a method with options enabled for long-running operations and # field headers. @@ -82,17 +82,15 @@ def make_service_with_method_options( ) -def get_method( - name: str, - in_type: str, - out_type: str, - lro_response_type: str = '', - lro_metadata_type: str = '', - *, - in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - http_rule: http_pb2.HttpRule = None, - method_signature: str = '', -) -> wrappers.Method: +def get_method(name: str, + in_type: str, + out_type: str, + lro_response_type: str = '', + lro_metadata_type: str = '', *, + in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), + http_rule: http_pb2.HttpRule = None, + method_signature: str = '', + ) -> wrappers.Method: input_ = get_message(in_type, fields=in_fields) output = get_message(out_type) lro = None @@ -124,11 +122,9 @@ def get_method( ) -def get_message( - dot_path: str, - *, - fields: typing.Tuple[desc.FieldDescriptorProto] = (), -) -> wrappers.MessageType: +def get_message(dot_path: str, *, + fields: typing.Tuple[desc.FieldDescriptorProto] = (), + ) -> wrappers.MessageType: # Pass explicit None through (for lro_metadata). if dot_path is None: return None @@ -143,33 +139,32 @@ def get_message( pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] return wrappers.MessageType( - fields={ - i.name: wrappers.Field( - field_pb=i, - enum=get_enum(i.type_name) if i.type_name else None, - ) for i in fields - }, + fields={i.name: wrappers.Field( + field_pb=i, + enum=get_enum(i.type_name) if i.type_name else None, + ) for i in fields}, nested_messages={}, nested_enums={}, message_pb=desc.DescriptorProto(name=name, field=fields), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), ) -def make_method(name: str, - input_message: wrappers.MessageType = None, - output_message: wrappers.MessageType = None, - package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', - module: str = 'baz', - http_rule: http_pb2.HttpRule = None, - signatures: typing.Sequence[str] = (), - is_deprecated: bool = False, - **kwargs) -> wrappers.Method: +def make_method( + name: str, + input_message: wrappers.MessageType = None, + output_message: wrappers.MessageType = None, + package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', + module: str = 'baz', + http_rule: http_pb2.HttpRule = None, + signatures: typing.Sequence[str] = (), + is_deprecated: bool = False, + **kwargs +) -> wrappers.Method: # Use default input and output messages if they are not provided. input_message = input_message or make_message('MethodInput') output_message = output_message or make_message('MethodOutput') @@ -179,7 +174,8 @@ def make_method(name: str, name=name, input_type=str(input_message.meta.address), output_type=str(output_message.meta.address), - **kwargs) + **kwargs + ) # If there is an HTTP rule, process it. if http_rule: @@ -202,31 +198,32 @@ def make_method(name: str, method_pb=method_pb, input=input_message, output=output_message, - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=package, - module=module, - parent=(f'{name}Service',), - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=package, + module=module, + parent=(f'{name}Service',), + )), ) -def make_field(name: str = 'my_field', - number: int = 1, - repeated: bool = False, - message: wrappers.MessageType = None, - enum: wrappers.EnumType = None, - meta: metadata.Metadata = None, - oneof: str = None, - **kwargs) -> wrappers.Field: +def make_field( + name: str = 'my_field', + number: int = 1, + repeated: bool = False, + message: wrappers.MessageType = None, + enum: wrappers.EnumType = None, + meta: metadata.Metadata = None, + oneof: str = None, + **kwargs +) -> wrappers.Field: T = desc.FieldDescriptorProto.Type if message: kwargs.setdefault('type_name', str(message.meta.address)) kwargs['type'] = 'TYPE_MESSAGE' elif enum: - kwargs.setdefault('type_name', str(enum.meta.address)) + kwargs.setdefault('type_name', str(enum.meta.address)) kwargs['type'] = 'TYPE_ENUM' else: kwargs.setdefault('type', T.Value('TYPE_BOOL')) @@ -236,7 +233,11 @@ def make_field(name: str = 'my_field', label = kwargs.pop('label', 3 if repeated else 1) field_pb = desc.FieldDescriptorProto( - name=name, label=label, number=number, **kwargs) + name=name, + label=label, + number=number, + **kwargs + ) return wrappers.Field( field_pb=field_pb, @@ -265,12 +266,11 @@ def make_message( fields=collections.OrderedDict((i.name, i) for i in fields), nested_messages={}, nested_enums={}, - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), + meta=meta or metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), ) @@ -279,12 +279,11 @@ def get_enum(dot_path: str) -> wrappers.EnumType: pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] return wrappers.EnumType( enum_pb=desc.EnumDescriptorProto(name=name), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), values=[], ) @@ -298,7 +297,8 @@ def make_enum( options: desc.EnumOptions = None, ) -> wrappers.EnumType: enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=i[0], number=i[1]) for i in values + desc.EnumValueDescriptorProto(name=i[0], number=i[1]) + for i in values ] enum_pb = desc.EnumDescriptorProto( name=name, @@ -307,15 +307,13 @@ def make_enum( ) return wrappers.EnumType( enum_pb=enum_pb, - values=[ - wrappers.EnumValueType(enum_value_pb=evpb) for evpb in enum_value_pbs - ], - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), + values=[wrappers.EnumValueType(enum_value_pb=evpb) + for evpb in enum_value_pbs], + meta=meta or metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), ) @@ -327,31 +325,33 @@ def make_naming(**kwargs) -> naming.Naming: return naming.NewNaming(**kwargs) -def make_enum_pb2(name: str, *values: typing.Sequence[str], - **kwargs) -> desc.EnumDescriptorProto: +def make_enum_pb2( + name: str, + *values: typing.Sequence[str], + **kwargs +) -> desc.EnumDescriptorProto: enum_value_pbs = [ desc.EnumValueDescriptorProto(name=n, number=i) for i, n in enumerate(values) ] - enum_pb = desc.EnumDescriptorProto( - name=name, value=enum_value_pbs, **kwargs) + enum_pb = desc.EnumDescriptorProto(name=name, value=enum_value_pbs, **kwargs) return enum_pb -def make_message_pb2(name: str, - fields: tuple = (), - oneof_decl: tuple = (), - **kwargs) -> desc.DescriptorProto: - return desc.DescriptorProto( - name=name, field=fields, oneof_decl=oneof_decl, **kwargs) - - -def make_field_pb2( +def make_message_pb2( name: str, - number: int, - type: int = 11, # 11 == message - type_name: str = None, - oneof_index: int = None) -> desc.FieldDescriptorProto: + fields: tuple = (), + oneof_decl: tuple = (), + **kwargs +) -> desc.DescriptorProto: + return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs) + + +def make_field_pb2(name: str, number: int, + type: int = 11, # 11 == message + type_name: str = None, + oneof_index: int = None + ) -> desc.FieldDescriptorProto: return desc.FieldDescriptorProto( name=name, number=number, @@ -360,20 +360,18 @@ def make_field_pb2( oneof_index=oneof_index, ) - def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto: - return desc.OneofDescriptorProto(name=name,) + return desc.OneofDescriptorProto( + name=name, + ) -def make_file_pb2( - name: str = 'my_proto.proto', - package: str = 'example.v1', - *, - messages: typing.Sequence[desc.DescriptorProto] = (), - enums: typing.Sequence[desc.EnumDescriptorProto] = (), - services: typing.Sequence[desc.ServiceDescriptorProto] = (), - locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), -) -> desc.FileDescriptorProto: +def make_file_pb2(name: str = 'my_proto.proto', package: str = 'example.v1', *, + messages: typing.Sequence[desc.DescriptorProto] = (), + enums: typing.Sequence[desc.EnumDescriptorProto] = (), + services: typing.Sequence[desc.ServiceDescriptorProto] = (), + locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), + ) -> desc.FileDescriptorProto: return desc.FileDescriptorProto( name=name, package=package, @@ -385,14 +383,15 @@ def make_file_pb2( def make_doc_meta( - *, - leading: str = '', - trailing: str = '', - detached: typing.List[str] = [], + *, + leading: str = '', + trailing: str = '', + detached: typing.List[str] = [], ) -> desc.SourceCodeInfo.Location: return metadata.Metadata( documentation=desc.SourceCodeInfo.Location( leading_comments=leading, trailing_comments=trailing, leading_detached_comments=detached, - ),) + ), + ) diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 6168f58564..c13a9afb28 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -33,8 +33,8 @@ def test_method_types(): input_msg = make_message(name='Input', module='baz') output_msg = make_message(name='Output', module='baz') - method = make_method( - 'DoSomething', input_msg, output_msg, package='foo.bar', module='bacon') + method = make_method('DoSomething', input_msg, output_msg, + package='foo.bar', module='bacon') assert method.name == 'DoSomething' assert method.input.name == 'Input' assert method.output.name == 'Output' @@ -71,22 +71,19 @@ def test_method_client_output_empty(): def test_method_client_output_paged(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - parent = make_field(name='parent', type=9) # str - page_size = make_field(name='page_size', type=5) # int + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int page_token = make_field(name='page_token', type=9) # str - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - page_size, - page_token, - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + page_size, + page_token, + )) + output_msg = make_message(name='ListFoosResponse', fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) method = make_method( 'ListFoos', input_message=input_msg, @@ -96,12 +93,11 @@ def test_method_client_output_paged(): assert method.client_output.ident.name == 'ListFoosPager' max_results = make_field(name='max_results', type=5) # int - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - max_results, - page_token, - )) + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) method = make_method( 'ListFoos', input_message=input_msg, @@ -119,47 +115,36 @@ def test_method_client_output_async_empty(): def test_method_paged_result_field_not_first(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='next_page_token', type=9), # str - paged, - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + input_msg = make_message(name='ListFoosRequest', fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message(name='ListFoosResponse', fields=( + make_field(name='next_page_token', type=9), # str + paged, + )) + method = make_method('ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field == paged def test_method_paged_result_field_no_page_field(): - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='foos', message=make_message( - 'Foo'), repeated=False), - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + input_msg = make_message(name='ListFoosRequest', fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message(name='ListFoosResponse', fields=( + make_field(name='foos', message=make_message('Foo'), repeated=False), + make_field(name='next_page_token', type=9), # str + )) + method = make_method('ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field is None method = make_method( @@ -171,7 +156,8 @@ def test_method_paged_result_field_no_page_field(): output_message=make_message( name='FooResponse', fields=(make_field(name='next_page_token', type=9),) # str - )) + ) + ) assert method.paged_result_field is None @@ -179,8 +165,8 @@ def test_method_paged_result_ref_types(): input_msg = make_message( name='ListSquidsRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int make_field(name='page_token', type=9), # str ), module='squid', @@ -192,12 +178,14 @@ def test_method_paged_result_ref_types(): make_field(name='molluscs', message=mollusc_msg, repeated=True), make_field(name='next_page_token', type=9) # str ), - module='mollusc') + module='mollusc' + ) method = make_method( 'ListSquids', input_message=input_msg, output_message=output_msg, - module='squid') + module='squid' + ) ref_type_names = {t.name for t in method.ref_types} assert ref_type_names == { @@ -233,7 +221,12 @@ def test_flattened_ref_types(): ), ), ), - make_field('stratum', enum=make_enum('Stratum',)), + make_field( + 'stratum', + enum=make_enum( + 'Stratum', + ) + ), ), ), signatures=('cephalopod.squid,stratum',), @@ -251,21 +244,19 @@ def test_flattened_ref_types(): def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) # str + paged = make_field(name='squids', type=9, repeated=True) # str input_msg = make_message( name='ListSquidsRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int make_field(name='page_token', type=9), # str ), ) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) + output_msg = make_message(name='ListFoosResponse', fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) method = make_method( 'ListSquids', input_message=input_msg, @@ -297,15 +288,15 @@ def test_method_field_headers_present(): def test_method_http_opt(): http_rule = http_pb2.HttpRule( - post='/v1/{parent=projects/*}/topics', body='*') + post='/v1/{parent=projects/*}/topics', + body='*' + ) method = make_method('DoSomething', http_rule=http_rule) assert method.http_opt == { 'verb': 'post', 'url': '/v1/{parent=projects/*}/topics', 'body': '*' } - - # TODO(yon-mg) to test: grpc transcoding, # correct handling of path/query params # correct handling of body & additional binding @@ -339,13 +330,20 @@ def test_method_path_params_no_http_rule(): def test_method_query_params(): # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics', body='address') + http_rule = http_pb2.HttpRule( + post='/v1/{project}/topics', + body='address' + ) input_message = make_message( 'MethodInput', - fields=(make_field('region'), make_field('project'), - make_field('address'))) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) + fields=( + make_field('region'), + make_field('project'), + make_field('address') + ) + ) + method = make_method('DoSomething', http_rule=http_rule, + input_message=input_message) assert method.query_params == {'region'} @@ -353,12 +351,14 @@ def test_method_query_params_no_body(): # tests only the basic case of grpc transcoding http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') input_message = make_message( - 'MethodInput', fields=( + 'MethodInput', + fields=( make_field('region'), make_field('project'), - )) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) + ) + ) + method = make_method('DoSomething', http_rule=http_rule, + input_message=input_message) assert method.query_params == {'region'} @@ -432,22 +432,18 @@ def test_method_flattened_fields_different_package_non_primitive(): # directly to its fields, which complicates request construction. # The easiest solution in this case is to just prohibit these fields # in the method flattening. - message = make_message( - 'Mantle', package='mollusc.cephalopod.v1', module='squid') - mantle = make_field( - 'mantle', type=11, type_name='Mantle', message=message, meta=message.meta) + message = make_message('Mantle', + package="mollusc.cephalopod.v1", module="squid") + mantle = make_field('mantle', type=11, type_name='Mantle', + message=message, meta=message.meta) arms_count = make_field('arms_count', type=5, meta=message.meta) input_message = make_message( - 'Squid', - fields=(mantle, arms_count), - package='.'.join(message.meta.address.package), - module=message.meta.address.module) - method = make_method( - 'PutSquid', - input_message=input_message, - package='remote.package.v1', - module='module', - signatures=('mantle,arms_count',)) + 'Squid', fields=(mantle, arms_count), + package=".".join(message.meta.address.package), + module=message.meta.address.module + ) + method = make_method('PutSquid', input_message=input_message, + package="remote.package.v1", module="module", signatures=("mantle,arms_count",)) assert set(method.flattened_fields) == {'arms_count'} @@ -463,63 +459,75 @@ def test_method_include_flattened_message_fields(): def test_method_legacy_flattened_fields(): required_options = descriptor_pb2.FieldOptions() required_options.Extensions[field_behavior_pb2.field_behavior].append( - field_behavior_pb2.FieldBehavior.Value('REQUIRED')) + field_behavior_pb2.FieldBehavior.Value("REQUIRED")) # Cephalopods are required. - squid = make_field(name='squid', options=required_options) + squid = make_field(name="squid", options=required_options) octopus = make_field( - name='octopus', + name="octopus", message=make_message( - name='Octopus', - fields=[make_field(name='mass', options=required_options)]), + name="Octopus", + fields=[make_field(name="mass", options=required_options)] + ), options=required_options) # Bivalves are optional. - clam = make_field(name='clam') + clam = make_field(name="clam") oyster = make_field( - name='oyster', + name="oyster", message=make_message( - name='Oyster', fields=[make_field(name='has_pearl')])) + name="Oyster", + fields=[make_field(name="has_pearl")] + ) + ) # Interleave required and optional fields to make sure # that, in the legacy flattening, required fields are always first. - request = make_message('request', fields=[squid, clam, octopus, oyster]) + request = make_message("request", fields=[squid, clam, octopus, oyster]) method = make_method( - name='CreateMolluscs', + name="CreateMolluscs", input_message=request, # Signatures should be ignored. - signatures=['squid,octopus.mass', 'squid,octopus,oyster.has_pearl']) + signatures=[ + "squid,octopus.mass", + "squid,octopus,oyster.has_pearl" + ] + ) # Use an ordered dict because ordering is important: # required fields should come first. - expected = collections.OrderedDict([('squid', squid), ('octopus', octopus), - ('clam', clam), ('oyster', oyster)]) + expected = collections.OrderedDict([ + ("squid", squid), + ("octopus", octopus), + ("clam", clam), + ("oyster", oyster) + ]) assert method.legacy_flattened_fields == expected def test_flattened_oneof_fields(): - mass_kg = make_field(name='mass_kg', oneof='mass', type=5) - mass_lbs = make_field(name='mass_lbs', oneof='mass', type=5) + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) - length_m = make_field(name='length_m', oneof='length', type=5) - length_f = make_field(name='length_f', oneof='length', type=5) + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) - color = make_field(name='color', type=5) + color = make_field(name="color", type=5) mantle = make_field( - name='mantle', + name="mantle", message=make_message( - name='Mantle', + name="Mantle", fields=( - make_field(name='color', type=5), + make_field(name="color", type=5), mass_kg, mass_lbs, ), ), ) request = make_message( - name='CreateMolluscReuqest', + name="CreateMolluscReuqest", fields=( length_m, length_f, @@ -528,27 +536,28 @@ def test_flattened_oneof_fields(): ), ) method = make_method( - name='CreateMollusc', + name="CreateMollusc", input_message=request, signatures=[ - 'length_m,', - 'length_f,', - 'mantle.mass_kg,', - 'mantle.mass_lbs,', - 'color', - ]) - - expected = {'mass': [mass_kg, mass_lbs], 'length': [length_m, length_f]} + "length_m,", + "length_f,", + "mantle.mass_kg,", + "mantle.mass_lbs,", + "color", + ] + ) + + expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} actual = method.flattened_oneof_fields() assert expected == actual # Check this method too becasue the setup is a lot of work. expected = { - 'color': 'color', - 'length_m': 'length_m', - 'length_f': 'length_f', - 'mass_kg': 'mantle.mass_kg', - 'mass_lbs': 'mantle.mass_lbs', + "color": "color", + "length_m": "length_m", + "length_f": "length_f", + "mass_kg": "mantle.mass_kg", + "mass_lbs": "mantle.mass_lbs", } actual = method.flattened_field_to_key assert expected == actual From 69a05043844c5b02177af4496077871bb2c235a2 Mon Sep 17 00:00:00 2001 From: Mira Leung Date: Wed, 12 May 2021 23:20:59 -0700 Subject: [PATCH 4/4] fix: autopep8 linter fixes --- gapic/schema/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 7a5d907da9..249ca5b5d4 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -739,9 +739,9 @@ def is_deprecated(self) -> bool: """Returns true if the method is deprecated, false otherwise.""" return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') - # TODO(yon-mg): remove or rewrite: don't think it performs as intended # e.g. doesn't work with basic case of gRPC transcoding + @property def field_headers(self) -> Sequence[str]: """Return the field headers defined for this method."""