Skip to content

Commit ccf7ff5

Browse files
committed
address comments
1 parent 3e584f4 commit ccf7ff5

File tree

6 files changed

+78
-201
lines changed

6 files changed

+78
-201
lines changed

bigframes/bigquery/__init__.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,23 @@
2020

2121
from __future__ import annotations
2222

23+
import typing
2324
from typing import Literal, Optional, Union
2425

25-
import bigframes.bigquery.utils as utils
2626
import bigframes.constants as constants
2727
import bigframes.core.groupby as groupby
28+
import bigframes.core.sql
29+
import bigframes.ml.utils as utils
2830
import bigframes.operations as ops
2931
import bigframes.operations.aggregations as agg_ops
30-
import bigframes.pandas as bpd
32+
import bigframes.series
3133

34+
if typing.TYPE_CHECKING:
35+
import bigframes.dataframe as dataframe
36+
import bigframes.series as series
3237

33-
def array_length(series: bpd.Series) -> bpd.Series:
38+
39+
def array_length(series: series.Series) -> series.Series:
3440
"""Compute the length of each array element in the Series.
3541
3642
**Examples:**
@@ -67,7 +73,7 @@ def array_length(series: bpd.Series) -> bpd.Series:
6773

6874
def array_agg(
6975
obj: groupby.SeriesGroupBy | groupby.DataFrameGroupBy,
70-
) -> bpd.Series | bpd.DataFrame:
76+
) -> series.Series | dataframe.DataFrame:
7177
"""Group data and create arrays from selected columns, omitting NULLs to avoid
7278
BigQuery errors (NULLs not allowed in arrays).
7379
@@ -118,7 +124,7 @@ def array_agg(
118124
)
119125

120126

121-
def array_to_string(series: bpd.Series, delimiter: str) -> bpd.Series:
127+
def array_to_string(series: series.Series, delimiter: str) -> series.Series:
122128
"""Converts array elements within a Series into delimited strings.
123129
124130
**Examples:**
@@ -151,14 +157,14 @@ def array_to_string(series: bpd.Series, delimiter: str) -> bpd.Series:
151157
def vector_search(
152158
base_table: str,
153159
column_to_search: str,
154-
query: Union[bpd.DataFrame, bpd.Series],
160+
query: Union[dataframe.DataFrame, series.Series],
155161
*,
156162
query_column_to_search: Optional[str] = None,
157163
top_k: Optional[int] = 10,
158164
distance_type: Literal["euclidean", "cosine"] = "euclidean",
159165
fraction_lists_to_search: Optional[float] = None,
160166
use_brute_force: bool = False,
161-
) -> bpd.DataFrame:
167+
) -> dataframe.DataFrame:
162168
"""
163169
Conduct vector search which searches embeddings to find semantically similar entities.
164170
@@ -258,11 +264,14 @@ def vector_search(
258264
raise ValueError(
259265
"You can't specify fraction_lists_to_search when use_brute_force is set to True."
260266
)
261-
if isinstance(query, bpd.Series) and query_column_to_search is not None:
267+
if (
268+
isinstance(query, bigframes.series.Series)
269+
and query_column_to_search is not None
270+
):
262271
raise ValueError(
263272
"You can't specify query_column_to_search when query is a Series."
264273
)
265-
## (TODO: ashleyxu. Support options in vector search.)
274+
# TODO(ashleyxu): Support options in vector search. b/344019989
266275
if fraction_lists_to_search is not None or use_brute_force is True:
267276
raise NotImplementedError(
268277
f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}"
@@ -277,8 +286,16 @@ def vector_search(
277286
"use_brute_force": use_brute_force,
278287
}
279288

280-
df = utils.apply_sql(
281-
query,
282-
options, # type:ignore
289+
(query,) = utils.convert_to_dataframe(query)
290+
sql_string, index_col_ids, index_labels = query._to_sql_query(include_index=True)
291+
292+
sql = bigframes.core.sql.create_vector_search_sql(
293+
sql_string=sql_string, options=options # type: ignore
283294
)
295+
if index_col_ids is not None:
296+
df = query._session.read_gbq(sql, index_col=index_col_ids)
297+
else:
298+
df = query._session.read_gbq(sql)
299+
df.index.names = index_labels
300+
284301
return df

bigframes/bigquery/utils.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

bigframes/core/sql.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import datetime
2121
import math
2222
import textwrap
23-
from typing import Iterable, TYPE_CHECKING
23+
from typing import Iterable, Mapping, TYPE_CHECKING, Union
2424

2525
# Literals and identifiers matching this pattern can be unquoted
2626
unquoted = r"^[A-Za-z_][A-Za-z_0-9]*$"
@@ -169,3 +169,47 @@ def ordering_clause(
169169
part = f"`{ordering_expr.id}` {asc_desc} {null_clause}"
170170
parts.append(part)
171171
return f"ORDER BY {' ,'.join(parts)}"
172+
173+
174+
def create_vector_search_sql(
    sql_string: str,
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> str:
    """Encode the VECTOR SEARCH statement for BigQuery Vector Search.

    Args:
        sql_string:
            SQL text of the query side. It is embedded, wrapped in
            parentheses, as the query argument of ``VECTOR_SEARCH``.
        options:
            Search settings. ``base_table``, ``column_to_search``,
            ``distance_type`` and ``top_k`` are required.
            ``query_column_to_search`` is optional; when it is present and
            not ``None`` it is emitted as an extra quoted argument right
            after the query. Any other keys are ignored. The mapping is
            only read, never mutated, so the shared default is safe.

    Returns:
        The ``SELECT query.*, base.*, distance FROM VECTOR_SEARCH(...)``
        statement as a string.

    Raises:
        KeyError: If one or more required options are missing. All missing
            keys are named in the message.
    """
    required_keys = ("base_table", "column_to_search", "distance_type", "top_k")
    missing_keys = [key for key in required_keys if key not in options]
    if missing_keys:
        # Report every missing key at once instead of failing on the first
        # lookup; still a KeyError so existing handlers keep working.
        raise KeyError(
            f"Missing required vector search option(s): {', '.join(missing_keys)}"
        )

    base_table = options["base_table"]
    column_to_search = options["column_to_search"]
    distance_type = options["distance_type"]
    top_k = options["top_k"]
    query_column_to_search = options.get("query_column_to_search", None)

    # The only difference between the two statement shapes is this one
    # optional argument line, so build it separately and share the template.
    if query_column_to_search is not None:
        query_column_arg = f'\n        "{query_column_to_search}",'
    else:
        query_column_arg = ""

    # NOTE(review): option values are interpolated into the SQL text without
    # escaping; callers should not pass untrusted strings here.
    return f"""
    SELECT
        query.*,
        base.*,
        distance,
    FROM VECTOR_SEARCH(
        TABLE `{base_table}`,
        "{column_to_search}",
        ({sql_string}),{query_column_arg}
        distance_type => "{distance_type}",
        top_k => {top_k}
    )
    """

tests/system/small/bigquery/test_utils.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

tests/unit/bigquery/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/unit/bigquery/test_utils.py renamed to tests/unit/core/test_sql.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import bigframes.bigquery as bbq
15+
16+
from bigframes.core import sql
1617

1718

1819
def test_create_vector_search_sql_simple():
@@ -39,7 +40,7 @@ def test_create_vector_search_sql_simple():
3940
)
4041
"""
4142

42-
result_query = bbq.utils.create_vector_search_sql(
43+
result_query = sql.create_vector_search_sql(
4344
sql_string, options # type:ignore
4445
)
4546
assert result_query == expected_query
@@ -71,7 +72,7 @@ def test_create_vector_search_sql_query_column_to_search():
7172
)
7273
"""
7374

74-
result_query = bbq.utils.create_vector_search_sql(
75+
result_query = sql.create_vector_search_sql(
7576
sql_string, options # type:ignore
7677
)
7778
assert result_query == expected_query

0 commit comments

Comments
 (0)