From cecff757cb166eaac3205b00624eabcc8b87dbb5 Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Wed, 8 Jul 2020 13:58:18 -0700 Subject: [PATCH 1/4] Fix oneof detection Oneof detection and assignment to fields is tricky. This patch fixes detection of oneof fields, fixes uses in generated clients and tweaks generated tests to use them correctly. --- gapic/schema/api.py | 13 ++++- gapic/schema/wrappers.py | 22 ++++++++ .../%name_%version/%sub/test_%service.py.j2 | 27 +++++++-- tests/unit/schema/wrappers/test_message.py | 20 +++++++ tests/unit/schema/wrappers/test_method.py | 56 +++++++++++++++++++ 5 files changed, 132 insertions(+), 6 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index df3e1daa8e..6bc5ef6f93 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -26,6 +26,7 @@ from google.api_core import exceptions # type: ignore from google.longrunning import operations_pb2 # type: ignore from google.protobuf import descriptor_pb2 +from google.protobuf.json_format import MessageToDict import grpc # type: ignore @@ -614,8 +615,18 @@ def _get_fields(self, # first) and this will be None. This case is addressed in the # `_load_message` method. answer: Dict[str, wrappers.Field] = collections.OrderedDict() + + def oneof_p(field_pb): + # This is the _only_ way I have found to determine whether + # a FieldDescriptor's oneof_index is 0 or unset. + # It's frustrating, misdocumented, and it feels like there should + # be a better solution by digging through the field or its class, + # but at this point I've just given up. + # Protobuf has won this round. + return "oneofIndex" in MessageToDict(field_pb) + for i, field_pb in enumerate(field_pbs): - is_oneof = oneofs and field_pb.oneof_index > 0 + is_oneof = oneofs and oneof_p(field_pb) oneof_name = nth( (oneofs or {}).keys(), field_pb.oneof_index diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 1061620378..bbeeec679b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -239,6 +239,15 @@ def __hash__(self): # Identity is sufficiently unambiguous. return hash(self.ident) + def oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + @utils.cached_property def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: answer = tuple( @@ -583,6 +592,15 @@ def client_output(self): def client_output_async(self): return self._client_output(enable_asyncio=True) + def flattened_oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.flattened_fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + def _client_output(self, enable_asyncio: bool): """Return the output from the client layer. @@ -685,6 +703,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: return answer + @utils.cached_property + def flattened_field_to_key(self): + return {field.name: key for key, field in self.flattened_fields.items()} + @utils.cached_property def legacy_flattened_fields(self) -> Mapping[str, Field]: """Return the legacy flattening interface: top level fields only, diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 4f30579121..322094d226 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -288,9 +288,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): call.return_value = iter([{{ method.output.ident }}()]) {% else -%} call.return_value = {{ method.output.ident }}( - {%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %} + {%- for field in method.output.fields.values() | rejectattr('message')%}{% if not field.oneof or field.proto3_optional %} {{ field.name }}={{ field.mock_value }}, {% endif %}{%- endfor %} + {#- This is a hack to only pick one field #} + {%- for oneof_fields in method.output.oneof_fields().values() %} + {% with field = oneof_fields[0] %} + {{ field.name }}={{ field.mock_value }}, + {%- endwith %} + {%- endfor %} ) {% endif -%} {% if method.client_streaming %} @@ -567,9 +573,15 @@ def test_{{ method.name|snake_case }}_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - {% for key, field in method.flattened_fields.items() -%} + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} assert args[0].{{ key }} == {{ field.mock_value }} - {% endfor %} + {% endif %}{% endfor %} + {%- for oneofs in method.flattened_oneof_fields().values() %} + {%- with field = oneofs[-1] %} + assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }} + {%- endwith %} + {%- endfor %} + def test_{{ method.name|snake_case }}_flattened_error(): @@ -640,9 +652,14 @@ async def test_{{ method.name|snake_case }}_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - {% for key, field in method.flattened_fields.items() -%} + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} assert args[0].{{ key }} == {{ field.mock_value }} - {% endfor %} + {% endif %}{% endfor %} + {%- for oneofs in method.flattened_oneof_fields().values() %} + {%- with field = oneofs[-1] %} + assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }} + {%- endwith %} + {%- endfor %} @pytest.mark.asyncio diff --git a/tests/unit/schema/wrappers/test_message.py b/tests/unit/schema/wrappers/test_message.py index 7ae95d0299..325bca9dad 100644 --- a/tests/unit/schema/wrappers/test_message.py +++ b/tests/unit/schema/wrappers/test_message.py @@ -235,3 +235,23 @@ def test_field_map(): entry_field = make_field('foos', message=entry_msg, repeated=True) assert entry_msg.map assert entry_field.map + + +def test_oneof_fields(): + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) + color = make_field(name="color", type=5) + request = make_message( + name="CreateMolluscReuqest", + fields=( + mass_kg, + mass_lbs, + length_m, + length_f, + color, + ), + ) + actual_oneofs = request.oneof_fields() + expected_oneofs = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index c0102402c2..61dfedfca6 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -364,3 +364,59 @@ def test_method_legacy_flattened_fields(): ]) assert method.legacy_flattened_fields == expected + + +def test_flattened_oneof_fields(): + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) + + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) + + color = make_field(name="color", type=5) + mantle = make_field( + name="mantle", + message=make_message( + name="Mantle", + fields=( + make_field(name="color", type=5), + mass_kg, + mass_lbs, + ), + ), + ) + request = make_message( + name="CreateMolluscReuqest", + fields=( + length_m, + length_f, + color, + mantle, + ), + ) + method = make_method( + name="CreateMollusc", + input_message=request, + signatures=[ + "length_m,", + "length_f,", + "mantle.mass_kg,", + "mantle.mass_lbs,", + "color", + ] + ) + + expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} + actual = method.flattened_oneof_fields() + assert expected == actual + + # Check this method too becasue the setup is a lot of work. + expected = { + "color": "color", + "length_m": "length_m", + "length_f": "length_f", + "mass_kg": "mantle.mass_kg", + "mass_lbs": "mantle.mass_lbs", + } + actual = method.flattened_field_to_key + assert expected == actual From b6baa7a2dfa36c32f955e902851c484282a718cf Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Thu, 9 Jul 2020 14:42:05 -0700 Subject: [PATCH 2/4] Whitespace --- tests/unit/schema/wrappers/test_method.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 61dfedfca6..f10bb078cd 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -369,10 +369,10 @@ def test_method_legacy_flattened_fields(): def test_flattened_oneof_fields(): mass_kg = make_field(name="mass_kg", oneof="mass", type=5) mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) - + length_m = make_field(name="length_m", oneof="length", type=5) length_f = make_field(name="length_f", oneof="length", type=5) - + color = make_field(name="color", type=5) mantle = make_field( name="mantle", From 66e8e9c4705df60abb38cecfd382bdc73473c5e9 Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Thu, 9 Jul 2020 14:44:01 -0700 Subject: [PATCH 3/4] Style --- tests/unit/schema/wrappers/test_message.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/schema/wrappers/test_message.py b/tests/unit/schema/wrappers/test_message.py index 325bca9dad..7d8cca169a 100644 --- a/tests/unit/schema/wrappers/test_message.py +++ b/tests/unit/schema/wrappers/test_message.py @@ -254,4 +254,7 @@ def test_oneof_fields(): ), ) actual_oneofs = request.oneof_fields() - expected_oneofs = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} + expected_oneofs = { + "mass": [mass_kg, mass_lbs], + "length": [length_m, length_f], + } From 039ecbaa6263e62f7eb902118a4b25c8ecc079c9 Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Thu, 9 Jul 2020 17:26:30 -0700 Subject: [PATCH 4/4] Correct field checking --- gapic/schema/api.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 6bc5ef6f93..3c79d8f7cd 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -26,7 +26,6 @@ from google.api_core import exceptions # type: ignore from google.longrunning import operations_pb2 # type: ignore from google.protobuf import descriptor_pb2 -from google.protobuf.json_format import MessageToDict import grpc # type: ignore @@ -615,18 +614,8 @@ def _get_fields(self, # first) and this will be None. This case is addressed in the # `_load_message` method. answer: Dict[str, wrappers.Field] = collections.OrderedDict() - - def oneof_p(field_pb): - # This is the _only_ way I have found to determine whether - # a FieldDescriptor's oneof_index is 0 or unset. - # It's frustrating, misdocumented, and it feels like there should - # be a better solution by digging through the field or its class, - # but at this point I've just given up. - # Protobuf has won this round. - return "oneofIndex" in MessageToDict(field_pb) - for i, field_pb in enumerate(field_pbs): - is_oneof = oneofs and oneof_p(field_pb) + is_oneof = oneofs and field_pb.HasField('oneof_index') oneof_name = nth( (oneofs or {}).keys(), field_pb.oneof_index