
Commit 461e9e0

feat: support list output for managed function (#1457)
* feat: support list output for managed function
* add test decorator
* resolve comments
1 parent ff46f5a commit 461e9e0
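
For orientation, a minimal usage sketch of what this commit enables: a managed function created with session.udf that declares a list output type and is applied to a BigFrames Series. The dataset name "my_dataset" and the use of the global session are illustrative placeholders; the call pattern mirrors the tests added in this commit.

import bigframes.pandas as bpd

# Assumes BigQuery credentials and a default session are already configured.
session = bpd.get_global_session()

# A managed (BigQuery Python) UDF whose output type is a list, which this
# commit starts supporting. "my_dataset" is a placeholder dataset name.
@session.udf(dataset="my_dataset")
def featurize(x: int) -> list[float]:
    return [float(x), float(x + 1), float(x + 2)]

s = bpd.Series([1, 2, 3], dtype="Int64")
result = s.apply(featurize).to_pandas()  # each element is a list of three floats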

6 files changed (+345, -20 lines)

bigframes/dataframe.py

Lines changed: 4 additions & 2 deletions
@@ -4199,11 +4199,13 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
             udf_input_dtypes = getattr(func, "input_dtypes")
             if len(udf_input_dtypes) != len(self.columns):
                 raise ValueError(
-                    f"Remote function takes {len(udf_input_dtypes)} arguments but DataFrame has {len(self.columns)} columns."
+                    f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
+                    f" arguments but DataFrame has {len(self.columns)} columns."
                 )
             if udf_input_dtypes != tuple(self.dtypes.to_list()):
                 raise ValueError(
-                    f"Remote function takes arguments of types {udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
+                    f"BigFrames BigQuery function takes arguments of types "
+                    f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
                 )

             series_list = [self[col] for col in self.columns]
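
To connect the reworded errors above to caller behavior, a short sketch mirroring the managed-function test added later in this commit; session and dataset_id are assumed to be an existing Session and BigQuery dataset, as in the system-test fixtures.

import bigframes
import pytest

bf_df = bigframes.dataframe.DataFrame(
    {"Id": [1, 2, 3], "Age": [22.5, 23.0, 23.5], "Name": ["alpha", "beta", "gamma"]}
)

# `session` and `dataset_id` are assumed fixtures, as in the tests below.
@session.udf(input_types=[int, float, str], output_type=list[str], dataset=dataset_id)
def foo(x, y, z):
    return [str(x), str(y), z]

# Selecting only two columns for a three-argument function trips the first
# check above and raises the reworded ValueError.
with pytest.raises(
    ValueError,
    match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
):
    bf_df[["Id", "Age"]].apply(foo, axis=1)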

bigframes/functions/_function_session.py

Lines changed: 5 additions & 0 deletions
@@ -892,6 +892,7 @@ def wrapper(func):
             func = cloudpickle.loads(cloudpickle.dumps(func))

             self._try_delattr(func, "bigframes_bigquery_function")
+            self._try_delattr(func, "bigframes_bigquery_function_output_dtype")
             self._try_delattr(func, "input_dtypes")
             self._try_delattr(func, "output_dtype")
             self._try_delattr(func, "is_row_processor")
@@ -951,6 +952,10 @@ def wrapper(func):
                     ibis_signature.output_type
                 )
             )
+            # Managed functions directly support certain output types that are not
+            # supported in remote functions (e.g. list output), so no further
+            # processing of 'bigframes_bigquery_function_output_dtype' is needed.
+            func.bigframes_bigquery_function_output_dtype = func.output_dtype
             func.is_row_processor = is_row_processor
             func.ibis_node = node
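
To make the assignment above concrete for a list output, a small sketch of the dtype value involved; the helpers used are the same public ones the new test asserts against, not new assumptions.

import pandas
import pyarrow
import bigframes.dtypes

# For a managed function declared with output_type=list[str], output_dtype is
# an Arrow list-of-string dtype, and bigframes_bigquery_function_output_dtype
# is assigned that same value with no remote-function-style post-processing.
list_of_str_dtype = pandas.ArrowDtype(
    pyarrow.list_(
        bigframes.dtypes.bigframes_dtype_to_arrow_dtype(bigframes.dtypes.STRING_DTYPE)
    )
)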

bigframes/operations/remote_function_ops.py

Lines changed: 12 additions & 9 deletions
@@ -29,11 +29,12 @@ def expensive(self) -> bool:
         return True

     def output_type(self, *input_types):
-        # This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
+        # The output dtype should be set to a valid Dtype by @udf decorator,
+        # @remote_function decorator, or read_gbq_function method.
         if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
             return self.func.bigframes_bigquery_function_output_dtype
-        else:
-            raise AttributeError("bigframes_bigquery_function_output_dtype not defined")
+
+        raise AttributeError("bigframes_bigquery_function_output_dtype not defined")


 @dataclasses.dataclass(frozen=True)
@@ -46,11 +47,12 @@ def expensive(self) -> bool:
         return True

     def output_type(self, *input_types):
-        # This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
+        # The output dtype should be set to a valid Dtype by @udf decorator,
+        # @remote_function decorator, or read_gbq_function method.
         if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
             return self.func.bigframes_bigquery_function_output_dtype
-        else:
-            raise AttributeError("bigframes_bigquery_function_output_dtype not defined")
+
+        raise AttributeError("bigframes_bigquery_function_output_dtype not defined")


 @dataclasses.dataclass(frozen=True)
@@ -63,8 +65,9 @@ def expensive(self) -> bool:
         return True

     def output_type(self, *input_types):
-        # This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
+        # The output dtype should be set to a valid Dtype by @udf decorator,
+        # @remote_function decorator, or read_gbq_function method.
         if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
             return self.func.bigframes_bigquery_function_output_dtype
-        else:
-            raise AttributeError("bigframes_bigquery_function_output_dtype not defined")
+
+        raise AttributeError("bigframes_bigquery_function_output_dtype not defined")

tests/system/large/functions/test_managed_function.py

Lines changed: 160 additions & 0 deletions
@@ -13,8 +13,10 @@
 # limitations under the License.

 import pandas
+import pyarrow
 import pytest

+import bigframes
 from bigframes.functions import _function_session as bff_session
 from bigframes.functions._utils import get_python_version
 import bigframes.pandas as bpd
@@ -164,3 +166,161 @@ def func(x, y):
         cleanup_function_assets(
             session.bqclient, session.cloudfunctionsclient, managed_func
         )
+
+
+@pytest.mark.parametrize(
+    "array_dtype",
+    [
+        bool,
+        int,
+        float,
+        str,
+    ],
+)
+@pytest.mark.skipif(
+    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
+    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
+)
+def test_managed_function_array_output(session, scalars_dfs, dataset_id, array_dtype):
+    try:
+
+        @session.udf(dataset=dataset_id)
+        def featurize(x: int) -> list[array_dtype]:  # type: ignore
+            return [array_dtype(i) for i in [x, x + 1, x + 2]]
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+
+        bf_int64_col = scalars_df["int64_too"]
+        bf_result = bf_int64_col.apply(featurize).to_pandas()
+
+        pd_int64_col = scalars_pandas_df["int64_too"]
+        pd_result = pd_int64_col.apply(featurize)
+
+        # Ignore any dtype disparity.
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # Clean up the gcp assets created for the managed function.
+        cleanup_function_assets(
+            featurize, session.bqclient, session.cloudfunctionsclient
+        )
+
+
+@pytest.mark.skipif(
+    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
+    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
+)
+def test_managed_function_binop_array_output(session, scalars_dfs, dataset_id):
+    try:
+
+        def func(x, y):
+            return [len(x), abs(y % 4)]
+
+        managed_func = session.udf(
+            input_types=[str, int],
+            output_type=list[int],
+            dataset=dataset_id,
+        )(func)
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+
+        scalars_df = scalars_df.dropna()
+        scalars_pandas_df = scalars_pandas_df.dropna()
+        bf_result = (
+            scalars_df["string_col"]
+            .combine(scalars_df["int64_col"], managed_func)
+            .to_pandas()
+        )
+        pd_result = scalars_pandas_df["string_col"].combine(
+            scalars_pandas_df["int64_col"], func
+        )
+        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
+    finally:
+        # Clean up the gcp assets created for the managed function.
+        cleanup_function_assets(
+            managed_func, session.bqclient, session.cloudfunctionsclient
+        )
+
+
+@pytest.mark.skipif(
+    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
+    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
+)
+def test_manage_function_df_apply_axis_1_array_output(session):
+    bf_df = bigframes.dataframe.DataFrame(
+        {
+            "Id": [1, 2, 3],
+            "Age": [22.5, 23, 23.5],
+            "Name": ["alpha", "beta", "gamma"],
+        }
+    )
+
+    expected_dtypes = (
+        bigframes.dtypes.INT_DTYPE,
+        bigframes.dtypes.FLOAT_DTYPE,
+        bigframes.dtypes.STRING_DTYPE,
+    )
+
+    # Assert the dataframe dtypes.
+    assert tuple(bf_df.dtypes) == expected_dtypes
+
+    try:
+
+        @session.udf(input_types=[int, float, str], output_type=list[str])
+        def foo(x, y, z):
+            return [str(x), str(y), z]
+
+        assert getattr(foo, "is_row_processor") is False
+        assert getattr(foo, "input_dtypes") == expected_dtypes
+        assert getattr(foo, "output_dtype") == pandas.ArrowDtype(
+            pyarrow.list_(
+                bigframes.dtypes.bigframes_dtype_to_arrow_dtype(
+                    bigframes.dtypes.STRING_DTYPE
+                )
+            )
+        )
+        assert getattr(foo, "output_dtype") == getattr(
+            foo, "bigframes_bigquery_function_output_dtype"
+        )
+
+        # Fails to apply on dataframe with incompatible number of columns.
+        with pytest.raises(
+            ValueError,
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
+        ):
+            bf_df[["Id", "Age"]].apply(foo, axis=1)
+
+        with pytest.raises(
+            ValueError,
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
+        ):
+            bf_df.assign(Country="lalaland").apply(foo, axis=1)
+
+        # Fails to apply on dataframe with incompatible column datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        ):
+            bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)
+
+        # Successfully applies to dataframe with matching number of columns
+        # and their datatypes.
+        bf_result = bf_df.apply(foo, axis=1).to_pandas()
+
+        # Since this scenario is not pandas-like, let's handcraft the
+        # expected result.
+        expected_result = pandas.Series(
+            [
+                ["1", "22.5", "alpha"],
+                ["2", "23.0", "beta"],
+                ["3", "23.5", "gamma"],
+            ]
+        )
+
+        pandas.testing.assert_series_equal(
+            expected_result, bf_result, check_dtype=False, check_index_type=False
+        )
+
+    finally:
+        # Clean up the gcp assets created for the managed function.
+        cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)

tests/system/large/functions/test_remote_function.py

Lines changed: 9 additions & 9 deletions
@@ -2085,19 +2085,19 @@ def foo(x, y, z):
         # Fails to apply on dataframe with incompatible number of columns
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 3 arguments but DataFrame has 2 columns\\.$",
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
         ):
             bf_df[["Id", "Age"]].apply(foo, axis=1)
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 3 arguments but DataFrame has 4 columns\\.$",
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
         ):
             bf_df.assign(Country="lalaland").apply(foo, axis=1)

         # Fails to apply on dataframe with incompatible column datatypes
         with pytest.raises(
             ValueError,
-            match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
+            match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
         ):
             bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

@@ -2171,19 +2171,19 @@ def foo(x, y, z):
         # Fails to apply on dataframe with incompatible number of columns
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 3 arguments but DataFrame has 2 columns\\.$",
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
         ):
             bf_df[["Id", "Age"]].apply(foo, axis=1)
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 3 arguments but DataFrame has 4 columns\\.$",
+            match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
         ):
             bf_df.assign(Country="lalaland").apply(foo, axis=1)

         # Fails to apply on dataframe with incompatible column datatypes
         with pytest.raises(
             ValueError,
-            match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
+            match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
         ):
             bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

@@ -2240,19 +2240,19 @@ def foo(x):
         # Fails to apply on dataframe with incompatible number of columns
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 1 arguments but DataFrame has 0 columns\\.$",
+            match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 0 columns\\.$",
         ):
             bf_df[[]].apply(foo, axis=1)
         with pytest.raises(
             ValueError,
-            match="^Remote function takes 1 arguments but DataFrame has 2 columns\\.$",
+            match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 2 columns\\.$",
         ):
             bf_df.assign(Country="lalaland").apply(foo, axis=1)

         # Fails to apply on dataframe with incompatible column datatypes
         with pytest.raises(
             ValueError,
-            match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
+            match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
         ):
             bf_df.assign(Id=bf_df["Id"].astype("Float64")).apply(foo, axis=1)
