diff --git a/google/cloud/bigquery/job/base.py b/google/cloud/bigquery/job/base.py index 9b7ddb82d..7576fc9aa 100644 --- a/google/cloud/bigquery/job/base.py +++ b/google/cloud/bigquery/job/base.py @@ -1044,8 +1044,7 @@ def result( # type: ignore # (incompatible with supertype) if self.state is None: self._begin(retry=retry, timeout=timeout) - kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry} - return super(_AsyncJob, self).result(timeout=timeout, **kwargs) + return super(_AsyncJob, self).result(timeout=timeout, retry=retry) def cancelled(self): """Check if the job has been cancelled. diff --git a/tests/unit/test_job_retry.py b/tests/unit/test_job_retry.py index 7343fed3d..fa55e8f6a 100644 --- a/tests/unit/test_job_retry.py +++ b/tests/unit/test_job_retry.py @@ -615,3 +615,80 @@ def test_query_and_wait_retries_job_for_DDL_queries(global_time_lock): _, kwargs = calls[3] assert kwargs["method"] == "POST" assert kwargs["path"] == query_request_path + + +@pytest.mark.parametrize( + "result_retry_param", + [ + pytest.param( + {}, + id="default retry {}", + ), + pytest.param( + { + "retry": google.cloud.bigquery.retry.DEFAULT_RETRY.with_timeout( + timeout=10.0 + ) + }, + id="custom retry object with timeout 10.0", + ), + ], +) +def test_retry_load_job_result(result_retry_param, PROJECT, DS_ID): + from google.cloud.bigquery.dataset import DatasetReference + from google.cloud.bigquery.job.load import LoadJob + import google.cloud.bigquery.retry + + client = make_client() + conn = client._connection = make_connection( + dict( + status=dict(state="RUNNING"), + jobReference={"jobId": "id_1"}, + ), + google.api_core.exceptions.ServiceUnavailable("retry me"), + dict( + status=dict(state="DONE"), + jobReference={"jobId": "id_1"}, + statistics={"load": {"outputRows": 1}}, + ), + ) + + table_ref = DatasetReference(project=PROJECT, dataset_id=DS_ID).table("new_table") + job = LoadJob("id_1", source_uris=None, destination=table_ref, client=client) + + with mock.patch.object( + client, "_call_api", wraps=client._call_api + ) as wrapped_call_api: + result = job.result(**result_retry_param) + + assert job.state == "DONE" + assert result.output_rows == 1 + + # Check that _call_api was called multiple times due to retry + assert wrapped_call_api.call_count > 1 + + # Verify the retry object used in the calls to _call_api + expected_retry = result_retry_param.get( + "retry", google.cloud.bigquery.retry.DEFAULT_RETRY + ) + + for call in wrapped_call_api.mock_calls: + name, args, kwargs = call + # The retry object is the first positional argument to _call_api + called_retry = args[0] + + # We only care about the calls made during the job.result() polling + if kwargs.get("method") == "GET" and "jobs/id_1" in kwargs.get("path", ""): + assert called_retry._predicate == expected_retry._predicate + assert called_retry._initial == expected_retry._initial + assert called_retry._maximum == expected_retry._maximum + assert called_retry._multiplier == expected_retry._multiplier + assert called_retry._deadline == expected_retry._deadline + if "retry" in result_retry_param: + # Specifically check the timeout for the custom retry case + assert called_retry._timeout == 10.0 + else: + assert called_retry._timeout == expected_retry._timeout + + # The number of api_request calls should still be 3 + assert conn.api_request.call_count == 3