Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,24 +287,28 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
# Minor optimization to avoid making a copy if the user passes
# in a {{ method.input.ident }}.
# There's no risk of modifying the input as we've already verified
# there are no flattened fields.
if not isinstance(request, {{ method.input.ident }}):
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,15 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():


{% for method in service.methods.values() -%}
def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = {{ method.input.ident }}()
request = request_type()
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == request
assert args[0] == {{ method.input.ident }}()
{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -275,6 +275,11 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% endfor %}
{% endif %}


def test_{{ method.name|snake_case }}_from_dict():
test_{{ method.name|snake_case }}(request_type=dict)


{% if method.field_headers and not method.client_streaming %}
def test_{{ method.name|snake_case }}_field_headers():
client = {{ service.client_name }}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% if method.flattened_fields -%}
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
has_flattened_params = any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}])
if request is not None and has_flattened_params:
raise ValueError('If the `request` argument is set, then none of '
'the individual field arguments should be set.')

Expand All @@ -297,24 +298,29 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
# Minor optimization to avoid making a copy if the user passes
# in a {{ method.input.ident }}.
# There's no risk of modifying the input as we've already verified
# there are no flattened fields.
if not isinstance(request, {{ method.input.ident }}):
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():


{% for method in service.methods.values() -%}
def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = {{ method.input.ident }}()
request = request_type()
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == request
assert args[0] == {{ method.input.ident }}()
{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -374,6 +374,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% endif %}


def test_{{ method.name|snake_case }}_from_dict():
test_{{ method.name|snake_case }}(request_type=dict)


@pytest.mark.asyncio
async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio'):
client = {{ service.async_client_name }}(
Expand Down