Skip to content

fix(functions): Refresh credentials before enqueueing first task #907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions firebase_admin/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from base64 import b64encode
from typing import Any, Optional, Dict
from dataclasses import dataclass

from google.auth.compute_engine import Credentials as ComputeEngineCredentials
from google.auth.exceptions import RefreshError
from google.auth.transport import requests as google_auth_requests

import requests
import firebase_admin
Expand Down Expand Up @@ -100,6 +103,12 @@ def __init__(self, app: App):
'GOOGLE_CLOUD_PROJECT environment variable.')

self._credential = app.credential.get_credential()
try:
# Refresh the credential to ensure all attributes (e.g. service_account_email)
# are populated, preventing cold start errors.
self._credential.refresh(google_auth_requests.Request())
except RefreshError as err:
raise ValueError(f'Initial credential refresh failed: {err}') from err
self._http_client = _http_client.JsonHttpClient(credential=self._credential)

def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue:
Expand Down Expand Up @@ -289,10 +298,9 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
# Meaning that it's credential should be a Compute Engine Credential.
if _Validators.is_non_empty_string(extension_id) and \
isinstance(self._credential, ComputeEngineCredentials):

id_token = self._credential.token
task.http_request['headers'] = \
{**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'}
{**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'}
# Delete oidc token
del task.http_request['oidc_token']
else:
Expand Down
57 changes: 57 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def test_task_enqueue(self):
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension(self):
resource_name = (
'projects/test-project/locations/us-central1/queues/'
Expand All @@ -142,6 +146,59 @@ def test_task_enqueue_with_extension(self):
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_compute_engine(self):
app = firebase_admin.initialize_app(
testutils.MockComputeEngineCredential(),
options={'projectId': 'test-project'},
name='test-project-gce')
_, recorder = self._instrument_functions_service(app)
queue = functions.task_queue('test-function-name', app=app)
task_id = queue.enqueue(_DEFAULT_DATA)
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == _DEFAULT_REQUEST_URL
assert recorder[0].headers['Content-Type'] == 'application/json'
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension_compute_engine(self):
resource_name = (
'projects/test-project/locations/us-central1/queues/'
'ext-test-extension-id-test-function-name/tasks'
)
extension_response = json.dumps({'name': resource_name + '/test-task-id'})
app = firebase_admin.initialize_app(
testutils.MockComputeEngineCredential(),
options={'projectId': 'test-project'},
name='test-project-gce-extensions')
_, recorder = self._instrument_functions_service(app, payload=extension_response)
queue = functions.task_queue('test-function-name', 'test-extension-id', app)
task_id = queue.enqueue(_DEFAULT_DATA)
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == _CLOUD_TASKS_URL + resource_name
assert recorder[0].headers['Content-Type'] == 'application/json'
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert 'oidc_token' not in task['http_request']
assert task['http_request']['headers'] == {
'Content-Type': 'application/json',
'Authorization': 'Bearer mock-compute-engine-token'}

def test_task_delete(self):
_, recorder = self._instrument_functions_service()
queue = functions.task_queue('test-function-name')
Expand Down
7 changes: 6 additions & 1 deletion tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ class MockGoogleCredential(credentials.Credentials):
"""A mock Google authentication credential."""
def refresh(self, request):
self.token = 'mock-token'
self._service_account_email = "mock-email"

@property
def service_account_email(self):
return 'mock-email'
return self._service_account_email

# Simulate x-goog-api-client modification in credential refresh
def _metric_header_for_usage(self):
Expand All @@ -141,6 +142,10 @@ class MockGoogleComputeEngineCredential(compute_engine.Credentials):
"""A mock Compute Engine credential"""
def refresh(self, request):
self.token = 'mock-compute-engine-token'
self._service_account_email = 'mock-gce-email'

def _metric_header_for_usage(self):
return 'mock-gce-cred-metric-tag'

class MockComputeEngineCredential(firebase_admin.credentials.Base):
"""A mock Firebase credential implementation."""
Expand Down