Skip to content

Commit dad66fd

Browse files
ashleyxuutswast
andauthored
feat: support bigquery.vector_search() (#736)
* feat: supoort bigquery.vector_search() * minor fix * address comments * docstring fix * address comments * small fix * add docstring clarification * Update bigframes/bigquery/__init__.py --------- Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 60f13e7 commit dad66fd

File tree

4 files changed

+413
-1
lines changed

4 files changed

+413
-1
lines changed

bigframes/bigquery/__init__.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@
2121
from __future__ import annotations
2222

2323
import typing
24+
from typing import Literal, Optional, Union
2425

2526
import bigframes.constants as constants
2627
import bigframes.core.groupby as groupby
28+
import bigframes.core.sql
29+
import bigframes.ml.utils as utils
2730
import bigframes.operations as ops
2831
import bigframes.operations.aggregations as agg_ops
32+
import bigframes.series
2933

3034
if typing.TYPE_CHECKING:
3135
import bigframes.dataframe as dataframe
@@ -148,3 +152,153 @@ def array_to_string(series: series.Series, delimiter: str) -> series.Series:
148152
149153
"""
150154
return series._apply_unary_op(ops.ArrayToStringOp(delimiter=delimiter))
155+
156+
157+
def vector_search(
    base_table: str,
    column_to_search: str,
    query: Union[dataframe.DataFrame, series.Series],
    *,
    query_column_to_search: Optional[str] = None,
    top_k: Optional[int] = 10,
    distance_type: Literal["euclidean", "cosine"] = "euclidean",
    fraction_lists_to_search: Optional[float] = None,
    use_brute_force: bool = False,
) -> dataframe.DataFrame:
    """
    Conduct vector search which searches embeddings to find semantically similar entities.

    **Examples:**


        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None

    DataFrame embeddings for which to find nearest neighbors. The ``ARRAY<FLOAT64>`` column
    is used as the search query:

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                               "embedding": [[1.0, 2.0], [3.0, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
          query_id  embedding  id my_embedding  distance
        1      cat   [3. 5.2]   5    [5. 5.4]  2.009975
        0      dog    [1. 2.]   1     [1. 2.]       0.0
        0      dog    [1. 2.]   4    [1. 3.2]       1.2
        1      cat   [3. 5.2]   2     [2. 4.]   1.56205
        <BLANKLINE>
        [4 rows x 5 columns]

    Series embeddings for which to find nearest neighbors:

        >>> search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]],
        ...                            index=["dog", "cat"],
        ...                            name="embedding")
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
             embedding  id my_embedding  distance
        dog    [1. 2.]   1     [1. 2.]       0.0
        cat   [3. 5.2]   5    [5. 5.4]  2.009975
        dog    [1. 2.]   4    [1. 3.2]       1.2
        cat   [3. 5.2]   2     [2. 4.]   1.56205
        <BLANKLINE>
        [4 rows x 4 columns]

    You can specify the name of the column in the query DataFrame embeddings and distance type.
    If you specify ``query_column_to_search``, it will use the provided column which contains
    the embeddings for which to find nearest neighbors. Otherwise, it uses the
    ``column_to_search`` value.

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                               "embedding": [[1.0, 2.0], [3.0, 5.2]],
        ...                               "another_embedding": [[0.7, 2.2], [3.3, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             distance_type="cosine",
        ...             query_column_to_search="another_embedding",
        ...             top_k=2)
          query_id  embedding another_embedding  id my_embedding  distance
        1      cat   [3. 5.2]         [3.3 5.2]   2     [2. 4.]  0.005181
        0      dog    [1. 2.]         [0.7 2.2]   4    [1. 3.2]  0.000013
        1      cat   [3. 5.2]         [3.3 5.2]   1     [1. 2.]  0.005181
        0      dog    [1. 2.]         [0.7 2.2]   3   [1.5 7. ]  0.004697
        <BLANKLINE>
        [4 rows x 6 columns]

    Args:
        base_table (str):
            The table to search for nearest neighbor embeddings.
        column_to_search (str):
            The name of the base table column to search for nearest neighbor embeddings.
            The column must have a type of ``ARRAY<FLOAT64>``. All elements in the array must be non-NULL.
        query (bigframes.dataframe.DataFrame | bigframes.series.Series):
            A Series or DataFrame that provides the embeddings for which to find nearest neighbors.
        query_column_to_search (str):
            Specifies the name of the column in the query that contains the embeddings for which to
            find nearest neighbors. The column must have a type of ``ARRAY<FLOAT64>``. All elements in
            the array must be non-NULL and all values in the column must have the same array dimensions
            as the values in the ``column_to_search`` column. Can only be set when query is a DataFrame.
        top_k (int, default 10):
            Specifies the number of nearest neighbors to return. Default to 10.
        distance_type (str, default "euclidean"):
            Specifies the type of metric to use to compute the distance between two vectors.
            Possible values are "euclidean" and "cosine". Default to "euclidean".
        fraction_lists_to_search (float, range in [0.0, 1.0]):
            Specifies the percentage of lists to search. Specifying a higher percentage leads to
            higher recall and slower performance, and the converse is true when specifying a lower
            percentage. It is only used when a vector index is also used. You can only specify
            ``fraction_lists_to_search`` when ``use_brute_force`` is set to False.
        use_brute_force (bool, default False):
            Determines whether to use brute force search by skipping the vector index if one is available.
            Default to False.

    Returns:
        bigframes.dataframe.DataFrame: A DataFrame containing vector search result.

    Raises:
        ValueError:
            If ``fraction_lists_to_search`` is given together with
            ``use_brute_force=True``, or if ``query_column_to_search`` is given
            while ``query`` is a Series.
        NotImplementedError:
            If ``fraction_lists_to_search`` or ``use_brute_force`` is requested;
            these options are not supported yet.
    """
    # The two options are mutually exclusive. Compare against None (rather than
    # relying on truthiness) so an explicit 0.0 still counts as "specified".
    if fraction_lists_to_search is not None and use_brute_force is True:
        raise ValueError(
            "You can't specify fraction_lists_to_search when use_brute_force is set to True."
        )
    if query_column_to_search is not None and isinstance(
        query, bigframes.series.Series
    ):
        raise ValueError(
            "You can't specify query_column_to_search when query is a Series."
        )
    # TODO(ashleyxu): Support options in vector search. b/344019989
    if fraction_lists_to_search is not None or use_brute_force is True:
        raise NotImplementedError(
            f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}"
        )
    options = {
        "base_table": base_table,
        "column_to_search": column_to_search,
        "query_column_to_search": query_column_to_search,
        "distance_type": distance_type,
        "top_k": top_k,
        "fraction_lists_to_search": fraction_lists_to_search,
        "use_brute_force": use_brute_force,
    }

    # Normalize a Series query into a single-column DataFrame so that the rest
    # of the function only has to deal with one shape of input.
    (query,) = utils.convert_to_dataframe(query)
    sql_string, index_col_ids, index_labels = query._to_sql_query(include_index=True)

    sql = bigframes.core.sql.create_vector_search_sql(
        sql_string=sql_string, options=options  # type: ignore
    )
    if index_col_ids is not None:
        df = query._session.read_gbq(sql, index_col=index_col_ids)
    else:
        df = query._session.read_gbq(sql)
    # Restore the query's original index labels on the result.
    df.index.names = index_labels

    return df

bigframes/core/sql.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import datetime
2121
import math
2222
import textwrap
23-
from typing import Iterable, TYPE_CHECKING
23+
from typing import Iterable, Mapping, TYPE_CHECKING, Union
2424

2525
# Literals and identifiers matching this pattern can be unquoted
2626
unquoted = r"^[A-Za-z_][A-Za-z_0-9]*$"
@@ -169,3 +169,47 @@ def ordering_clause(
169169
part = f"`{ordering_expr.id}` {asc_desc} {null_clause}"
170170
parts.append(part)
171171
return f"ORDER BY {' ,'.join(parts)}"
172+
173+
174+
def create_vector_search_sql(
    sql_string: str,
    options: Mapping[str, Union[str, int, bool, float]] = {},
) -> str:
    """Encode the VECTOR SEARCH statement for BigQuery Vector Search.

    Args:
        sql_string (str):
            SQL text of the query side; it is embedded as a parenthesized
            subquery argument of ``VECTOR_SEARCH``.
        options (Mapping):
            Search settings. The keys ``base_table``, ``column_to_search``,
            ``distance_type`` and ``top_k`` are required;
            ``query_column_to_search`` is optional and is only emitted when
            present and not None. The mapping is only read, never mutated.

    Returns:
        str: A ``SELECT query.*, base.*, distance, FROM VECTOR_SEARCH(...)``
        statement.

    Raises:
        KeyError: If one of the required keys is missing from ``options``.
    """
    base_table = options["base_table"]
    column_to_search = options["column_to_search"]
    distance_type = options["distance_type"]
    top_k = options["top_k"]
    query_column_to_search = options.get("query_column_to_search", None)

    # Positional arguments come first; the query column is an optional
    # positional argument that must precede the named arguments.
    search_args = [
        f"TABLE `{base_table}`",
        simple_literal(column_to_search),
        f"({sql_string})",
    ]
    if query_column_to_search is not None:
        search_args.append(simple_literal(query_column_to_search))
    search_args.append(f"distance_type => {simple_literal(distance_type)}")
    search_args.append(f"top_k => {simple_literal(top_k)}")

    joined_args = ",\n    ".join(search_args)
    return (
        "\n"
        "SELECT\n"
        "    query.*,\n"
        "    base.*,\n"
        "    distance,\n"
        "FROM VECTOR_SEARCH(\n"
        f"    {joined_args}\n"
        ")\n"
    )
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pandas as pd
17+
18+
import bigframes.bigquery as bbq
19+
import bigframes.pandas as bpd
20+
21+
22+
def test_vector_search_basic_params_with_df():
    """vector_search with a DataFrame query and top_k=2 returns the expected rows."""
    queries = bpd.DataFrame(
        {
            "query_id": ["dog", "cat"],
            "embedding": [[1.0, 2.0], [3.0, 5.2]],
        }
    )

    actual = bbq.vector_search(
        base_table="bigframes-dev.bigframes_tests_sys.base_table",
        column_to_search="my_embedding",
        query=queries,
        top_k=2,
    ).to_pandas()  # type:ignore

    # Expected columns, listed row by row in the order the result is compared.
    query_embeddings = [[3.0, 5.2], [1.0, 2.0], [1.0, 2.0], [3.0, 5.2]]
    base_embeddings = [[5.0, 5.4], [1.0, 2.0], [1.0, 3.2], [2.0, 4.0]]
    expected = pd.DataFrame(
        {
            "query_id": ["cat", "dog", "dog", "cat"],
            "embedding": [np.array(vec) for vec in query_embeddings],
            "id": [5, 1, 4, 2],
            "my_embedding": [np.array(vec) for vec in base_embeddings],
            "distance": [2.009975, 0.0, 1.2, 1.56205],
        },
        index=pd.Index([1, 0, 0, 1], dtype="Int64"),
    )

    pd.testing.assert_frame_equal(actual, expected, check_dtype=False, rtol=0.1)
58+
59+
60+
def test_vector_search_different_params_with_query():
    """vector_search with an unnamed Series query and cosine distance."""
    queries = bpd.Series([[1.0, 2.0], [3.0, 5.2]])

    actual = bbq.vector_search(
        base_table="bigframes-dev.bigframes_tests_sys.base_table",
        column_to_search="my_embedding",
        query=queries,
        distance_type="cosine",
        top_k=2,
    ).to_pandas()  # type:ignore

    # Expected columns, listed row by row in the order the result is compared.
    query_embeddings = [[1.0, 2.0], [1.0, 2.0], [3.0, 5.2], [3.0, 5.2]]
    base_embeddings = [[2.0, 4.0], [1.0, 2.0], [1.0, 2.0], [2.0, 4.0]]
    expected = pd.DataFrame(
        {
            "0": [np.array(vec) for vec in query_embeddings],
            "id": [2, 1, 1, 2],
            "my_embedding": [np.array(vec) for vec in base_embeddings],
            "distance": [0.0, 0.0, 0.001777, 0.001777],
        },
        index=pd.Index([0, 0, 1, 1], dtype="Int64"),
    )

    pd.testing.assert_frame_equal(actual, expected, check_dtype=False, rtol=0.1)
91+
92+
93+
def test_vector_search_df_with_query_column_to_search():
    """vector_search searching with an explicitly chosen query embedding column."""
    queries = bpd.DataFrame(
        {
            "query_id": ["dog", "cat"],
            "embedding": [[1.0, 2.0], [3.0, 5.2]],
            "another_embedding": [[1.0, 2.5], [3.3, 5.2]],
        }
    )

    actual = bbq.vector_search(
        base_table="bigframes-dev.bigframes_tests_sys.base_table",
        column_to_search="my_embedding",
        query=queries,
        query_column_to_search="another_embedding",
        top_k=2,
    ).to_pandas()  # type:ignore

    # Expected columns, listed row by row in the order the result is compared.
    query_embeddings = [[1.0, 2.0], [1.0, 2.0], [3.0, 5.2], [3.0, 5.2]]
    alternate_embeddings = [[1.0, 2.5], [1.0, 2.5], [3.3, 5.2], [3.3, 5.2]]
    base_embeddings = [[1.0, 2.0], [1.0, 3.2], [2.0, 4.0], [5.0, 5.4]]
    expected = pd.DataFrame(
        {
            "query_id": ["dog", "dog", "cat", "cat"],
            "embedding": [np.array(vec) for vec in query_embeddings],
            "another_embedding": [np.array(vec) for vec in alternate_embeddings],
            "id": [1, 4, 2, 5],
            "my_embedding": [np.array(vec) for vec in base_embeddings],
            "distance": [0.5, 0.7, 1.769181, 1.711724],
        },
        index=pd.Index([0, 0, 1, 1], dtype="Int64"),
    )

    pd.testing.assert_frame_equal(actual, expected, check_dtype=False, rtol=0.1)

0 commit comments

Comments
 (0)