Skip to content

Commit 133ac6b

Browse files
authored
feat!: add structured output for ai map, ai filter and ai join (#1746)
* add structured output for ai map, ai filter and ai join * fix mypy * fix test * update notebook
1 parent 80aad9a commit 133ac6b

File tree

4 files changed

+291
-92
lines changed

4 files changed

+291
-92
lines changed

bigframes/operations/ai.py

Lines changed: 69 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import re
1618
import typing
17-
from typing import List, Optional
19+
from typing import Dict, List, Optional
1820
import warnings
1921

2022
import numpy as np
@@ -34,7 +36,13 @@ def __init__(self, df) -> None:
3436

3537
self._df: bigframes.dataframe.DataFrame = df
3638

37-
def filter(self, instruction: str, model, ground_with_google_search: bool = False):
39+
def filter(
40+
self,
41+
instruction: str,
42+
model,
43+
ground_with_google_search: bool = False,
44+
attach_logprobs: bool = False,
45+
):
3846
"""
3947
Filters the DataFrame with the semantics of the user instruction.
4048
@@ -74,6 +82,10 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = Fals
7482
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
7583
The default is `False`.
7684
85+
attach_logprobs (bool, default False):
86+
Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
86+
of the LLM for their responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
88+
7789
Returns:
7890
bigframes.pandas.DataFrame: DataFrame filtered by the instruction.
7991
@@ -82,72 +94,27 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = Fals
8294
ValueError: when the instruction refers to a non-existing column, or when no
8395
columns are referred to.
8496
"""
85-
import bigframes.dataframe
86-
import bigframes.series
8797

88-
self._validate_model(model)
89-
columns = self._parse_columns(instruction)
90-
for column in columns:
91-
if column not in self._df.columns:
92-
raise ValueError(f"Column {column} not found.")
98+
answer_col = "answer"
9399

94-
if ground_with_google_search:
95-
msg = exceptions.format_message(
96-
"Enables Grounding with Google Search may impact billing cost. See pricing "
97-
"details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
98-
)
99-
warnings.warn(msg, category=UserWarning)
100-
101-
self._confirm_operation(len(self._df))
102-
103-
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
104-
has_blob_column = False
105-
for column in columns:
106-
if df[column].dtype == dtypes.OBJ_REF_DTYPE:
107-
# Don't cast blob columns to string
108-
has_blob_column = True
109-
continue
110-
111-
if df[column].dtype != dtypes.STRING_DTYPE:
112-
df[column] = df[column].astype(dtypes.STRING_DTYPE)
113-
114-
user_instruction = self._format_instruction(instruction, columns)
115-
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
116-
117-
if has_blob_column:
118-
results = typing.cast(
119-
bigframes.dataframe.DataFrame,
120-
model.predict(
121-
df,
122-
prompt=self._make_multimodel_prompt(
123-
df, columns, user_instruction, output_instruction
124-
),
125-
temperature=0.0,
126-
ground_with_google_search=ground_with_google_search,
127-
),
128-
)
129-
else:
130-
results = typing.cast(
131-
bigframes.dataframe.DataFrame,
132-
model.predict(
133-
self._make_text_prompt(
134-
df, columns, user_instruction, output_instruction
135-
),
136-
temperature=0.0,
137-
ground_with_google_search=ground_with_google_search,
138-
),
139-
)
100+
output_schema = {answer_col: "bool"}
101+
result = self.map(
102+
instruction,
103+
model,
104+
output_schema,
105+
ground_with_google_search,
106+
attach_logprobs,
107+
)
140108

141-
return self._df[
142-
results["ml_generate_text_llm_result"].str.lower().str.contains("true")
143-
]
109+
return result[result[answer_col]].drop(answer_col, axis=1)
144110

145111
def map(
146112
self,
147113
instruction: str,
148-
output_column: str,
149114
model,
115+
output_schema: Dict[str, str] | None = None,
150116
ground_with_google_search: bool = False,
117+
attach_logprobs=False,
151118
):
152119
"""
153120
Maps the DataFrame with the semantics of the user instruction.
@@ -163,7 +130,7 @@ def map(
163130
>>> model = llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001")
164131
165132
>>> df = bpd.DataFrame({"ingredient_1": ["Burger Bun", "Soy Bean"], "ingredient_2": ["Beef Patty", "Bittern"]})
166-
>>> df.ai.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", output_column="food", model=model)
133+
>>> df.ai.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", model=model, output_schema={"food": "string"})
167134
ingredient_1 ingredient_2 food
168135
0 Burger Bun Beef Patty Burger
169136
<BLANKLINE>
@@ -180,12 +147,14 @@ def map(
180147
in the instructions like:
181148
"Get the ingredients of {food}."
182149
183-
output_column (str):
184-
The column name of the mapping result.
185-
186150
model (bigframes.ml.llm.GeminiTextGenerator):
187151
A GeminiTextGenerator provided by Bigframes ML package.
188152
153+
output_schema (Dict[str, str] or None, default None):
154+
The schema used to generate structured output as a bigframes DataFrame. The schema is a string key-value pair of <column_name>:<type>.
155+
Supported types are int64, float64, bool, string, array<type> and struct<column type>. If None, generate string result under the column
156+
"ml_generate_text_llm_result".
157+
189158
ground_with_google_search (bool, default False):
190159
Enables Grounding with Google Search for the GeminiTextGenerator model.
191160
When set to True, the model incorporates relevant information from Google
@@ -194,6 +163,11 @@ def map(
194163
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
195164
The default is `False`.
196165
166+
attach_logprobs (bool, default False):
167+
Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
168+
of the LLM for their responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
169+
170+
197171
Returns:
198172
bigframes.pandas.DataFrame: DataFrame with attached mapping results.
199173
@@ -236,6 +210,9 @@ def map(
236210
"Based on the provided contenxt, answer the following instruction:"
237211
)
238212

213+
if output_schema is None:
214+
output_schema = {"ml_generate_text_llm_result": "string"}
215+
239216
if has_blob_column:
240217
results = typing.cast(
241218
bigframes.series.Series,
@@ -246,7 +223,8 @@ def map(
246223
),
247224
temperature=0.0,
248225
ground_with_google_search=ground_with_google_search,
249-
)["ml_generate_text_llm_result"],
226+
output_schema=output_schema,
227+
),
250228
)
251229
else:
252230
results = typing.cast(
@@ -257,19 +235,36 @@ def map(
257235
),
258236
temperature=0.0,
259237
ground_with_google_search=ground_with_google_search,
260-
)["ml_generate_text_llm_result"],
238+
output_schema=output_schema,
239+
),
240+
)
241+
242+
attach_columns = [results[col] for col, _ in output_schema.items()]
243+
244+
def extract_logprob(s: bigframes.series.Series) -> bigframes.series.Series:
245+
from bigframes import bigquery as bbq
246+
247+
logprob_jsons = bbq.json_extract_array(s, "$.candidates").list[0]
248+
logprobs = bbq.json_extract(logprob_jsons, "$.avg_logprobs").astype(
249+
"Float64"
261250
)
251+
logprobs.name = "logprob"
252+
return logprobs
253+
254+
if attach_logprobs:
255+
attach_columns.append(extract_logprob(results["full_response"]))
262256

263257
from bigframes.core.reshape.api import concat
264258

265-
return concat([self._df, results.rename(output_column)], axis=1)
259+
return concat([self._df, *attach_columns], axis=1)
266260

267261
def join(
268262
self,
269263
other,
270264
instruction: str,
271265
model,
272266
ground_with_google_search: bool = False,
267+
attach_logprobs=False,
273268
):
274269
"""
275270
Joines two dataframes by applying the instruction over each pair of rows from
@@ -313,10 +308,6 @@ def join(
313308
model (bigframes.ml.llm.GeminiTextGenerator):
314309
A GeminiTextGenerator provided by Bigframes ML package.
315310
316-
max_rows (int, default 1000):
317-
The maximum number of rows allowed to be sent to the model per call. If the result is too large, the method
318-
call will end early with an error.
319-
320311
ground_with_google_search (bool, default False):
321312
Enables Grounding with Google Search for the GeminiTextGenerator model.
322313
When set to True, the model incorporates relevant information from Google
@@ -325,6 +316,10 @@ def join(
325316
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
326317
The default is `False`.
327318
319+
attach_logprobs (bool, default False):
320+
Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
321+
of the LLM for their responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
322+
328323
Returns:
329324
bigframes.pandas.DataFrame: The joined dataframe.
330325
@@ -400,7 +395,10 @@ def join(
400395
joined_df = self._df.merge(other, how="cross", suffixes=("_left", "_right"))
401396

402397
return joined_df.ai.filter(
403-
instruction, model, ground_with_google_search=ground_with_google_search
398+
instruction,
399+
model,
400+
ground_with_google_search=ground_with_google_search,
401+
attach_logprobs=attach_logprobs,
404402
).reset_index(drop=True)
405403

406404
def search(

0 commit comments

Comments
 (0)