|
20 | 20 |
|
21 | 21 | from __future__ import annotations
|
22 | 22 |
|
23 |
| -import typing |
| 23 | +from typing import Literal, Optional, Union |
24 | 24 |
|
| 25 | +import bigframes.bigquery.utils as utils |
25 | 26 | import bigframes.constants as constants
|
26 | 27 | import bigframes.core.groupby as groupby
|
27 | 28 | import bigframes.operations as ops
|
28 | 29 | import bigframes.operations.aggregations as agg_ops
|
| 30 | +import bigframes.pandas as bpd |
29 | 31 |
|
30 |
| -if typing.TYPE_CHECKING: |
31 |
| - import bigframes.dataframe as dataframe |
32 |
| - import bigframes.series as series |
33 | 32 |
|
34 |
| - |
35 |
| -def array_length(series: series.Series) -> series.Series: |
| 33 | +def array_length(series: bpd.Series) -> bpd.Series: |
36 | 34 | """Compute the length of each array element in the Series.
|
37 | 35 |
|
38 | 36 | **Examples:**
|
@@ -69,7 +67,7 @@ def array_length(series: series.Series) -> series.Series:
|
69 | 67 |
|
70 | 68 | def array_agg(
|
71 | 69 | obj: groupby.SeriesGroupBy | groupby.DataFrameGroupBy,
|
72 |
| -) -> series.Series | dataframe.DataFrame: |
| 70 | +) -> bpd.Series | bpd.DataFrame: |
73 | 71 | """Group data and create arrays from selected columns, omitting NULLs to avoid
|
74 | 72 | BigQuery errors (NULLs not allowed in arrays).
|
75 | 73 |
|
@@ -120,7 +118,7 @@ def array_agg(
|
120 | 118 | )
|
121 | 119 |
|
122 | 120 |
|
123 |
| -def array_to_string(series: series.Series, delimiter: str) -> series.Series: |
| 121 | +def array_to_string(series: bpd.Series, delimiter: str) -> bpd.Series: |
124 | 122 | """Converts array elements within a Series into delimited strings.
|
125 | 123 |
|
126 | 124 | **Examples:**
|
@@ -148,3 +146,139 @@ def array_to_string(series: series.Series, delimiter: str) -> series.Series:
|
148 | 146 |
|
149 | 147 | """
|
150 | 148 | return series._apply_unary_op(ops.ArrayToStringOp(delimiter=delimiter))
|
| 149 | + |
| 150 | + |
def vector_search(
    base_table: str,
    column_to_search: str,
    query: Union[bpd.DataFrame, bpd.Series],
    *,
    query_column_to_search: Optional[str] = None,
    top_k: Optional[int] = 10,
    distance_type: Literal["euclidean", "cosine"] = "euclidean",
    fraction_lists_to_search: Optional[float] = None,
    use_brute_force: bool = False,
) -> bpd.DataFrame:
    """
    Conduct vector search to search embeddings to find semantically similar entities.

    **Examples:**


        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None

    DataFrame embeddings for which to find nearest neighbors:

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                              "embedding": [[1.0, 2.0], [3.0, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
          query_id  embedding  id my_embedding  distance
        1      cat  [3.  5.2]   5    [5.  5.4]  2.009975
        0      dog    [1. 2.]   1      [1. 2.]       0.0
        0      dog    [1. 2.]   4    [1.  3.2]       1.2
        1      cat  [3.  5.2]   2      [2. 4.]   1.56205
        <BLANKLINE>
        [4 rows x 5 columns]

    Series embeddings for which to find nearest neighbors:

        >>> search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]],
        ...                            index=["dog", "cat"],
        ...                            name="embedding")
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             top_k=2)
             embedding  id my_embedding  distance
        dog    [1. 2.]   1      [1. 2.]       0.0
        cat  [3.  5.2]   5    [5.  5.4]  2.009975
        dog    [1. 2.]   4    [1.  3.2]       1.2
        cat  [3.  5.2]   2      [2. 4.]   1.56205
        <BLANKLINE>
        [4 rows x 4 columns]

    You can specify the name of the column in the query DataFrame embeddings and distance type:

        >>> search_query = bpd.DataFrame({"query_id": ["dog", "cat"],
        ...                              "embedding": [[1.0, 2.0], [3.0, 5.2]],
        ...                              "another_embedding": [[0.7, 2.2], [3.3, 5.2]]})
        >>> bbq.vector_search(
        ...             base_table="bigframes-dev.bigframes_tests_sys.base_table",
        ...             column_to_search="my_embedding",
        ...             query=search_query,
        ...             distance_type="cosine",
        ...             query_column_to_search="another_embedding",
        ...             top_k=2)
          query_id  embedding another_embedding  id my_embedding  distance
        1      cat  [3.  5.2]         [3.3 5.2]   2      [2. 4.]  0.005181
        0      dog    [1. 2.]         [0.7 2.2]   4    [1.  3.2]  0.000013
        1      cat  [3.  5.2]         [3.3 5.2]   1      [1. 2.]  0.005181
        0      dog    [1. 2.]         [0.7 2.2]   3    [1.5 7. ]  0.004697
        <BLANKLINE>
        [4 rows x 6 columns]

    Args:
        base_table (str):
            The table to search for nearest neighbor embeddings.
        column_to_search (str):
            The name of the base table column to search for nearest neighbor embeddings.
            The column must have a type of ``ARRAY<FLOAT64>``. All elements in the array must be non-NULL.
        query (bigframes.dataframe.DataFrame | bigframes.series.Series):
            A Series or DataFrame that provides the embeddings for which to find nearest neighbors.
        query_column_to_search (str):
            Specifies the name of the column in the query that contains the embeddings for which to
            find nearest neighbors. The column must have a type of ``ARRAY<FLOAT64>``. All elements in
            the array must be non-NULL and all values in the column must have the same array dimensions
            as the values in the ``column_to_search`` column. Can only be set when query is a DataFrame.
        top_k (int, default 10):
            Specifies the number of nearest neighbors to return. Defaults to 10.
        distance_type (str, default "euclidean"):
            Specifies the type of metric to use to compute the distance between two vectors.
            Possible values are "euclidean" and "cosine". Defaults to "euclidean".
        fraction_lists_to_search (float, range in [0.0, 1.0]):
            Specifies the percentage of lists to search. Specifying a higher percentage leads to
            higher recall and slower performance, and the converse is true when specifying a lower
            percentage. It is only used when a vector index is also used. You can only specify
            ``fraction_lists_to_search`` when ``use_brute_force`` is set to False.
            Not supported yet: any non-None value raises ``NotImplementedError``.
        use_brute_force (bool, default False):
            Determines whether to use brute force search by skipping the vector index if one is available.
            Defaults to False. Not supported yet: ``True`` raises ``NotImplementedError``.

    Returns:
        bigframes.dataframe.DataFrame: A DataFrame containing vector search result.

    Raises:
        ValueError:
            If ``fraction_lists_to_search`` is given together with ``use_brute_force=True``,
            or if ``query_column_to_search`` is given while ``query`` is a Series.
        NotImplementedError:
            If ``fraction_lists_to_search`` is given or ``use_brute_force`` is True.
    """
    # The two options are mutually exclusive: a fraction of index lists only
    # makes sense when the vector index is actually used. The check must fire
    # when the fraction IS provided (the previous `not fraction_lists_to_search`
    # test was inverted and rejected the valid "brute force only" call instead).
    if fraction_lists_to_search is not None and use_brute_force is True:
        raise ValueError(
            "You can't specify fraction_lists_to_search when use_brute_force is set to True."
        )
    # A Series carries exactly one column of embeddings, so there is nothing
    # for query_column_to_search to select.
    if isinstance(query, bpd.Series) and query_column_to_search is not None:
        raise ValueError(
            "You can't specify query_column_to_search when query is a Series."
        )
    # TODO(ashleyxu): Support options in vector search.
    if fraction_lists_to_search is not None or use_brute_force is True:
        raise NotImplementedError(
            f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}"
        )

    # Past the guards above, fraction_lists_to_search is always None and
    # use_brute_force is never True; both are still forwarded so the helper
    # receives the complete option set.
    options = {
        "base_table": base_table,
        "column_to_search": column_to_search,
        "query_column_to_search": query_column_to_search,
        "distance_type": distance_type,
        "top_k": top_k,
        "fraction_lists_to_search": fraction_lists_to_search,
        "use_brute_force": use_brute_force,
    }

    df = utils.apply_sql(
        query,
        options,  # type: ignore
    )
    return df
0 commit comments