Skip to content

Commit c36909e

Browse files
committed
feat: support bigquery.vector_search()
1 parent 56cbd3b commit c36909e

File tree

5 files changed

+525
-8
lines changed

5 files changed

+525
-8
lines changed

bigframes/bigquery/__init__.py

Lines changed: 142 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,17 @@
2020

2121
from __future__ import annotations
2222

23-
import typing
23+
from typing import Literal, Optional, Union
2424

25+
import bigframes.bigquery.utils as utils
2526
import bigframes.constants as constants
2627
import bigframes.core.groupby as groupby
2728
import bigframes.operations as ops
2829
import bigframes.operations.aggregations as agg_ops
30+
import bigframes.pandas as bpd
2931

30-
if typing.TYPE_CHECKING:
31-
import bigframes.dataframe as dataframe
32-
import bigframes.series as series
3332

34-
35-
def array_length(series: series.Series) -> series.Series:
33+
def array_length(series: bpd.Series) -> bpd.Series:
3634
"""Compute the length of each array element in the Series.
3735
3836
**Examples:**
@@ -69,7 +67,7 @@ def array_length(series: series.Series) -> series.Series:
6967

7068
def array_agg(
7169
obj: groupby.SeriesGroupBy | groupby.DataFrameGroupBy,
72-
) -> series.Series | dataframe.DataFrame:
70+
) -> bpd.Series | bpd.DataFrame:
7371
"""Group data and create arrays from selected columns, omitting NULLs to avoid
7472
BigQuery errors (NULLs not allowed in arrays).
7573
@@ -120,7 +118,7 @@ def array_agg(
120118
)
121119

122120

123-
def array_to_string(series: series.Series, delimiter: str) -> series.Series:
121+
def array_to_string(series: bpd.Series, delimiter: str) -> bpd.Series:
124122
"""Converts array elements within a Series into delimited strings.
125123
126124
**Examples:**
@@ -148,3 +146,139 @@ def array_to_string(series: series.Series, delimiter: str) -> series.Series:
148146
149147
"""
150148
return series._apply_unary_op(ops.ArrayToStringOp(delimiter=delimiter))
149+
150+
151+
def vector_search(
    base_table: str,
    column_to_search: str,
    query: Union[bpd.DataFrame, bpd.Series],
    *,
    query_column_to_search: Optional[str] = None,
    top_k: Optional[int] = 10,
    distance_type: Literal["euclidean", "cosine"] = "euclidean",
    fraction_lists_to_search: Optional[float] = None,
    use_brute_force: bool = False,
) -> bpd.DataFrame:
    """
    Conduct vector search which searches embeddings to find semantically similar entities.

    **Examples:**


        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None

    DataFrame embeddings for which to find nearest neighbors:

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                               "embedding": [[1.0, 2.0], [3.0, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
        query_id embedding id my_embedding distance
        1 cat [3. 5.2] 5 [5. 5.4] 2.009975
        0 dog [1. 2.] 1 [1. 2.] 0.0
        0 dog [1. 2.] 4 [1. 3.2] 1.2
        1 cat [3. 5.2] 2 [2. 4.] 1.56205
        <BLANKLINE>
        [4 rows x 5 columns]

    Series embeddings for which to find nearest neighbors:

        >>> search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]],
        ...                            index=["dog", "cat"],
        ...                            name="embedding")
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
        embedding id my_embedding distance
        dog [1. 2.] 1 [1. 2.] 0.0
        cat [3. 5.2] 5 [5. 5.4] 2.009975
        dog [1. 2.] 4 [1. 3.2] 1.2
        cat [3. 5.2] 2 [2. 4.] 1.56205
        <BLANKLINE>
        [4 rows x 4 columns]

    You can specify the name of the column in the query DataFrame embeddings and distance type:

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                               "embedding": [[1.0, 2.0], [3.0, 5.2]],
        ...                               "another_embedding": [[0.7, 2.2], [3.3, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             distance_type="cosine",
        ...             query_column_to_search="another_embedding",
        ...             top_k=2)
        query_id embedding another_embedding id my_embedding distance
        1 cat [3. 5.2] [3.3 5.2] 2 [2. 4.] 0.005181
        0 dog [1. 2.] [0.7 2.2] 4 [1. 3.2] 0.000013
        1 cat [3. 5.2] [3.3 5.2] 1 [1. 2.] 0.005181
        0 dog [1. 2.] [0.7 2.2] 3 [1.5 7. ] 0.004697
        <BLANKLINE>
        [4 rows x 6 columns]

    Args:
        base_table (str):
            The table to search for nearest neighbor embeddings.
        column_to_search (str):
            The name of the base table column to search for nearest neighbor embeddings.
            The column must have a type of ``ARRAY<FLOAT64>``. All elements in the array must be non-NULL.
        query (bigframes.dataframe.DataFrame | bigframes.series.Series):
            A Series or DataFrame that provides the embeddings for which to find nearest neighbors.
        query_column_to_search (str):
            Specifies the name of the column in the query that contains the embeddings for which to
            find nearest neighbors. The column must have a type of ``ARRAY<FLOAT64>``. All elements in
            the array must be non-NULL and all values in the column must have the same array dimensions
            as the values in the ``column_to_search`` column. Can only be set when query is a DataFrame.
        top_k (int, default 10):
            Specifies the number of nearest neighbors to return. Default to 10.
        distance_type (str, default "euclidean"):
            Specifies the type of metric to use to compute the distance between two vectors.
            Possible values are "euclidean" and "cosine". Default to "euclidean".
        fraction_lists_to_search (float, range in [0.0, 1.0]):
            Specifies the percentage of lists to search. Specifying a higher percentage leads to
            higher recall and slower performance, and the converse is true when specifying a lower
            percentage. It is only used when a vector index is also used. You can only specify
            ``fraction_lists_to_search`` when ``use_brute_force`` is set to False.
        use_brute_force (bool, default False):
            Determines whether to use brute force search by skipping the vector index if one is available.
            Default to False.

    Returns:
        bigframes.dataframe.DataFrame: A DataFrame containing vector search result.

    Raises:
        ValueError:
            If ``fraction_lists_to_search`` is given together with ``use_brute_force=True``,
            or if ``query_column_to_search`` is given while ``query`` is a Series.
        NotImplementedError:
            If ``fraction_lists_to_search`` or ``use_brute_force`` is set; these options are
            not supported yet.
    """
    # The two options conflict only when BOTH are supplied: brute force skips
    # the vector index, while fraction_lists_to_search tunes an indexed search.
    # Compare against None (not truthiness) so an explicit 0.0 is still
    # recognised as "specified".
    if fraction_lists_to_search is not None and use_brute_force is True:
        raise ValueError(
            "You can't specify fraction_lists_to_search when use_brute_force is set to True."
        )
    if isinstance(query, bpd.Series) and query_column_to_search is not None:
        raise ValueError(
            "You can't specify query_column_to_search when query is a Series."
        )
    # TODO(ashleyxu): Support options in vector search.
    if fraction_lists_to_search is not None or use_brute_force is True:
        raise NotImplementedError(
            f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}"
        )

    # Every setting is forwarded, including None values; the SQL builder
    # decides which ones end up in the generated statement.
    options = {
        "base_table": base_table,
        "column_to_search": column_to_search,
        "query_column_to_search": query_column_to_search,
        "distance_type": distance_type,
        "top_k": top_k,
        "fraction_lists_to_search": fraction_lists_to_search,
        "use_brute_force": use_brute_force,
    }

    df = utils.apply_sql(
        query,
        options,  # type:ignore
    )
    return df

bigframes/bigquery/utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Iterable, Mapping, Union
16+
17+
import bigframes.ml.utils as utils
18+
import bigframes.pandas as bpd
19+
20+
21+
def create_vector_search_sql(
    sql_string: str,
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> str:
    """Encode the VECTOR SEARCH statement for BigQuery Vector Search.

    Args:
        sql_string (str):
            SQL of the query table; it is embedded as a parenthesised subquery.
        options (Mapping):
            Search settings. ``base_table``, ``column_to_search``,
            ``distance_type`` and ``top_k`` are required keys;
            ``query_column_to_search`` is optional. The mapping is only read,
            never mutated.

    Returns:
        str: A ``SELECT query.*, base.*, distance FROM VECTOR_SEARCH(...)`` statement.

    Raises:
        KeyError: If one of the required keys is missing from ``options``.
    """
    base_table = options["base_table"]
    column_to_search = options["column_to_search"]
    distance_type = options["distance_type"]
    top_k = options["top_k"]
    query_column_to_search = options.get("query_column_to_search", None)

    # NOTE(review): the values below are interpolated into the statement
    # without escaping; callers are expected to pass trusted identifiers.
    arguments = [
        f"TABLE `{base_table}`",
        f'"{column_to_search}"',
        f"({sql_string})",
    ]
    # The query column is a positional argument, so it has to come right after
    # the query subquery and before any named argument.
    if query_column_to_search is not None:
        arguments.append(f'"{query_column_to_search}"')
    arguments.append(f'distance_type => "{distance_type}"')
    # Rendering None would emit the invalid SQL ``top_k => None``; omit the
    # argument instead so BigQuery falls back to its own default.
    if top_k is not None:
        arguments.append(f"top_k => {top_k}")

    # Joined outside the f-string: backslashes are not allowed inside f-string
    # expressions before Python 3.12.
    joined_arguments = ",\n        ".join(arguments)
    query_str = f"""
    SELECT
        query.*,
        base.*,
        distance,
    FROM VECTOR_SEARCH(
        {joined_arguments}
    )
    """
    return query_str
63+
64+
65+
def apply_sql(
    query: Union[bpd.DataFrame, bpd.Series],
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> bpd.DataFrame:
    """Wrap ``query`` in a vector search SQL statement and read the result back.

    The query is converted to a DataFrame, compiled to SQL with its index
    columns included, embedded into the statement produced by
    ``create_vector_search_sql`` and executed through the query's own session.
    The original index labels are then re-applied to the result.

    Args:
        query (bigframes.dataframe.DataFrame | bigframes.series.Series):
            The dataframe (or series) to be wrapped.
        options (Mapping):
            Settings forwarded unchanged to ``create_vector_search_sql``.

    Returns:
        bigframes.dataframe.DataFrame: The result of the generated statement.
    """
    (query_df,) = utils.convert_to_dataframe(query)
    inner_sql, index_ids, index_names = query_df._to_sql_query(include_index=True)
    search_sql = create_vector_search_sql(sql_string=inner_sql, options=options)

    # Pass index_col only when index columns were reported.
    read_kwargs = {} if index_ids is None else {"index_col": index_ids}
    result = query_df._session.read_gbq(search_sql, **read_kwargs)

    # Restore the labels the index carried before the SQL round trip.
    result.index.names = index_names
    return result

0 commit comments

Comments
 (0)