feat: support type annotations to supply input and output types to @remote_function decorator #717

Merged · 6 commits · May 25, 2024
94 changes: 69 additions & 25 deletions bigframes/functions/remote_function.py
@@ -24,7 +24,17 @@
import sys
import tempfile
import textwrap
from typing import cast, List, NamedTuple, Optional, Sequence, TYPE_CHECKING, Union
from typing import (
Any,
cast,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
import warnings

import ibis
@@ -736,8 +746,8 @@ def get_routine_reference(
# which has moved as @js to the ibis package
# https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/udf/__init__.py
def remote_function(
input_types: Union[type, Sequence[type]],
output_type: type,
input_types: Union[None, type, Sequence[type]] = None,
output_type: Optional[type] = None,
session: Optional[Session] = None,
bigquery_client: Optional[bigquery.Client] = None,
bigquery_connection_client: Optional[
@@ -801,11 +811,11 @@ def remote_function(
`$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`.

Args:
input_types (type or sequence(type)):
input_types (None, type, or sequence(type)):
For a scalar user defined function, this should be the input type or a
sequence of input types. For a row processing user defined function,
type `Series` should be specified.
output_type (type):
output_type (Optional[type]):
Data type of the output in the user defined function.
session (bigframes.Session, Optional):
BigQuery DataFrames session to use for getting default project,
@@ -908,27 +918,10 @@ def remote_function(
service(s) that are on a VPC network. For more details, see
https://cloud.google.com/functions/docs/networking/connecting-vpc.
"""
is_row_processor = False

import bigframes.series
import bigframes.session

if input_types == bigframes.series.Series:
warnings.warn(
"input_types=Series scenario is in preview.",
stacklevel=1,
category=bigframes.exceptions.PreviewWarning,
)

# We will model the row as a JSON-serialized string containing the data
# and the metadata representing the row
input_types = [str]
is_row_processor = True
elif isinstance(input_types, type):
input_types = [input_types]

# Some defaults may be used from the session if not provided otherwise
import bigframes.pandas as bpd
import bigframes.series
import bigframes.session

session = cast(bigframes.session.Session, session or bpd.get_global_session())

@@ -1021,10 +1014,61 @@ def remote_function(
bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
nonlocal input_types, output_type

if not callable(f):
raise TypeError("f must be callable, got {}".format(f))

signature = inspect.signature(f)
if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
# corresponding type objects. Need Python 3.10 for eval_str parameter.
# https://docs.python.org/3/library/inspect.html#inspect.signature
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
else:
signature_kwargs = {}

signature = inspect.signature(
f,
**signature_kwargs,
)

# Try to get input types via type annotations.
if input_types is None:
input_types = []
for parameter in signature.parameters.values():
if (param_type := parameter.annotation) is inspect.Signature.empty:
raise ValueError(
"'input_types' was not set and parameter "
f"'{parameter.name}' is missing a type annotation. "
"Types are required to use @remote_function."
)
input_types.append(param_type)

if output_type is None:
if (output_type := signature.return_annotation) is inspect.Signature.empty:
raise ValueError(
"'output_type' was not set and function is missing a "
"return type annotation. Types are required to use "
"@remote_function."
)

# The function will actually be receiving a pandas Series, but allow both
# BigQuery DataFrames and pandas object types for compatibility.
is_row_processor = False
if input_types == bigframes.series.Series or input_types == pandas.Series:
warnings.warn(
"input_types=Series scenario is in preview.",
stacklevel=1,
category=bigframes.exceptions.PreviewWarning,
)

# We will model the row as a JSON-serialized string containing the data
# and the metadata representing the row
input_types = [str]
is_row_processor = True
elif isinstance(input_types, type):
input_types = [input_types]

# TODO(b/340898611): fix type error
ibis_signature = ibis_signature_from_python_signature(
signature, input_types, output_type # type: ignore
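Note on the hunk above: when `input_types` or `output_type` is omitted, the decorator now recovers them from the function's signature, passing `eval_str=True` on Python 3.10+ so that deferred (string) annotations resolve to real type objects. A minimal standalone sketch of the same inference pattern — the helper name `infer_types` is hypothetical, not part of this PR:

```python
import inspect
import sys
from typing import Any, List, Mapping, Tuple


def infer_types(f) -> Tuple[List[type], type]:
    """Hypothetical helper mirroring the decorator's inference logic."""
    # eval_str=True (Python 3.10+) resolves deferred/string annotations,
    # e.g. those produced by `from __future__ import annotations`.
    kwargs: Mapping[str, Any] = (
        {"eval_str": True} if sys.version_info >= (3, 10) else {}
    )
    signature = inspect.signature(f, **kwargs)

    input_types = []
    for parameter in signature.parameters.values():
        if parameter.annotation is inspect.Signature.empty:
            raise ValueError(
                f"parameter {parameter.name!r} is missing a type annotation"
            )
        input_types.append(parameter.annotation)

    if signature.return_annotation is inspect.Signature.empty:
        raise ValueError("function is missing a return type annotation")
    return input_types, signature.return_annotation


def minutes_to_hours(x: int) -> float:
    return x / 60


assert infer_types(minutes_to_hours) == ([int], float)
```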
102 changes: 51 additions & 51 deletions tests/system/conftest.py
@@ -18,9 +18,9 @@
import math
import pathlib
import textwrap
import traceback
import typing
from typing import Dict, Generator, Optional
import warnings

import google.api_core.exceptions
import google.cloud.bigquery as bigquery
@@ -1097,54 +1097,54 @@ def cleanup_cloud_functions(session, cloudfunctions_client, dataset_id_permanent
session.bqclient, dataset_id_permanent
)
delete_count = 0
for cloud_function in tests.system.utils.get_cloud_functions(
cloudfunctions_client,
session.bqclient.project,
session.bqclient.location,
name_prefix="bigframes-",
):
# Ignore bigframes cloud functions referred to by the remote functions in
# the permanent dataset
if cloud_function.service_config.uri in permanent_endpoints:
continue

# Ignore the functions less than one day old
age = datetime.now() - datetime.fromtimestamp(
cloud_function.update_time.timestamp()
)
if age.days <= 0:
continue

# Go ahead and delete
try:
tests.system.utils.delete_cloud_function(
cloudfunctions_client, cloud_function.name
try:
for cloud_function in tests.system.utils.get_cloud_functions(
cloudfunctions_client,
session.bqclient.project,
session.bqclient.location,
name_prefix="bigframes-",
):
# Ignore bigframes cloud functions referred to by the remote functions in
# the permanent dataset
if cloud_function.service_config.uri in permanent_endpoints:
continue

# Ignore the functions less than one day old
age = datetime.now() - datetime.fromtimestamp(
cloud_function.update_time.timestamp()
)
delete_count += 1
if delete_count >= MAX_NUM_FUNCTIONS_TO_DELETE_PER_SESSION:
break
except google.api_core.exceptions.NotFound:
# This can happen when multiple pytest sessions are running in
# parallel. Two or more sessions may discover the same cloud
# function, but only one of them would be able to delete it
# successfully, while the other instance will run into this
# exception. Ignore this exception.
pass
except Exception as exc:
# Don't fail the tests for unknown exceptions.
#
# This can happen if we are hitting GCP limits, e.g.
# google.api_core.exceptions.ResourceExhausted: 429 Quota exceeded
# for quota metric 'Per project mutation requests' and limit
# 'Per project mutation requests per minute per region' of service
# 'cloudfunctions.googleapis.com' for consumer
# 'project_number:1084210331973'.
# [reason: "RATE_LIMIT_EXCEEDED" domain: "googleapis.com" ...
#
# It can also happen occasionally with
# google.api_core.exceptions.ServiceUnavailable when there is some
# backend flakiness.
#
# Let's stop further cleanup and leave it for later.
warnings.warn(f"Cloud functions cleanup failed: {str(exc)}")
break
if age.days <= 0:
continue

# Go ahead and delete
try:
tests.system.utils.delete_cloud_function(
cloudfunctions_client, cloud_function.name
)
delete_count += 1
if delete_count >= MAX_NUM_FUNCTIONS_TO_DELETE_PER_SESSION:
break
except google.api_core.exceptions.NotFound:
# This can happen when multiple pytest sessions are running in
# parallel. Two or more sessions may discover the same cloud
# function, but only one of them would be able to delete it
# successfully, while the other instance will run into this
# exception. Ignore this exception.
pass
except Exception as exc:
# Don't fail the tests for unknown exceptions.
#
# This can happen if we are hitting GCP limits, e.g.
# google.api_core.exceptions.ResourceExhausted: 429 Quota exceeded
# for quota metric 'Per project mutation requests' and limit
# 'Per project mutation requests per minute per region' of service
# 'cloudfunctions.googleapis.com' for consumer
# 'project_number:1084210331973'.
# [reason: "RATE_LIMIT_EXCEEDED" domain: "googleapis.com" ...
#
# It can also happen occasionally with
# google.api_core.exceptions.ServiceUnavailable when there is some
# backend flakiness.
#
# Let's stop further cleanup and leave it for later.
traceback.print_exception(exc)
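One portability note on the rewritten cleanup above: `traceback.print_exception(exc)` uses the single-argument form that was added in Python 3.10; earlier interpreters require the `(type, value, traceback)` triple. A hedged, version-portable sketch:

```python
import sys
import traceback


def print_exc_portable(exc: BaseException) -> None:
    # Python 3.10+ accepts the exception instance directly.
    if sys.version_info >= (3, 10):
        traceback.print_exception(exc)
    else:
        # Older versions need the (type, value, traceback) triple.
        traceback.print_exception(type(exc), exc, exc.__traceback__)
```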
8 changes: 4 additions & 4 deletions tests/unit/resources.py
@@ -39,6 +39,7 @@ def create_bigquery_session(
session_id: str = "abcxyz",
table_schema: Sequence[google.cloud.bigquery.SchemaField] = TEST_SCHEMA,
anonymous_dataset: Optional[google.cloud.bigquery.DatasetReference] = None,
location: str = "test-region",
) -> bigframes.Session:
credentials = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
@@ -53,11 +54,12 @@
if bqclient is None:
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
bqclient.project = "test-project"
bqclient.location = location

# Mock the location.
table = mock.create_autospec(google.cloud.bigquery.Table, instance=True)
table._properties = {}
type(table).location = mock.PropertyMock(return_value="test-region")
type(table).location = mock.PropertyMock(return_value=location)
type(table).schema = mock.PropertyMock(return_value=table_schema)
type(table).reference = mock.PropertyMock(
return_value=anonymous_dataset.table("test_table")
@@ -93,9 +95,7 @@ def query_mock(query, *args, **kwargs):
type(clients_provider).bqclient = mock.PropertyMock(return_value=bqclient)
clients_provider._credentials = credentials

bqoptions = bigframes.BigQueryOptions(
credentials=credentials, location="test-region"
)
bqoptions = bigframes.BigQueryOptions(credentials=credentials, location=location)
session = bigframes.Session(context=bqoptions, clients_provider=clients_provider)
return session

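The new `location` parameter lets unit tests exercise region-dependent code paths against the fully mocked session; the factory threads the value through both `bqclient.location` and the mocked table's `location` property. A sketch of intended usage — the region string is illustrative:

```python
from tests.unit import resources

# Default preserves the previous hard-coded behavior ("test-region").
session = resources.create_bigquery_session()

# Region-sensitive tests can now override it.
eu_session = resources.create_bigquery_session(location="europe-west4")
assert eu_session.bqclient.location == "europe-west4"
```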
2 changes: 1 addition & 1 deletion tests/unit/test_pandas.py
@@ -50,7 +50,7 @@ def all_session_methods():
[(method_name,) for method_name in all_session_methods()],
)
def test_method_matches_session(method_name: str):
if sys.version_info <= (3, 10):
if sys.version_info < (3, 10):
pytest.skip(
"Need Python 3.10 to reconcile deferred annotations."
) # pragma: no cover
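The `<=` → `<` change above is a clarity fix rather than a behavior change: `sys.version_info` is a five-component tuple, and a longer tuple with an equal prefix compares greater than `(3, 10)` under both operators, so 3.10.x was never skipped by the old check either. A quick demonstration of the tuple semantics:

```python
# A longer tuple with an equal prefix compares greater.
assert (3, 10, 0) > (3, 10)
assert not ((3, 10, 0) <= (3, 10))  # old check: 3.10.x did not skip
assert not ((3, 10, 0) < (3, 10))   # new check: same result, clearer intent
assert (3, 9, 18) < (3, 10)         # 3.9 skips under both
```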
39 changes: 39 additions & 0 deletions tests/unit/test_remote_function.py
@@ -14,8 +14,11 @@

import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
from ibis.expr import datatypes as ibis_types
import pytest

import bigframes.dtypes
import bigframes.functions.remote_function
from tests.unit import resources


def test_supported_types_correspond():
@@ -29,3 +32,39 @@ def test_supported_types_correspond():
}

assert ibis_types_from_python == ibis_types_from_bigquery


def test_missing_input_types():
session = resources.create_bigquery_session()
remote_function_decorator = bigframes.functions.remote_function.remote_function(
session=session
)

def function_without_parameter_annotations(myparam) -> str:
return str(myparam)

assert function_without_parameter_annotations(42) == "42"

with pytest.raises(
ValueError,
match="'input_types' was not set .* 'myparam' is missing a type annotation",
):
remote_function_decorator(function_without_parameter_annotations)


def test_missing_output_type():
session = resources.create_bigquery_session()
remote_function_decorator = bigframes.functions.remote_function.remote_function(
session=session
)

def function_without_return_annotation(myparam: int):
return str(myparam)

assert function_without_return_annotation(42) == "42"

with pytest.raises(
ValueError,
match="'output_type' was not set .* missing a return type annotation",
):
remote_function_decorator(function_without_return_annotation)
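Taken together with the docstring updates below, the user-facing effect of this PR is that explicit type arguments become optional when annotations are present. A hedged before/after sketch (decoration deploys a cloud function, so it assumes a configured BigQuery session and connection):

```python
import bigframes.pandas as bpd

# Before this change: types passed explicitly to the decorator.
@bpd.remote_function(int, float, reuse=False)
def minutes_to_hours_explicit(x):
    return x / 60

# After this change: inferred from the annotations.
@bpd.remote_function(reuse=False)
def minutes_to_hours(x: int) -> float:
    return x / 60
```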
16 changes: 10 additions & 6 deletions third_party/bigframes_vendored/pandas/core/frame.py
@@ -3916,8 +3916,8 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
to potentially reuse a previously deployed ``remote_function`` from
the same user defined function.

>>> @bpd.remote_function(int, float, reuse=False)
... def minutes_to_hours(x):
>>> @bpd.remote_function(reuse=False)
... def minutes_to_hours(x: int) -> float:
... return x/60

>>> df_minutes = bpd.DataFrame(
@@ -4238,6 +4238,7 @@ def apply(self, func, *, axis=0, args=(), **kwargs):
**Examples:**

>>> import bigframes.pandas as bpd
>>> import pandas as pd
>>> bpd.options.display.progress_bar = None

>>> df = bpd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
@@ -4259,16 +4260,19 @@
[2 rows x 2 columns]

You could apply a user defined function to every row of the DataFrame by
creating a remote function out of it, and using it with `axis=1`.
creating a remote function out of it, and using it with `axis=1`. Within
the function, each row is passed as a ``pandas.Series``. It is recommended
to select only the necessary columns before calling `apply()`. Note: This
feature is currently in **preview**.

>>> @bpd.remote_function(bpd.Series, int, reuse=False)
... def foo(row):
>>> @bpd.remote_function(reuse=False)
... def foo(row: pd.Series) -> int:
... result = 1
... result += row["col1"]
... result += row["col2"]*row["col2"]
... return result

>>> df.apply(foo, axis=1)
>>> df[["col1", "col2"]].apply(foo, axis=1)
0 11
1 19
dtype: Int64