Skip to content

feat: allow functions decorated with @bpd.remote_function to execute locally #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 26, 2024
10 changes: 6 additions & 4 deletions bigframes/core/compile/scalar_op_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,11 +856,12 @@ def to_timestamp_op_impl(x: ibis_types.Value, op: ops.ToTimestampOp):

@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
if not hasattr(op.func, "bigframes_remote_function"):
ibis_node = getattr(op.func, "ibis_node", None)
if ibis_node is None:
raise TypeError(
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
)
x_transformed = op.func(x)
x_transformed = ibis_node(x)
if not op.apply_on_null:
x_transformed = ibis.case().when(x.isnull(), x).else_(x_transformed).end()
return x_transformed
Expand Down Expand Up @@ -1342,11 +1343,12 @@ def minimum_impl(
def binary_remote_function_op_impl(
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
):
if not hasattr(op.func, "bigframes_remote_function"):
ibis_node = getattr(op.func, "ibis_node", None)
if ibis_node is None:
raise TypeError(
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
)
x_transformed = op.func(x, y)
x_transformed = ibis_node(x, y)
return x_transformed


Expand Down
52 changes: 36 additions & 16 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,11 @@ def remote_function(

bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
def wrapper(func):
nonlocal input_types, output_type

if not callable(f):
raise TypeError("f must be callable, got {}".format(f))
if not callable(func):
raise TypeError("f must be callable, got {}".format(func))

if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
Expand All @@ -1028,7 +1028,7 @@ def wrapper(f):
signature_kwargs = {}

signature = inspect.signature(
f,
func,
**signature_kwargs,
)

Expand Down Expand Up @@ -1089,8 +1089,23 @@ def wrapper(f):
session=session, # type: ignore
)

# In the unlikely case where the user is trying to re-deploy the same
# function, cleanup the attributes we add below, first. This prevents
# the pickle from having dependencies that might not otherwise be
# present such as ibis or pandas.
def try_delattr(attr):
try:
delattr(func, attr)
except AttributeError:
pass

try_delattr("bigframes_cloud_function")
try_delattr("bigframes_remote_function")
try_delattr("output_dtype")
try_delattr("ibis_node")

rf_name, cf_name = remote_function_client.provision_bq_remote_function(
f,
func,
ibis_signature.input_types,
ibis_signature.output_type,
reuse,
Expand All @@ -1105,19 +1120,20 @@ def wrapper(f):

# TODO: Move ibis logic to compiler step
node = ibis.udf.scalar.builtin(
f,
func,
name=rf_name,
schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}",
signature=(ibis_signature.input_types, ibis_signature.output_type),
)
node.bigframes_cloud_function = (
func.bigframes_cloud_function = (
remote_function_client.get_cloud_function_fully_qualified_name(cf_name)
)
node.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
func.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
ibis_signature.output_type
)
return node
func.ibis_node = node
return func

return wrapper

Expand Down Expand Up @@ -1168,19 +1184,23 @@ def read_gbq_function(

# The name "args" conflicts with the Ibis operator, so we use
# non-standard names for the arguments here.
def node(*ignored_args, **ignored_kwargs):
def func(*ignored_args, **ignored_kwargs):
f"""Remote function {str(routine_ref)}."""
# TODO(swast): Construct an ibis client from bigquery_client and
# execute node via a query.

# TODO: Move ibis logic to compiler step
node.__name__ = routine_ref.routine_id
func.__name__ = routine_ref.routine_id

node = ibis.udf.scalar.builtin(
node,
func,
name=routine_ref.routine_id,
schema=f"{routine_ref.project}.{routine_ref.dataset_id}",
signature=(ibis_signature.input_types, ibis_signature.output_type),
)
node.bigframes_remote_function = str(routine_ref) # type: ignore
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
func.bigframes_remote_function = str(routine_ref) # type: ignore
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
ibis_signature.output_type
)
return node
func.ibis_node = node # type: ignore
return func
5 changes: 4 additions & 1 deletion tests/system/large/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def test_remote_function_stringify_with_ibis(
def stringify(x):
return f"I got {x}"

# Function should work locally.
assert stringify(42) == "I got 42"

_, dataset_name, table_name = scalars_table_id.split(".")
if not ibis_client.dataset:
ibis_client.dataset = dataset_name
Expand All @@ -205,7 +208,7 @@ def stringify(x):
pandas_df_orig = bigquery_client.query(sql).to_dataframe()

col = table[col_name]
col_2x = stringify(col).name("int64_str_col")
col_2x = stringify.ibis_node(col).name("int64_str_col")
table = table.mutate([col_2x])
sql = table.compile()
pandas_df_new = bigquery_client.query(sql).to_dataframe()
Expand Down
70 changes: 48 additions & 22 deletions tests/system/small/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def bq_cf_connection_location_project(bigquery_client) -> str:

@pytest.fixture(scope="module")
def bq_cf_connection_location_project_mismatched() -> str:
"""Pre-created BQ connection in the migframes-metrics project in US location,
"""Pre-created BQ connection in the bigframes-metrics project in US location,
in format PROJECT_ID.LOCATION.CONNECTION_NAME, used to invoke cloud function.

$ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn
Expand Down Expand Up @@ -108,11 +108,15 @@ def test_remote_function_direct_no_session_param(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

assert square.bigframes_remote_function
assert square.bigframes_cloud_function
# Function should still work normally.
assert square(2) == 4

# Function should have extra metadata attached for remote execution.
assert hasattr(square, "bigframes_remote_function")
assert hasattr(square, "bigframes_cloud_function")
assert hasattr(square, "ibis_node")

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -161,8 +165,10 @@ def test_remote_function_direct_no_session_param_location_specified(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -197,7 +203,10 @@ def test_remote_function_direct_no_session_param_location_mismatched(
dataset_id_permanent,
bq_cf_connection_location_mismatched,
):
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("The location does not match BigQuery connection location:"),
):

@rf.remote_function(
[int],
Expand All @@ -212,7 +221,8 @@ def test_remote_function_direct_no_session_param_location_mismatched(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the location of the
# connection doesn't match the location of the dataset.
return x * x # pragma: NO COVER


Expand All @@ -239,8 +249,10 @@ def test_remote_function_direct_no_session_param_location_project_specified(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -275,7 +287,12 @@ def test_remote_function_direct_no_session_param_project_mismatched(
dataset_id_permanent,
bq_cf_connection_location_project_mismatched,
):
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape(
"The project_id does not match BigQuery connection gcp_project_id:"
),
):

@rf.remote_function(
[int],
Expand All @@ -290,7 +307,8 @@ def test_remote_function_direct_no_session_param_project_mismatched(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the project of the
# connection doesn't match the project of the dataset.
return x * x # pragma: NO COVER


Expand All @@ -302,8 +320,10 @@ def test_remote_function_direct_session_param(session_with_bq_connection, scalar
session=session_with_bq_connection,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -340,8 +360,10 @@ def test_remote_function_via_session_default(session_with_bq_connection, scalars
# cloud function would be common and quickly reused.
@session_with_bq_connection.remote_function([int], int)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -380,8 +402,10 @@ def test_remote_function_via_session_with_overrides(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -508,7 +532,7 @@ def test_skip_bq_connection_check(dataset_id_permanent):

@session.remote_function([int], int, dataset=dataset_id_permanent)
def add_one(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the connection doesn't exist.
return x + 1 # pragma: NO COVER


Expand Down Expand Up @@ -546,8 +570,10 @@ def test_read_gbq_function_like_original(
reuse=True,
)
def square1(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square1(2) == 4

square2 = rf.read_gbq_function(
function_name=square1.bigframes_remote_function,
Expand Down