20
20
21
21
from __future__ import annotations
22
22
23
+ import typing
23
24
from typing import Literal , Optional , Union
24
25
25
- import bigframes .bigquery .utils as utils
26
26
import bigframes .constants as constants
27
27
import bigframes .core .groupby as groupby
28
+ import bigframes .core .sql
29
+ import bigframes .ml .utils as utils
28
30
import bigframes .operations as ops
29
31
import bigframes .operations .aggregations as agg_ops
30
- import bigframes .pandas as bpd
32
+ import bigframes .series
31
33
34
+ if typing .TYPE_CHECKING :
35
+ import bigframes .dataframe as dataframe
36
+ import bigframes .series as series
32
37
33
- def array_length (series : bpd .Series ) -> bpd .Series :
38
+
39
+ def array_length (series : series .Series ) -> series .Series :
34
40
"""Compute the length of each array element in the Series.
35
41
36
42
**Examples:**
@@ -67,7 +73,7 @@ def array_length(series: bpd.Series) -> bpd.Series:
67
73
68
74
def array_agg (
69
75
obj : groupby .SeriesGroupBy | groupby .DataFrameGroupBy ,
70
- ) -> bpd .Series | bpd .DataFrame :
76
+ ) -> series .Series | dataframe .DataFrame :
71
77
"""Group data and create arrays from selected columns, omitting NULLs to avoid
72
78
BigQuery errors (NULLs not allowed in arrays).
73
79
@@ -118,7 +124,7 @@ def array_agg(
118
124
)
119
125
120
126
121
- def array_to_string (series : bpd .Series , delimiter : str ) -> bpd .Series :
127
+ def array_to_string (series : series .Series , delimiter : str ) -> series .Series :
122
128
"""Converts array elements within a Series into delimited strings.
123
129
124
130
**Examples:**
@@ -151,14 +157,14 @@ def array_to_string(series: bpd.Series, delimiter: str) -> bpd.Series:
151
157
def vector_search (
152
158
base_table : str ,
153
159
column_to_search : str ,
154
- query : Union [bpd .DataFrame , bpd .Series ],
160
+ query : Union [dataframe .DataFrame , series .Series ],
155
161
* ,
156
162
query_column_to_search : Optional [str ] = None ,
157
163
top_k : Optional [int ] = 10 ,
158
164
distance_type : Literal ["euclidean" , "cosine" ] = "euclidean" ,
159
165
fraction_lists_to_search : Optional [float ] = None ,
160
166
use_brute_force : bool = False ,
161
- ) -> bpd .DataFrame :
167
+ ) -> dataframe .DataFrame :
162
168
"""
163
169
Conduct vector search which searches embeddings to find semantically similar entities.
164
170
@@ -258,11 +264,14 @@ def vector_search(
258
264
raise ValueError (
259
265
"You can't specify fraction_lists_to_search when use_brute_force is set to True."
260
266
)
261
- if isinstance (query , bpd .Series ) and query_column_to_search is not None :
267
+ if (
268
+ isinstance (query , bigframes .series .Series )
269
+ and query_column_to_search is not None
270
+ ):
262
271
raise ValueError (
263
272
"You can't specify query_column_to_search when query is a Series."
264
273
)
265
- ## ( TODO: ashleyxu. Support options in vector search.)
274
+ # TODO(ashleyxu): Support options in vector search. b/344019989
266
275
if fraction_lists_to_search is not None or use_brute_force is True :
267
276
raise NotImplementedError (
268
277
f"fraction_lists_to_search and use_brute_force is not supported. { constants .FEEDBACK_LINK } "
@@ -277,8 +286,16 @@ def vector_search(
277
286
"use_brute_force" : use_brute_force ,
278
287
}
279
288
280
- df = utils .apply_sql (
281
- query ,
282
- options , # type:ignore
289
+ (query ,) = utils .convert_to_dataframe (query )
290
+ sql_string , index_col_ids , index_labels = query ._to_sql_query (include_index = True )
291
+
292
+ sql = bigframes .core .sql .create_vector_search_sql (
293
+ sql_string = sql_string , options = options # type: ignore
283
294
)
295
+ if index_col_ids is not None :
296
+ df = query ._session .read_gbq (sql , index_col = index_col_ids )
297
+ else :
298
+ df = query ._session .read_gbq (sql )
299
+ df .index .names = index_labels
300
+
284
301
return df
0 commit comments