diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 6db0fbb4..5f490fb8 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -22,7 +22,10 @@ from base64 import b64encode from typing import Any, Optional, Dict from dataclasses import dataclass + from google.auth.compute_engine import Credentials as ComputeEngineCredentials +from google.auth.exceptions import RefreshError +from google.auth.transport import requests as google_auth_requests import requests import firebase_admin @@ -100,6 +103,12 @@ def __init__(self, app: App): 'GOOGLE_CLOUD_PROJECT environment variable.') self._credential = app.credential.get_credential() + try: + # Refresh the credential to ensure all attributes (e.g. service_account_email) + # are populated, preventing cold start errors. + self._credential.refresh(google_auth_requests.Request()) + except RefreshError as err: + raise ValueError(f'Initial credential refresh failed: {err}') from err self._http_client = _http_client.JsonHttpClient(credential=self._credential) def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: @@ -289,10 +298,9 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str # Meaning that it's credential should be a Compute Engine Credential. if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): - id_token = self._credential.token task.http_request['headers'] = \ - {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} + {**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'} # Delete oidc token del task.http_request['oidc_token'] else: diff --git a/tests/test_functions.py b/tests/test_functions.py index 52e92c1b..95356344 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -124,6 +124,10 @@ def test_task_enqueue(self): assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + def test_task_enqueue_with_extension(self): resource_name = ( 'projects/test-project/locations/us-central1/queues/' @@ -142,6 +146,59 @@ def test_task_enqueue_with_extension(self): assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + + def test_task_enqueue_compute_engine(self): + app = firebase_admin.initialize_app( + testutils.MockComputeEngineCredential(), + options={'projectId': 'test-project'}, + name='test-project-gce') + _, recorder = self._instrument_functions_service(app) + queue = functions.task_queue('test-function-name', app=app) + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _DEFAULT_REQUEST_URL + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token' + expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + assert task_id == 'test-task-id' + + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + + def test_task_enqueue_with_extension_compute_engine(self): + resource_name = ( + 'projects/test-project/locations/us-central1/queues/' + 'ext-test-extension-id-test-function-name/tasks' + ) + extension_response = json.dumps({'name': resource_name + '/test-task-id'}) + app = firebase_admin.initialize_app( + testutils.MockComputeEngineCredential(), + options={'projectId': 'test-project'}, + name='test-project-gce-extensions') + _, recorder = self._instrument_functions_service(app, payload=extension_response) + queue = functions.task_queue('test-function-name', 'test-extension-id', app) + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _CLOUD_TASKS_URL + resource_name + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token' + expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + assert task_id == 'test-task-id' + + task = json.loads(recorder[0].body.decode())['task'] + assert 'oidc_token' not in task['http_request'] + assert task['http_request']['headers'] == { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer mock-compute-engine-token'} + def test_task_delete(self): _, recorder = self._instrument_functions_service() queue = functions.task_queue('test-function-name') diff --git a/tests/testutils.py b/tests/testutils.py index 598a929b..d331d231 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -118,10 +118,11 @@ class MockGoogleCredential(credentials.Credentials): """A mock Google authentication credential.""" def refresh(self, request): self.token = 'mock-token' + self._service_account_email = "mock-email" @property def service_account_email(self): - return 'mock-email' + return self._service_account_email # Simulate x-goog-api-client modification in credential refresh def _metric_header_for_usage(self): @@ -141,6 +142,10 @@ class MockGoogleComputeEngineCredential(compute_engine.Credentials): """A mock Compute Engine credential""" def refresh(self, request): self.token = 'mock-compute-engine-token' + self._service_account_email = 'mock-gce-email' + + def _metric_header_for_usage(self): + return 'mock-gce-cred-metric-tag' class MockComputeEngineCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation."""