diff --git a/google/cloud/dataproc_v1/gapic_metadata.json b/google/cloud/dataproc_v1/gapic_metadata.json index da52acb4..f8a05276 100644 --- a/google/cloud/dataproc_v1/gapic_metadata.json +++ b/google/cloud/dataproc_v1/gapic_metadata.json @@ -66,6 +66,36 @@ ] } } + }, + "rest": { + "libraryClient": "AutoscalingPolicyServiceClient", + "rpcs": { + "CreateAutoscalingPolicy": { + "methods": [ + "create_autoscaling_policy" + ] + }, + "DeleteAutoscalingPolicy": { + "methods": [ + "delete_autoscaling_policy" + ] + }, + "GetAutoscalingPolicy": { + "methods": [ + "get_autoscaling_policy" + ] + }, + "ListAutoscalingPolicies": { + "methods": [ + "list_autoscaling_policies" + ] + }, + "UpdateAutoscalingPolicy": { + "methods": [ + "update_autoscaling_policy" + ] + } + } } } }, @@ -120,6 +150,31 @@ ] } } + }, + "rest": { + "libraryClient": "BatchControllerClient", + "rpcs": { + "CreateBatch": { + "methods": [ + "create_batch" + ] + }, + "DeleteBatch": { + "methods": [ + "delete_batch" + ] + }, + "GetBatch": { + "methods": [ + "get_batch" + ] + }, + "ListBatches": { + "methods": [ + "list_batches" + ] + } + } } } }, @@ -214,6 +269,51 @@ ] } } + }, + "rest": { + "libraryClient": "ClusterControllerClient", + "rpcs": { + "CreateCluster": { + "methods": [ + "create_cluster" + ] + }, + "DeleteCluster": { + "methods": [ + "delete_cluster" + ] + }, + "DiagnoseCluster": { + "methods": [ + "diagnose_cluster" + ] + }, + "GetCluster": { + "methods": [ + "get_cluster" + ] + }, + "ListClusters": { + "methods": [ + "list_clusters" + ] + }, + "StartCluster": { + "methods": [ + "start_cluster" + ] + }, + "StopCluster": { + "methods": [ + "stop_cluster" + ] + }, + "UpdateCluster": { + "methods": [ + "update_cluster" + ] + } + } } } }, @@ -298,6 +398,46 @@ ] } } + }, + "rest": { + "libraryClient": "JobControllerClient", + "rpcs": { + "CancelJob": { + "methods": [ + "cancel_job" + ] + }, + "DeleteJob": { + "methods": [ + "delete_job" + ] + }, + "GetJob": { + "methods": [ + "get_job" + ] + }, + "ListJobs": { + "methods": [ + "list_jobs" + ] + }, + "SubmitJob": { + "methods": [ + "submit_job" + ] + }, + "SubmitJobAsOperation": { + "methods": [ + "submit_job_as_operation" + ] + }, + "UpdateJob": { + "methods": [ + "update_job" + ] + } + } } } }, @@ -342,6 +482,26 @@ ] } } + }, + "rest": { + "libraryClient": "NodeGroupControllerClient", + "rpcs": { + "CreateNodeGroup": { + "methods": [ + "create_node_group" + ] + }, + "GetNodeGroup": { + "methods": [ + "get_node_group" + ] + }, + "ResizeNodeGroup": { + "methods": [ + "resize_node_group" + ] + } + } } } }, @@ -426,6 +586,46 @@ ] } } + }, + "rest": { + "libraryClient": "WorkflowTemplateServiceClient", + "rpcs": { + "CreateWorkflowTemplate": { + "methods": [ + "create_workflow_template" + ] + }, + "DeleteWorkflowTemplate": { + "methods": [ + "delete_workflow_template" + ] + }, + "GetWorkflowTemplate": { + "methods": [ + "get_workflow_template" + ] + }, + "InstantiateInlineWorkflowTemplate": { + "methods": [ + "instantiate_inline_workflow_template" + ] + }, + "InstantiateWorkflowTemplate": { + "methods": [ + "instantiate_workflow_template" + ] + }, + "ListWorkflowTemplates": { + "methods": [ + "list_workflow_templates" + ] + }, + "UpdateWorkflowTemplate": { + "methods": [ + "update_workflow_template" + ] + } + } } } } diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/client.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/client.py index 0a434a38..0a423c15 100644 --- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/client.py 
+++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/client.py @@ -51,6 +51,7 @@ from .transports.base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import AutoscalingPolicyServiceGrpcTransport from .transports.grpc_asyncio import AutoscalingPolicyServiceGrpcAsyncIOTransport +from .transports.rest import AutoscalingPolicyServiceRestTransport class AutoscalingPolicyServiceClientMeta(type): @@ -66,6 +67,7 @@ class AutoscalingPolicyServiceClientMeta(type): ) # type: Dict[str, Type[AutoscalingPolicyServiceTransport]] _transport_registry["grpc"] = AutoscalingPolicyServiceGrpcTransport _transport_registry["grpc_asyncio"] = AutoscalingPolicyServiceGrpcAsyncIOTransport + _transport_registry["rest"] = AutoscalingPolicyServiceRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/__init__.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/__init__.py index 2aec20e0..0672f4b2 100644 --- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/__init__.py @@ -19,6 +19,8 @@ from .base import AutoscalingPolicyServiceTransport from .grpc import AutoscalingPolicyServiceGrpcTransport from .grpc_asyncio import AutoscalingPolicyServiceGrpcAsyncIOTransport +from .rest import AutoscalingPolicyServiceRestTransport +from .rest import AutoscalingPolicyServiceRestInterceptor # Compile a registry of transports. @@ -27,9 +29,12 @@ ) # type: Dict[str, Type[AutoscalingPolicyServiceTransport]] _transport_registry["grpc"] = AutoscalingPolicyServiceGrpcTransport _transport_registry["grpc_asyncio"] = AutoscalingPolicyServiceGrpcAsyncIOTransport +_transport_registry["rest"] = AutoscalingPolicyServiceRestTransport __all__ = ( "AutoscalingPolicyServiceTransport", "AutoscalingPolicyServiceGrpcTransport", "AutoscalingPolicyServiceGrpcAsyncIOTransport", + "AutoscalingPolicyServiceRestTransport", + "AutoscalingPolicyServiceRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/rest.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/rest.py new file mode 100644 index 00000000..05ef2ac7 --- /dev/null +++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/rest.py @@ -0,0 +1,877 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import autoscaling_policies +from google.protobuf import empty_pb2 # type: ignore + +from .base import ( + AutoscalingPolicyServiceTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class AutoscalingPolicyServiceRestInterceptor: + """Interceptor for AutoscalingPolicyService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the AutoscalingPolicyServiceRestTransport. + + .. 
code-block:: python + class MyCustomAutoscalingPolicyServiceInterceptor(AutoscalingPolicyServiceRestInterceptor): + def pre_create_autoscaling_policy(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_autoscaling_policy(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_autoscaling_policy(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_autoscaling_policy(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_autoscaling_policy(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_autoscaling_policies(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_autoscaling_policies(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_autoscaling_policy(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_autoscaling_policy(self, response): + logging.log(f"Received response: {response}") + return response + + transport = AutoscalingPolicyServiceRestTransport(interceptor=MyCustomAutoscalingPolicyServiceInterceptor()) + client = AutoscalingPolicyServiceClient(transport=transport) + + + """ + + def pre_create_autoscaling_policy( + self, + request: autoscaling_policies.CreateAutoscalingPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + autoscaling_policies.CreateAutoscalingPolicyRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for create_autoscaling_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the AutoscalingPolicyService server. + """ + return request, metadata + + def post_create_autoscaling_policy( + self, response: autoscaling_policies.AutoscalingPolicy + ) -> autoscaling_policies.AutoscalingPolicy: + """Post-rpc interceptor for create_autoscaling_policy + + Override in a subclass to manipulate the response + after it is returned by the AutoscalingPolicyService server but before + it is returned to user code. + """ + return response + + def pre_delete_autoscaling_policy( + self, + request: autoscaling_policies.DeleteAutoscalingPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + autoscaling_policies.DeleteAutoscalingPolicyRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for delete_autoscaling_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the AutoscalingPolicyService server. + """ + return request, metadata + + def pre_get_autoscaling_policy( + self, + request: autoscaling_policies.GetAutoscalingPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + autoscaling_policies.GetAutoscalingPolicyRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for get_autoscaling_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the AutoscalingPolicyService server. 
+ """ + return request, metadata + + def post_get_autoscaling_policy( + self, response: autoscaling_policies.AutoscalingPolicy + ) -> autoscaling_policies.AutoscalingPolicy: + """Post-rpc interceptor for get_autoscaling_policy + + Override in a subclass to manipulate the response + after it is returned by the AutoscalingPolicyService server but before + it is returned to user code. + """ + return response + + def pre_list_autoscaling_policies( + self, + request: autoscaling_policies.ListAutoscalingPoliciesRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + autoscaling_policies.ListAutoscalingPoliciesRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for list_autoscaling_policies + + Override in a subclass to manipulate the request or metadata + before they are sent to the AutoscalingPolicyService server. + """ + return request, metadata + + def post_list_autoscaling_policies( + self, response: autoscaling_policies.ListAutoscalingPoliciesResponse + ) -> autoscaling_policies.ListAutoscalingPoliciesResponse: + """Post-rpc interceptor for list_autoscaling_policies + + Override in a subclass to manipulate the response + after it is returned by the AutoscalingPolicyService server but before + it is returned to user code. + """ + return response + + def pre_update_autoscaling_policy( + self, + request: autoscaling_policies.UpdateAutoscalingPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + autoscaling_policies.UpdateAutoscalingPolicyRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for update_autoscaling_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the AutoscalingPolicyService server. + """ + return request, metadata + + def post_update_autoscaling_policy( + self, response: autoscaling_policies.AutoscalingPolicy + ) -> autoscaling_policies.AutoscalingPolicy: + """Post-rpc interceptor for update_autoscaling_policy + + Override in a subclass to manipulate the response + after it is returned by the AutoscalingPolicyService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class AutoscalingPolicyServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: AutoscalingPolicyServiceRestInterceptor + + +class AutoscalingPolicyServiceRestTransport(AutoscalingPolicyServiceTransport): + """REST backend transport for AutoscalingPolicyService. + + The API interface for managing autoscaling policies in the + Dataproc API. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "dataproc.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[AutoscalingPolicyServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. 
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional(Sequence[str])): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
+ certificate to configure mutual TLS HTTP channel. It is ignored
+ if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you are developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ url_scheme: the protocol scheme for the API endpoint. Normally
+ "https", but for testing or local servers,
+ "http" can be specified.
+ """
+ # Run the base constructor
+ # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+ # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+ # credentials object
+ maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+ if maybe_url_match is None:
+ raise ValueError(
+ f"Unexpected hostname structure: {host}"
+ ) # pragma: NO COVER
+
+ url_match_items = maybe_url_match.groupdict()
+
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience,
+ )
+ self._session = AuthorizedSession(
+ self._credentials, default_host=self.DEFAULT_HOST
+ )
+ if client_cert_source_for_mtls:
+ self._session.configure_mtls_channel(client_cert_source_for_mtls)
+ self._interceptor = interceptor or AutoscalingPolicyServiceRestInterceptor()
+ self._prep_wrapped_messages(client_info)
+
+ class _CreateAutoscalingPolicy(AutoscalingPolicyServiceRestStub):
+ def __hash__(self):
+ return hash("CreateAutoscalingPolicy")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {}
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {
+ k: v
+ for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+ if k not in message_dict
+ }
+
+ def __call__(
+ self,
+ request: autoscaling_policies.CreateAutoscalingPolicyRequest,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> autoscaling_policies.AutoscalingPolicy:
+ r"""Call the create autoscaling policy method over HTTP.
+
+ Args:
+ request (~.autoscaling_policies.CreateAutoscalingPolicyRequest):
+ The request object. A request to create an autoscaling
+ policy.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ + Returns: + ~.autoscaling_policies.AutoscalingPolicy: + Describes an autoscaling policy for + Dataproc cluster autoscaler. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/locations/*}/autoscalingPolicies", + "body": "policy", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/regions/*}/autoscalingPolicies", + "body": "policy", + }, + ] + request, metadata = self._interceptor.pre_create_autoscaling_policy( + request, metadata + ) + pb_request = autoscaling_policies.CreateAutoscalingPolicyRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = autoscaling_policies.AutoscalingPolicy() + pb_resp = autoscaling_policies.AutoscalingPolicy.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_autoscaling_policy(resp) + return resp + + class _DeleteAutoscalingPolicy(AutoscalingPolicyServiceRestStub): + def __hash__(self): + return hash("DeleteAutoscalingPolicy") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: autoscaling_policies.DeleteAutoscalingPolicyRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): + r"""Call the delete autoscaling policy method over HTTP. + + Args: + request (~.autoscaling_policies.DeleteAutoscalingPolicyRequest): + The request object. A request to delete an autoscaling + policy. + Autoscaling policies in use by one or + more clusters will not be deleted. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/autoscalingPolicies/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/regions/*/autoscalingPolicies/*}", + }, + ] + request, metadata = self._interceptor.pre_delete_autoscaling_policy( + request, metadata + ) + pb_request = autoscaling_policies.DeleteAutoscalingPolicyRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetAutoscalingPolicy(AutoscalingPolicyServiceRestStub): + def __hash__(self): + return hash("GetAutoscalingPolicy") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: autoscaling_policies.GetAutoscalingPolicyRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> autoscaling_policies.AutoscalingPolicy: + r"""Call the get autoscaling policy method over HTTP. + + Args: + request (~.autoscaling_policies.GetAutoscalingPolicyRequest): + The request object. A request to fetch an autoscaling + policy. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.autoscaling_policies.AutoscalingPolicy: + Describes an autoscaling policy for + Dataproc cluster autoscaler. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/autoscalingPolicies/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/autoscalingPolicies/*}", + }, + ] + request, metadata = self._interceptor.pre_get_autoscaling_policy( + request, metadata + ) + pb_request = autoscaling_policies.GetAutoscalingPolicyRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = autoscaling_policies.AutoscalingPolicy() + pb_resp = autoscaling_policies.AutoscalingPolicy.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_autoscaling_policy(resp) + return resp + + class _ListAutoscalingPolicies(AutoscalingPolicyServiceRestStub): + def __hash__(self): + return hash("ListAutoscalingPolicies") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: autoscaling_policies.ListAutoscalingPoliciesRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> autoscaling_policies.ListAutoscalingPoliciesResponse: + r"""Call the list autoscaling policies method over HTTP. + + Args: + request (~.autoscaling_policies.ListAutoscalingPoliciesRequest): + The request object. A request to list autoscaling + policies in a project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.autoscaling_policies.ListAutoscalingPoliciesResponse: + A response to a request to list + autoscaling policies in a project. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{parent=projects/*/locations/*}/autoscalingPolicies", + }, + { + "method": "get", + "uri": "/v1/{parent=projects/*/regions/*}/autoscalingPolicies", + }, + ] + request, metadata = self._interceptor.pre_list_autoscaling_policies( + request, metadata + ) + pb_request = autoscaling_policies.ListAutoscalingPoliciesRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = autoscaling_policies.ListAutoscalingPoliciesResponse() + pb_resp = autoscaling_policies.ListAutoscalingPoliciesResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_autoscaling_policies(resp) + return resp + + class _UpdateAutoscalingPolicy(AutoscalingPolicyServiceRestStub): + def __hash__(self): + return hash("UpdateAutoscalingPolicy") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: autoscaling_policies.UpdateAutoscalingPolicyRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> autoscaling_policies.AutoscalingPolicy: + r"""Call the update autoscaling policy method over HTTP. + + Args: + request (~.autoscaling_policies.UpdateAutoscalingPolicyRequest): + The request object. A request to update an autoscaling + policy. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.autoscaling_policies.AutoscalingPolicy: + Describes an autoscaling policy for + Dataproc cluster autoscaler. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "put", + "uri": "/v1/{policy.name=projects/*/locations/*/autoscalingPolicies/*}", + "body": "policy", + }, + { + "method": "put", + "uri": "/v1/{policy.name=projects/*/regions/*/autoscalingPolicies/*}", + "body": "policy", + }, + ] + request, metadata = self._interceptor.pre_update_autoscaling_policy( + request, metadata + ) + pb_request = autoscaling_policies.UpdateAutoscalingPolicyRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = autoscaling_policies.AutoscalingPolicy() + pb_resp = autoscaling_policies.AutoscalingPolicy.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_autoscaling_policy(resp) + return resp + + @property + def create_autoscaling_policy( + self, + ) -> Callable[ + [autoscaling_policies.CreateAutoscalingPolicyRequest], + autoscaling_policies.AutoscalingPolicy, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateAutoscalingPolicy(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_autoscaling_policy( + self, + ) -> Callable[ + [autoscaling_policies.DeleteAutoscalingPolicyRequest], empty_pb2.Empty + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteAutoscalingPolicy(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_autoscaling_policy( + self, + ) -> Callable[ + [autoscaling_policies.GetAutoscalingPolicyRequest], + autoscaling_policies.AutoscalingPolicy, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetAutoscalingPolicy(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_autoscaling_policies( + self, + ) -> Callable[ + [autoscaling_policies.ListAutoscalingPoliciesRequest], + autoscaling_policies.ListAutoscalingPoliciesResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ListAutoscalingPolicies(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_autoscaling_policy( + self, + ) -> Callable[ + [autoscaling_policies.UpdateAutoscalingPolicyRequest], + autoscaling_policies.AutoscalingPolicy, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateAutoscalingPolicy(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("AutoscalingPolicyServiceRestTransport",) diff --git a/google/cloud/dataproc_v1/services/batch_controller/client.py b/google/cloud/dataproc_v1/services/batch_controller/client.py index 05d0e9f0..6044c058 100644 --- a/google/cloud/dataproc_v1/services/batch_controller/client.py +++ b/google/cloud/dataproc_v1/services/batch_controller/client.py @@ -56,6 +56,7 @@ from .transports.base import BatchControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import BatchControllerGrpcTransport from .transports.grpc_asyncio import BatchControllerGrpcAsyncIOTransport +from .transports.rest import BatchControllerRestTransport class BatchControllerClientMeta(type): @@ -71,6 +72,7 @@ class BatchControllerClientMeta(type): ) # type: Dict[str, Type[BatchControllerTransport]] _transport_registry["grpc"] = BatchControllerGrpcTransport _transport_registry["grpc_asyncio"] = BatchControllerGrpcAsyncIOTransport + _transport_registry["rest"] = BatchControllerRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/batch_controller/transports/__init__.py b/google/cloud/dataproc_v1/services/batch_controller/transports/__init__.py index ff09d005..352e3932 100644 --- a/google/cloud/dataproc_v1/services/batch_controller/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/batch_controller/transports/__init__.py @@ -19,15 +19,20 @@ from .base import BatchControllerTransport from .grpc import BatchControllerGrpcTransport from .grpc_asyncio import BatchControllerGrpcAsyncIOTransport +from .rest import BatchControllerRestTransport +from .rest import BatchControllerRestInterceptor # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[BatchControllerTransport]] _transport_registry["grpc"] = BatchControllerGrpcTransport _transport_registry["grpc_asyncio"] = BatchControllerGrpcAsyncIOTransport +_transport_registry["rest"] = BatchControllerRestTransport __all__ = ( "BatchControllerTransport", "BatchControllerGrpcTransport", "BatchControllerGrpcAsyncIOTransport", + "BatchControllerRestTransport", + "BatchControllerRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/batch_controller/transports/rest.py b/google/cloud/dataproc_v1/services/batch_controller/transports/rest.py new file mode 100644 index 00000000..f43641b0 --- /dev/null +++ b/google/cloud/dataproc_v1/services/batch_controller/transports/rest.py @@ -0,0 +1,715 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import batches +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from .base import ( + BatchControllerTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class BatchControllerRestInterceptor: + """Interceptor for BatchController. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the BatchControllerRestTransport. + + .. 
code-block:: python + class MyCustomBatchControllerInterceptor(BatchControllerRestInterceptor): + def pre_create_batch(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_batch(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_batch(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_batch(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_batch(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_batches(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_batches(self, response): + logging.log(f"Received response: {response}") + return response + + transport = BatchControllerRestTransport(interceptor=MyCustomBatchControllerInterceptor()) + client = BatchControllerClient(transport=transport) + + + """ + + def pre_create_batch( + self, request: batches.CreateBatchRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[batches.CreateBatchRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for create_batch + + Override in a subclass to manipulate the request or metadata + before they are sent to the BatchController server. + """ + return request, metadata + + def post_create_batch( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for create_batch + + Override in a subclass to manipulate the response + after it is returned by the BatchController server but before + it is returned to user code. + """ + return response + + def pre_delete_batch( + self, request: batches.DeleteBatchRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[batches.DeleteBatchRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for delete_batch + + Override in a subclass to manipulate the request or metadata + before they are sent to the BatchController server. + """ + return request, metadata + + def pre_get_batch( + self, request: batches.GetBatchRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[batches.GetBatchRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_batch + + Override in a subclass to manipulate the request or metadata + before they are sent to the BatchController server. + """ + return request, metadata + + def post_get_batch(self, response: batches.Batch) -> batches.Batch: + """Post-rpc interceptor for get_batch + + Override in a subclass to manipulate the response + after it is returned by the BatchController server but before + it is returned to user code. + """ + return response + + def pre_list_batches( + self, request: batches.ListBatchesRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[batches.ListBatchesRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_batches + + Override in a subclass to manipulate the request or metadata + before they are sent to the BatchController server. + """ + return request, metadata + + def post_list_batches( + self, response: batches.ListBatchesResponse + ) -> batches.ListBatchesResponse: + """Post-rpc interceptor for list_batches + + Override in a subclass to manipulate the response + after it is returned by the BatchController server but before + it is returned to user code. 
+ """ + return response + + +@dataclasses.dataclass +class BatchControllerRestStub: + _session: AuthorizedSession + _host: str + _interceptor: BatchControllerRestInterceptor + + +class BatchControllerRestTransport(BatchControllerTransport): + """REST backend transport for BatchController. + + The BatchController provides methods to manage batch + workloads. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "dataproc.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[BatchControllerRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
+ # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+ # credentials object
+ maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+ if maybe_url_match is None:
+ raise ValueError(
+ f"Unexpected hostname structure: {host}"
+ ) # pragma: NO COVER
+
+ url_match_items = maybe_url_match.groupdict()
+
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience,
+ )
+ self._session = AuthorizedSession(
+ self._credentials, default_host=self.DEFAULT_HOST
+ )
+ self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
+ if client_cert_source_for_mtls:
+ self._session.configure_mtls_channel(client_cert_source_for_mtls)
+ self._interceptor = interceptor or BatchControllerRestInterceptor()
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def operations_client(self) -> operations_v1.AbstractOperationsClient:
+ """Create the client designed to process long-running operations.
+
+ This property caches on the instance; repeated calls return the same
+ client.
+ """
+ # Only create a new client if we do not already have one.
+ if self._operations_client is None:
+ http_options: Dict[str, List[Dict[str, str]]] = {
+ "google.longrunning.Operations.CancelOperation": [
+ {
+ "method": "post",
+ "uri": "/v1/{name=projects/*/regions/*/operations/*}:cancel",
+ },
+ ],
+ "google.longrunning.Operations.DeleteOperation": [
+ {
+ "method": "delete",
+ "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+ },
+ ],
+ "google.longrunning.Operations.GetOperation": [
+ {
+ "method": "get",
+ "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+ },
+ ],
+ "google.longrunning.Operations.ListOperations": [
+ {
+ "method": "get",
+ "uri": "/v1/{name=projects/*/regions/*/operations}",
+ },
+ ],
+ }
+
+ rest_transport = operations_v1.OperationsRestTransport(
+ host=self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ scopes=self._scopes,
+ http_options=http_options,
+ path_prefix="v1",
+ )
+
+ self._operations_client = operations_v1.AbstractOperationsClient(
+ transport=rest_transport
+ )
+
+ # Return the client from cache.
+ return self._operations_client
+
+ class _CreateBatch(BatchControllerRestStub):
+ def __hash__(self):
+ return hash("CreateBatch")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {}
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {
+ k: v
+ for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+ if k not in message_dict
+ }
+
+ def __call__(
+ self,
+ request: batches.CreateBatchRequest,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Call the create batch method over HTTP.
+
+ Args:
+ request (~.batches.CreateBatchRequest):
+ The request object. A request to create a batch workload.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ ~.operations_pb2.Operation:
+ This resource represents a
+ long-running operation that is the
+ result of a network API call.
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/locations/*}/batches", + "body": "batch", + }, + ] + request, metadata = self._interceptor.pre_create_batch(request, metadata) + pb_request = batches.CreateBatchRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_batch(resp) + return resp + + class _DeleteBatch(BatchControllerRestStub): + def __hash__(self): + return hash("DeleteBatch") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: batches.DeleteBatchRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): + r"""Call the delete batch method over HTTP. + + Args: + request (~.batches.DeleteBatchRequest): + The request object. A request to delete a batch workload. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/batches/*}", + }, + ] + request, metadata = self._interceptor.pre_delete_batch(request, metadata) + pb_request = batches.DeleteBatchRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetBatch(BatchControllerRestStub): + def __hash__(self): + return hash("GetBatch") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: batches.GetBatchRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batches.Batch: + r"""Call the get batch method over HTTP. + + Args: + request (~.batches.GetBatchRequest): + The request object. A request to get the resource + representation for a batch workload. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.batches.Batch: + A representation of a batch workload + in the service. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/batches/*}", + }, + ] + request, metadata = self._interceptor.pre_get_batch(request, metadata) + pb_request = batches.GetBatchRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = batches.Batch() + pb_resp = batches.Batch.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_batch(resp) + return resp + + class _ListBatches(BatchControllerRestStub): + def __hash__(self): + return hash("ListBatches") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: batches.ListBatchesRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batches.ListBatchesResponse: + r"""Call the list batches method over HTTP. + + Args: + request (~.batches.ListBatchesRequest): + The request object. A request to list batch workloads in + a project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.batches.ListBatchesResponse: + A list of batch workloads. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{parent=projects/*/locations/*}/batches", + }, + ] + request, metadata = self._interceptor.pre_list_batches(request, metadata) + pb_request = batches.ListBatchesRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = batches.ListBatchesResponse() + pb_resp = batches.ListBatchesResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_batches(resp) + return resp + + @property + def create_batch( + self, + ) -> Callable[[batches.CreateBatchRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateBatch(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_batch(self) -> Callable[[batches.DeleteBatchRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._DeleteBatch(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_batch(self) -> Callable[[batches.GetBatchRequest], batches.Batch]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetBatch(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_batches( + self, + ) -> Callable[[batches.ListBatchesRequest], batches.ListBatchesResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListBatches(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("BatchControllerRestTransport",) diff --git a/google/cloud/dataproc_v1/services/cluster_controller/client.py b/google/cloud/dataproc_v1/services/cluster_controller/client.py index 921a308c..73a08f8f 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/client.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/client.py @@ -56,6 +56,7 @@ from .transports.base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import ClusterControllerGrpcTransport from .transports.grpc_asyncio import ClusterControllerGrpcAsyncIOTransport +from .transports.rest import ClusterControllerRestTransport class ClusterControllerClientMeta(type): @@ -71,6 +72,7 @@ class ClusterControllerClientMeta(type): ) # type: Dict[str, Type[ClusterControllerTransport]] _transport_registry["grpc"] = ClusterControllerGrpcTransport _transport_registry["grpc_asyncio"] = ClusterControllerGrpcAsyncIOTransport + _transport_registry["rest"] = ClusterControllerRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/cluster_controller/transports/__init__.py b/google/cloud/dataproc_v1/services/cluster_controller/transports/__init__.py index 592eeaaf..896b77f6 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/transports/__init__.py @@ -19,15 +19,20 @@ from .base import ClusterControllerTransport from .grpc import ClusterControllerGrpcTransport from .grpc_asyncio import ClusterControllerGrpcAsyncIOTransport +from .rest import ClusterControllerRestTransport +from .rest import ClusterControllerRestInterceptor # Compile a registry of transports. 
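# The registry keys are the values accepted by a client's ``transport``
# argument; for example, ClusterControllerClient(transport="rest") selects
# the REST transport registered below.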
_transport_registry = OrderedDict() # type: Dict[str, Type[ClusterControllerTransport]] _transport_registry["grpc"] = ClusterControllerGrpcTransport _transport_registry["grpc_asyncio"] = ClusterControllerGrpcAsyncIOTransport +_transport_registry["rest"] = ClusterControllerRestTransport __all__ = ( "ClusterControllerTransport", "ClusterControllerGrpcTransport", "ClusterControllerGrpcAsyncIOTransport", + "ClusterControllerRestTransport", + "ClusterControllerRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/cluster_controller/transports/rest.py b/google/cloud/dataproc_v1/services/cluster_controller/transports/rest.py new file mode 100644 index 00000000..58586417 --- /dev/null +++ b/google/cloud/dataproc_v1/services/cluster_controller/transports/rest.py @@ -0,0 +1,1294 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import clusters +from google.longrunning import operations_pb2 # type: ignore + +from .base import ( + ClusterControllerTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class ClusterControllerRestInterceptor: + """Interceptor for ClusterController. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the ClusterControllerRestTransport. + + .. 
code-block:: python + class MyCustomClusterControllerInterceptor(ClusterControllerRestInterceptor): + def pre_create_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_delete_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_diagnose_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_diagnose_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_clusters(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_clusters(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_start_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_start_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_stop_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_stop_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_cluster(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_cluster(self, response): + logging.log(f"Received response: {response}") + return response + + transport = ClusterControllerRestTransport(interceptor=MyCustomClusterControllerInterceptor()) + client = ClusterControllerClient(transport=transport) + + + """ + + def pre_create_cluster( + self, + request: clusters.CreateClusterRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[clusters.CreateClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for create_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_create_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for create_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + def pre_delete_cluster( + self, + request: clusters.DeleteClusterRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[clusters.DeleteClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for delete_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_delete_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for delete_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. 
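+        (The default implementation returns the response unchanged.)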
+ """ + return response + + def pre_diagnose_cluster( + self, + request: clusters.DiagnoseClusterRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[clusters.DiagnoseClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for diagnose_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_diagnose_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for diagnose_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + def pre_get_cluster( + self, request: clusters.GetClusterRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[clusters.GetClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_get_cluster(self, response: clusters.Cluster) -> clusters.Cluster: + """Post-rpc interceptor for get_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + def pre_list_clusters( + self, request: clusters.ListClustersRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[clusters.ListClustersRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_clusters + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_list_clusters( + self, response: clusters.ListClustersResponse + ) -> clusters.ListClustersResponse: + """Post-rpc interceptor for list_clusters + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + def pre_start_cluster( + self, request: clusters.StartClusterRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[clusters.StartClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for start_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_start_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for start_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + def pre_stop_cluster( + self, request: clusters.StopClusterRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[clusters.StopClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for stop_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_stop_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for stop_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. 
+ """ + return response + + def pre_update_cluster( + self, + request: clusters.UpdateClusterRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[clusters.UpdateClusterRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for update_cluster + + Override in a subclass to manipulate the request or metadata + before they are sent to the ClusterController server. + """ + return request, metadata + + def post_update_cluster( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for update_cluster + + Override in a subclass to manipulate the response + after it is returned by the ClusterController server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class ClusterControllerRestStub: + _session: AuthorizedSession + _host: str + _interceptor: ClusterControllerRestInterceptor + + +class ClusterControllerRestTransport(ClusterControllerTransport): + """REST backend transport for ClusterController. + + The ClusterControllerService provides methods to manage + clusters of Compute Engine instances. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "dataproc.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[ClusterControllerRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
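+        # Note: the code below normalizes the host; a bare hostname such as
+        # "dataproc.googleapis.com" gets the url_scheme prefixed to it.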
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+        # credentials object
+        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+        if maybe_url_match is None:
+            raise ValueError(
+                f"Unexpected hostname structure: {host}"
+            )  # pragma: NO COVER
+
+        url_match_items = maybe_url_match.groupdict()
+
+        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+        self._session = AuthorizedSession(
+            self._credentials, default_host=self.DEFAULT_HOST
+        )
+        self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
+        if client_cert_source_for_mtls:
+            self._session.configure_mtls_channel(client_cert_source_for_mtls)
+        self._interceptor = interceptor or ClusterControllerRestInterceptor()
+        self._prep_wrapped_messages(client_info)
+
+    @property
+    def operations_client(self) -> operations_v1.AbstractOperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            http_options: Dict[str, List[Dict[str, str]]] = {
+                "google.longrunning.Operations.CancelOperation": [
+                    {
+                        "method": "post",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}:cancel",
+                    },
+                ],
+                "google.longrunning.Operations.DeleteOperation": [
+                    {
+                        "method": "delete",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+                    },
+                ],
+                "google.longrunning.Operations.GetOperation": [
+                    {
+                        "method": "get",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+                    },
+                ],
+                "google.longrunning.Operations.ListOperations": [
+                    {
+                        "method": "get",
+                        "uri": "/v1/{name=projects/*/regions/*/operations}",
+                    },
+                ],
+            }
+
+            rest_transport = operations_v1.OperationsRestTransport(
+                host=self._host,
+                # use the credentials which are saved
+                credentials=self._credentials,
+                scopes=self._scopes,
+                http_options=http_options,
+                path_prefix="v1",
+            )
+
+            self._operations_client = operations_v1.AbstractOperationsClient(
+                transport=rest_transport
+            )
+
+        # Return the client from cache.
+        return self._operations_client
+
+    class _CreateCluster(ClusterControllerRestStub):
+        def __hash__(self):
+            return hash("CreateCluster")
+
+        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {}
+
+        @classmethod
+        def _get_unset_required_fields(cls, message_dict):
+            return {
+                k: v
+                for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+                if k not in message_dict
+            }
+
+        def __call__(
+            self,
+            request: clusters.CreateClusterRequest,
+            *,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Optional[float] = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+        ) -> operations_pb2.Operation:
+            r"""Call the create cluster method over HTTP.
+
+            Args:
+                request (~.clusters.CreateClusterRequest):
+                    The request object. A request to create a cluster.
+                retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                    should be retried.
+                timeout (float): The timeout for this request.
+                metadata (Sequence[Tuple[str, str]]): Strings which should be
+                    sent along with the request as metadata.
+
+            Returns:
+                ~.operations_pb2.Operation:
+                    This resource represents a
+                    long-running operation that is the
+                    result of a network API call.
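+                    The finished operation can be polled via this
+                    transport's ``operations_client`` property.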
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters", + "body": "cluster", + }, + ] + request, metadata = self._interceptor.pre_create_cluster(request, metadata) + pb_request = clusters.CreateClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_cluster(resp) + return resp + + class _DeleteCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("DeleteCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.DeleteClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the delete cluster method over HTTP. + + Args: + request (~.clusters.DeleteClusterRequest): + The request object. A request to delete a cluster. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}", + }, + ] + request, metadata = self._interceptor.pre_delete_cluster(request, metadata) + pb_request = clusters.DeleteClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_delete_cluster(resp) + return resp + + class _DiagnoseCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("DiagnoseCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.DiagnoseClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the diagnose cluster method over HTTP. + + Args: + request (~.clusters.DiagnoseClusterRequest): + The request object. A request to collect cluster + diagnostic information. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
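+                    For diagnose, the completed operation's response
+                    is expected to carry the URI of the diagnostic
+                    output (see the DiagnoseClusterResults message).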
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}:diagnose", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_diagnose_cluster( + request, metadata + ) + pb_request = clusters.DiagnoseClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_diagnose_cluster(resp) + return resp + + class _GetCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("GetCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.GetClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> clusters.Cluster: + r"""Call the get cluster method over HTTP. + + Args: + request (~.clusters.GetClusterRequest): + The request object. Request to get the resource + representation for a cluster in a + project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + ~.clusters.Cluster: + Describes the identifying + information, config, and status of a + Dataproc cluster + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}", + }, + ] + request, metadata = self._interceptor.pre_get_cluster(request, metadata) + pb_request = clusters.GetClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = clusters.Cluster() + pb_resp = clusters.Cluster.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_cluster(resp) + return resp + + class _ListClusters(ClusterControllerRestStub): + def __hash__(self): + return hash("ListClusters") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.ListClustersRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> clusters.ListClustersResponse: + r"""Call the list clusters method over HTTP. + + Args: + request (~.clusters.ListClustersRequest): + The request object. A request to list the clusters in a + project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.clusters.ListClustersResponse: + The list of all clusters in a + project. 
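+                    Callers typically iterate results through the
+                    client's pager rather than this raw response.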
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters", + }, + ] + request, metadata = self._interceptor.pre_list_clusters(request, metadata) + pb_request = clusters.ListClustersRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = clusters.ListClustersResponse() + pb_resp = clusters.ListClustersResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_clusters(resp) + return resp + + class _StartCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("StartCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.StartClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the start cluster method over HTTP. + + Args: + request (~.clusters.StartClusterRequest): + The request object. A request to start a cluster. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}:start", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_start_cluster(request, metadata) + pb_request = clusters.StartClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_start_cluster(resp) + return resp + + class _StopCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("StopCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.StopClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the stop cluster method over HTTP. + + Args: + request (~.clusters.StopClusterRequest): + The request object. A request to stop a cluster. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}:stop", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_stop_cluster(request, metadata) + pb_request = clusters.StopClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_stop_cluster(resp) + return resp + + class _UpdateCluster(ClusterControllerRestStub): + def __hash__(self): + return hash("UpdateCluster") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: clusters.UpdateClusterRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the update cluster method over HTTP. + + Args: + request (~.clusters.UpdateClusterRequest): + The request object. A request to update a cluster. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}", + "body": "cluster", + }, + ] + request, metadata = self._interceptor.pre_update_cluster(request, metadata) + pb_request = clusters.UpdateClusterRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_cluster(resp) + return resp + + @property + def create_cluster( + self, + ) -> Callable[[clusters.CreateClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_cluster( + self, + ) -> Callable[[clusters.DeleteClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def diagnose_cluster( + self, + ) -> Callable[[clusters.DiagnoseClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DiagnoseCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_cluster(self) -> Callable[[clusters.GetClusterRequest], clusters.Cluster]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_clusters( + self, + ) -> Callable[[clusters.ListClustersRequest], clusters.ListClustersResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ListClusters(self._session, self._host, self._interceptor) # type: ignore + + @property + def start_cluster( + self, + ) -> Callable[[clusters.StartClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._StartCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def stop_cluster( + self, + ) -> Callable[[clusters.StopClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._StopCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_cluster( + self, + ) -> Callable[[clusters.UpdateClusterRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateCluster(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("ClusterControllerRestTransport",) diff --git a/google/cloud/dataproc_v1/services/job_controller/client.py b/google/cloud/dataproc_v1/services/job_controller/client.py index a5d75602..15e749c5 100644 --- a/google/cloud/dataproc_v1/services/job_controller/client.py +++ b/google/cloud/dataproc_v1/services/job_controller/client.py @@ -53,6 +53,7 @@ from .transports.base import JobControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import JobControllerGrpcTransport from .transports.grpc_asyncio import JobControllerGrpcAsyncIOTransport +from .transports.rest import JobControllerRestTransport class JobControllerClientMeta(type): @@ -66,6 +67,7 @@ class JobControllerClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[JobControllerTransport]] _transport_registry["grpc"] = JobControllerGrpcTransport _transport_registry["grpc_asyncio"] = JobControllerGrpcAsyncIOTransport + _transport_registry["rest"] = JobControllerRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/job_controller/transports/__init__.py b/google/cloud/dataproc_v1/services/job_controller/transports/__init__.py index 26539a6d..60486d0d 100644 --- a/google/cloud/dataproc_v1/services/job_controller/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/job_controller/transports/__init__.py @@ -19,15 +19,20 @@ from .base import JobControllerTransport from .grpc import JobControllerGrpcTransport from .grpc_asyncio import JobControllerGrpcAsyncIOTransport +from .rest import JobControllerRestTransport +from .rest import JobControllerRestInterceptor # Compile a registry of transports. 
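# As with the other services, the "rest" key added below enables
# JobControllerClient(transport="rest").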
_transport_registry = OrderedDict() # type: Dict[str, Type[JobControllerTransport]] _transport_registry["grpc"] = JobControllerGrpcTransport _transport_registry["grpc_asyncio"] = JobControllerGrpcAsyncIOTransport +_transport_registry["rest"] = JobControllerRestTransport __all__ = ( "JobControllerTransport", "JobControllerGrpcTransport", "JobControllerGrpcAsyncIOTransport", + "JobControllerRestTransport", + "JobControllerRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/job_controller/transports/rest.py b/google/cloud/dataproc_v1/services/job_controller/transports/rest.py new file mode 100644 index 00000000..0f5727cb --- /dev/null +++ b/google/cloud/dataproc_v1/services/job_controller/transports/rest.py @@ -0,0 +1,1094 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import jobs +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from .base import ( + JobControllerTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class JobControllerRestInterceptor: + """Interceptor for JobController. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the JobControllerRestTransport. + + .. 
code-block:: python + class MyCustomJobControllerInterceptor(JobControllerRestInterceptor): + def pre_cancel_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_cancel_job(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_job(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_jobs(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_jobs(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_submit_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_submit_job(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_submit_job_as_operation(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_submit_job_as_operation(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_job(self, response): + logging.log(f"Received response: {response}") + return response + + transport = JobControllerRestTransport(interceptor=MyCustomJobControllerInterceptor()) + client = JobControllerClient(transport=transport) + + + """ + + def pre_cancel_job( + self, request: jobs.CancelJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.CancelJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for cancel_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_cancel_job(self, response: jobs.Job) -> jobs.Job: + """Post-rpc interceptor for cancel_job + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. + """ + return response + + def pre_delete_job( + self, request: jobs.DeleteJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.DeleteJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for delete_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def pre_get_job( + self, request: jobs.GetJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.GetJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_get_job(self, response: jobs.Job) -> jobs.Job: + """Post-rpc interceptor for get_job + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. 
+ """ + return response + + def pre_list_jobs( + self, request: jobs.ListJobsRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.ListJobsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_jobs + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_list_jobs(self, response: jobs.ListJobsResponse) -> jobs.ListJobsResponse: + """Post-rpc interceptor for list_jobs + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. + """ + return response + + def pre_submit_job( + self, request: jobs.SubmitJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.SubmitJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for submit_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_submit_job(self, response: jobs.Job) -> jobs.Job: + """Post-rpc interceptor for submit_job + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. + """ + return response + + def pre_submit_job_as_operation( + self, request: jobs.SubmitJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.SubmitJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for submit_job_as_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_submit_job_as_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for submit_job_as_operation + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. + """ + return response + + def pre_update_job( + self, request: jobs.UpdateJobRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[jobs.UpdateJobRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for update_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the JobController server. + """ + return request, metadata + + def post_update_job(self, response: jobs.Job) -> jobs.Job: + """Post-rpc interceptor for update_job + + Override in a subclass to manipulate the response + after it is returned by the JobController server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class JobControllerRestStub: + _session: AuthorizedSession + _host: str + _interceptor: JobControllerRestInterceptor + + +class JobControllerRestTransport(JobControllerTransport): + """REST backend transport for JobController. + + The JobController provides methods to manage jobs. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. 
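+
+    (A sketch, not the usual entry point: the transport is normally
+    constructed for you via JobControllerClient(transport="rest").)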
+
+    It sends JSON representations of protocol buffers over HTTP/1.1
+
+    """
+
+    def __init__(
+        self,
+        *,
+        host: str = "dataproc.googleapis.com",
+        credentials: Optional[ga_credentials.Credentials] = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
+        url_scheme: str = "https",
+        interceptor: Optional[JobControllerRestInterceptor] = None,
+        api_audience: Optional[str] = None,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional(Sequence[str])): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
+                certificate to configure mutual TLS HTTP channel. It is ignored
+                if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you are developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+            url_scheme: the protocol scheme for the API endpoint. Normally
+                "https", but for testing or local servers,
+                "http" can be specified.
+        """
+        # Run the base constructor
+        # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+        # credentials object
+        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+        if maybe_url_match is None:
+            raise ValueError(
+                f"Unexpected hostname structure: {host}"
+            )  # pragma: NO COVER
+
+        url_match_items = maybe_url_match.groupdict()
+
+        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+        self._session = AuthorizedSession(
+            self._credentials, default_host=self.DEFAULT_HOST
+        )
+        self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
+        if client_cert_source_for_mtls:
+            self._session.configure_mtls_channel(client_cert_source_for_mtls)
+        self._interceptor = interceptor or JobControllerRestInterceptor()
+        self._prep_wrapped_messages(client_info)
+
+    @property
+    def operations_client(self) -> operations_v1.AbstractOperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Only create a new client if we do not already have one.
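+        # The LRO client reuses this transport's credentials and scopes and
+        # the v1 regions/operations HTTP bindings declared below.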
+ if self._operations_client is None: + http_options: Dict[str, List[Dict[str, str]]] = { + "google.longrunning.Operations.CancelOperation": [ + { + "method": "post", + "uri": "/v1/{name=projects/*/regions/*/operations/*}:cancel", + }, + ], + "google.longrunning.Operations.DeleteOperation": [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/regions/*/operations/*}", + }, + ], + "google.longrunning.Operations.GetOperation": [ + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/operations/*}", + }, + ], + "google.longrunning.Operations.ListOperations": [ + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/operations}", + }, + ], + } + + rest_transport = operations_v1.OperationsRestTransport( + host=self._host, + # use the credentials which are saved + credentials=self._credentials, + scopes=self._scopes, + http_options=http_options, + path_prefix="v1", + ) + + self._operations_client = operations_v1.AbstractOperationsClient( + transport=rest_transport + ) + + # Return the client from cache. + return self._operations_client + + class _CancelJob(JobControllerRestStub): + def __hash__(self): + return hash("CancelJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.CancelJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> jobs.Job: + r"""Call the cancel job method over HTTP. + + Args: + request (~.jobs.CancelJobRequest): + The request object. A request to cancel a job. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.jobs.Job: + A Dataproc job resource. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs/{job_id}:cancel", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_cancel_job(request, metadata) + pb_request = jobs.CancelJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
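+            # (For instance, a 404 status maps to core_exceptions.NotFound and
+            # a 403 to core_exceptions.Forbidden.)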
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = jobs.Job() + pb_resp = jobs.Job.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_cancel_job(resp) + return resp + + class _DeleteJob(JobControllerRestStub): + def __hash__(self): + return hash("DeleteJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.DeleteJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): + r"""Call the delete job method over HTTP. + + Args: + request (~.jobs.DeleteJobRequest): + The request object. A request to delete a job. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs/{job_id}", + }, + ] + request, metadata = self._interceptor.pre_delete_job(request, metadata) + pb_request = jobs.DeleteJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetJob(JobControllerRestStub): + def __hash__(self): + return hash("GetJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.GetJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> jobs.Job: + r"""Call the get job method over HTTP. + + Args: + request (~.jobs.GetJobRequest): + The request object. A request to get the resource + representation for a job in a project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.jobs.Job: + A Dataproc job resource. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs/{job_id}", + }, + ] + request, metadata = self._interceptor.pre_get_job(request, metadata) + pb_request = jobs.GetJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = jobs.Job() + pb_resp = jobs.Job.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_job(resp) + return resp + + class _ListJobs(JobControllerRestStub): + def __hash__(self): + return hash("ListJobs") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.ListJobsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> jobs.ListJobsResponse: + r"""Call the list jobs method over HTTP. + + Args: + request (~.jobs.ListJobsRequest): + The request object. A request to list jobs in a project. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.jobs.ListJobsResponse: + A list of jobs in a project. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs", + }, + ] + request, metadata = self._interceptor.pre_list_jobs(request, metadata) + pb_request = jobs.ListJobsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = jobs.ListJobsResponse() + pb_resp = jobs.ListJobsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_jobs(resp) + return resp + + class _SubmitJob(JobControllerRestStub): + def __hash__(self): + return hash("SubmitJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.SubmitJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> jobs.Job: + r"""Call the submit job method over HTTP. + + Args: + request (~.jobs.SubmitJobRequest): + The request object. A request to submit a job. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.jobs.Job: + A Dataproc job resource. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs:submit", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_submit_job(request, metadata) + pb_request = jobs.SubmitJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = jobs.Job() + pb_resp = jobs.Job.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_submit_job(resp) + return resp + + class _SubmitJobAsOperation(JobControllerRestStub): + def __hash__(self): + return hash("SubmitJobAsOperation") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.SubmitJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the submit job as operation method over HTTP. + + Args: + request (~.jobs.SubmitJobRequest): + The request object. A request to submit a job. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs:submitAsOperation", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_submit_job_as_operation( + request, metadata + ) + pb_request = jobs.SubmitJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_submit_job_as_operation(resp) + return resp + + class _UpdateJob(JobControllerRestStub): + def __hash__(self): + return hash("UpdateJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: jobs.UpdateJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> jobs.Job: + r"""Call the update job method over HTTP. + + Args: + request (~.jobs.UpdateJobRequest): + The request object. A request to update a job. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.jobs.Job: + A Dataproc job resource. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1/projects/{project_id}/regions/{region}/jobs/{job_id}", + "body": "job", + }, + ] + request, metadata = self._interceptor.pre_update_job(request, metadata) + pb_request = jobs.UpdateJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = jobs.Job() + pb_resp = jobs.Job.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_job(resp) + return resp + + @property + def cancel_job(self) -> Callable[[jobs.CancelJobRequest], jobs.Job]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CancelJob(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_job(self) -> Callable[[jobs.DeleteJobRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteJob(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_job(self) -> Callable[[jobs.GetJobRequest], jobs.Job]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetJob(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_jobs(self) -> Callable[[jobs.ListJobsRequest], jobs.ListJobsResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListJobs(self._session, self._host, self._interceptor) # type: ignore + + @property + def submit_job(self) -> Callable[[jobs.SubmitJobRequest], jobs.Job]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._SubmitJob(self._session, self._host, self._interceptor) # type: ignore + + @property + def submit_job_as_operation( + self, + ) -> Callable[[jobs.SubmitJobRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._SubmitJobAsOperation(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_job(self) -> Callable[[jobs.UpdateJobRequest], jobs.Job]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateJob(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("JobControllerRestTransport",) diff --git a/google/cloud/dataproc_v1/services/node_group_controller/client.py b/google/cloud/dataproc_v1/services/node_group_controller/client.py index 402b31e8..35d8d357 100644 --- a/google/cloud/dataproc_v1/services/node_group_controller/client.py +++ b/google/cloud/dataproc_v1/services/node_group_controller/client.py @@ -54,6 +54,7 @@ from .transports.base import NodeGroupControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import NodeGroupControllerGrpcTransport from .transports.grpc_asyncio import NodeGroupControllerGrpcAsyncIOTransport +from .transports.rest import NodeGroupControllerRestTransport class NodeGroupControllerClientMeta(type): @@ -69,6 +70,7 @@ class NodeGroupControllerClientMeta(type): ) # type: Dict[str, Type[NodeGroupControllerTransport]] _transport_registry["grpc"] = NodeGroupControllerGrpcTransport _transport_registry["grpc_asyncio"] = NodeGroupControllerGrpcAsyncIOTransport + _transport_registry["rest"] = NodeGroupControllerRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/node_group_controller/transports/__init__.py b/google/cloud/dataproc_v1/services/node_group_controller/transports/__init__.py index 0803a9a0..466aa0a7 100644 --- a/google/cloud/dataproc_v1/services/node_group_controller/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/node_group_controller/transports/__init__.py @@ -19,6 +19,8 @@ from .base import NodeGroupControllerTransport from .grpc import NodeGroupControllerGrpcTransport from .grpc_asyncio import NodeGroupControllerGrpcAsyncIOTransport +from .rest import NodeGroupControllerRestTransport +from .rest import NodeGroupControllerRestInterceptor # Compile a registry of transports. @@ -27,9 +29,12 @@ ) # type: Dict[str, Type[NodeGroupControllerTransport]] _transport_registry["grpc"] = NodeGroupControllerGrpcTransport _transport_registry["grpc_asyncio"] = NodeGroupControllerGrpcAsyncIOTransport +_transport_registry["rest"] = NodeGroupControllerRestTransport __all__ = ( "NodeGroupControllerTransport", "NodeGroupControllerGrpcTransport", "NodeGroupControllerGrpcAsyncIOTransport", + "NodeGroupControllerRestTransport", + "NodeGroupControllerRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/node_group_controller/transports/rest.py b/google/cloud/dataproc_v1/services/node_group_controller/transports/rest.py new file mode 100644 index 00000000..c2b5a0f5 --- /dev/null +++ b/google/cloud/dataproc_v1/services/node_group_controller/transports/rest.py @@ -0,0 +1,641 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import clusters +from google.cloud.dataproc_v1.types import node_groups +from google.longrunning import operations_pb2 # type: ignore + +from .base import ( + NodeGroupControllerTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class NodeGroupControllerRestInterceptor: + """Interceptor for NodeGroupController. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the NodeGroupControllerRestTransport. + + .. 
code-block:: python + class MyCustomNodeGroupControllerInterceptor(NodeGroupControllerRestInterceptor): + def pre_create_node_group(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_node_group(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_node_group(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_node_group(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_resize_node_group(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_resize_node_group(self, response): + logging.log(f"Received response: {response}") + return response + + transport = NodeGroupControllerRestTransport(interceptor=MyCustomNodeGroupControllerInterceptor()) + client = NodeGroupControllerClient(transport=transport) + + + """ + + def pre_create_node_group( + self, + request: node_groups.CreateNodeGroupRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[node_groups.CreateNodeGroupRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for create_node_group + + Override in a subclass to manipulate the request or metadata + before they are sent to the NodeGroupController server. + """ + return request, metadata + + def post_create_node_group( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for create_node_group + + Override in a subclass to manipulate the response + after it is returned by the NodeGroupController server but before + it is returned to user code. + """ + return response + + def pre_get_node_group( + self, + request: node_groups.GetNodeGroupRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[node_groups.GetNodeGroupRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_node_group + + Override in a subclass to manipulate the request or metadata + before they are sent to the NodeGroupController server. + """ + return request, metadata + + def post_get_node_group(self, response: clusters.NodeGroup) -> clusters.NodeGroup: + """Post-rpc interceptor for get_node_group + + Override in a subclass to manipulate the response + after it is returned by the NodeGroupController server but before + it is returned to user code. + """ + return response + + def pre_resize_node_group( + self, + request: node_groups.ResizeNodeGroupRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[node_groups.ResizeNodeGroupRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for resize_node_group + + Override in a subclass to manipulate the request or metadata + before they are sent to the NodeGroupController server. + """ + return request, metadata + + def post_resize_node_group( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for resize_node_group + + Override in a subclass to manipulate the response + after it is returned by the NodeGroupController server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class NodeGroupControllerRestStub: + _session: AuthorizedSession + _host: str + _interceptor: NodeGroupControllerRestInterceptor + + +class NodeGroupControllerRestTransport(NodeGroupControllerTransport): + """REST backend transport for NodeGroupController. 
+ + The ``NodeGroupControllerService`` provides methods to manage node + groups of Compute Engine managed instances. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "dataproc.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[NodeGroupControllerRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+        # credentials object
+        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+        if maybe_url_match is None:
+            raise ValueError(
+                f"Unexpected hostname structure: {host}"
+            )  # pragma: NO COVER
+
+        url_match_items = maybe_url_match.groupdict()
+
+        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+        self._session = AuthorizedSession(
+            self._credentials, default_host=self.DEFAULT_HOST
+        )
+        self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
+        if client_cert_source_for_mtls:
+            self._session.configure_mtls_channel(client_cert_source_for_mtls)
+        self._interceptor = interceptor or NodeGroupControllerRestInterceptor()
+        self._prep_wrapped_messages(client_info)
+
+    @property
+    def operations_client(self) -> operations_v1.AbstractOperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            http_options: Dict[str, List[Dict[str, str]]] = {
+                "google.longrunning.Operations.CancelOperation": [
+                    {
+                        "method": "post",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}:cancel",
+                    },
+                ],
+                "google.longrunning.Operations.DeleteOperation": [
+                    {
+                        "method": "delete",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+                    },
+                ],
+                "google.longrunning.Operations.GetOperation": [
+                    {
+                        "method": "get",
+                        "uri": "/v1/{name=projects/*/regions/*/operations/*}",
+                    },
+                ],
+                "google.longrunning.Operations.ListOperations": [
+                    {
+                        "method": "get",
+                        "uri": "/v1/{name=projects/*/regions/*/operations}",
+                    },
+                ],
+            }
+
+            rest_transport = operations_v1.OperationsRestTransport(
+                host=self._host,
+                # use the credentials which are saved
+                credentials=self._credentials,
+                scopes=self._scopes,
+                http_options=http_options,
+                path_prefix="v1",
+            )
+
+            self._operations_client = operations_v1.AbstractOperationsClient(
+                transport=rest_transport
+            )
+
+        # Return the client from cache.
+        return self._operations_client
+
+    class _CreateNodeGroup(NodeGroupControllerRestStub):
+        def __hash__(self):
+            return hash("CreateNodeGroup")
+
+        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {}
+
+        @classmethod
+        def _get_unset_required_fields(cls, message_dict):
+            return {
+                k: v
+                for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+                if k not in message_dict
+            }
+
+        def __call__(
+            self,
+            request: node_groups.CreateNodeGroupRequest,
+            *,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Optional[float] = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+        ) -> operations_pb2.Operation:
+            r"""Call the create node group method over HTTP.
+
+            Args:
+                request (~.node_groups.CreateNodeGroupRequest):
+                    The request object. A request to create a node group.
+                retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                    should be retried.
+                timeout (float): The timeout for this request.
+                metadata (Sequence[Tuple[str, str]]): Strings which should be
+                    sent along with the request as metadata.
+
+            Returns:
+                ~.operations_pb2.Operation:
+                    This resource represents a
+                    long-running operation that is the
+                    result of a network API call.
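+                    (On success the operation's ``response`` field contains
+                    the created ``NodeGroup``.)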
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/regions/*/clusters/*}/nodeGroups", + "body": "node_group", + }, + ] + request, metadata = self._interceptor.pre_create_node_group( + request, metadata + ) + pb_request = node_groups.CreateNodeGroupRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_node_group(resp) + return resp + + class _GetNodeGroup(NodeGroupControllerRestStub): + def __hash__(self): + return hash("GetNodeGroup") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: node_groups.GetNodeGroupRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> clusters.NodeGroup: + r"""Call the get node group method over HTTP. + + Args: + request (~.node_groups.GetNodeGroupRequest): + The request object. A request to get a node group . + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.clusters.NodeGroup: + Dataproc Node Group. 
**The Dataproc ``NodeGroup`` + resource is not related to the Dataproc + [NodeGroupAffinity][google.cloud.dataproc.v1.NodeGroupAffinity] + resource.** + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/clusters/*/nodeGroups/*}", + }, + ] + request, metadata = self._interceptor.pre_get_node_group(request, metadata) + pb_request = node_groups.GetNodeGroupRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = clusters.NodeGroup() + pb_resp = clusters.NodeGroup.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_node_group(resp) + return resp + + class _ResizeNodeGroup(NodeGroupControllerRestStub): + def __hash__(self): + return hash("ResizeNodeGroup") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: node_groups.ResizeNodeGroupRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the resize node group method over HTTP. + + Args: + request (~.node_groups.ResizeNodeGroupRequest): + The request object. A request to resize a node group. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
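+                    (On success the operation's ``response`` field contains
+                    the resized ``NodeGroup``.)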
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{name=projects/*/regions/*/clusters/*/nodeGroups/*}:resize", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_resize_node_group( + request, metadata + ) + pb_request = node_groups.ResizeNodeGroupRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_resize_node_group(resp) + return resp + + @property + def create_node_group( + self, + ) -> Callable[[node_groups.CreateNodeGroupRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateNodeGroup(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_node_group( + self, + ) -> Callable[[node_groups.GetNodeGroupRequest], clusters.NodeGroup]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetNodeGroup(self._session, self._host, self._interceptor) # type: ignore + + @property + def resize_node_group( + self, + ) -> Callable[[node_groups.ResizeNodeGroupRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ResizeNodeGroup(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("NodeGroupControllerRestTransport",) diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/client.py b/google/cloud/dataproc_v1/services/workflow_template_service/client.py index c13dacf5..7612a621 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/client.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/client.py @@ -56,6 +56,7 @@ from .transports.base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import WorkflowTemplateServiceGrpcTransport from .transports.grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport +from .transports.rest import WorkflowTemplateServiceRestTransport class WorkflowTemplateServiceClientMeta(type): @@ -71,6 +72,7 @@ class WorkflowTemplateServiceClientMeta(type): ) # type: Dict[str, Type[WorkflowTemplateServiceTransport]] _transport_registry["grpc"] = WorkflowTemplateServiceGrpcTransport _transport_registry["grpc_asyncio"] = WorkflowTemplateServiceGrpcAsyncIOTransport + _transport_registry["rest"] = WorkflowTemplateServiceRestTransport def get_transport_class( cls, diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/transports/__init__.py b/google/cloud/dataproc_v1/services/workflow_template_service/transports/__init__.py index b170d8f5..e854d6fe 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/transports/__init__.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/transports/__init__.py @@ -19,6 +19,8 @@ from .base import WorkflowTemplateServiceTransport from .grpc import WorkflowTemplateServiceGrpcTransport from .grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport +from .rest import WorkflowTemplateServiceRestTransport +from .rest import WorkflowTemplateServiceRestInterceptor # Compile a registry of transports. @@ -27,9 +29,12 @@ ) # type: Dict[str, Type[WorkflowTemplateServiceTransport]] _transport_registry["grpc"] = WorkflowTemplateServiceGrpcTransport _transport_registry["grpc_asyncio"] = WorkflowTemplateServiceGrpcAsyncIOTransport +_transport_registry["rest"] = WorkflowTemplateServiceRestTransport __all__ = ( "WorkflowTemplateServiceTransport", "WorkflowTemplateServiceGrpcTransport", "WorkflowTemplateServiceGrpcAsyncIOTransport", + "WorkflowTemplateServiceRestTransport", + "WorkflowTemplateServiceRestInterceptor", ) diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/transports/rest.py b/google/cloud/dataproc_v1/services/workflow_template_service/transports/rest.py new file mode 100644 index 00000000..18d19590 --- /dev/null +++ b/google/cloud/dataproc_v1/services/workflow_template_service/transports/rest.py @@ -0,0 +1,1238 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.cloud.dataproc_v1.types import workflow_templates +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from .base import ( + WorkflowTemplateServiceTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class WorkflowTemplateServiceRestInterceptor: + """Interceptor for WorkflowTemplateService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the WorkflowTemplateServiceRestTransport. + + .. 
code-block:: python + class MyCustomWorkflowTemplateServiceInterceptor(WorkflowTemplateServiceRestInterceptor): + def pre_create_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_workflow_template(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_workflow_template(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_instantiate_inline_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_instantiate_inline_workflow_template(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_instantiate_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_instantiate_workflow_template(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_workflow_templates(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_workflow_templates(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_workflow_template(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_workflow_template(self, response): + logging.log(f"Received response: {response}") + return response + + transport = WorkflowTemplateServiceRestTransport(interceptor=MyCustomWorkflowTemplateServiceInterceptor()) + client = WorkflowTemplateServiceClient(transport=transport) + + + """ + + def pre_create_workflow_template( + self, + request: workflow_templates.CreateWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.CreateWorkflowTemplateRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for create_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_create_workflow_template( + self, response: workflow_templates.WorkflowTemplate + ) -> workflow_templates.WorkflowTemplate: + """Post-rpc interceptor for create_workflow_template + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. + """ + return response + + def pre_delete_workflow_template( + self, + request: workflow_templates.DeleteWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.DeleteWorkflowTemplateRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for delete_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. 
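+
+        (No matching ``post_delete_workflow_template`` hook exists, since the
+        RPC returns ``google.protobuf.Empty``.)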
+ """ + return request, metadata + + def pre_get_workflow_template( + self, + request: workflow_templates.GetWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.GetWorkflowTemplateRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for get_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_get_workflow_template( + self, response: workflow_templates.WorkflowTemplate + ) -> workflow_templates.WorkflowTemplate: + """Post-rpc interceptor for get_workflow_template + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. + """ + return response + + def pre_instantiate_inline_workflow_template( + self, + request: workflow_templates.InstantiateInlineWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.InstantiateInlineWorkflowTemplateRequest, + Sequence[Tuple[str, str]], + ]: + """Pre-rpc interceptor for instantiate_inline_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_instantiate_inline_workflow_template( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for instantiate_inline_workflow_template + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. + """ + return response + + def pre_instantiate_workflow_template( + self, + request: workflow_templates.InstantiateWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.InstantiateWorkflowTemplateRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for instantiate_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_instantiate_workflow_template( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for instantiate_workflow_template + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. + """ + return response + + def pre_list_workflow_templates( + self, + request: workflow_templates.ListWorkflowTemplatesRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.ListWorkflowTemplatesRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for list_workflow_templates + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_list_workflow_templates( + self, response: workflow_templates.ListWorkflowTemplatesResponse + ) -> workflow_templates.ListWorkflowTemplatesResponse: + """Post-rpc interceptor for list_workflow_templates + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. 
+ """ + return response + + def pre_update_workflow_template( + self, + request: workflow_templates.UpdateWorkflowTemplateRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + workflow_templates.UpdateWorkflowTemplateRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for update_workflow_template + + Override in a subclass to manipulate the request or metadata + before they are sent to the WorkflowTemplateService server. + """ + return request, metadata + + def post_update_workflow_template( + self, response: workflow_templates.WorkflowTemplate + ) -> workflow_templates.WorkflowTemplate: + """Post-rpc interceptor for update_workflow_template + + Override in a subclass to manipulate the response + after it is returned by the WorkflowTemplateService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class WorkflowTemplateServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: WorkflowTemplateServiceRestInterceptor + + +class WorkflowTemplateServiceRestTransport(WorkflowTemplateServiceTransport): + """REST backend transport for WorkflowTemplateService. + + The API interface for managing Workflow Templates in the + Dataproc API. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "dataproc.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[WorkflowTemplateServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. 
+ """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or WorkflowTemplateServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + @property + def operations_client(self) -> operations_v1.AbstractOperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Only create a new client if we do not already have one. + if self._operations_client is None: + http_options: Dict[str, List[Dict[str, str]]] = { + "google.longrunning.Operations.CancelOperation": [ + { + "method": "post", + "uri": "/v1/{name=projects/*/regions/*/operations/*}:cancel", + }, + ], + "google.longrunning.Operations.DeleteOperation": [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/regions/*/operations/*}", + }, + ], + "google.longrunning.Operations.GetOperation": [ + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/operations/*}", + }, + ], + "google.longrunning.Operations.ListOperations": [ + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/operations}", + }, + ], + } + + rest_transport = operations_v1.OperationsRestTransport( + host=self._host, + # use the credentials which are saved + credentials=self._credentials, + scopes=self._scopes, + http_options=http_options, + path_prefix="v1", + ) + + self._operations_client = operations_v1.AbstractOperationsClient( + transport=rest_transport + ) + + # Return the client from cache. + return self._operations_client + + class _CreateWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("CreateWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.CreateWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> workflow_templates.WorkflowTemplate: + r"""Call the create workflow template method over HTTP. + + Args: + request (~.workflow_templates.CreateWorkflowTemplateRequest): + The request object. A request to create a workflow + template. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + ~.workflow_templates.WorkflowTemplate: + A Dataproc workflow template + resource. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/locations/*}/workflowTemplates", + "body": "template", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/regions/*}/workflowTemplates", + "body": "template", + }, + ] + request, metadata = self._interceptor.pre_create_workflow_template( + request, metadata + ) + pb_request = workflow_templates.CreateWorkflowTemplateRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = workflow_templates.WorkflowTemplate() + pb_resp = workflow_templates.WorkflowTemplate.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_workflow_template(resp) + return resp + + class _DeleteWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("DeleteWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.DeleteWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): + r"""Call the delete workflow template method over HTTP. + + Args: + request (~.workflow_templates.DeleteWorkflowTemplateRequest): + The request object. A request to delete a workflow + template. + Currently started workflows will remain + running. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/workflowTemplates/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/regions/*/workflowTemplates/*}", + }, + ] + request, metadata = self._interceptor.pre_delete_workflow_template( + request, metadata + ) + pb_request = workflow_templates.DeleteWorkflowTemplateRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("GetWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.GetWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> workflow_templates.WorkflowTemplate: + r"""Call the get workflow template method over HTTP. + + Args: + request (~.workflow_templates.GetWorkflowTemplateRequest): + The request object. A request to fetch a workflow + template. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.workflow_templates.WorkflowTemplate: + A Dataproc workflow template + resource. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/workflowTemplates/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/regions/*/workflowTemplates/*}", + }, + ] + request, metadata = self._interceptor.pre_get_workflow_template( + request, metadata + ) + pb_request = workflow_templates.GetWorkflowTemplateRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = workflow_templates.WorkflowTemplate() + pb_resp = workflow_templates.WorkflowTemplate.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_workflow_template(resp) + return resp + + class _InstantiateInlineWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("InstantiateInlineWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.InstantiateInlineWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the instantiate inline + workflow template method over HTTP. + + Args: + request (~.workflow_templates.InstantiateInlineWorkflowTemplateRequest): + The request object. A request to instantiate an inline + workflow template. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/locations/*}/workflowTemplates:instantiateInline", + "body": "template", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/regions/*}/workflowTemplates:instantiateInline", + "body": "template", + }, + ] + ( + request, + metadata, + ) = self._interceptor.pre_instantiate_inline_workflow_template( + request, metadata + ) + pb_request = workflow_templates.InstantiateInlineWorkflowTemplateRequest.pb( + request + ) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_instantiate_inline_workflow_template(resp) + return resp + + class _InstantiateWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("InstantiateWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.InstantiateWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the instantiate workflow + template method over HTTP. + + Args: + request (~.workflow_templates.InstantiateWorkflowTemplateRequest): + The request object. A request to instantiate a workflow + template. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/workflowTemplates/*}:instantiate", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/regions/*/workflowTemplates/*}:instantiate", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_instantiate_workflow_template( + request, metadata + ) + pb_request = workflow_templates.InstantiateWorkflowTemplateRequest.pb( + request + ) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_instantiate_workflow_template(resp) + return resp + + class _ListWorkflowTemplates(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("ListWorkflowTemplates") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.ListWorkflowTemplatesRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> workflow_templates.ListWorkflowTemplatesResponse: + r"""Call the list workflow templates method over HTTP. + + Args: + request (~.workflow_templates.ListWorkflowTemplatesRequest): + The request object. A request to list workflow templates + in a project. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.workflow_templates.ListWorkflowTemplatesResponse: + A response to a request to list + workflow templates in a project. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{parent=projects/*/locations/*}/workflowTemplates", + }, + { + "method": "get", + "uri": "/v1/{parent=projects/*/regions/*}/workflowTemplates", + }, + ] + request, metadata = self._interceptor.pre_list_workflow_templates( + request, metadata + ) + pb_request = workflow_templates.ListWorkflowTemplatesRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = workflow_templates.ListWorkflowTemplatesResponse() + pb_resp = workflow_templates.ListWorkflowTemplatesResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_workflow_templates(resp) + return resp + + class _UpdateWorkflowTemplate(WorkflowTemplateServiceRestStub): + def __hash__(self): + return hash("UpdateWorkflowTemplate") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: workflow_templates.UpdateWorkflowTemplateRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> workflow_templates.WorkflowTemplate: + r"""Call the update workflow template method over HTTP. + + Args: + request (~.workflow_templates.UpdateWorkflowTemplateRequest): + The request object. A request to update a workflow + template. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.workflow_templates.WorkflowTemplate: + A Dataproc workflow template + resource. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "put", + "uri": "/v1/{template.name=projects/*/locations/*/workflowTemplates/*}", + "body": "template", + }, + { + "method": "put", + "uri": "/v1/{template.name=projects/*/regions/*/workflowTemplates/*}", + "body": "template", + }, + ] + request, metadata = self._interceptor.pre_update_workflow_template( + request, metadata + ) + pb_request = workflow_templates.UpdateWorkflowTemplateRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = workflow_templates.WorkflowTemplate() + pb_resp = workflow_templates.WorkflowTemplate.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_workflow_template(resp) + return resp + + @property + def create_workflow_template( + self, + ) -> Callable[ + [workflow_templates.CreateWorkflowTemplateRequest], + workflow_templates.WorkflowTemplate, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_workflow_template( + self, + ) -> Callable[[workflow_templates.DeleteWorkflowTemplateRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_workflow_template( + self, + ) -> Callable[ + [workflow_templates.GetWorkflowTemplateRequest], + workflow_templates.WorkflowTemplate, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def instantiate_inline_workflow_template( + self, + ) -> Callable[ + [workflow_templates.InstantiateInlineWorkflowTemplateRequest], + operations_pb2.Operation, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._InstantiateInlineWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def instantiate_workflow_template( + self, + ) -> Callable[ + [workflow_templates.InstantiateWorkflowTemplateRequest], + operations_pb2.Operation, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._InstantiateWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_workflow_templates( + self, + ) -> Callable[ + [workflow_templates.ListWorkflowTemplatesRequest], + workflow_templates.ListWorkflowTemplatesResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListWorkflowTemplates(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_workflow_template( + self, + ) -> Callable[ + [workflow_templates.UpdateWorkflowTemplateRequest], + workflow_templates.WorkflowTemplate, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateWorkflowTemplate(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("WorkflowTemplateServiceRestTransport",) diff --git a/tests/system/gapic/v1/test_system_cluster_controller_v1.py b/tests/system/gapic/v1/test_system_cluster_controller_v1.py index 604a2f1f..7aa87662 100644 --- a/tests/system/gapic/v1/test_system_cluster_controller_v1.py +++ b/tests/system/gapic/v1/test_system_cluster_controller_v1.py @@ -15,16 +15,17 @@ # limitations under the License. 
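+# A minimal sketch of what the new "rest" parametrization below exercises (it
+# assumes PROJECT_ID is set in the environment, as the test itself does):
+#
+#     client = dataproc_v1.ClusterControllerClient(transport="rest")
+#     client.list_clusters(project_id=os.environ["PROJECT_ID"], region="global")
+#
+# The REST transport serves the same surface as gRPC, but over HTTP/1.1 JSON.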
import os -import time +import pytest from google.cloud import dataproc_v1 +@pytest.mark.parametrize("transport", ["grpc", "rest"]) class TestSystemClusterController(object): - def test_list_clusters(self): + def test_list_clusters(self, transport): project_id = os.environ["PROJECT_ID"] - client = dataproc_v1.ClusterControllerClient() + client = dataproc_v1.ClusterControllerClient(transport=transport) project_id_2 = project_id region = "global" response = client.list_clusters( diff --git a/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py b/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py index 6bcc0e9d..f58a98ff 100644 --- a/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py +++ b/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py @@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -101,6 +108,7 @@ def test__get_default_mtls_endpoint(): [ (AutoscalingPolicyServiceClient, "grpc"), (AutoscalingPolicyServiceAsyncClient, "grpc_asyncio"), + (AutoscalingPolicyServiceClient, "rest"), ], ) def test_autoscaling_policy_service_client_from_service_account_info( @@ -116,7 +124,11 @@ def test_autoscaling_policy_service_client_from_service_account_info( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -124,6 +136,7 @@ def test_autoscaling_policy_service_client_from_service_account_info( [ (transports.AutoscalingPolicyServiceGrpcTransport, "grpc"), (transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.AutoscalingPolicyServiceRestTransport, "rest"), ], ) def test_autoscaling_policy_service_client_service_account_always_use_jwt( @@ -149,6 +162,7 @@ def test_autoscaling_policy_service_client_service_account_always_use_jwt( [ (AutoscalingPolicyServiceClient, "grpc"), (AutoscalingPolicyServiceAsyncClient, "grpc_asyncio"), + (AutoscalingPolicyServiceClient, "rest"), ], ) def test_autoscaling_policy_service_client_from_service_account_file( @@ -171,13 +185,18 @@ def test_autoscaling_policy_service_client_from_service_account_file( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) def test_autoscaling_policy_service_client_get_transport_class(): transport = AutoscalingPolicyServiceClient.get_transport_class() available_transports = [ transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceRestTransport, ] assert transport in available_transports @@ -198,6 +217,11 @@ def test_autoscaling_policy_service_client_get_transport_class(): 
transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceRestTransport, + "rest", + ), ], ) @mock.patch.object( @@ -357,6 +381,18 @@ def test_autoscaling_policy_service_client_client_options( "grpc_asyncio", "false", ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceRestTransport, + "rest", + "true", + ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceRestTransport, + "rest", + "false", + ), ], ) @mock.patch.object( @@ -563,6 +599,11 @@ def test_autoscaling_policy_service_client_get_mtls_endpoint_and_cert_source( transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceRestTransport, + "rest", + ), ], ) def test_autoscaling_policy_service_client_client_options_scopes( @@ -603,6 +644,12 @@ def test_autoscaling_policy_service_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceRestTransport, + "rest", + None, + ), ], ) def test_autoscaling_policy_service_client_client_options_credentials_file( @@ -2175,163 +2222,1706 @@ async def test_delete_autoscaling_policy_flattened_error_async(): ) -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.AutoscalingPolicyServiceGrpcTransport( +@pytest.mark.parametrize( + "request_type", + [ + autoscaling_policies.CreateAutoscalingPolicyRequest, + dict, + ], +) +def test_create_autoscaling_policy_rest(request_type): + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - with pytest.raises(ValueError): - client = AutoscalingPolicyServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["policy"] = { + "id": "id_value", + "name": "name_value", + "basic_algorithm": { + "yarn_config": { + "graceful_decommission_timeout": {"seconds": 751, "nanos": 543}, + "scale_up_factor": 0.1578, + "scale_down_factor": 0.1789, + "scale_up_min_worker_fraction": 0.2973, + "scale_down_min_worker_fraction": 0.3184, + }, + "cooldown_period": {}, + }, + "worker_config": {"min_instances": 1387, "max_instances": 1389, "weight": 648}, + "secondary_worker_config": {}, + "labels": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy( + id="id_value", + name="name_value", + basic_algorithm=autoscaling_policies.BasicAutoscalingAlgorithm( + yarn_config=autoscaling_policies.BasicYarnAutoscalingConfig( + graceful_decommission_timeout=duration_pb2.Duration(seconds=751) + ) + ), ) - # It is an error to provide a credentials file and a transport instance. 
- transport = transports.AutoscalingPolicyServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = AutoscalingPolicyServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_autoscaling_policy(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) + assert response.id == "id_value" + assert response.name == "name_value" + + +def test_create_autoscaling_policy_rest_required_fields( + request_type=autoscaling_policies.CreateAutoscalingPolicyRequest, +): + transport_class = transports.AutoscalingPolicyServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, ) + ) - # It is an error to provide an api_key and a transport instance. - transport = transports.AutoscalingPolicyServiceGrpcTransport( + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
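+            # (With transcode() mocked out, no request field is folded into the
+            # URI, and since every field of `request` is still at its proto default
+            # the serialized query params stay empty; the expected_params assertion
+            # at the end of this test therefore only sees the "$alt" parameter that
+            # the transport appends unconditionally.)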
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_autoscaling_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_autoscaling_policy_rest_unset_required_fields(): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = AutoscalingPolicyServiceClient( - client_options=options, - transport=transport, - ) - # It is an error to provide an api_key and a credential. - options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = AutoscalingPolicyServiceClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + unset_fields = transport.create_autoscaling_policy._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "policy", + ) ) + ) - # It is an error to provide scopes and a transport instance. - transport = transports.AutoscalingPolicyServiceGrpcTransport( + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_autoscaling_policy_rest_interceptors(null_interceptor): + transport = transports.AutoscalingPolicyServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.AutoscalingPolicyServiceRestInterceptor(), ) - with pytest.raises(ValueError): - client = AutoscalingPolicyServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client = AutoscalingPolicyServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "post_create_autoscaling_policy", + ) as post, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "pre_create_autoscaling_policy", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = autoscaling_policies.CreateAutoscalingPolicyRequest.pb( + autoscaling_policies.CreateAutoscalingPolicyRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = autoscaling_policies.AutoscalingPolicy.to_json( + autoscaling_policies.AutoscalingPolicy() ) + request = autoscaling_policies.CreateAutoscalingPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = autoscaling_policies.AutoscalingPolicy() -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.AutoscalingPolicyServiceGrpcTransport( + client.create_autoscaling_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_autoscaling_policy_rest_bad_request( + transport: str = "rest", + request_type=autoscaling_policies.CreateAutoscalingPolicyRequest, +): + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - client = AutoscalingPolicyServiceClient(transport=transport) - assert client.transport is transport + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["policy"] = { + "id": "id_value", + "name": "name_value", + "basic_algorithm": { + "yarn_config": { + "graceful_decommission_timeout": {"seconds": 751, "nanos": 543}, + "scale_up_factor": 0.1578, + "scale_down_factor": 0.1789, + "scale_up_min_worker_fraction": 0.2973, + "scale_down_min_worker_fraction": 0.3184, + }, + "cooldown_period": {}, + }, + "worker_config": {"min_instances": 1387, "max_instances": 1389, "weight": 648}, + "secondary_worker_config": {}, + "labels": {}, + } + request = request_type(**request_init) -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.AutoscalingPolicyServiceGrpcTransport( + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_autoscaling_policy(request) + + +def test_create_autoscaling_policy_rest_flattened(): + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - channel = transport.grpc_channel - assert channel - transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + policy=autoscaling_policies.AutoscalingPolicy(id="id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_autoscaling_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
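+        # (client.transport._session is an AuthorizedSession, whose post() funnels
+        # through Session.request(method, url, ...); args[1] below is therefore the
+        # fully expanded URL that path_template.validate matches against.)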
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/autoscalingPolicies" + % client.transport._host, + args[1], + ) + + +def test_create_autoscaling_policy_rest_flattened_error(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - channel = transport.grpc_channel - assert channel + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_autoscaling_policy( + autoscaling_policies.CreateAutoscalingPolicyRequest(), + parent="parent_value", + policy=autoscaling_policies.AutoscalingPolicy(id="id_value"), + ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.AutoscalingPolicyServiceGrpcTransport, - transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() + +def test_create_autoscaling_policy_rest_error(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) @pytest.mark.parametrize( - "transport_name", + "request_type", [ - "grpc", + autoscaling_policies.UpdateAutoscalingPolicyRequest, + dict, ], ) -def test_transport_kind(transport_name): - transport = AutoscalingPolicyServiceClient.get_transport_class(transport_name)( +def test_update_autoscaling_policy_rest(request_type): + client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert transport.kind == transport_name + # send a request that will satisfy transcoding + request_init = { + "policy": { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + } + request_init["policy"] = { + "id": "id_value", + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3", + "basic_algorithm": { + "yarn_config": { + "graceful_decommission_timeout": {"seconds": 751, "nanos": 543}, + "scale_up_factor": 0.1578, + "scale_down_factor": 0.1789, + "scale_up_min_worker_fraction": 0.2973, + "scale_down_min_worker_fraction": 0.3184, + }, + "cooldown_period": {}, + }, + "worker_config": {"min_instances": 1387, "max_instances": 1389, "weight": 648}, + "secondary_worker_config": {}, + "labels": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
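+        # (The nested basic_algorithm value exercises the proto -> JSON -> proto
+        # round trip; the assertions afterwards only check the scalar `id` and
+        # `name` fields, which is all this unit test needs to verify.)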
+ return_value = autoscaling_policies.AutoscalingPolicy( + id="id_value", + name="name_value", + basic_algorithm=autoscaling_policies.BasicAutoscalingAlgorithm( + yarn_config=autoscaling_policies.BasicYarnAutoscalingConfig( + graceful_decommission_timeout=duration_pb2.Duration(seconds=751) + ) + ), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_autoscaling_policy(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) + assert response.id == "id_value" + assert response.name == "name_value" + + +def test_update_autoscaling_policy_rest_required_fields( + request_type=autoscaling_policies.UpdateAutoscalingPolicyRequest, +): + transport_class = transports.AutoscalingPolicyServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. client = AutoscalingPolicyServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert isinstance( - client.transport, - transports.AutoscalingPolicyServiceGrpcTransport, + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "put", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_autoscaling_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_autoscaling_policy_rest_unset_required_fields(): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials ) + unset_fields = transport.update_autoscaling_policy._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("policy",))) -def test_autoscaling_policy_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.AutoscalingPolicyServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json", + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_autoscaling_policy_rest_interceptors(null_interceptor): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.AutoscalingPolicyServiceRestInterceptor(), + ) + client = AutoscalingPolicyServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "post_update_autoscaling_policy", + ) as post, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "pre_update_autoscaling_policy", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = autoscaling_policies.UpdateAutoscalingPolicyRequest.pb( + autoscaling_policies.UpdateAutoscalingPolicyRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = autoscaling_policies.AutoscalingPolicy.to_json( + autoscaling_policies.AutoscalingPolicy() ) + request = autoscaling_policies.UpdateAutoscalingPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = autoscaling_policies.AutoscalingPolicy() -def test_autoscaling_policy_service_base_transport(): - # Instantiate the base transport. 
- with mock.patch( - "google.cloud.dataproc_v1.services.autoscaling_policy_service.transports.AutoscalingPolicyServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.AutoscalingPolicyServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), + client.update_autoscaling_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], ) - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_autoscaling_policy", - "update_autoscaling_policy", - "get_autoscaling_policy", - "list_autoscaling_policies", - "delete_autoscaling_policy", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) + pre.assert_called_once() + post.assert_called_once() - with pytest.raises(NotImplementedError): - transport.close() - # Catch all for all remaining methods and properties - remainder = [ - "kind", - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() +def test_update_autoscaling_policy_rest_bad_request( + transport: str = "rest", + request_type=autoscaling_policies.UpdateAutoscalingPolicyRequest, +): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "policy": { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + } + request_init["policy"] = { + "id": "id_value", + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3", + "basic_algorithm": { + "yarn_config": { + "graceful_decommission_timeout": {"seconds": 751, "nanos": 543}, + "scale_up_factor": 0.1578, + "scale_down_factor": 0.1789, + "scale_up_min_worker_fraction": 0.2973, + "scale_down_min_worker_fraction": 0.3184, + }, + "cooldown_period": {}, + }, + "worker_config": {"min_instances": 1387, "max_instances": 1389, "weight": 648}, + "secondary_worker_config": {}, + "labels": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_autoscaling_policy(request) + + +def test_update_autoscaling_policy_rest_flattened(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
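+        # (An empty AutoscalingPolicy is enough here: this flattened-call test
+        # only asserts on the outgoing URL, not on decoded response fields.)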
+ return_value = autoscaling_policies.AutoscalingPolicy() + + # get arguments that satisfy an http rule for this method + sample_request = { + "policy": { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + } + + # get truthy value for each flattened field + mock_args = dict( + policy=autoscaling_policies.AutoscalingPolicy(id="id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_autoscaling_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{policy.name=projects/*/locations/*/autoscalingPolicies/*}" + % client.transport._host, + args[1], + ) + + +def test_update_autoscaling_policy_rest_flattened_error(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_autoscaling_policy( + autoscaling_policies.UpdateAutoscalingPolicyRequest(), + policy=autoscaling_policies.AutoscalingPolicy(id="id_value"), + ) + + +def test_update_autoscaling_policy_rest_error(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + autoscaling_policies.GetAutoscalingPolicyRequest, + dict, + ], +) +def test_get_autoscaling_policy_rest(request_type): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy( + id="id_value", + name="name_value", + basic_algorithm=autoscaling_policies.BasicAutoscalingAlgorithm( + yarn_config=autoscaling_policies.BasicYarnAutoscalingConfig( + graceful_decommission_timeout=duration_pb2.Duration(seconds=751) + ) + ), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_autoscaling_policy(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, autoscaling_policies.AutoscalingPolicy) + assert response.id == "id_value" + assert response.name == "name_value" + + +def test_get_autoscaling_policy_rest_required_fields( + request_type=autoscaling_policies.GetAutoscalingPolicyRequest, +): + transport_class = transports.AutoscalingPolicyServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
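+            # For a well-formed request, transcode() would expand the http
+            # rule into roughly (illustrative only, not asserted here):
+            #   {"uri": "v1/projects/.../autoscalingPolicies/...",
+            #    "method": "get", "query_params": {}}
+            # Here the required field is blank, hence the stub below.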
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_autoscaling_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_autoscaling_policy_rest_unset_required_fields(): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_autoscaling_policy._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_autoscaling_policy_rest_interceptors(null_interceptor): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.AutoscalingPolicyServiceRestInterceptor(), + ) + client = AutoscalingPolicyServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "post_get_autoscaling_policy", + ) as post, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, "pre_get_autoscaling_policy" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = autoscaling_policies.GetAutoscalingPolicyRequest.pb( + autoscaling_policies.GetAutoscalingPolicyRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = autoscaling_policies.AutoscalingPolicy.to_json( + autoscaling_policies.AutoscalingPolicy() + ) + + request = autoscaling_policies.GetAutoscalingPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = autoscaling_policies.AutoscalingPolicy() + + client.get_autoscaling_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_autoscaling_policy_rest_bad_request( + transport: str = "rest", + request_type=autoscaling_policies.GetAutoscalingPolicyRequest, +): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
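+    # google.api_core maps HTTP status codes to exception classes, so the
+    # mocked 400 below is raised to the caller as core_exceptions.BadRequest.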
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_autoscaling_policy(request) + + +def test_get_autoscaling_policy_rest_flattened(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.AutoscalingPolicy() + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.AutoscalingPolicy.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_autoscaling_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/autoscalingPolicies/*}" + % client.transport._host, + args[1], + ) + + +def test_get_autoscaling_policy_rest_flattened_error(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_autoscaling_policy( + autoscaling_policies.GetAutoscalingPolicyRequest(), + name="name_value", + ) + + +def test_get_autoscaling_policy_rest_error(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + autoscaling_policies.ListAutoscalingPoliciesRequest, + dict, + ], +) +def test_list_autoscaling_policies_rest(request_type): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = autoscaling_policies.ListAutoscalingPoliciesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.ListAutoscalingPoliciesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_autoscaling_policies(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAutoscalingPoliciesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_autoscaling_policies_rest_required_fields( + request_type=autoscaling_policies.ListAutoscalingPoliciesRequest, +): + transport_class = transports.AutoscalingPolicyServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_autoscaling_policies._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_autoscaling_policies._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = autoscaling_policies.ListAutoscalingPoliciesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_autoscaling_policies(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_autoscaling_policies_rest_unset_required_fields(): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_autoscaling_policies._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_autoscaling_policies_rest_interceptors(null_interceptor): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.AutoscalingPolicyServiceRestInterceptor(), + ) + client = AutoscalingPolicyServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "post_list_autoscaling_policies", + ) as post, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "pre_list_autoscaling_policies", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = autoscaling_policies.ListAutoscalingPoliciesRequest.pb( + autoscaling_policies.ListAutoscalingPoliciesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + autoscaling_policies.ListAutoscalingPoliciesResponse.to_json( + autoscaling_policies.ListAutoscalingPoliciesResponse() + ) + ) + + request = autoscaling_policies.ListAutoscalingPoliciesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() + + client.list_autoscaling_policies( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_autoscaling_policies_rest_bad_request( + transport: str = "rest", + request_type=autoscaling_policies.ListAutoscalingPoliciesRequest, +): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_autoscaling_policies(request) + + +def test_list_autoscaling_policies_rest_flattened(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = autoscaling_policies.ListAutoscalingPoliciesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_autoscaling_policies(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/autoscalingPolicies" + % client.transport._host, + args[1], + ) + + +def test_list_autoscaling_policies_rest_flattened_error(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_autoscaling_policies( + autoscaling_policies.ListAutoscalingPoliciesRequest(), + parent="parent_value", + ) + + +def test_list_autoscaling_policies_rest_pager(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + autoscaling_policies.ListAutoscalingPoliciesResponse( + policies=[ + autoscaling_policies.AutoscalingPolicy(), + autoscaling_policies.AutoscalingPolicy(), + autoscaling_policies.AutoscalingPolicy(), + ], + next_page_token="abc", + ), + autoscaling_policies.ListAutoscalingPoliciesResponse( + policies=[], + next_page_token="def", + ), + autoscaling_policies.ListAutoscalingPoliciesResponse( + policies=[ + autoscaling_policies.AutoscalingPolicy(), + ], + next_page_token="ghi", + ), + autoscaling_policies.ListAutoscalingPoliciesResponse( + policies=[ + autoscaling_policies.AutoscalingPolicy(), + autoscaling_policies.AutoscalingPolicy(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + autoscaling_policies.ListAutoscalingPoliciesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_autoscaling_policies(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all( + isinstance(i, autoscaling_policies.AutoscalingPolicy) for i in results + ) + + pages = list(client.list_autoscaling_policies(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + autoscaling_policies.DeleteAutoscalingPolicyRequest, + dict, + ], +) +def test_delete_autoscaling_policy_rest(request_type): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_autoscaling_policy(request) + + # Establish that the response is the type that we expect. 
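+    # DeleteAutoscalingPolicy returns google.protobuf.Empty, which the
+    # client surfaces as None; an empty JSON body is the whole round trip.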
+ assert response is None + + +def test_delete_autoscaling_policy_rest_required_fields( + request_type=autoscaling_policies.DeleteAutoscalingPolicyRequest, +): + transport_class = transports.AutoscalingPolicyServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_autoscaling_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_autoscaling_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_autoscaling_policy_rest_unset_required_fields(): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_autoscaling_policy._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_autoscaling_policy_rest_interceptors(null_interceptor): + transport = transports.AutoscalingPolicyServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.AutoscalingPolicyServiceRestInterceptor(), + ) + client = AutoscalingPolicyServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.AutoscalingPolicyServiceRestInterceptor, + "pre_delete_autoscaling_policy", + ) as pre: + pre.assert_not_called() + pb_message = autoscaling_policies.DeleteAutoscalingPolicyRequest.pb( + autoscaling_policies.DeleteAutoscalingPolicyRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = autoscaling_policies.DeleteAutoscalingPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_autoscaling_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_delete_autoscaling_policy_rest_bad_request( + transport: str = "rest", + request_type=autoscaling_policies.DeleteAutoscalingPolicyRequest, +): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_autoscaling_policy(request) + + +def test_delete_autoscaling_policy_rest_flattened(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/autoscalingPolicies/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_autoscaling_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/autoscalingPolicies/*}" + % client.transport._host, + args[1], + ) + + +def test_delete_autoscaling_policy_rest_flattened_error(transport: str = "rest"): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_autoscaling_policy( + autoscaling_policies.DeleteAutoscalingPolicyRequest(), + name="name_value", + ) + + +def test_delete_autoscaling_policy_rest_error(): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = AutoscalingPolicyServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = AutoscalingPolicyServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = AutoscalingPolicyServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = AutoscalingPolicyServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
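+    # An injected transport must be adopted verbatim; the identity check
+    # below (is, not ==) guards against the client wrapping or copying it.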
+ transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = AutoscalingPolicyServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.AutoscalingPolicyServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + transports.AutoscalingPolicyServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) +def test_transport_kind(transport_name): + transport = AutoscalingPolicyServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = AutoscalingPolicyServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.AutoscalingPolicyServiceGrpcTransport, + ) + + +def test_autoscaling_policy_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.AutoscalingPolicyServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_autoscaling_policy_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.dataproc_v1.services.autoscaling_policy_service.transports.AutoscalingPolicyServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.AutoscalingPolicyServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
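+    # The base transport is only an interface; each concrete transport
+    # (grpc, grpc_asyncio, rest) must override every method listed here.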
+ methods = ( + "create_autoscaling_policy", + "update_autoscaling_policy", + "get_autoscaling_policy", + "list_autoscaling_policies", + "delete_autoscaling_policy", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() def test_autoscaling_policy_service_base_transport_with_credentials_file(): @@ -2403,6 +3993,7 @@ def test_autoscaling_policy_service_transport_auth_adc(transport_class): [ transports.AutoscalingPolicyServiceGrpcTransport, transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + transports.AutoscalingPolicyServiceRestTransport, ], ) def test_autoscaling_policy_service_transport_auth_gdch_credentials(transport_class): @@ -2504,11 +4095,23 @@ def test_autoscaling_policy_service_grpc_transport_client_cert_source_for_mtls( ) +def test_autoscaling_policy_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.AutoscalingPolicyServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_autoscaling_policy_service_host_no_port(transport_name): @@ -2519,7 +4122,11 @@ def test_autoscaling_policy_service_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -2527,6 +4134,7 @@ def test_autoscaling_policy_service_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_autoscaling_policy_service_host_with_port(transport_name): @@ -2537,7 +4145,45 @@ def test_autoscaling_policy_service_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:8000") + assert client.transport._host == ( + "dataproc.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_autoscaling_policy_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = AutoscalingPolicyServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = AutoscalingPolicyServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_autoscaling_policy._session + session2 = client2.transport.create_autoscaling_policy._session + assert session1 != session2 + session1 = client1.transport.update_autoscaling_policy._session + session2 = client2.transport.update_autoscaling_policy._session + assert session1 != session2 + session1 = client1.transport.get_autoscaling_policy._session + session2 = client2.transport.get_autoscaling_policy._session + assert session1 != session2 
+ session1 = client1.transport.list_autoscaling_policies._session + session2 = client2.transport.list_autoscaling_policies._session + assert session1 != session2 + session1 = client1.transport.delete_autoscaling_policy._session + session2 = client2.transport.delete_autoscaling_policy._session + assert session1 != session2 def test_autoscaling_policy_service_grpc_transport_channel(): @@ -2836,6 +4482,7 @@ async def test_transport_close_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -2853,6 +4500,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: diff --git a/tests/unit/gapic/dataproc_v1/test_batch_controller.py b/tests/unit/gapic/dataproc_v1/test_batch_controller.py index 26fa26d6..575fe534 100644 --- a/tests/unit/gapic/dataproc_v1/test_batch_controller.py +++ b/tests/unit/gapic/dataproc_v1/test_batch_controller.py @@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -105,6 +112,7 @@ def test__get_default_mtls_endpoint(): [ (BatchControllerClient, "grpc"), (BatchControllerAsyncClient, "grpc_asyncio"), + (BatchControllerClient, "rest"), ], ) def test_batch_controller_client_from_service_account_info( @@ -120,7 +128,11 @@ def test_batch_controller_client_from_service_account_info( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -128,6 +140,7 @@ def test_batch_controller_client_from_service_account_info( [ (transports.BatchControllerGrpcTransport, "grpc"), (transports.BatchControllerGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.BatchControllerRestTransport, "rest"), ], ) def test_batch_controller_client_service_account_always_use_jwt( @@ -153,6 +166,7 @@ def test_batch_controller_client_service_account_always_use_jwt( [ (BatchControllerClient, "grpc"), (BatchControllerAsyncClient, "grpc_asyncio"), + (BatchControllerClient, "rest"), ], ) def test_batch_controller_client_from_service_account_file( @@ -175,13 +189,18 @@ def test_batch_controller_client_from_service_account_file( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) def test_batch_controller_client_get_transport_class(): transport = BatchControllerClient.get_transport_class() available_transports = [ transports.BatchControllerGrpcTransport, + transports.BatchControllerRestTransport, ] assert transport in available_transports @@ -198,6 +217,7 @@ def test_batch_controller_client_get_transport_class(): transports.BatchControllerGrpcAsyncIOTransport, "grpc_asyncio", 
), + (BatchControllerClient, transports.BatchControllerRestTransport, "rest"), ], ) @mock.patch.object( @@ -353,6 +373,18 @@ def test_batch_controller_client_client_options( "grpc_asyncio", "false", ), + ( + BatchControllerClient, + transports.BatchControllerRestTransport, + "rest", + "true", + ), + ( + BatchControllerClient, + transports.BatchControllerRestTransport, + "rest", + "false", + ), ], ) @mock.patch.object( @@ -552,6 +584,7 @@ def test_batch_controller_client_get_mtls_endpoint_and_cert_source(client_class) transports.BatchControllerGrpcAsyncIOTransport, "grpc_asyncio", ), + (BatchControllerClient, transports.BatchControllerRestTransport, "rest"), ], ) def test_batch_controller_client_client_options_scopes( @@ -592,6 +625,7 @@ def test_batch_controller_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + (BatchControllerClient, transports.BatchControllerRestTransport, "rest", None), ], ) def test_batch_controller_client_client_options_credentials_file( @@ -1841,6 +1875,1276 @@ async def test_delete_batch_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + batches.CreateBatchRequest, + dict, + ], +) +def test_create_batch_rest(request_type): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["batch"] = { + "name": "name_value", + "uuid": "uuid_value", + "create_time": {"seconds": 751, "nanos": 543}, + "pyspark_batch": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": ["python_file_uris_value1", "python_file_uris_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_batch": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_r_batch": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_sql_batch": { + "query_file_uri": "query_file_uri_value", + "query_variables": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "runtime_info": { + "endpoints": {}, + "output_uri": "output_uri_value", + "diagnostic_output_uri": "diagnostic_output_uri_value", + }, + "state": 1, + "state_message": "state_message_value", + "state_time": {}, + "creator": "creator_value", + "labels": {}, + "runtime_config": { + "version": "version_value", + "container_image": "container_image_value", + "properties": {}, + }, + "environment_config": { + "execution_config": { + "service_account": "service_account_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "network_tags": ["network_tags_value1", "network_tags_value2"], + "kms_key": "kms_key_value", + }, + "peripherals_config": { + "metastore_service": "metastore_service_value", + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + 
"operation": "operation_value", + "state_history": [ + {"state": 1, "state_message": "state_message_value", "state_start_time": {}} + ], + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_batch(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_create_batch_rest_required_fields(request_type=batches.CreateBatchRequest): + transport_class = transports.BatchControllerRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_batch._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_batch._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "batch_id", + "request_id", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_batch(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_batch_rest_unset_required_fields(): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_batch._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "batchId", + "requestId", + ) + ) + & set( + ( + "parent", + "batch", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_batch_rest_interceptors(null_interceptor): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.BatchControllerRestInterceptor(), + ) + client = BatchControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.BatchControllerRestInterceptor, "post_create_batch" + ) as post, mock.patch.object( + transports.BatchControllerRestInterceptor, "pre_create_batch" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = batches.CreateBatchRequest.pb(batches.CreateBatchRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = batches.CreateBatchRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_batch( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_batch_rest_bad_request( + transport: str = "rest", request_type=batches.CreateBatchRequest +): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["batch"] = { + "name": "name_value", + "uuid": "uuid_value", + "create_time": {"seconds": 751, "nanos": 543}, + "pyspark_batch": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": ["python_file_uris_value1", "python_file_uris_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_batch": { + "main_jar_file_uri": 
"main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_r_batch": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + }, + "spark_sql_batch": { + "query_file_uri": "query_file_uri_value", + "query_variables": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "runtime_info": { + "endpoints": {}, + "output_uri": "output_uri_value", + "diagnostic_output_uri": "diagnostic_output_uri_value", + }, + "state": 1, + "state_message": "state_message_value", + "state_time": {}, + "creator": "creator_value", + "labels": {}, + "runtime_config": { + "version": "version_value", + "container_image": "container_image_value", + "properties": {}, + }, + "environment_config": { + "execution_config": { + "service_account": "service_account_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "network_tags": ["network_tags_value1", "network_tags_value2"], + "kms_key": "kms_key_value", + }, + "peripherals_config": { + "metastore_service": "metastore_service_value", + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + "operation": "operation_value", + "state_history": [ + {"state": 1, "state_message": "state_message_value", "state_start_time": {}} + ], + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_batch(request) + + +def test_create_batch_rest_flattened(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + batch=batches.Batch(name="name_value"), + batch_id="batch_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_batch(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/batches" % client.transport._host, + args[1], + ) + + +def test_create_batch_rest_flattened_error(transport: str = "rest"): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_batch( + batches.CreateBatchRequest(), + parent="parent_value", + batch=batches.Batch(name="name_value"), + batch_id="batch_id_value", + ) + + +def test_create_batch_rest_error(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + batches.GetBatchRequest, + dict, + ], +) +def test_get_batch_rest(request_type): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/batches/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = batches.Batch( + name="name_value", + uuid="uuid_value", + state=batches.Batch.State.PENDING, + state_message="state_message_value", + creator="creator_value", + operation="operation_value", + pyspark_batch=batches.PySparkBatch( + main_python_file_uri="main_python_file_uri_value" + ), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = batches.Batch.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_batch(request) + + # Establish that the response is the type that we expect. 
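+    # Note that the state enum round-trips through JSON back into the typed
+    # batches.Batch.State value rather than a bare int or string.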
+ assert isinstance(response, batches.Batch) + assert response.name == "name_value" + assert response.uuid == "uuid_value" + assert response.state == batches.Batch.State.PENDING + assert response.state_message == "state_message_value" + assert response.creator == "creator_value" + assert response.operation == "operation_value" + + +def test_get_batch_rest_required_fields(request_type=batches.GetBatchRequest): + transport_class = transports.BatchControllerRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_batch._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_batch._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = batches.Batch() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = batches.Batch.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_batch(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_batch_rest_unset_required_fields(): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_batch._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_batch_rest_interceptors(null_interceptor): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.BatchControllerRestInterceptor(), + ) + client = BatchControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.BatchControllerRestInterceptor, "post_get_batch" + ) as post, mock.patch.object( + transports.BatchControllerRestInterceptor, "pre_get_batch" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = batches.GetBatchRequest.pb(batches.GetBatchRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = batches.Batch.to_json(batches.Batch()) + + request = batches.GetBatchRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = batches.Batch() + + client.get_batch( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_batch_rest_bad_request( + transport: str = "rest", request_type=batches.GetBatchRequest +): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/batches/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_batch(request) + + +def test_get_batch_rest_flattened(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = batches.Batch() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "projects/sample1/locations/sample2/batches/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = batches.Batch.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_batch(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/batches/*}" % client.transport._host, + args[1], + ) + + +def test_get_batch_rest_flattened_error(transport: str = "rest"): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_batch( + batches.GetBatchRequest(), + name="name_value", + ) + + +def test_get_batch_rest_error(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + batches.ListBatchesRequest, + dict, + ], +) +def test_list_batches_rest(request_type): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = batches.ListBatchesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = batches.ListBatchesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_batches(request) + + # Establish that the response is the type that we expect. 
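+ # Even over REST, list responses are wrapped in a pager so callers can
+ # iterate across pages transparently.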
+ assert isinstance(response, pagers.ListBatchesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_batches_rest_required_fields(request_type=batches.ListBatchesRequest): + transport_class = transports.BatchControllerRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_batches._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_batches._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = batches.ListBatchesResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = batches.ListBatchesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_batches(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_batches_rest_unset_required_fields(): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_batches._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_batches_rest_interceptors(null_interceptor): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.BatchControllerRestInterceptor(), + ) + client = BatchControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.BatchControllerRestInterceptor, "post_list_batches" + ) as post, mock.patch.object( + transports.BatchControllerRestInterceptor, "pre_list_batches" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = batches.ListBatchesRequest.pb(batches.ListBatchesRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = batches.ListBatchesResponse.to_json( + batches.ListBatchesResponse() + ) + + request = batches.ListBatchesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = batches.ListBatchesResponse() + + client.list_batches( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_batches_rest_bad_request( + transport: str = "rest", request_type=batches.ListBatchesRequest +): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
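+ # A mocked 400 response should surface as
+ # google.api_core.exceptions.BadRequest.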
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_batches(request) + + +def test_list_batches_rest_flattened(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = batches.ListBatchesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = batches.ListBatchesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_batches(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/batches" % client.transport._host, + args[1], + ) + + +def test_list_batches_rest_flattened_error(transport: str = "rest"): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_batches( + batches.ListBatchesRequest(), + parent="parent_value", + ) + + +def test_list_batches_rest_pager(transport: str = "rest"): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + batches.ListBatchesResponse( + batches=[ + batches.Batch(), + batches.Batch(), + batches.Batch(), + ], + next_page_token="abc", + ), + batches.ListBatchesResponse( + batches=[], + next_page_token="def", + ), + batches.ListBatchesResponse( + batches=[ + batches.Batch(), + ], + next_page_token="ghi", + ), + batches.ListBatchesResponse( + batches=[ + batches.Batch(), + batches.Batch(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(batches.ListBatchesResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_batches(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, batches.Batch) for i in results) + + pages = list(client.list_batches(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + batches.DeleteBatchRequest, + dict, + ], +) +def test_delete_batch_rest(request_type): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/batches/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_batch(request) + + # Establish that the response is the type that we expect. 
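+ # delete_batch has an empty response body, so the client returns None.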
+ assert response is None + + +def test_delete_batch_rest_required_fields(request_type=batches.DeleteBatchRequest): + transport_class = transports.BatchControllerRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_batch._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_batch._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
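+ # DELETE carries no request body, so every field surfaces in the query
+ # string here.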
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_batch(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_batch_rest_unset_required_fields(): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_batch._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_batch_rest_interceptors(null_interceptor): + transport = transports.BatchControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.BatchControllerRestInterceptor(), + ) + client = BatchControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.BatchControllerRestInterceptor, "pre_delete_batch" + ) as pre: + pre.assert_not_called() + pb_message = batches.DeleteBatchRequest.pb(batches.DeleteBatchRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = batches.DeleteBatchRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_batch( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_delete_batch_rest_bad_request( + transport: str = "rest", request_type=batches.DeleteBatchRequest +): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/batches/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_batch(request) + + +def test_delete_batch_rest_flattened(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "projects/sample1/locations/sample2/batches/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_batch(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/batches/*}" % client.transport._host, + args[1], + ) + + +def test_delete_batch_rest_flattened_error(transport: str = "rest"): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_batch( + batches.DeleteBatchRequest(), + name="name_value", + ) + + +def test_delete_batch_rest_error(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.BatchControllerGrpcTransport( @@ -1922,6 +3226,7 @@ def test_transport_get_channel(): [ transports.BatchControllerGrpcTransport, transports.BatchControllerGrpcAsyncIOTransport, + transports.BatchControllerRestTransport, ], ) def test_transport_adc(transport_class): @@ -1936,6 +3241,7 @@ def test_transport_adc(transport_class): "transport_name", [ "grpc", + "rest", ], ) def test_transport_kind(transport_name): @@ -2073,6 +3379,7 @@ def test_batch_controller_transport_auth_adc(transport_class): [ transports.BatchControllerGrpcTransport, transports.BatchControllerGrpcAsyncIOTransport, + transports.BatchControllerRestTransport, ], ) def test_batch_controller_transport_auth_gdch_credentials(transport_class): @@ -2170,11 +3477,40 @@ def test_batch_controller_grpc_transport_client_cert_source_for_mtls(transport_c ) +def test_batch_controller_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.BatchControllerRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_batch_controller_rest_lro_client(): + client = BatchControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. 
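+ # (identity, not equality: the operations client is created once and
+ # cached on the transport)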
+ assert transport.operations_client is transport.operations_client + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_batch_controller_host_no_port(transport_name): @@ -2185,7 +3521,11 @@ def test_batch_controller_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -2193,6 +3533,7 @@ def test_batch_controller_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_batch_controller_host_with_port(transport_name): @@ -2203,7 +3544,42 @@ def test_batch_controller_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:8000") + assert client.transport._host == ( + "dataproc.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_batch_controller_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = BatchControllerClient( + credentials=creds1, + transport=transport_name, + ) + client2 = BatchControllerClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_batch._session + session2 = client2.transport.create_batch._session + assert session1 != session2 + session1 = client1.transport.get_batch._session + session2 = client2.transport.get_batch._session + assert session1 != session2 + session1 = client1.transport.list_batches._session + session2 = client2.transport.list_batches._session + assert session1 != session2 + session1 = client1.transport.delete_batch._session + session2 = client2.transport.delete_batch._session + assert session1 != session2 def test_batch_controller_grpc_transport_channel(): @@ -2534,6 +3910,7 @@ async def test_transport_close_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -2551,6 +3928,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: diff --git a/tests/unit/gapic/dataproc_v1/test_cluster_controller.py b/tests/unit/gapic/dataproc_v1/test_cluster_controller.py index dbb07acb..a73555db 100644 --- a/tests/unit/gapic/dataproc_v1/test_cluster_controller.py +++ b/tests/unit/gapic/dataproc_v1/test_cluster_controller.py @@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -109,6 +116,7 @@ def test__get_default_mtls_endpoint(): [ (ClusterControllerClient, "grpc"), (ClusterControllerAsyncClient, "grpc_asyncio"), + (ClusterControllerClient, "rest"), ], ) def test_cluster_controller_client_from_service_account_info( @@ -124,7 +132,11 @@ def 
test_cluster_controller_client_from_service_account_info( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -132,6 +144,7 @@ def test_cluster_controller_client_from_service_account_info( [ (transports.ClusterControllerGrpcTransport, "grpc"), (transports.ClusterControllerGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.ClusterControllerRestTransport, "rest"), ], ) def test_cluster_controller_client_service_account_always_use_jwt( @@ -157,6 +170,7 @@ def test_cluster_controller_client_service_account_always_use_jwt( [ (ClusterControllerClient, "grpc"), (ClusterControllerAsyncClient, "grpc_asyncio"), + (ClusterControllerClient, "rest"), ], ) def test_cluster_controller_client_from_service_account_file( @@ -179,13 +193,18 @@ def test_cluster_controller_client_from_service_account_file( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) def test_cluster_controller_client_get_transport_class(): transport = ClusterControllerClient.get_transport_class() available_transports = [ transports.ClusterControllerGrpcTransport, + transports.ClusterControllerRestTransport, ] assert transport in available_transports @@ -202,6 +221,7 @@ def test_cluster_controller_client_get_transport_class(): transports.ClusterControllerGrpcAsyncIOTransport, "grpc_asyncio", ), + (ClusterControllerClient, transports.ClusterControllerRestTransport, "rest"), ], ) @mock.patch.object( @@ -357,6 +377,18 @@ def test_cluster_controller_client_client_options( "grpc_asyncio", "false", ), + ( + ClusterControllerClient, + transports.ClusterControllerRestTransport, + "rest", + "true", + ), + ( + ClusterControllerClient, + transports.ClusterControllerRestTransport, + "rest", + "false", + ), ], ) @mock.patch.object( @@ -556,6 +588,7 @@ def test_cluster_controller_client_get_mtls_endpoint_and_cert_source(client_clas transports.ClusterControllerGrpcAsyncIOTransport, "grpc_asyncio", ), + (ClusterControllerClient, transports.ClusterControllerRestTransport, "rest"), ], ) def test_cluster_controller_client_client_options_scopes( @@ -596,6 +629,12 @@ def test_cluster_controller_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + ( + ClusterControllerClient, + transports.ClusterControllerRestTransport, + "rest", + None, + ), ], ) def test_cluster_controller_client_client_options_credentials_file( @@ -2729,197 +2768,3236 @@ async def test_diagnose_cluster_flattened_error_async(): ) -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. 
- transport = transports.ClusterControllerGrpcTransport( +@pytest.mark.parametrize( + "request_type", + [ + clusters.CreateClusterRequest, + dict, + ], +) +def test_create_cluster_rest(request_type): + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - with pytest.raises(ValueError): - client = ClusterControllerClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - # It is an error to provide a credentials file and a transport instance. - transport = transports.ClusterControllerGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request_init["cluster"] = { + "project_id": "project_id_value", + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": {"node_group_uri": "node_group_uri_value"}, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": {"enable_confidential_compute": True}, + }, + "master_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": {"gce_pd_kms_key_name": "gce_pd_kms_key_name_value"}, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + 
"cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {"seconds": 751, "nanos": 543}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": {"http_ports": {}, "enable_http_port_access": True}, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "virtual_cluster_config": { + "staging_bucket": "staging_bucket_value", + "kubernetes_cluster_config": { + "kubernetes_namespace": "kubernetes_namespace_value", + "gke_cluster_config": { + "gke_cluster_target": "gke_cluster_target_value", + "node_pool_target": [ + { + "node_pool": "node_pool_value", + "roles": [1], + "node_pool_config": { + "config": { + "machine_type": "machine_type_value", + "preemptible": True, + "local_ssd_count": 1596, + "accelerators": [ + { + "accelerator_count": 1805, + "accelerator_type": "accelerator_type_value", + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "locations": ["locations_value1", "locations_value2"], + "autoscaling": { + "min_node_count": 1489, + "max_node_count": 1491, + }, + }, + } + ], + }, + "kubernetes_software_config": { + "component_version": {}, + "properties": {}, + }, + }, + "auxiliary_services_config": { + "metastore_config": {}, + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + "labels": {}, + "status": { + "state": 1, + "detail": "detail_value", + "state_start_time": {}, + "substate": 1, + }, + "status_history": {}, + "cluster_uuid": "cluster_uuid_value", + "metrics": {"hdfs_metrics": {}, "yarn_metrics": {}}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_cluster(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_create_cluster_rest_required_fields( + request_type=clusters.CreateClusterRequest, +): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) ) - with pytest.raises(ValueError): - client = ClusterControllerClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_cluster._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "action_on_failed_primary_workers", + "request_id", ) + ) + jsonified_request.update(unset_fields) - # It is an error to provide an api_key and a transport instance. - transport = transports.ClusterControllerGrpcTransport( + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
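+ # Unlike the GET tests above, this POST also carries the request proto
+ # in the body (transcode_result["body"] below).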
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = ClusterControllerClient( - client_options=options, - transport=transport, - ) - # It is an error to provide an api_key and a credential. - options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = ClusterControllerClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + unset_fields = transport.create_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "actionOnFailedPrimaryWorkers", + "requestId", + ) + ) + & set( + ( + "projectId", + "region", + "cluster", + ) ) + ) - # It is an error to provide scopes and a transport instance. - transport = transports.ClusterControllerGrpcTransport( + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), ) - with pytest.raises(ValueError): - client = ClusterControllerClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_create_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_create_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.CreateClusterRequest.pb(clusters.CreateClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) + request = clusters.CreateClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.ClusterControllerGrpcTransport( + client.create_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.CreateClusterRequest +): + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - client = ClusterControllerClient(transport=transport) - assert client.transport is transport + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request_init["cluster"] = { + "project_id": "project_id_value", + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": {"node_group_uri": "node_group_uri_value"}, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": {"enable_confidential_compute": True}, + }, + "master_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": {"gce_pd_kms_key_name": "gce_pd_kms_key_name_value"}, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + 
"cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {"seconds": 751, "nanos": 543}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": {"http_ports": {}, "enable_http_port_access": True}, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "virtual_cluster_config": { + "staging_bucket": "staging_bucket_value", + "kubernetes_cluster_config": { + "kubernetes_namespace": "kubernetes_namespace_value", + "gke_cluster_config": { + "gke_cluster_target": "gke_cluster_target_value", + "node_pool_target": [ + { + "node_pool": "node_pool_value", + "roles": [1], + "node_pool_config": { + "config": { + "machine_type": "machine_type_value", + "preemptible": True, + "local_ssd_count": 1596, + "accelerators": [ + { + "accelerator_count": 1805, + "accelerator_type": "accelerator_type_value", + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "locations": ["locations_value1", "locations_value2"], + "autoscaling": { + "min_node_count": 1489, + "max_node_count": 1491, + }, + }, + } + ], + }, + "kubernetes_software_config": { + "component_version": {}, + "properties": {}, + }, + }, + "auxiliary_services_config": { + "metastore_config": {}, + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + "labels": {}, + "status": { + "state": 1, + "detail": "detail_value", + "state_start_time": {}, + "substate": 1, + }, + "status_history": {}, + "cluster_uuid": "cluster_uuid_value", + "metrics": {"hdfs_metrics": {}, "yarn_metrics": {}}, + } + request = request_type(**request_init) -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.ClusterControllerGrpcTransport( + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_cluster(request) + + +def test_create_cluster_rest_flattened(): + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - channel = transport.grpc_channel - assert channel - transport = transports.ClusterControllerGrpcAsyncIOTransport( + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"project_id": "sample1", "region": "sample2"} + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + cluster=clusters.Cluster(project_id="project_id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_cluster(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters" + % client.transport._host, + args[1], + ) + + +def test_create_cluster_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - channel = transport.grpc_channel - assert channel + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_cluster( + clusters.CreateClusterRequest(), + project_id="project_id_value", + region="region_value", + cluster=clusters.Cluster(project_id="project_id_value"), + ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.ClusterControllerGrpcTransport, - transports.ClusterControllerGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. 
- with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() +def test_create_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) @pytest.mark.parametrize( - "transport_name", + "request_type", [ - "grpc", + clusters.UpdateClusterRequest, + dict, ], ) -def test_transport_kind(transport_name): - transport = ClusterControllerClient.get_transport_class(transport_name)( +def test_update_cluster_rest(request_type): + client = ClusterControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert transport.kind == transport_name + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request_init["cluster"] = { + "project_id": "project_id_value", + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": {"node_group_uri": "node_group_uri_value"}, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": {"enable_confidential_compute": True}, + }, + "master_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": {"gce_pd_kms_key_name": "gce_pd_kms_key_name_value"}, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": 
"truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {"seconds": 751, "nanos": 543}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": {"http_ports": {}, "enable_http_port_access": True}, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "virtual_cluster_config": { + "staging_bucket": "staging_bucket_value", + "kubernetes_cluster_config": { + "kubernetes_namespace": "kubernetes_namespace_value", + "gke_cluster_config": { + "gke_cluster_target": "gke_cluster_target_value", + "node_pool_target": [ + { + "node_pool": "node_pool_value", + "roles": [1], + "node_pool_config": { + "config": { + "machine_type": "machine_type_value", + "preemptible": True, + "local_ssd_count": 1596, + "accelerators": [ + { + "accelerator_count": 1805, + "accelerator_type": "accelerator_type_value", + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "locations": ["locations_value1", "locations_value2"], + "autoscaling": { + "min_node_count": 1489, + "max_node_count": 1491, + }, + }, + } + ], + }, + "kubernetes_software_config": { + "component_version": {}, + "properties": {}, + }, + }, + "auxiliary_services_config": { + "metastore_config": {}, + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + "labels": {}, + "status": { + "state": 1, + "detail": "detail_value", + "state_start_time": {}, + "substate": 1, + }, + "status_history": {}, + "cluster_uuid": "cluster_uuid_value", + "metrics": {"hdfs_metrics": {}, "yarn_metrics": {}}, + } + request = request_type(**request_init) -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = ClusterControllerClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.ClusterControllerGrpcTransport, - ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation(name="operations/spam") + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) -def test_cluster_controller_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.ClusterControllerTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json", - ) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_cluster(request) + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" -def test_cluster_controller_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.ClusterControllerTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_cluster", - "update_cluster", - "stop_cluster", - "start_cluster", - "delete_cluster", - "get_cluster", - "list_clusters", - "diagnose_cluster", +def test_update_cluster_rest_required_fields( + request_type=clusters.UpdateClusterRequest, +): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - with pytest.raises(NotImplementedError): - transport.close() + # verify fields with default values are dropped - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) - # Catch all for all remaining methods and properties - remainder = [ - "kind", - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() + # verify required fields with default values are now present + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" -def test_cluster_controller_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - google.auth, "load_credentials_from_file", autospec=True - ) as load_creds, mock.patch( - "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.ClusterControllerTransport( - credentials_file="credentials.json", - 
quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=None, - default_scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_cluster._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "graceful_decommission_timeout", + "request_id", + "update_mask", ) + ) + jsonified_request.update(unset_fields) + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" -def test_cluster_controller_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
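
`transcode` is patched here because the test leaves required fields at their defaults, which a real `http_options`-driven transcode would reject. The canned result mirrors the shape the transport consumes, and the `$alt` assertion further down pins the query string that forces integer enum encoding. A self-contained sketch of that shape (assumption: it matches what `google.api_core.path_template.transcode` returns):

```python
from google.cloud.dataproc_v1.types import clusters

pb_request = clusters.UpdateClusterRequest.pb(clusters.UpdateClusterRequest())
transcode_result = {
    "uri": "v1/sample_method",   # URI with path variables already expanded
    "method": "patch",           # HTTP verb taken from the http rule
    "query_params": pb_request,  # whatever was not consumed by URI or body
    "body": pb_request,          # present only for verbs that carry a body
}
# Every REST call also appends $alt=json;enum-encoding=int so enum fields
# round-trip as integers rather than symbolic names.
expected_params = [("$alt", "json;enum-encoding=int")]
```
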
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "gracefulDecommissionTimeout", + "requestId", + "updateMask", + ) + ) + & set( + ( + "projectId", + "region", + "clusterName", + "cluster", + "updateMask", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_update_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_update_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.UpdateClusterRequest.pb(clusters.UpdateClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = clusters.UpdateClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.update_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.UpdateClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request_init["cluster"] = { + "project_id": "project_id_value", + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + 
"service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": {"node_group_uri": "node_group_uri_value"}, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": {"enable_confidential_compute": True}, + }, + "master_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": {"gce_pd_kms_key_name": "gce_pd_kms_key_name_value"}, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {"seconds": 751, "nanos": 543}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": {"http_ports": {}, "enable_http_port_access": True}, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "virtual_cluster_config": { + "staging_bucket": "staging_bucket_value", + "kubernetes_cluster_config": { + "kubernetes_namespace": 
"kubernetes_namespace_value", + "gke_cluster_config": { + "gke_cluster_target": "gke_cluster_target_value", + "node_pool_target": [ + { + "node_pool": "node_pool_value", + "roles": [1], + "node_pool_config": { + "config": { + "machine_type": "machine_type_value", + "preemptible": True, + "local_ssd_count": 1596, + "accelerators": [ + { + "accelerator_count": 1805, + "accelerator_type": "accelerator_type_value", + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "locations": ["locations_value1", "locations_value2"], + "autoscaling": { + "min_node_count": 1489, + "max_node_count": 1491, + }, + }, + } + ], + }, + "kubernetes_software_config": { + "component_version": {}, + "properties": {}, + }, + }, + "auxiliary_services_config": { + "metastore_config": {}, + "spark_history_server_config": { + "dataproc_cluster": "dataproc_cluster_value" + }, + }, + }, + "labels": {}, + "status": { + "state": 1, + "detail": "detail_value", + "state_start_time": {}, + "substate": 1, + }, + "status_history": {}, + "cluster_uuid": "cluster_uuid_value", + "metrics": {"hdfs_metrics": {}, "yarn_metrics": {}}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_cluster(request) + + +def test_update_cluster_rest_flattened(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + cluster=clusters.Cluster(project_id="project_id_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_cluster(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}" + % client.transport._host, + args[1], + ) + + +def test_update_cluster_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.update_cluster( + clusters.UpdateClusterRequest(), + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + cluster=clusters.Cluster(project_id="project_id_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_update_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.StopClusterRequest, + dict, + ], +) +def test_stop_cluster_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.stop_cluster(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_stop_cluster_rest_required_fields(request_type=clusters.StopClusterRequest): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).stop_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).stop_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
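
The required-fields tests all perform the same dance seen above: serialize the request with `including_default_value_fields=False` (which drops empty required fields), then hand-populate the camelCase keys before the mocked call that follows. A minimal standalone version, using `StopClusterRequest` as in the test above:

```python
import json

from google.protobuf import json_format

from google.cloud.dataproc_v1.types import clusters

pb = clusters.StopClusterRequest.pb(clusters.StopClusterRequest())
jsonified = json.loads(
    json_format.MessageToJson(pb, including_default_value_fields=False)
)
assert "projectId" not in jsonified  # empty string fields were dropped
jsonified["projectId"] = "project_id_value"  # re-added by hand, as the test does
```
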
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.stop_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_stop_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.stop_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "clusterName", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_stop_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_stop_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_stop_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.StopClusterRequest.pb(clusters.StopClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = clusters.StopClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.stop_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_stop_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.StopClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
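
The interceptor test above only verifies that the `pre_stop_cluster`/`post_stop_cluster` hooks fire exactly once around the mocked call. For orientation, a hypothetical user-defined interceptor built on the base class this diff introduces (method names follow the mocks above; treat the exact signatures as an assumption):

```python
from google.cloud.dataproc_v1.services.cluster_controller import transports


class LoggingInterceptor(transports.ClusterControllerRestInterceptor):
    """Logs each StopCluster call; hooks may rewrite what they return."""

    def pre_stop_cluster(self, request, metadata):
        print("stopping cluster:", request.cluster_name)
        return request, metadata  # (possibly modified) request and metadata

    def post_stop_cluster(self, response):
        print("operation started:", response.name)
        return response  # still a raw operations_pb2.Operation at this layer
```
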
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.stop_cluster(request) + + +def test_stop_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.StartClusterRequest, + dict, + ], +) +def test_start_cluster_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.start_cluster(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_start_cluster_rest_required_fields(request_type=clusters.StartClusterRequest): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).start_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).start_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
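
The `*_rest_bad_request` tests rely on the transport translating HTTP status codes into `google.api_core` exceptions, which is why their mocked responses use `status_code = 400` together with `pytest.raises(core_exceptions.BadRequest)`. That mapping can be exercised directly, independent of the mocked client:

```python
from google.api_core import exceptions as core_exceptions

# HTTP 400 maps to BadRequest in google.api_core's status-code translation.
exc = core_exceptions.from_http_status(400, "invalid cluster state")
assert isinstance(exc, core_exceptions.BadRequest)
```
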
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.start_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_start_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.start_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "clusterName", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_start_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_start_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_start_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.StartClusterRequest.pb(clusters.StartClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = clusters.StartClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.start_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_start_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.StartClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a 
BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.start_cluster(request) + + +def test_start_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.DeleteClusterRequest, + dict, + ], +) +def test_delete_cluster_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_cluster(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_delete_cluster_rest_required_fields( + request_type=clusters.DeleteClusterRequest, +): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_cluster._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. 
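
The "not mixing in" assertion that follows is plain set algebra: the optional query parameters for DeleteCluster must be disjoint from its required path parameters. Spelled out, with the values copied from the unset-required-fields assertion in this diff:

```python
optional_query_params = {"clusterUuid", "requestId"}
required_path_params = {"projectId", "region", "clusterName"}

# If this intersection were non-empty, a field would be serialized into
# both the URL path and the query string.
assert optional_query_params & required_path_params == set()
```
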
+ assert not set(unset_fields) - set( + ( + "cluster_uuid", + "request_id", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "clusterUuid", + "requestId", + ) + ) + & set( + ( + "projectId", + "region", + "clusterName", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_delete_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_delete_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.DeleteClusterRequest.pb(clusters.DeleteClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + 
operations_pb2.Operation() + ) + + request = clusters.DeleteClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.delete_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.DeleteClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_cluster(request) + + +def test_delete_cluster_rest_flattened(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_cluster(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}" + % client.transport._host, + args[1], + ) + + +def test_delete_cluster_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
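
For reference, this is the user-facing call the flattened tests model, and the mutual-exclusion rule checked just below. A hypothetical usage sketch (project, region, and cluster names are placeholders; credentials are assumed to come from the environment):

```python
from google.cloud import dataproc_v1

client = dataproc_v1.ClusterControllerClient(transport="rest")

# Flattened arguments are copied into a DeleteClusterRequest internally;
# passing a request object *and* flattened fields raises ValueError.
operation = client.delete_cluster(
    project_id="my-project",
    region="us-central1",
    cluster_name="my-cluster",
)
operation.result()  # block until the long-running delete finishes
```
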
+ with pytest.raises(ValueError): + client.delete_cluster( + clusters.DeleteClusterRequest(), + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + + +def test_delete_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.GetClusterRequest, + dict, + ], +) +def test_get_cluster_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = clusters.Cluster( + project_id="project_id_value", + cluster_name="cluster_name_value", + cluster_uuid="cluster_uuid_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.Cluster.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_cluster(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, clusters.Cluster) + assert response.project_id == "project_id_value" + assert response.cluster_name == "cluster_name_value" + assert response.cluster_uuid == "cluster_uuid_value" + + +def test_get_cluster_rest_required_fields(request_type=clusters.GetClusterRequest): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned 
response. + return_value = clusters.Cluster() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = clusters.Cluster.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "clusterName", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_get_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_get_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.GetClusterRequest.pb(clusters.GetClusterRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = clusters.Cluster.to_json(clusters.Cluster()) + + request = clusters.GetClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = clusters.Cluster() + + client.get_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.GetClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and 
fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_cluster(request) + + +def test_get_cluster_rest_flattened(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = clusters.Cluster() + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.Cluster.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_cluster(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}" + % client.transport._host, + args[1], + ) + + +def test_get_cluster_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_cluster( + clusters.GetClusterRequest(), + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + + +def test_get_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.ListClustersRequest, + dict, + ], +) +def test_list_clusters_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
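
Unlike the Operation-returning methods, `get_cluster` is unary: the fake body is the protobuf JSON of a `Cluster`, which the transport parses straight back into the response type. The round-trip in isolation:

```python
from google.protobuf import json_format

from google.cloud.dataproc_v1.types import clusters

ret = clusters.Cluster(project_id="project_id_value", cluster_name="cluster_name_value")
body = clusters.Cluster.to_json(ret)           # what the mocked response carries
target = clusters.Cluster.pb(clusters.Cluster())
json_format.Parse(body, target)                # what the transport does on receipt
assert target.project_id == "project_id_value"
```
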
+ return_value = clusters.ListClustersResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.ListClustersResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_clusters(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListClustersPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_clusters_rest_required_fields(request_type=clusters.ListClustersRequest): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_clusters._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_clusters._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "filter", + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = clusters.ListClustersResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
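
Note the snake_case/camelCase split running through these assertions: requests are built with proto field names (`page_size`), but the JSON representation, and therefore the query string, uses camelCase (`pageSize`). `json_format` does the renaming:

```python
from google.protobuf import json_format

from google.cloud.dataproc_v1.types import clusters

req = clusters.ListClustersRequest(project_id="p", region="r", page_size=5)
print(json_format.MessageToJson(clusters.ListClustersRequest.pb(req)))
# {"projectId": "p", "region": "r", "pageSize": 5}  -- camelCase by default
```
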
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = clusters.ListClustersResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_clusters(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_clusters_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_clusters._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "pageSize", + "pageToken", + ) + ) + & set( + ( + "projectId", + "region", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_clusters_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_list_clusters" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_list_clusters" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.ListClustersRequest.pb(clusters.ListClustersRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = clusters.ListClustersResponse.to_json( + clusters.ListClustersResponse() + ) + + request = clusters.ListClustersRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = clusters.ListClustersResponse() + + client.list_clusters( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_clusters_rest_bad_request( + transport: str = "rest", request_type=clusters.ListClustersRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_clusters(request) + + +def test_list_clusters_rest_flattened(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = clusters.ListClustersResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"project_id": "sample1", "region": "sample2"} + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + filter="filter_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.ListClustersResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_clusters(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters" + % client.transport._host, + args[1], + ) + + +def test_list_clusters_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_clusters( + clusters.ListClustersRequest(), + project_id="project_id_value", + region="region_value", + filter="filter_value", + ) + + +def test_list_clusters_rest_pager(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
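
This pager test stitches four fake pages together and checks that iteration crosses page boundaries (six clusters over four pages). The behavior it models, as a hypothetical usage sketch with placeholder names:

```python
from google.cloud import dataproc_v1

client = dataproc_v1.ClusterControllerClient(transport="rest")

# The pager follows next_page_token transparently: iterating yields Cluster
# messages across every page, matching the len(results) == 6 assertion below.
for cluster in client.list_clusters(project_id="my-project", region="us-central1"):
    print(cluster.cluster_name)
```
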
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + clusters.ListClustersResponse( + clusters=[ + clusters.Cluster(), + clusters.Cluster(), + clusters.Cluster(), + ], + next_page_token="abc", + ), + clusters.ListClustersResponse( + clusters=[], + next_page_token="def", + ), + clusters.ListClustersResponse( + clusters=[ + clusters.Cluster(), + ], + next_page_token="ghi", + ), + clusters.ListClustersResponse( + clusters=[ + clusters.Cluster(), + clusters.Cluster(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(clusters.ListClustersResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"project_id": "sample1", "region": "sample2"} + + pager = client.list_clusters(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, clusters.Cluster) for i in results) + + pages = list(client.list_clusters(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + clusters.DiagnoseClusterRequest, + dict, + ], +) +def test_diagnose_cluster_rest(request_type): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.diagnose_cluster(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_diagnose_cluster_rest_required_fields( + request_type=clusters.DiagnoseClusterRequest, +): + transport_class = transports.ClusterControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["cluster_name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).diagnose_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["clusterName"] = "cluster_name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).diagnose_cluster._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "clusterName" in jsonified_request + assert jsonified_request["clusterName"] == "cluster_name_value" + + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
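+ # (For reference: with the real http_options, transcode would typically
+ # produce a mapping like {"uri": ".../clusters/{cluster_name}:diagnose",
+ # "method": "post", ...} and would fail on the blank required fields, so
+ # the stubbed result below is what lets this test focus on query params.)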
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.diagnose_cluster(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_diagnose_cluster_rest_unset_required_fields(): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.diagnose_cluster._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "clusterName", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_diagnose_cluster_rest_interceptors(null_interceptor): + transport = transports.ClusterControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ClusterControllerRestInterceptor(), + ) + client = ClusterControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ClusterControllerRestInterceptor, "post_diagnose_cluster" + ) as post, mock.patch.object( + transports.ClusterControllerRestInterceptor, "pre_diagnose_cluster" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = clusters.DiagnoseClusterRequest.pb( + clusters.DiagnoseClusterRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = clusters.DiagnoseClusterRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.diagnose_cluster( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_diagnose_cluster_rest_bad_request( + transport: str = "rest", request_type=clusters.DiagnoseClusterRequest +): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
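+ # (A 400 status on the mocked session is enough here: the REST transport
+ # translates HTTP error codes into google.api_core exceptions, so the
+ # client surfaces core_exceptions.BadRequest without real network I/O.)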
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.diagnose_cluster(request) + + +def test_diagnose_cluster_rest_flattened(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "cluster_name": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.diagnose_cluster(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/clusters/{cluster_name}:diagnose" + % client.transport._host, + args[1], + ) + + +def test_diagnose_cluster_rest_flattened_error(transport: str = "rest"): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.diagnose_cluster( + clusters.DiagnoseClusterRequest(), + project_id="project_id_value", + region="region_value", + cluster_name="cluster_name_value", + ) + + +def test_diagnose_cluster_rest_error(): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ClusterControllerClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. 
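+ # (The api_key travels in client_options, while a pre-built transport
+ # already carries its own credentials, so the combination is rejected as
+ # ambiguous.)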
+ transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ClusterControllerClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ClusterControllerClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ClusterControllerClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = ClusterControllerClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ClusterControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ClusterControllerGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ClusterControllerGrpcTransport, + transports.ClusterControllerGrpcAsyncIOTransport, + transports.ClusterControllerRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) +def test_transport_kind(transport_name): + transport = ClusterControllerClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ClusterControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ClusterControllerGrpcTransport, + ) + + +def test_cluster_controller_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.ClusterControllerTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_cluster_controller_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ClusterControllerTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
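+ # (On the base class each RPC is declared roughly as a stub along the
+ # lines of:
+ #     def create_cluster(self, ...): raise NotImplementedError()
+ # so a transport subclass that misses an override fails loudly here.)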
+ methods = (
+ "create_cluster",
+ "update_cluster",
+ "stop_cluster",
+ "start_cluster",
+ "delete_cluster",
+ "get_cluster",
+ "list_clusters",
+ "diagnose_cluster",
+ )
+ for method in methods:
+ with pytest.raises(NotImplementedError):
+ getattr(transport, method)(request=object())
+
+ with pytest.raises(NotImplementedError):
+ transport.close()
+
+ # Additionally, the LRO client (a property) should
+ # also raise NotImplementedError
+ with pytest.raises(NotImplementedError):
+ transport.operations_client
+
+ # Catch all for all remaining methods and properties
+ remainder = [
+ "kind",
+ ]
+ for r in remainder:
+ with pytest.raises(NotImplementedError):
+ getattr(transport, r)()
+
+
+def test_cluster_controller_base_transport_with_credentials_file():
+ # Instantiate the base transport with a credentials file
+ with mock.patch.object(
+ google.auth, "load_credentials_from_file", autospec=True
+ ) as load_creds, mock.patch(
+ "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport = transports.ClusterControllerTransport(
+ credentials_file="credentials.json",
+ quota_project_id="octopus",
+ )
+ load_creds.assert_called_once_with(
+ "credentials.json",
+ scopes=None,
+ default_scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+def test_cluster_controller_base_transport_with_adc():
+ # Test the default credentials are used if credentials and credentials_file are None.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch(
+ "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
@@ -2965,6 +6043,7 @@ def test_cluster_controller_transport_auth_adc(transport_class):
 [
 transports.ClusterControllerGrpcTransport,
 transports.ClusterControllerGrpcAsyncIOTransport,
+ transports.ClusterControllerRestTransport,
 ],
 )
 def test_cluster_controller_transport_auth_gdch_credentials(transport_class):
@@ -3062,11 +6141,40 @@ def test_cluster_controller_grpc_transport_client_cert_source_for_mtls(transport
 )
+def test_cluster_controller_http_transport_client_cert_source_for_mtls():
+ cred = ga_credentials.AnonymousCredentials()
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel"
+ ) as mock_configure_mtls_channel:
+ transports.ClusterControllerRestTransport(
+ credentials=cred, client_cert_source_for_mtls=client_cert_source_callback
+ )
+ mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+def test_cluster_controller_rest_lro_client():
+ client = ClusterControllerClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+ transport = client.transport
+
+ # Ensure that we have an api-core operations client.
+ assert isinstance(
+ transport.operations_client,
+ operations_v1.AbstractOperationsClient,
+ )
+
+ # Ensure that subsequent calls to the property send the exact same object.
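+ # (The REST transport is expected to build its AbstractOperationsClient
+ # lazily on first access and cache it, which the identity check verifies.)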
+ assert transport.operations_client is transport.operations_client
+
+
 @pytest.mark.parametrize(
 "transport_name",
 [
 "grpc",
 "grpc_asyncio",
+ "rest",
 ],
 )
 def test_cluster_controller_host_no_port(transport_name):
@@ -3077,7 +6185,11 @@ def test_cluster_controller_host_no_port(transport_name):
 ),
 transport=transport_name,
 )
- assert client.transport._host == ("dataproc.googleapis.com:443")
+ assert client.transport._host == (
+ "dataproc.googleapis.com:443"
+ if transport_name in ["grpc", "grpc_asyncio"]
+ else "https://dataproc.googleapis.com"
+ )
@@ -3085,6 +6197,7 @@
 [
 "grpc",
 "grpc_asyncio",
+ "rest",
 ],
 )
 def test_cluster_controller_host_with_port(transport_name):
@@ -3095,7 +6208,54 @@
 ),
 transport=transport_name,
 )
- assert client.transport._host == ("dataproc.googleapis.com:8000")
+ assert client.transport._host == (
+ "dataproc.googleapis.com:8000"
+ if transport_name in ["grpc", "grpc_asyncio"]
+ else "https://dataproc.googleapis.com:8000"
+ )
+
+
+@pytest.mark.parametrize(
+ "transport_name",
+ [
+ "rest",
+ ],
+)
+def test_cluster_controller_client_transport_session_collision(transport_name):
+ creds1 = ga_credentials.AnonymousCredentials()
+ creds2 = ga_credentials.AnonymousCredentials()
+ client1 = ClusterControllerClient(
+ credentials=creds1,
+ transport=transport_name,
+ )
+ client2 = ClusterControllerClient(
+ credentials=creds2,
+ transport=transport_name,
+ )
+ session1 = client1.transport.create_cluster._session
+ session2 = client2.transport.create_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.update_cluster._session
+ session2 = client2.transport.update_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.stop_cluster._session
+ session2 = client2.transport.stop_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.start_cluster._session
+ session2 = client2.transport.start_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.delete_cluster._session
+ session2 = client2.transport.delete_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.get_cluster._session
+ session2 = client2.transport.get_cluster._session
+ assert session1 != session2
+ session1 = client1.transport.list_clusters._session
+ session2 = client2.transport.list_clusters._session
+ assert session1 != session2
+ session1 = client1.transport.diagnose_cluster._session
+ session2 = client2.transport.diagnose_cluster._session
+ assert session1 != session2


 def test_cluster_controller_grpc_transport_channel():
@@ -3457,6 +6617,7 @@ async def test_transport_close_async():
 def test_transport_close():
 transports = {
+ "rest": "_session",
 "grpc": "_grpc_channel",
 }
@@ -3474,6 +6635,7 @@ def test_client_ctx():
 transports = [
+ "rest",
 "grpc",
 ]
 for transport in transports:
diff --git a/tests/unit/gapic/dataproc_v1/test_job_controller.py b/tests/unit/gapic/dataproc_v1/test_job_controller.py
index 54554155..48c08540 100644
--- a/tests/unit/gapic/dataproc_v1/test_job_controller.py
+++ b/tests/unit/gapic/dataproc_v1/test_job_controller.py
@@ -24,10 +24,17 @@ import grpc
 from grpc.experimental import aio
+from collections.abc import Iterable
+from google.protobuf import json_format
+import json
 import math
 import pytest
 from proto.marshal.rules.dates import DurationRule, TimestampRule
 from proto.marshal.rules import wrappers
+from requests import Response
+from requests import Request, PreparedRequest
+from requests.sessions import Session

 from google.api_core import client_options
 from google.api_core import exceptions as core_exceptions
@@ -102,6 +109,7 @@ def test__get_default_mtls_endpoint():
 [
 (JobControllerClient, "grpc"),
 (JobControllerAsyncClient, "grpc_asyncio"),
+ (JobControllerClient, "rest"),
 ],
 )
 def test_job_controller_client_from_service_account_info(client_class, transport_name):
@@ -115,7 +123,11 @@ def test_job_controller_client_from_service_account_info(client_class, transport
 assert client.transport._credentials == creds
 assert isinstance(client, client_class)

- assert client.transport._host == ("dataproc.googleapis.com:443")
+ assert client.transport._host == (
+ "dataproc.googleapis.com:443"
+ if transport_name in ["grpc", "grpc_asyncio"]
+ else "https://dataproc.googleapis.com"
+ )


 @pytest.mark.parametrize(
@@ -123,6 +135,7 @@
 [
 (transports.JobControllerGrpcTransport, "grpc"),
 (transports.JobControllerGrpcAsyncIOTransport, "grpc_asyncio"),
+ (transports.JobControllerRestTransport, "rest"),
 ],
 )
 def test_job_controller_client_service_account_always_use_jwt(
@@ -148,6 +161,7 @@
 [
 (JobControllerClient, "grpc"),
 (JobControllerAsyncClient, "grpc_asyncio"),
+ (JobControllerClient, "rest"),
 ],
 )
 def test_job_controller_client_from_service_account_file(client_class, transport_name):
@@ -168,13 +182,18 @@
 assert client.transport._credentials == creds
 assert isinstance(client, client_class)

- assert client.transport._host == ("dataproc.googleapis.com:443")
+ assert client.transport._host == (
+ "dataproc.googleapis.com:443"
+ if transport_name in ["grpc", "grpc_asyncio"]
+ else "https://dataproc.googleapis.com"
+ )


 def test_job_controller_client_get_transport_class():
 transport = JobControllerClient.get_transport_class()
 available_transports = [
 transports.JobControllerGrpcTransport,
+ transports.JobControllerRestTransport,
 ]
 assert transport in available_transports

@@ -191,6 +210,7 @@
 transports.JobControllerGrpcAsyncIOTransport,
 "grpc_asyncio",
 ),
+ (JobControllerClient, transports.JobControllerRestTransport, "rest"),
 ],
 )
 @mock.patch.object(
@@ -336,6 +356,8 @@
 "grpc_asyncio",
 "false",
 ),
+ (JobControllerClient, transports.JobControllerRestTransport, "rest", "true"),
+ (JobControllerClient, transports.JobControllerRestTransport, "rest", "false"),
 ],
 )
 @mock.patch.object(
@@ -535,6 +557,7 @@
 transports.JobControllerGrpcAsyncIOTransport,
 "grpc_asyncio",
 ),
+ (JobControllerClient, transports.JobControllerRestTransport, "rest"),
 ],
 )
 def test_job_controller_client_client_options_scopes(
@@ -575,6 +598,7 @@
 "grpc_asyncio",
 grpc_helpers_async,
 ),
+ (JobControllerClient, transports.JobControllerRestTransport, "rest", None),
 ],
 )
 def test_job_controller_client_client_options_credentials_file(
@@ -2598,191 +2622,2414 @@ async def test_delete_job_flattened_error_async():
 )


-def test_credentials_transport_error():
- # It is an error to provide credentials and a transport instance.
-    transport = transports.JobControllerGrpcTransport(
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ jobs.SubmitJobRequest,
+ dict,
+ ],
+)
+def test_submit_job_rest(request_type):
+ client = JobControllerClient(
 credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
 )
- with pytest.raises(ValueError):
- client = JobControllerClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport=transport,
+
+ # send a request that will satisfy transcoding
+ request_init = {"project_id": "sample1", "region": "sample2"}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), "request") as req:
+ # Designate an appropriate value for the returned response.
+ return_value = jobs.Job(
+ driver_output_resource_uri="driver_output_resource_uri_value",
+ driver_control_files_uri="driver_control_files_uri_value",
+ job_uuid="job_uuid_value",
+ done=True,
+ hadoop_job=jobs.HadoopJob(main_jar_file_uri="main_jar_file_uri_value"),
 )

- # It is an error to provide a credentials file and a transport instance.
- transport = transports.JobControllerGrpcTransport(
- credentials=ga_credentials.AnonymousCredentials(),
- )
- with pytest.raises(ValueError):
- client = JobControllerClient(
- client_options={"credentials_file": "credentials.json"},
- transport=transport,
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ pb_return_value = jobs.Job.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ response = client.submit_job(request)
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, jobs.Job)
+ assert response.driver_output_resource_uri == "driver_output_resource_uri_value"
+ assert response.driver_control_files_uri == "driver_control_files_uri_value"
+ assert response.job_uuid == "job_uuid_value"
+ assert response.done is True
+
+
+def test_submit_job_rest_required_fields(request_type=jobs.SubmitJobRequest):
+ transport_class = transports.JobControllerRestTransport
+
+ request_init = {}
+ request_init["project_id"] = ""
+ request_init["region"] = ""
+ request = request_type(**request_init)
+ pb_request = request_type.pb(request)
+ jsonified_request = json.loads(
+ json_format.MessageToJson(
+ pb_request,
+ including_default_value_fields=False,
+ use_integers_for_enums=False,
 )
+ )

- # It is an error to provide an api_key and a transport instance.
-    transport = transports.JobControllerGrpcTransport(
+ # verify fields with default values are dropped
+
+ unset_fields = transport_class(
+ credentials=ga_credentials.AnonymousCredentials()
+ ).submit_job._get_unset_required_fields(jsonified_request)
+ jsonified_request.update(unset_fields)
+
+ # verify required fields with default values are now present
+
+ jsonified_request["projectId"] = "project_id_value"
+ jsonified_request["region"] = "region_value"
+
+ unset_fields = transport_class(
+ credentials=ga_credentials.AnonymousCredentials()
+ ).submit_job._get_unset_required_fields(jsonified_request)
+ jsonified_request.update(unset_fields)
+
+ # verify required fields with non-default values are left alone
+ assert "projectId" in jsonified_request
+ assert jsonified_request["projectId"] == "project_id_value"
+ assert "region" in jsonified_request
+ assert jsonified_request["region"] == "region_value"
+
+ client = JobControllerClient(
 credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+ request = request_type(**request_init)
+
+ # Designate an appropriate value for the returned response.
+ return_value = jobs.Job()
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # We need to mock transcode() because providing default values
+ # for required fields will fail the real version if the http_options
+ # expect actual values for those fields.
+ with mock.patch.object(path_template, "transcode") as transcode:
+ # A uri without fields and an empty body will force all the
+ # request fields to show up in the query_params.
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ "uri": "v1/sample_method",
+ "method": "post",
+ "query_params": pb_request,
+ }
+ transcode_result["body"] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = jobs.Job.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+
+ response = client.submit_job(request)
+
+ expected_params = [("$alt", "json;enum-encoding=int")]
+ actual_params = req.call_args.kwargs["params"]
+ assert expected_params == actual_params
+
+
+def test_submit_job_rest_unset_required_fields():
+ transport = transports.JobControllerRestTransport(
+ credentials=ga_credentials.AnonymousCredentials
 )
- options = client_options.ClientOptions()
- options.api_key = "api_key"
- with pytest.raises(ValueError):
- client = JobControllerClient(
- client_options=options,
- transport=transport,
- )
- # It is an error to provide an api_key and a credential.
- options = mock.Mock()
- options.api_key = "api_key"
- with pytest.raises(ValueError):
- client = JobControllerClient(
- client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ unset_fields = transport.submit_job._get_unset_required_fields({})
+ assert set(unset_fields) == (
+ set(())
+ & set(
+ (
+ "projectId",
+ "region",
+ "job",
+ )
 )
+ )

- # It is an error to provide scopes and a transport instance.
-    transport = transports.JobControllerGrpcTransport(
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_submit_job_rest_interceptors(null_interceptor):
+ transport = transports.JobControllerRestTransport(
 credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None
+ if null_interceptor
+ else transports.JobControllerRestInterceptor(),
 )
- with pytest.raises(ValueError):
- client = JobControllerClient(
- client_options={"scopes": ["1", "2"]},
- transport=transport,
+ client = JobControllerClient(transport=transport)
+ with mock.patch.object(
+ type(client.transport._session), "request"
+ ) as req, mock.patch.object(
+ path_template, "transcode"
+ ) as transcode, mock.patch.object(
+ transports.JobControllerRestInterceptor, "post_submit_job"
+ ) as post, mock.patch.object(
+ transports.JobControllerRestInterceptor, "pre_submit_job"
+ ) as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = jobs.SubmitJobRequest.pb(jobs.SubmitJobRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = jobs.Job.to_json(jobs.Job())
+
+ request = jobs.SubmitJobRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = jobs.Job()
+
+ client.submit_job(
+ request,
+ metadata=[
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ],
 )

 pre.assert_called_once()
 post.assert_called_once()


-def test_transport_instance():
- # A client may be instantiated with a custom transport instance.
- transport = transports.JobControllerGrpcTransport(
+
+def test_submit_job_rest_bad_request(
+ transport: str = "rest", request_type=jobs.SubmitJobRequest
+):
+ client = JobControllerClient(
 credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
 )
- client = JobControllerClient(transport=transport)
- assert client.transport is transport

+ # send a request that will satisfy transcoding
+ request_init = {"project_id": "sample1", "region": "sample2"}
+ request = request_type(**request_init)

-def test_transport_get_channel():
- # A client may be instantiated with a custom transport instance.
- transport = transports.JobControllerGrpcTransport(
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, "request") as req, pytest.raises(
+ core_exceptions.BadRequest
+ ):
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 400
+ response_value.request = Request()
+ req.return_value = response_value
+ client.submit_job(request)
+
+
+def test_submit_job_rest_flattened():
+ client = JobControllerClient(
 credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
 )
- channel = transport.grpc_channel
- assert channel

- transport = transports.JobControllerGrpcAsyncIOTransport(
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), "request") as req:
+ # Designate an appropriate value for the returned response.
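+ # (An empty jobs.Job() is enough here: the flattened-call tests only
+ # inspect the outgoing URI, not the decoded response payload.)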
+ return_value = jobs.Job() + + # get arguments that satisfy an http rule for this method + sample_request = {"project_id": "sample1", "region": "sample2"} + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + job=jobs.Job(reference=jobs.JobReference(project_id="project_id_value")), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.submit_job(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs:submit" + % client.transport._host, + args[1], + ) + + +def test_submit_job_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - channel = transport.grpc_channel - assert channel + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.submit_job( + jobs.SubmitJobRequest(), + project_id="project_id_value", + region="region_value", + job=jobs.Job(reference=jobs.JobReference(project_id="project_id_value")), + ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.JobControllerGrpcTransport, - transports.JobControllerGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() +def test_submit_job_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) @pytest.mark.parametrize( - "transport_name", + "request_type", [ - "grpc", + jobs.SubmitJobRequest, + dict, ], ) -def test_transport_kind(transport_name): - transport = JobControllerClient.get_transport_class(transport_name)( +def test_submit_job_as_operation_rest(request_type): + client = JobControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert transport.kind == transport_name + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.submit_job_as_operation(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_submit_job_as_operation_rest_required_fields( + request_type=jobs.SubmitJobRequest, +): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).submit_job_as_operation._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).submit_job_as_operation._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. client = JobControllerClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert isinstance( - client.transport, - transports.JobControllerGrpcTransport, + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.submit_job_as_operation(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_submit_job_as_operation_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials ) + unset_fields = transport.submit_job_as_operation._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "job", + ) + ) + ) -def test_job_controller_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.JobControllerTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json", + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_submit_job_as_operation_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.JobControllerRestInterceptor, "post_submit_job_as_operation" + ) as post, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_submit_job_as_operation" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = jobs.SubmitJobRequest.pb(jobs.SubmitJobRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) + request = jobs.SubmitJobRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() -def test_job_controller_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.dataproc_v1.services.job_controller.transports.JobControllerTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.JobControllerTransport( - credentials=ga_credentials.AnonymousCredentials(), + client.submit_job_as_operation( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], ) - # Every method on the transport should just blindly - # raise NotImplementedError. 
- methods = ( - "submit_job", - "submit_job_as_operation", - "get_job", - "list_jobs", - "update_job", - "cancel_job", - "delete_job", + pre.assert_called_once() + post.assert_called_once() + + +def test_submit_job_as_operation_rest_bad_request( + transport: str = "rest", request_type=jobs.SubmitJobRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - with pytest.raises(NotImplementedError): - transport.close() + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.submit_job_as_operation(request) - # Catch all for all remaining methods and properties - remainder = [ - "kind", - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() +def test_submit_job_as_operation_rest_flattened(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) -def test_job_controller_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - google.auth, "load_credentials_from_file", autospec=True - ) as load_creds, mock.patch( - "google.cloud.dataproc_v1.services.job_controller.transports.JobControllerTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.JobControllerTransport( - credentials_file="credentials.json", - quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=None, - default_scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"project_id": "sample1", "region": "sample2"} + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + job=jobs.Job(reference=jobs.JobReference(project_id="project_id_value")), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.submit_job_as_operation(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
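+ # (path_template.validate treats the string below as a pattern, so with
+ # the sample_request above the call should have hit a URI like
+ # ".../v1/projects/sample1/regions/sample2/jobs:submitAsOperation".)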
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs:submitAsOperation" + % client.transport._host, + args[1], + ) + + +def test_submit_job_as_operation_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.submit_job_as_operation( + jobs.SubmitJobRequest(), + project_id="project_id_value", + region="region_value", + job=jobs.Job(reference=jobs.JobReference(project_id="project_id_value")), + ) + + +def test_submit_job_as_operation_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + jobs.GetJobRequest, + dict, + ], +) +def test_get_job_rest(request_type): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.Job( + driver_output_resource_uri="driver_output_resource_uri_value", + driver_control_files_uri="driver_control_files_uri_value", + job_uuid="job_uuid_value", + done=True, + hadoop_job=jobs.HadoopJob(main_jar_file_uri="main_jar_file_uri_value"), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_job(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, jobs.Job) + assert response.driver_output_resource_uri == "driver_output_resource_uri_value" + assert response.driver_control_files_uri == "driver_control_files_uri_value" + assert response.job_uuid == "job_uuid_value" + assert response.done is True + + +def test_get_job_rest_required_fields(request_type=jobs.GetJobRequest): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["job_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["jobId"] = "job_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "jobId" in jsonified_request + assert jsonified_request["jobId"] == "job_id_value" + + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = jobs.Job() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_job(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_job_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_job._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "jobId", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_job_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.JobControllerRestInterceptor, "post_get_job" + ) as post, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_get_job" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = jobs.GetJobRequest.pb(jobs.GetJobRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = jobs.Job.to_json(jobs.Job()) + + request = jobs.GetJobRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = jobs.Job() + + client.get_job( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_job_rest_bad_request( + transport: str = "rest", request_type=jobs.GetJobRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_job(request) + + +def test_get_job_rest_flattened(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
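+ # (Patching request() on type(client.transport._session), presumably the
+ # AuthorizedSession class backing the REST transport, keeps the test from
+ # making any real HTTP call.)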
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.Job() + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "job_id": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_job(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs/{job_id}" + % client.transport._host, + args[1], + ) + + +def test_get_job_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_job( + jobs.GetJobRequest(), + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + + +def test_get_job_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + jobs.ListJobsRequest, + dict, + ], +) +def test_list_jobs_rest(request_type): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.ListJobsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.ListJobsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_jobs(request) + + # Establish that the response is the type that we expect. 
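+ # (Even over REST, list_jobs wraps the raw ListJobsResponse in a
+ # ListJobsPager; iterating it issues follow-up requests driven by
+ # next_page_token, as the pager test further below exercises.)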
+ assert isinstance(response, pagers.ListJobsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_jobs_rest_required_fields(request_type=jobs.ListJobsRequest): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_jobs._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_jobs._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "cluster_name", + "filter", + "job_state_matcher", + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = jobs.ListJobsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = jobs.ListJobsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_jobs(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_jobs_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_jobs._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "clusterName", + "filter", + "jobStateMatcher", + "pageSize", + "pageToken", + ) + ) + & set( + ( + "projectId", + "region", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_jobs_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.JobControllerRestInterceptor, "post_list_jobs" + ) as post, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_list_jobs" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = jobs.ListJobsRequest.pb(jobs.ListJobsRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = jobs.ListJobsResponse.to_json( + jobs.ListJobsResponse() + ) + + request = jobs.ListJobsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = jobs.ListJobsResponse() + + client.list_jobs( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_jobs_rest_bad_request( + transport: str = "rest", request_type=jobs.ListJobsRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
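+    # The REST transport converts any non-2xx status into the matching
+    # google.api_core.exceptions class, so a 400 surfaces as BadRequest.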
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_jobs(request) + + +def test_list_jobs_rest_flattened(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.ListJobsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"project_id": "sample1", "region": "sample2"} + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + filter="filter_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.ListJobsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_jobs(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs" + % client.transport._host, + args[1], + ) + + +def test_list_jobs_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_jobs( + jobs.ListJobsRequest(), + project_id="project_id_value", + region="region_value", + filter="filter_value", + ) + + +def test_list_jobs_rest_pager(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + jobs.ListJobsResponse( + jobs=[ + jobs.Job(), + jobs.Job(), + jobs.Job(), + ], + next_page_token="abc", + ), + jobs.ListJobsResponse( + jobs=[], + next_page_token="def", + ), + jobs.ListJobsResponse( + jobs=[ + jobs.Job(), + ], + next_page_token="ghi", + ), + jobs.ListJobsResponse( + jobs=[ + jobs.Job(), + jobs.Job(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(jobs.ListJobsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"project_id": "sample1", "region": "sample2"} + + pager = client.list_jobs(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, jobs.Job) for i in results) + + pages = list(client.list_jobs(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + jobs.UpdateJobRequest, + dict, + ], +) +def test_update_job_rest(request_type): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request_init["job"] = { + "reference": {"project_id": "project_id_value", "job_id": "job_id_value"}, + "placement": { + "cluster_name": "cluster_name_value", + "cluster_uuid": "cluster_uuid_value", + "cluster_labels": {}, + }, + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": ["python_file_uris_value1", "python_file_uris_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + 
"properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "status": { + "state": 1, + "details": "details_value", + "state_start_time": {"seconds": 751, "nanos": 543}, + "substate": 1, + }, + "status_history": {}, + "yarn_applications": [ + { + "name": "name_value", + "state": 1, + "progress": 0.885, + "tracking_url": "tracking_url_value", + } + ], + "driver_output_resource_uri": "driver_output_resource_uri_value", + "driver_control_files_uri": "driver_control_files_uri_value", + "labels": {}, + "scheduling": {"max_failures_per_hour": 2243, "max_failures_total": 1923}, + "job_uuid": "job_uuid_value", + "done": True, + "driver_scheduling_config": {"memory_mb": 967, "vcores": 658}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.Job( + driver_output_resource_uri="driver_output_resource_uri_value", + driver_control_files_uri="driver_control_files_uri_value", + job_uuid="job_uuid_value", + done=True, + hadoop_job=jobs.HadoopJob(main_jar_file_uri="main_jar_file_uri_value"), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_job(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, jobs.Job) + assert response.driver_output_resource_uri == "driver_output_resource_uri_value" + assert response.driver_control_files_uri == "driver_control_files_uri_value" + assert response.job_uuid == "job_uuid_value" + assert response.done is True + + +def test_update_job_rest_required_fields(request_type=jobs.UpdateJobRequest): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["job_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["jobId"] = "job_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_job._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "jobId" in jsonified_request + assert jsonified_request["jobId"] == "job_id_value" + + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = jobs.Job() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
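+            # UpdateJob transcodes to a PATCH whose body is the Job itself,
+            # so unlike the GET stubs above this one also sets a "body" key.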
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_job(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_job_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_job._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "projectId", + "region", + "jobId", + "job", + "updateMask", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_job_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.JobControllerRestInterceptor, "post_update_job" + ) as post, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_update_job" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = jobs.UpdateJobRequest.pb(jobs.UpdateJobRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = jobs.Job.to_json(jobs.Job()) + + request = jobs.UpdateJobRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = jobs.Job() + + client.update_job( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_job_rest_bad_request( + transport: str = "rest", request_type=jobs.UpdateJobRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request_init["job"] = { + "reference": {"project_id": "project_id_value", "job_id": "job_id_value"}, + "placement": { + "cluster_name": "cluster_name_value", + "cluster_uuid": "cluster_uuid_value", + "cluster_labels": {}, + }, + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + 
"main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": ["python_file_uris_value1", "python_file_uris_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "status": { + "state": 1, + "details": "details_value", + "state_start_time": {"seconds": 751, "nanos": 543}, + "substate": 1, + }, + "status_history": {}, + "yarn_applications": [ + { + "name": "name_value", + "state": 1, + "progress": 0.885, + "tracking_url": "tracking_url_value", + } + ], + "driver_output_resource_uri": "driver_output_resource_uri_value", + "driver_control_files_uri": "driver_control_files_uri_value", + "labels": {}, + "scheduling": {"max_failures_per_hour": 2243, "max_failures_total": 1923}, + "job_uuid": "job_uuid_value", + "done": True, + "driver_scheduling_config": {"memory_mb": 967, "vcores": 658}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_job(request) + + +def test_update_job_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + jobs.CancelJobRequest, + dict, + ], +) +def test_cancel_job_rest(request_type): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.Job( + driver_output_resource_uri="driver_output_resource_uri_value", + driver_control_files_uri="driver_control_files_uri_value", + job_uuid="job_uuid_value", + done=True, + hadoop_job=jobs.HadoopJob(main_jar_file_uri="main_jar_file_uri_value"), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.cancel_job(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, jobs.Job) + assert response.driver_output_resource_uri == "driver_output_resource_uri_value" + assert response.driver_control_files_uri == "driver_control_files_uri_value" + assert response.job_uuid == "job_uuid_value" + assert response.done is True + + +def test_cancel_job_rest_required_fields(request_type=jobs.CancelJobRequest): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["job_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).cancel_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["jobId"] = "job_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).cancel_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "jobId" in jsonified_request + assert jsonified_request["jobId"] == "job_id_value" + + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = jobs.Job() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
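+            # CancelJobRequest has no optional query parameters; everything
+            # required lives in the URL path.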
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.cancel_job(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_cancel_job_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.cancel_job._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "jobId", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_cancel_job_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.JobControllerRestInterceptor, "post_cancel_job" + ) as post, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_cancel_job" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = jobs.CancelJobRequest.pb(jobs.CancelJobRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = jobs.Job.to_json(jobs.Job()) + + request = jobs.CancelJobRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = jobs.Job() + + client.cancel_job( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_cancel_job_rest_bad_request( + transport: str = "rest", request_type=jobs.CancelJobRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.cancel_job(request) + + +def test_cancel_job_rest_flattened(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = jobs.Job() + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "job_id": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = jobs.Job.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.cancel_job(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs/{job_id}:cancel" + % client.transport._host, + args[1], + ) + + +def test_cancel_job_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_job( + jobs.CancelJobRequest(), + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + + +def test_cancel_job_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + jobs.DeleteJobRequest, + dict, + ], +) +def test_delete_job_rest(request_type): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_job(request) + + # Establish that the response is the type that we expect. 
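+    # DeleteJob returns google.protobuf.Empty, which the client surfaces as
+    # None; an empty JSON string is all the mocked response needs to carry.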
+ assert response is None + + +def test_delete_job_rest_required_fields(request_type=jobs.DeleteJobRequest): + transport_class = transports.JobControllerRestTransport + + request_init = {} + request_init["project_id"] = "" + request_init["region"] = "" + request_init["job_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["projectId"] = "project_id_value" + jsonified_request["region"] = "region_value" + jsonified_request["jobId"] = "job_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_job._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "projectId" in jsonified_request + assert jsonified_request["projectId"] == "project_id_value" + assert "region" in jsonified_request + assert jsonified_request["region"] == "region_value" + assert "jobId" in jsonified_request + assert jsonified_request["jobId"] == "job_id_value" + + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
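+            # DELETE requests carry no body, so the stub omits the "body" key
+            # that the POST/PATCH variants above had to supply.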
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_job(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_job_rest_unset_required_fields(): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_job._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "projectId", + "region", + "jobId", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_job_rest_interceptors(null_interceptor): + transport = transports.JobControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.JobControllerRestInterceptor(), + ) + client = JobControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.JobControllerRestInterceptor, "pre_delete_job" + ) as pre: + pre.assert_not_called() + pb_message = jobs.DeleteJobRequest.pb(jobs.DeleteJobRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = jobs.DeleteJobRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_job( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_delete_job_rest_bad_request( + transport: str = "rest", request_type=jobs.DeleteJobRequest +): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"project_id": "sample1", "region": "sample2", "job_id": "sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_job(request) + + +def test_delete_job_rest_flattened(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = { + "project_id": "sample1", + "region": "sample2", + "job_id": "sample3", + } + + # get truthy value for each flattened field + mock_args = dict( + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_job(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/projects/{project_id}/regions/{region}/jobs/{job_id}" + % client.transport._host, + args[1], + ) + + +def test_delete_job_rest_flattened_error(transport: str = "rest"): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_job( + jobs.DeleteJobRequest(), + project_id="project_id_value", + region="region_value", + job_id="job_id_value", + ) + + +def test_delete_job_rest_error(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobControllerClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobControllerClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobControllerClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobControllerClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
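+    # (The same holds for JobControllerRestTransport; the gRPC flavor is
+    # just the representative instance here.)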
+ transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = JobControllerClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.JobControllerGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.JobControllerGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + transports.JobControllerRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) +def test_transport_kind(transport_name): + transport = JobControllerClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.JobControllerGrpcTransport, + ) + + +def test_job_controller_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.JobControllerTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_job_controller_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.dataproc_v1.services.job_controller.transports.JobControllerTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.JobControllerTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
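+    # The concrete gRPC, gRPC-asyncio, and REST transports override each of
+    # these with real implementations.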
+ methods = ( + "submit_job", + "submit_job_as_operation", + "get_job", + "list_jobs", + "update_job", + "cancel_job", + "delete_job", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_job_controller_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.dataproc_v1.services.job_controller.transports.JobControllerTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.JobControllerTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) def test_job_controller_base_transport_with_adc(): @@ -2833,6 +5080,7 @@ def test_job_controller_transport_auth_adc(transport_class): [ transports.JobControllerGrpcTransport, transports.JobControllerGrpcAsyncIOTransport, + transports.JobControllerRestTransport, ], ) def test_job_controller_transport_auth_gdch_credentials(transport_class): @@ -2930,11 +5178,40 @@ def test_job_controller_grpc_transport_client_cert_source_for_mtls(transport_cla ) +def test_job_controller_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.JobControllerRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_job_controller_rest_lro_client(): + client = JobControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. 
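+    # The property caches the operations client on first access, hence the
+    # identity (is) check below.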
+ assert transport.operations_client is transport.operations_client + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_job_controller_host_no_port(transport_name): @@ -2945,7 +5222,11 @@ def test_job_controller_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -2953,6 +5234,7 @@ def test_job_controller_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_job_controller_host_with_port(transport_name): @@ -2963,7 +5245,51 @@ def test_job_controller_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:8000") + assert client.transport._host == ( + "dataproc.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_job_controller_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = JobControllerClient( + credentials=creds1, + transport=transport_name, + ) + client2 = JobControllerClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.submit_job._session + session2 = client2.transport.submit_job._session + assert session1 != session2 + session1 = client1.transport.submit_job_as_operation._session + session2 = client2.transport.submit_job_as_operation._session + assert session1 != session2 + session1 = client1.transport.get_job._session + session2 = client2.transport.get_job._session + assert session1 != session2 + session1 = client1.transport.list_jobs._session + session2 = client2.transport.list_jobs._session + assert session1 != session2 + session1 = client1.transport.update_job._session + session2 = client2.transport.update_job._session + assert session1 != session2 + session1 = client1.transport.cancel_job._session + session2 = client2.transport.cancel_job._session + assert session1 != session2 + session1 = client1.transport.delete_job._session + session2 = client2.transport.delete_job._session + assert session1 != session2 def test_job_controller_grpc_transport_channel(): @@ -3266,6 +5592,7 @@ async def test_transport_close_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -3283,6 +5610,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: diff --git a/tests/unit/gapic/dataproc_v1/test_node_group_controller.py b/tests/unit/gapic/dataproc_v1/test_node_group_controller.py index 3046da76..7cef549c 100644 --- a/tests/unit/gapic/dataproc_v1/test_node_group_controller.py +++ b/tests/unit/gapic/dataproc_v1/test_node_group_controller.py @@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from 
google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -107,6 +114,7 @@ def test__get_default_mtls_endpoint(): [ (NodeGroupControllerClient, "grpc"), (NodeGroupControllerAsyncClient, "grpc_asyncio"), + (NodeGroupControllerClient, "rest"), ], ) def test_node_group_controller_client_from_service_account_info( @@ -122,7 +130,11 @@ def test_node_group_controller_client_from_service_account_info( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -130,6 +142,7 @@ def test_node_group_controller_client_from_service_account_info( [ (transports.NodeGroupControllerGrpcTransport, "grpc"), (transports.NodeGroupControllerGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.NodeGroupControllerRestTransport, "rest"), ], ) def test_node_group_controller_client_service_account_always_use_jwt( @@ -155,6 +168,7 @@ def test_node_group_controller_client_service_account_always_use_jwt( [ (NodeGroupControllerClient, "grpc"), (NodeGroupControllerAsyncClient, "grpc_asyncio"), + (NodeGroupControllerClient, "rest"), ], ) def test_node_group_controller_client_from_service_account_file( @@ -177,13 +191,18 @@ def test_node_group_controller_client_from_service_account_file( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) def test_node_group_controller_client_get_transport_class(): transport = NodeGroupControllerClient.get_transport_class() available_transports = [ transports.NodeGroupControllerGrpcTransport, + transports.NodeGroupControllerRestTransport, ] assert transport in available_transports @@ -204,6 +223,11 @@ def test_node_group_controller_client_get_transport_class(): transports.NodeGroupControllerGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + NodeGroupControllerClient, + transports.NodeGroupControllerRestTransport, + "rest", + ), ], ) @mock.patch.object( @@ -359,6 +383,18 @@ def test_node_group_controller_client_client_options( "grpc_asyncio", "false", ), + ( + NodeGroupControllerClient, + transports.NodeGroupControllerRestTransport, + "rest", + "true", + ), + ( + NodeGroupControllerClient, + transports.NodeGroupControllerRestTransport, + "rest", + "false", + ), ], ) @mock.patch.object( @@ -562,6 +598,11 @@ def test_node_group_controller_client_get_mtls_endpoint_and_cert_source(client_c transports.NodeGroupControllerGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + NodeGroupControllerClient, + transports.NodeGroupControllerRestTransport, + "rest", + ), ], ) def test_node_group_controller_client_client_options_scopes( @@ -602,6 +643,12 @@ def test_node_group_controller_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + ( + NodeGroupControllerClient, + transports.NodeGroupControllerRestTransport, + "rest", + None, + ), ], ) def test_node_group_controller_client_client_options_credentials_file( @@ -1456,6 +1503,912 @@ async def test_get_node_group_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + node_groups.CreateNodeGroupRequest, + dict, + ], +) +def 
test_create_node_group_rest(request_type): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/regions/sample2/clusters/sample3"} + request_init["node_group"] = { + "name": "name_value", + "roles": [1], + "node_group_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "labels": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_node_group(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_create_node_group_rest_required_fields( + request_type=node_groups.CreateNodeGroupRequest, +): + transport_class = transports.NodeGroupControllerRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_node_group._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_node_group._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "node_group_id", + "request_id", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. 
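+    # CreateNodeGroup is long-running over REST as well, so the faked
+    # response is a raw operations_pb2.Operation rather than a NodeGroup.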
+ return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_node_group(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_node_group_rest_unset_required_fields(): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_node_group._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "nodeGroupId", + "requestId", + ) + ) + & set( + ( + "parent", + "nodeGroup", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_node_group_rest_interceptors(null_interceptor): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.NodeGroupControllerRestInterceptor(), + ) + client = NodeGroupControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "post_create_node_group" + ) as post, mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "pre_create_node_group" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = node_groups.CreateNodeGroupRequest.pb( + node_groups.CreateNodeGroupRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = node_groups.CreateNodeGroupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_node_group( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_node_group_rest_bad_request( + transport: str = "rest", request_type=node_groups.CreateNodeGroupRequest +): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that 
will satisfy transcoding + request_init = {"parent": "projects/sample1/regions/sample2/clusters/sample3"} + request_init["node_group"] = { + "name": "name_value", + "roles": [1], + "node_group_config": { + "num_instances": 1399, + "instance_names": ["instance_names_value1", "instance_names_value2"], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "labels": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_node_group(request) + + +def test_create_node_group_rest_flattened(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/regions/sample2/clusters/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + node_group=clusters.NodeGroup(name="name_value"), + node_group_id="node_group_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_node_group(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/regions/*/clusters/*}/nodeGroups" + % client.transport._host, + args[1], + ) + + +def test_create_node_group_rest_flattened_error(transport: str = "rest"): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
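+    # The generated surface treats the request message and the flattened
+    # keyword arguments as mutually exclusive ways of building the request,
+    # so passing both raises ValueError rather than silently merging them.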
+ with pytest.raises(ValueError): + client.create_node_group( + node_groups.CreateNodeGroupRequest(), + parent="parent_value", + node_group=clusters.NodeGroup(name="name_value"), + node_group_id="node_group_id_value", + ) + + +def test_create_node_group_rest_error(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + node_groups.ResizeNodeGroupRequest, + dict, + ], +) +def test_resize_node_group_rest(request_type): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.resize_node_group(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_resize_node_group_rest_required_fields( + request_type=node_groups.ResizeNodeGroupRequest, +): + transport_class = transports.NodeGroupControllerRestTransport + + request_init = {} + request_init["name"] = "" + request_init["size"] = 0 + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).resize_node_group._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + jsonified_request["size"] = 443 + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).resize_node_group._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + assert "size" in jsonified_request + assert jsonified_request["size"] == 443 + + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
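+        # transcode() is what maps the proto request onto an HTTP verb, URI,
+        # body, and query string per the google.api.http annotations, so
+        # stubbing it lets the test exercise the transport with placeholders.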
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.resize_node_group(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_resize_node_group_rest_unset_required_fields(): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.resize_node_group._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "size", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_resize_node_group_rest_interceptors(null_interceptor): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.NodeGroupControllerRestInterceptor(), + ) + client = NodeGroupControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "post_resize_node_group" + ) as post, mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "pre_resize_node_group" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = node_groups.ResizeNodeGroupRequest.pb( + node_groups.ResizeNodeGroupRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = node_groups.ResizeNodeGroupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.resize_node_group( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_resize_node_group_rest_bad_request( + transport: str = "rest", request_type=node_groups.ResizeNodeGroupRequest +): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
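+    # The REST transport translates non-2xx responses into the matching
+    # google.api_core exception, so a mocked 400 surfaces here as
+    # core_exceptions.BadRequest.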
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.resize_node_group(request) + + +def test_resize_node_group_rest_flattened(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + size=443, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.resize_node_group(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/regions/*/clusters/*/nodeGroups/*}:resize" + % client.transport._host, + args[1], + ) + + +def test_resize_node_group_rest_flattened_error(transport: str = "rest"): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.resize_node_group( + node_groups.ResizeNodeGroupRequest(), + name="name_value", + size=443, + ) + + +def test_resize_node_group_rest_error(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + node_groups.GetNodeGroupRequest, + dict, + ], +) +def test_get_node_group_rest(request_type): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = clusters.NodeGroup( + name="name_value", + roles=[clusters.NodeGroup.Role.DRIVER], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.NodeGroup.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_node_group(request) + + # Establish that the response is the type that we expect. 
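+    # GetNodeGroup is a plain unary call: the JSON payload decodes straight
+    # into the proto-plus NodeGroup wrapper (hence the .pb() round trip when
+    # serializing the fake response above).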
+ assert isinstance(response, clusters.NodeGroup) + assert response.name == "name_value" + assert response.roles == [clusters.NodeGroup.Role.DRIVER] + + +def test_get_node_group_rest_required_fields( + request_type=node_groups.GetNodeGroupRequest, +): + transport_class = transports.NodeGroupControllerRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_node_group._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_node_group._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = clusters.NodeGroup() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
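+            # GetNodeGroup transcodes to an HTTP GET, so the stubbed
+            # transcode result carries no "body" key, unlike the POST-based
+            # methods above.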
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = clusters.NodeGroup.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_node_group(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_node_group_rest_unset_required_fields(): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_node_group._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_node_group_rest_interceptors(null_interceptor): + transport = transports.NodeGroupControllerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.NodeGroupControllerRestInterceptor(), + ) + client = NodeGroupControllerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "post_get_node_group" + ) as post, mock.patch.object( + transports.NodeGroupControllerRestInterceptor, "pre_get_node_group" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = node_groups.GetNodeGroupRequest.pb( + node_groups.GetNodeGroupRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = clusters.NodeGroup.to_json(clusters.NodeGroup()) + + request = node_groups.GetNodeGroupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = clusters.NodeGroup() + + client.get_node_group( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_node_group_rest_bad_request( + transport: str = "rest", request_type=node_groups.GetNodeGroupRequest +): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_node_group(request) + + +def test_get_node_group_rest_flattened(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = clusters.NodeGroup() + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/regions/sample2/clusters/sample3/nodeGroups/sample4" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = clusters.NodeGroup.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_node_group(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/regions/*/clusters/*/nodeGroups/*}" + % client.transport._host, + args[1], + ) + + +def test_get_node_group_rest_flattened_error(transport: str = "rest"): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_node_group( + node_groups.GetNodeGroupRequest(), + name="name_value", + ) + + +def test_get_node_group_rest_error(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. 
transport = transports.NodeGroupControllerGrpcTransport( @@ -1537,6 +2490,7 @@ def test_transport_get_channel(): [ transports.NodeGroupControllerGrpcTransport, transports.NodeGroupControllerGrpcAsyncIOTransport, + transports.NodeGroupControllerRestTransport, ], ) def test_transport_adc(transport_class): @@ -1551,6 +2505,7 @@ def test_transport_adc(transport_class): "transport_name", [ "grpc", + "rest", ], ) def test_transport_kind(transport_name): @@ -1687,6 +2642,7 @@ def test_node_group_controller_transport_auth_adc(transport_class): [ transports.NodeGroupControllerGrpcTransport, transports.NodeGroupControllerGrpcAsyncIOTransport, + transports.NodeGroupControllerRestTransport, ], ) def test_node_group_controller_transport_auth_gdch_credentials(transport_class): @@ -1786,11 +2742,40 @@ def test_node_group_controller_grpc_transport_client_cert_source_for_mtls( ) +def test_node_group_controller_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NodeGroupControllerRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_node_group_controller_rest_lro_client(): + client = NodeGroupControllerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_node_group_controller_host_no_port(transport_name): @@ -1801,7 +2786,11 @@ def test_node_group_controller_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -1809,6 +2798,7 @@ def test_node_group_controller_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_node_group_controller_host_with_port(transport_name): @@ -1819,7 +2809,39 @@ def test_node_group_controller_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:8000") + assert client.transport._host == ( + "dataproc.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_node_group_controller_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = NodeGroupControllerClient( + credentials=creds1, + transport=transport_name, + ) + client2 = NodeGroupControllerClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_node_group._session + session2 = client2.transport.create_node_group._session + assert session1 != session2 + session1 = client1.transport.resize_node_group._session + 
session2 = client2.transport.resize_node_group._session + assert session1 != session2 + session1 = client1.transport.get_node_group._session + session2 = client2.transport.get_node_group._session + assert session1 != session2 def test_node_group_controller_grpc_transport_channel(): @@ -2155,6 +3177,7 @@ async def test_transport_close_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -2172,6 +3195,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: diff --git a/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py b/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py index a746971d..b8d9b273 100644 --- a/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py +++ b/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py @@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -111,6 +118,7 @@ def test__get_default_mtls_endpoint(): [ (WorkflowTemplateServiceClient, "grpc"), (WorkflowTemplateServiceAsyncClient, "grpc_asyncio"), + (WorkflowTemplateServiceClient, "rest"), ], ) def test_workflow_template_service_client_from_service_account_info( @@ -126,7 +134,11 @@ def test_workflow_template_service_client_from_service_account_info( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -134,6 +146,7 @@ def test_workflow_template_service_client_from_service_account_info( [ (transports.WorkflowTemplateServiceGrpcTransport, "grpc"), (transports.WorkflowTemplateServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.WorkflowTemplateServiceRestTransport, "rest"), ], ) def test_workflow_template_service_client_service_account_always_use_jwt( @@ -159,6 +172,7 @@ def test_workflow_template_service_client_service_account_always_use_jwt( [ (WorkflowTemplateServiceClient, "grpc"), (WorkflowTemplateServiceAsyncClient, "grpc_asyncio"), + (WorkflowTemplateServiceClient, "rest"), ], ) def test_workflow_template_service_client_from_service_account_file( @@ -181,13 +195,18 @@ def test_workflow_template_service_client_from_service_account_file( assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) def test_workflow_template_service_client_get_transport_class(): transport = WorkflowTemplateServiceClient.get_transport_class() available_transports = [ transports.WorkflowTemplateServiceGrpcTransport, + transports.WorkflowTemplateServiceRestTransport, ] assert transport in available_transports @@ -208,6 +227,11 @@ def 
test_workflow_template_service_client_get_transport_class(): transports.WorkflowTemplateServiceGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceRestTransport, + "rest", + ), ], ) @mock.patch.object( @@ -363,6 +387,18 @@ def test_workflow_template_service_client_client_options( "grpc_asyncio", "false", ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceRestTransport, + "rest", + "true", + ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceRestTransport, + "rest", + "false", + ), ], ) @mock.patch.object( @@ -568,6 +604,11 @@ def test_workflow_template_service_client_get_mtls_endpoint_and_cert_source( transports.WorkflowTemplateServiceGrpcAsyncIOTransport, "grpc_asyncio", ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceRestTransport, + "rest", + ), ], ) def test_workflow_template_service_client_client_options_scopes( @@ -608,6 +649,12 @@ def test_workflow_template_service_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceRestTransport, + "rest", + None, + ), ], ) def test_workflow_template_service_client_client_options_credentials_file( @@ -2677,162 +2724,3687 @@ async def test_delete_workflow_template_flattened_error_async(): ) -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.WorkflowTemplateServiceGrpcTransport( +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.CreateWorkflowTemplateRequest, + dict, + ], +) +def test_create_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - with pytest.raises(ValueError): - client = WorkflowTemplateServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["template"] = { + "id": "id_value", + "name": "name_value", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": 
"boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": 
"main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + "prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + "name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate( + id="id_value", + name="name_value", + version=774, ) - # It is an error to provide a credentials file and a transport instance. - transport = transports.WorkflowTemplateServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = WorkflowTemplateServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_workflow_template(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, workflow_templates.WorkflowTemplate) + assert response.id == "id_value" + assert response.name == "name_value" + assert response.version == 774 + + +def test_create_workflow_template_rest_required_fields( + request_type=workflow_templates.CreateWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, ) + ) - # It is an error to provide an api_key and a transport instance. - transport = transports.WorkflowTemplateServiceGrpcTransport( + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = WorkflowTemplateServiceClient( - client_options=options, - transport=transport, - ) - # It is an error to provide an api_key and a credential. 
- options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = WorkflowTemplateServiceClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + unset_fields = transport.create_workflow_template._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "template", + ) ) + ) - # It is an error to provide scopes and a transport instance. - transport = transports.WorkflowTemplateServiceGrpcTransport( + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), ) - with pytest.raises(ValueError): - client = WorkflowTemplateServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "post_create_workflow_template", + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "pre_create_workflow_template", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.CreateWorkflowTemplateRequest.pb( + workflow_templates.CreateWorkflowTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = workflow_templates.WorkflowTemplate.to_json( + workflow_templates.WorkflowTemplate() ) + request = workflow_templates.CreateWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = workflow_templates.WorkflowTemplate() -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.WorkflowTemplateServiceGrpcTransport( + client.create_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_workflow_template_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.CreateWorkflowTemplateRequest, +): + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - client = WorkflowTemplateServiceClient(transport=transport) - assert client.transport is transport + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["template"] = { + "id": "id_value", + "name": "name_value", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": 
"truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], 
+ "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + "prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + "name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.WorkflowTemplateServiceGrpcTransport( + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_workflow_template(request) + + +def test_create_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - channel = transport.grpc_channel - assert channel - transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport( + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/workflowTemplates" + % client.transport._host, + args[1], + ) + + +def test_create_workflow_template_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - channel = transport.grpc_channel - assert channel + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_workflow_template( + workflow_templates.CreateWorkflowTemplateRequest(), + parent="parent_value", + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.WorkflowTemplateServiceGrpcTransport, - transports.WorkflowTemplateServiceGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() + +def test_create_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) @pytest.mark.parametrize( - "transport_name", + "request_type", [ - "grpc", + workflow_templates.GetWorkflowTemplateRequest, + dict, ], ) -def test_transport_kind(transport_name): - transport = WorkflowTemplateServiceClient.get_transport_class(transport_name)( +def test_get_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert transport.kind == transport_name + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate( + id="id_value", + name="name_value", + version=774, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_workflow_template(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, workflow_templates.WorkflowTemplate) + assert response.id == "id_value" + assert response.name == "name_value" + assert response.version == 774 + + +def test_get_workflow_template_rest_required_fields( + request_type=workflow_templates.GetWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_workflow_template._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. 
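+    # GetWorkflowTemplate exposes an optional "version" query parameter; the
+    # set difference checks that nothing beyond that known parameter is
+    # reported as an unset required field.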
+ assert not set(unset_fields) - set(("version",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. client = WorkflowTemplateServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - assert isinstance( - client.transport, - transports.WorkflowTemplateServiceGrpcTransport, + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials ) + unset_fields = transport.get_workflow_template._get_unset_required_fields({}) + assert set(unset_fields) == (set(("version",)) & set(("name",))) -def test_workflow_template_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.WorkflowTemplateServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json", + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, "post_get_workflow_template" + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, "pre_get_workflow_template" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.GetWorkflowTemplateRequest.pb( + workflow_templates.GetWorkflowTemplateRequest() + ) + transcode.return_value 
= { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = workflow_templates.WorkflowTemplate.to_json( + workflow_templates.WorkflowTemplate() ) + request = workflow_templates.GetWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = workflow_templates.WorkflowTemplate() -def test_workflow_template_service_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.dataproc_v1.services.workflow_template_service.transports.WorkflowTemplateServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.WorkflowTemplateServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), + client.get_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], ) - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_workflow_template", - "get_workflow_template", - "instantiate_workflow_template", - "instantiate_inline_workflow_template", - "update_workflow_template", - "list_workflow_templates", - "delete_workflow_template", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) + pre.assert_called_once() + post.assert_called_once() - with pytest.raises(NotImplementedError): - transport.close() - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client +def test_get_workflow_template_rest_bad_request( + transport: str = "rest", request_type=workflow_templates.GetWorkflowTemplateRequest +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_workflow_template(request) + + +def test_get_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
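+        # The fake wire response below is built the way a real server would
+        # build it: the proto is rendered with MessageToJson and installed as
+        # the body of a requests.Response with status 200.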
+ return_value = workflow_templates.WorkflowTemplate() + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/workflowTemplates/*}" + % client.transport._host, + args[1], + ) + + +def test_get_workflow_template_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_workflow_template( + workflow_templates.GetWorkflowTemplateRequest(), + name="name_value", + ) + + +def test_get_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.InstantiateWorkflowTemplateRequest, + dict, + ], +) +def test_instantiate_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.instantiate_workflow_template(request) + + # Establish that the response is the type that we expect. 
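+    # Over REST, instantiate_workflow_template is a long-running operation:
+    # the transport wraps the returned operations_pb2.Operation in an
+    # api_core operation future, so the raw operation name stays reachable.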
+ assert response.operation.name == "operations/spam" + + +def test_instantiate_workflow_template_rest_required_fields( + request_type=workflow_templates.InstantiateWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).instantiate_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).instantiate_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.instantiate_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_instantiate_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.instantiate_workflow_template._get_unset_required_fields( + {} + ) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_instantiate_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "post_instantiate_workflow_template", + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "pre_instantiate_workflow_template", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.InstantiateWorkflowTemplateRequest.pb( + workflow_templates.InstantiateWorkflowTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = workflow_templates.InstantiateWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.instantiate_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_instantiate_workflow_template_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.InstantiateWorkflowTemplateRequest, +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.instantiate_workflow_template(request) + + +def test_instantiate_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + parameters={"key_value": "value_value"}, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.instantiate_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/workflowTemplates/*}:instantiate" + % client.transport._host, + args[1], + ) + + +def test_instantiate_workflow_template_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.instantiate_workflow_template( + workflow_templates.InstantiateWorkflowTemplateRequest(), + name="name_value", + parameters={"key_value": "value_value"}, + ) + + +def test_instantiate_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.InstantiateInlineWorkflowTemplateRequest, + dict, + ], +) +def test_instantiate_inline_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["template"] = { + "id": "id_value", + "name": "name_value", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": 
"keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + 
"script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + "prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + "name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.instantiate_inline_workflow_template(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_instantiate_inline_workflow_template_rest_required_fields( + request_type=workflow_templates.InstantiateInlineWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).instantiate_inline_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).instantiate_inline_workflow_template._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("request_id",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.instantiate_inline_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_instantiate_inline_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = ( + transport.instantiate_inline_workflow_template._get_unset_required_fields({}) + ) + assert set(unset_fields) == ( + set(("requestId",)) + & set( + ( + "parent", + "template", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_instantiate_inline_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "post_instantiate_inline_workflow_template", + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "pre_instantiate_inline_workflow_template", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.InstantiateInlineWorkflowTemplateRequest.pb( + workflow_templates.InstantiateInlineWorkflowTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = workflow_templates.InstantiateInlineWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.instantiate_inline_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_instantiate_inline_workflow_template_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.InstantiateInlineWorkflowTemplateRequest, +): + client = 
WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["template"] = { + "id": "id_value", + "name": "name_value", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + 
"lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + 
"prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + "name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.instantiate_inline_workflow_template(request) + + +def test_instantiate_inline_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.instantiate_inline_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/workflowTemplates:instantiateInline" + % client.transport._host, + args[1], + ) + + +def test_instantiate_inline_workflow_template_rest_flattened_error( + transport: str = "rest", +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.instantiate_inline_workflow_template( + workflow_templates.InstantiateInlineWorkflowTemplateRequest(), + parent="parent_value", + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) + + +def test_instantiate_inline_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.UpdateWorkflowTemplateRequest, + dict, + ], +) +def test_update_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "template": { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + } + request_init["template"] = { + "id": "id_value", + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": 
"keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + "auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + 
}, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + "prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + "name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate( + id="id_value", + name="name_value", + version=774, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_workflow_template(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, workflow_templates.WorkflowTemplate) + assert response.id == "id_value" + assert response.name == "name_value" + assert response.version == 774 + + +def test_update_workflow_template_rest_required_fields( + request_type=workflow_templates.UpdateWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate() + # Mock the http request call within the method and fake a response. 
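+    # UpdateWorkflowTemplate maps to HTTP PUT on {template.name=...}, so the
+    # stubbed transcode result below uses "put" and carries the request
+    # message as the body.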
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "put", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_workflow_template._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("template",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "post_update_workflow_template", + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "pre_update_workflow_template", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.UpdateWorkflowTemplateRequest.pb( + workflow_templates.UpdateWorkflowTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = workflow_templates.WorkflowTemplate.to_json( + workflow_templates.WorkflowTemplate() + ) + + request = workflow_templates.UpdateWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = workflow_templates.WorkflowTemplate() + + client.update_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_workflow_template_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.UpdateWorkflowTemplateRequest, +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy 
transcoding + request_init = { + "template": { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + } + request_init["template"] = { + "id": "id_value", + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3", + "version": 774, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "labels": {}, + "placement": { + "managed_cluster": { + "cluster_name": "cluster_name_value", + "config": { + "config_bucket": "config_bucket_value", + "temp_bucket": "temp_bucket_value", + "gce_cluster_config": { + "zone_uri": "zone_uri_value", + "network_uri": "network_uri_value", + "subnetwork_uri": "subnetwork_uri_value", + "internal_ip_only": True, + "private_ipv6_google_access": 1, + "service_account": "service_account_value", + "service_account_scopes": [ + "service_account_scopes_value1", + "service_account_scopes_value2", + ], + "tags": ["tags_value1", "tags_value2"], + "metadata": {}, + "reservation_affinity": { + "consume_reservation_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, + "node_group_affinity": { + "node_group_uri": "node_group_uri_value" + }, + "shielded_instance_config": { + "enable_secure_boot": True, + "enable_vtpm": True, + "enable_integrity_monitoring": True, + }, + "confidential_instance_config": { + "enable_confidential_compute": True + }, + }, + "master_config": { + "num_instances": 1399, + "instance_names": [ + "instance_names_value1", + "instance_names_value2", + ], + "image_uri": "image_uri_value", + "machine_type_uri": "machine_type_uri_value", + "disk_config": { + "boot_disk_type": "boot_disk_type_value", + "boot_disk_size_gb": 1792, + "num_local_ssds": 1494, + "local_ssd_interface": "local_ssd_interface_value", + }, + "is_preemptible": True, + "preemptibility": 1, + "managed_group_config": { + "instance_template_name": "instance_template_name_value", + "instance_group_manager_name": "instance_group_manager_name_value", + }, + "accelerators": [ + { + "accelerator_type_uri": "accelerator_type_uri_value", + "accelerator_count": 1805, + } + ], + "min_cpu_platform": "min_cpu_platform_value", + }, + "worker_config": {}, + "secondary_worker_config": {}, + "software_config": { + "image_version": "image_version_value", + "properties": {}, + "optional_components": [5], + }, + "initialization_actions": [ + { + "executable_file": "executable_file_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + } + ], + "encryption_config": { + "gce_pd_kms_key_name": "gce_pd_kms_key_name_value" + }, + "autoscaling_config": {"policy_uri": "policy_uri_value"}, + "security_config": { + "kerberos_config": { + "enable_kerberos": True, + "root_principal_password_uri": "root_principal_password_uri_value", + "kms_key_uri": "kms_key_uri_value", + "keystore_uri": "keystore_uri_value", + "truststore_uri": "truststore_uri_value", + "keystore_password_uri": "keystore_password_uri_value", + "key_password_uri": "key_password_uri_value", + "truststore_password_uri": "truststore_password_uri_value", + "cross_realm_trust_realm": "cross_realm_trust_realm_value", + "cross_realm_trust_kdc": "cross_realm_trust_kdc_value", + "cross_realm_trust_admin_server": "cross_realm_trust_admin_server_value", + "cross_realm_trust_shared_password_uri": "cross_realm_trust_shared_password_uri_value", + "kdc_db_key_uri": "kdc_db_key_uri_value", + "tgt_lifetime_hours": 1933, + "realm": "realm_value", + }, + "identity_config": {"user_service_account_mapping": {}}, + }, + "lifecycle_config": { + "idle_delete_ttl": {}, + 
"auto_delete_time": {}, + "auto_delete_ttl": {}, + "idle_start_time": {}, + }, + "endpoint_config": { + "http_ports": {}, + "enable_http_port_access": True, + }, + "metastore_config": { + "dataproc_metastore_service": "dataproc_metastore_service_value" + }, + "dataproc_metric_config": { + "metrics": [ + { + "metric_source": 1, + "metric_overrides": [ + "metric_overrides_value1", + "metric_overrides_value2", + ], + } + ] + }, + "auxiliary_node_groups": [ + { + "node_group": { + "name": "name_value", + "roles": [1], + "node_group_config": {}, + "labels": {}, + }, + "node_group_id": "node_group_id_value", + } + ], + }, + "labels": {}, + }, + "cluster_selector": {"zone": "zone_value", "cluster_labels": {}}, + }, + "jobs": [ + { + "step_id": "step_id_value", + "hadoop_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {"driver_log_levels": {}}, + }, + "spark_job": { + "main_jar_file_uri": "main_jar_file_uri_value", + "main_class": "main_class_value", + "args": ["args_value1", "args_value2"], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "pyspark_job": { + "main_python_file_uri": "main_python_file_uri_value", + "args": ["args_value1", "args_value2"], + "python_file_uris": [ + "python_file_uris_value1", + "python_file_uris_value2", + ], + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "hive_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {"queries": ["queries_value1", "queries_value2"]}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + }, + "pig_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "spark_r_job": { + "main_r_file_uri": "main_r_file_uri_value", + "args": ["args_value1", "args_value2"], + "file_uris": ["file_uris_value1", "file_uris_value2"], + "archive_uris": ["archive_uris_value1", "archive_uris_value2"], + "properties": {}, + "logging_config": {}, + }, + "spark_sql_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "script_variables": {}, + "properties": {}, + "jar_file_uris": ["jar_file_uris_value1", "jar_file_uris_value2"], + "logging_config": {}, + }, + "presto_job": { + "query_file_uri": "query_file_uri_value", + "query_list": {}, + "continue_on_failure": True, + "output_format": "output_format_value", + "client_tags": ["client_tags_value1", "client_tags_value2"], + "properties": {}, + "logging_config": {}, + }, + "labels": {}, + "scheduling": { + "max_failures_per_hour": 2243, + "max_failures_total": 1923, + }, + "prerequisite_step_ids": [ + "prerequisite_step_ids_value1", + "prerequisite_step_ids_value2", + ], + } + ], + "parameters": [ + { + 
"name": "name_value", + "fields": ["fields_value1", "fields_value2"], + "description": "description_value", + "validation": { + "regex": {"regexes": ["regexes_value1", "regexes_value2"]}, + "values": {"values": ["values_value1", "values_value2"]}, + }, + } + ], + "dag_timeout": {}, + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_workflow_template(request) + + +def test_update_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.WorkflowTemplate() + + # get arguments that satisfy an http rule for this method + sample_request = { + "template": { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + } + + # get truthy value for each flattened field + mock_args = dict( + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.WorkflowTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{template.name=projects/*/locations/*/workflowTemplates/*}" + % client.transport._host, + args[1], + ) + + +def test_update_workflow_template_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_workflow_template( + workflow_templates.UpdateWorkflowTemplateRequest(), + template=workflow_templates.WorkflowTemplate(id="id_value"), + ) + + +def test_update_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.ListWorkflowTemplatesRequest, + dict, + ], +) +def test_list_workflow_templates_rest(request_type): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.ListWorkflowTemplatesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.ListWorkflowTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_workflow_templates(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListWorkflowTemplatesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_workflow_templates_rest_required_fields( + request_type=workflow_templates.ListWorkflowTemplatesRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_workflow_templates._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_workflow_templates._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = workflow_templates.ListWorkflowTemplatesResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
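+            # For context: the real path_template.transcode() matches the
+            # request against the RPC's http_options and returns roughly
+            # (sketch; field values assumed):
+            #
+            #     {"uri": "/v1/projects/p1/locations/l1/workflowTemplates",
+            #      "method": "get", "query_params": {...}}
+            #
+            # Stubbing it out lets this test use placeholder values that the
+            # real URI pattern would otherwise reject.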
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = workflow_templates.ListWorkflowTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_workflow_templates(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_workflow_templates_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_workflow_templates._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_workflow_templates_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "post_list_workflow_templates", + ) as post, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, "pre_list_workflow_templates" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = workflow_templates.ListWorkflowTemplatesRequest.pb( + workflow_templates.ListWorkflowTemplatesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + workflow_templates.ListWorkflowTemplatesResponse.to_json( + workflow_templates.ListWorkflowTemplatesResponse() + ) + ) + + request = workflow_templates.ListWorkflowTemplatesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = workflow_templates.ListWorkflowTemplatesResponse() + + client.list_workflow_templates( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_workflow_templates_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.ListWorkflowTemplatesRequest, +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_workflow_templates(request) + + +def test_list_workflow_templates_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = workflow_templates.ListWorkflowTemplatesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = workflow_templates.ListWorkflowTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_workflow_templates(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/workflowTemplates" + % client.transport._host, + args[1], + ) + + +def test_list_workflow_templates_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_workflow_templates( + workflow_templates.ListWorkflowTemplatesRequest(), + parent="parent_value", + ) + + +def test_list_workflow_templates_rest_pager(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
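+        # The pager drains pages by re-sending the request with each
+        # response's next_page_token until an empty token comes back,
+        # consuming one mocked Response per page, e.g.:
+        #
+        #     for page in client.list_workflow_templates(request=...).pages:
+        #         ...  # one entry of req.side_effect per iteration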
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + workflow_templates.ListWorkflowTemplatesResponse( + templates=[ + workflow_templates.WorkflowTemplate(), + workflow_templates.WorkflowTemplate(), + workflow_templates.WorkflowTemplate(), + ], + next_page_token="abc", + ), + workflow_templates.ListWorkflowTemplatesResponse( + templates=[], + next_page_token="def", + ), + workflow_templates.ListWorkflowTemplatesResponse( + templates=[ + workflow_templates.WorkflowTemplate(), + ], + next_page_token="ghi", + ), + workflow_templates.ListWorkflowTemplatesResponse( + templates=[ + workflow_templates.WorkflowTemplate(), + workflow_templates.WorkflowTemplate(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + workflow_templates.ListWorkflowTemplatesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_workflow_templates(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, workflow_templates.WorkflowTemplate) for i in results) + + pages = list(client.list_workflow_templates(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + workflow_templates.DeleteWorkflowTemplateRequest, + dict, + ], +) +def test_delete_workflow_template_rest(request_type): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_workflow_template(request) + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_delete_workflow_template_rest_required_fields( + request_type=workflow_templates.DeleteWorkflowTemplateRequest, +): + transport_class = transports.WorkflowTemplateServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_workflow_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_workflow_template._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("version",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_workflow_template(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_workflow_template_rest_unset_required_fields(): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_workflow_template._get_unset_required_fields({}) + assert set(unset_fields) == (set(("version",)) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_workflow_template_rest_interceptors(null_interceptor): + transport = transports.WorkflowTemplateServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.WorkflowTemplateServiceRestInterceptor(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.WorkflowTemplateServiceRestInterceptor, + "pre_delete_workflow_template", + ) as pre: + pre.assert_not_called() + pb_message = workflow_templates.DeleteWorkflowTemplateRequest.pb( + workflow_templates.DeleteWorkflowTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = workflow_templates.DeleteWorkflowTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_workflow_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_delete_workflow_template_rest_bad_request( + transport: str = "rest", + request_type=workflow_templates.DeleteWorkflowTemplateRequest, +): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_workflow_template(request) + + +def test_delete_workflow_template_rest_flattened(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
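+    # Flattened calls are validated end to end: the sample request supplies a
+    # resource name matching the http rule, and the test then checks the URI
+    # the transport actually produced, roughly:
+    #
+    #     path_template.validate(
+    #         "%s/v1/{name=projects/*/locations/*/workflowTemplates/*}" % host,
+    #         uri_sent_on_the_wire,
+    #     )
+    #
+    # (host and uri_sent_on_the_wire stand in for the real values below.)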
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/workflowTemplates/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_workflow_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/locations/*/workflowTemplates/*}" + % client.transport._host, + args[1], + ) + + +def test_delete_workflow_template_rest_flattened_error(transport: str = "rest"): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_workflow_template( + workflow_templates.DeleteWorkflowTemplateRequest(), + name="name_value", + ) + + +def test_delete_workflow_template_rest_error(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = WorkflowTemplateServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = WorkflowTemplateServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = WorkflowTemplateServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = WorkflowTemplateServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
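+    # The same pattern works for the new REST transport, e.g. (sketch):
+    #
+    #     transport = transports.WorkflowTemplateServiceRestTransport(
+    #         credentials=ga_credentials.AnonymousCredentials(),
+    #     )
+    #     client = WorkflowTemplateServiceClient(transport=transport)
+    #
+    # Credentials must ride on the transport; passing them to the client as
+    # well is the ValueError case covered above.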
+ transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = WorkflowTemplateServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.WorkflowTemplateServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.WorkflowTemplateServiceGrpcTransport, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + transports.WorkflowTemplateServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) +def test_transport_kind(transport_name): + transport = WorkflowTemplateServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.WorkflowTemplateServiceGrpcTransport, + ) + + +def test_workflow_template_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.WorkflowTemplateServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_workflow_template_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.dataproc_v1.services.workflow_template_service.transports.WorkflowTemplateServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.WorkflowTemplateServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
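+    # The generated base transport stubs every RPC roughly as (sketch):
+    #
+    #     @property
+    #     def create_workflow_template(self):
+    #         raise NotImplementedError()
+    #
+    # so each concrete transport (gRPC, gRPC asyncio, and now REST) must
+    # override the full method list checked here.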
+ methods = ( + "create_workflow_template", + "get_workflow_template", + "instantiate_workflow_template", + "instantiate_inline_workflow_template", + "update_workflow_template", + "list_workflow_templates", + "delete_workflow_template", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client # Catch all for all remaining methods and properties remainder = [ @@ -2912,6 +6484,7 @@ def test_workflow_template_service_transport_auth_adc(transport_class): [ transports.WorkflowTemplateServiceGrpcTransport, transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + transports.WorkflowTemplateServiceRestTransport, ], ) def test_workflow_template_service_transport_auth_gdch_credentials(transport_class): @@ -3013,11 +6586,40 @@ def test_workflow_template_service_grpc_transport_client_cert_source_for_mtls( ) +def test_workflow_template_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.WorkflowTemplateServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_workflow_template_service_rest_lro_client(): + client = WorkflowTemplateServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. 
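+    # The REST transport memoizes the LRO client on first access, roughly
+    # (sketch; attribute name assumed):
+    #
+    #     if self._operations_client is None:
+    #         self._operations_client = operations_v1.AbstractOperationsClient(...)
+    #     return self._operations_client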
+ assert transport.operations_client is transport.operations_client + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_workflow_template_service_host_no_port(transport_name): @@ -3028,7 +6630,11 @@ def test_workflow_template_service_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:443") + assert client.transport._host == ( + "dataproc.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com" + ) @pytest.mark.parametrize( @@ -3036,6 +6642,7 @@ def test_workflow_template_service_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_workflow_template_service_host_with_port(transport_name): @@ -3046,7 +6653,51 @@ def test_workflow_template_service_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("dataproc.googleapis.com:8000") + assert client.transport._host == ( + "dataproc.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://dataproc.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_workflow_template_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = WorkflowTemplateServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = WorkflowTemplateServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_workflow_template._session + session2 = client2.transport.create_workflow_template._session + assert session1 != session2 + session1 = client1.transport.get_workflow_template._session + session2 = client2.transport.get_workflow_template._session + assert session1 != session2 + session1 = client1.transport.instantiate_workflow_template._session + session2 = client2.transport.instantiate_workflow_template._session + assert session1 != session2 + session1 = client1.transport.instantiate_inline_workflow_template._session + session2 = client2.transport.instantiate_inline_workflow_template._session + assert session1 != session2 + session1 = client1.transport.update_workflow_template._session + session2 = client2.transport.update_workflow_template._session + assert session1 != session2 + session1 = client1.transport.list_workflow_templates._session + session2 = client2.transport.list_workflow_templates._session + assert session1 != session2 + session1 = client1.transport.delete_workflow_template._session + session2 = client2.transport.delete_workflow_template._session + assert session1 != session2 def test_workflow_template_service_grpc_transport_channel(): @@ -3436,6 +7087,7 @@ async def test_transport_close_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -3453,6 +7105,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: