Skip to content

Commit 72f1474

Browse files
authored
feat(rig-1014): add backend specific vector search filters (#1032)
1 parent 9263016 commit 72f1474

File tree

9 files changed

+491
-59
lines changed

9 files changed

+491
-59
lines changed

rig-lancedb/src/lib.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::Range;
2+
13
use lancedb::{
24
DistanceType,
35
query::{QueryBase, VectorQuery},
@@ -205,6 +207,87 @@ impl LanceDBFilter {
205207
pub fn not(self) -> Self {
206208
Self(self.0.map(|s| format!("NOT ({s})")))
207209
}
210+
211+
/// IN operator
212+
pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
213+
Self(
214+
values
215+
.into_iter()
216+
.map(escape_value)
217+
.collect::<Result<Vec<_>, FilterError>>()
218+
.map(|xs| xs.join(","))
219+
.map(|xs| format!("{key} IN ({xs})")),
220+
)
221+
}
222+
223+
/// LIKE operator (string pattern matching)
224+
pub fn like<S>(key: String, pattern: S) -> Self
225+
where
226+
S: AsRef<str>,
227+
{
228+
Self(
229+
escape_value(serde_json::Value::String(pattern.as_ref().into()))
230+
.map(|pat| format!("{key} LIKE {pat}")),
231+
)
232+
}
233+
234+
/// ILIKE operator (case-insensitive pattern matching)
235+
pub fn ilike<S>(key: String, pattern: S) -> Self
236+
where
237+
S: AsRef<str>,
238+
{
239+
Self(
240+
escape_value(serde_json::Value::String(pattern.as_ref().into()))
241+
.map(|pat| format!("{key} ILIKE {pat}")),
242+
)
243+
}
244+
245+
/// IS NULL check
246+
pub fn is_null(key: String) -> Self {
247+
Self(Ok(format!("{key} IS NULL")))
248+
}
249+
250+
/// IS NOT NULL check
251+
pub fn is_not_null(key: String) -> Self {
252+
Self(Ok(format!("{key} IS NOT NULL")))
253+
}
254+
255+
/// Array has any (for LIST columns with scalar index)
256+
pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
257+
Self(
258+
values
259+
.into_iter()
260+
.map(escape_value)
261+
.collect::<Result<Vec<_>, FilterError>>()
262+
.map(|xs| xs.join(","))
263+
.map(|xs| format!("array_has_any({key}, ARRAY[{xs}])")),
264+
)
265+
}
266+
267+
/// Array has all (for LIST columns with scalar index)
268+
pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
269+
Self(
270+
values
271+
.into_iter()
272+
.map(escape_value)
273+
.collect::<Result<Vec<_>, FilterError>>()
274+
.map(|xs| xs.join(","))
275+
.map(|xs| format!("array_has_all({key}, ARRAY[{xs}])")),
276+
)
277+
}
278+
279+
/// Array length comparison
280+
pub fn array_length(key: String, length: i32) -> Self {
281+
Self(Ok(format!("array_length({key}) = {length}")))
282+
}
283+
284+
/// BETWEEN operator
285+
pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
286+
where
287+
T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
288+
{
289+
Self(Ok(format!("{key} BETWEEN {start} AND {end}")))
290+
}
208291
}
209292

210293
/// Parameters used to perform a vector search on a LanceDb table.

rig-milvus/src/filter.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,61 @@ impl Filter {
125125
Self(format!("{key} <= {}", value.escaped()))
126126
}
127127

128+
/// IN operator
129+
pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
130+
let values_str = values
131+
.into_iter()
132+
.map(|v| v.escaped())
133+
.collect::<Vec<_>>()
134+
.join(", ");
135+
Self(format!("{} in [{}]", key, values_str))
136+
}
137+
138+
/// NOT IN operator
139+
pub fn not_in(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
140+
let values_str = values
141+
.into_iter()
142+
.map(|v| v.escaped())
143+
.collect::<Vec<_>>()
144+
.join(", ");
145+
Self(format!("{} not in [{}]", key, values_str))
146+
}
147+
148+
/// LIKE operator (string pattern matching)
149+
pub fn like(key: String, pattern: String) -> Self {
150+
Self(format!("{} like '{}'", key, pattern))
151+
}
152+
153+
/// Array contains
154+
pub fn array_contains(key: String, value: <Self as SearchFilter>::Value) -> Self {
155+
Self(format!("array_contains({}, {})", key, value.escaped()))
156+
}
157+
158+
/// Array contains all
159+
pub fn array_contains_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
160+
let values_str = values
161+
.into_iter()
162+
.map(|v| v.escaped())
163+
.collect::<Vec<_>>()
164+
.join(", ");
165+
Self(format!("array_contains_all({}, [{}])", key, values_str))
166+
}
167+
168+
/// Array contains any
169+
pub fn array_contains_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
170+
let values_str = values
171+
.into_iter()
172+
.map(|v| v.escaped())
173+
.collect::<Vec<_>>()
174+
.join(", ");
175+
Self(format!("array_contains_any({}, [{}])", key, values_str))
176+
}
177+
178+
/// Array length comparison
179+
pub fn array_length_eq(key: String, length: i32) -> Self {
180+
Self(format!("array_length({}) == {}", key, length))
181+
}
182+
128183
pub fn into_inner(self) -> String {
129184
self.0
130185
}

rig-mongodb/src/lib.rs

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,36 @@ where
122122
{
123123
/// Vector search stage of aggregation pipeline of mongoDB collection.
124124
/// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
125-
fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
125+
fn pipeline_search_stage(
126+
&self,
127+
prompt_embedding: &Embedding,
128+
req: &VectorSearchRequest<MongoDbSearchFilter>,
129+
) -> bson::Document {
126130
let SearchParams {
127-
filter,
128131
exact,
129132
num_candidates,
130133
} = &self.search_params;
131134

135+
let samples = req.samples() as usize;
136+
137+
let thresh = req
138+
.threshold()
139+
.map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
140+
141+
let filter = match (thresh, req.filter()) {
142+
(Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
143+
(Some(thresh), _) => thresh.into_inner(),
144+
(_, Some(filt)) => filt.clone().into_inner(),
145+
_ => Default::default(),
146+
};
147+
132148
doc! {
133149
"$vectorSearch": {
134150
"index": &self.index_name,
135151
"path": self.embedded_field.clone(),
136152
"queryVector": &prompt_embedding.vec,
137-
"numCandidates": num_candidates.unwrap_or((n * 10) as u32),
138-
"limit": n as u32,
153+
"numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
154+
"limit": samples as u32,
139155
"filter": filter,
140156
"exact": exact.unwrap_or(false)
141157
}
@@ -201,7 +217,6 @@ where
201217
/// on each of the fields
202218
#[derive(Default)]
203219
pub struct SearchParams {
204-
filter: mongodb::bson::Document,
205220
exact: Option<bool>,
206221
num_candidates: Option<u32>,
207222
}
@@ -210,19 +225,11 @@ impl SearchParams {
210225
/// Initializes a new `SearchParams` with default values.
211226
pub fn new() -> Self {
212227
Self {
213-
filter: doc! {},
214228
exact: None,
215229
num_candidates: None,
216230
}
217231
}
218232

219-
/// Sets the pre-filter field of the search params.
220-
/// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
221-
pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
222-
self.filter = filter;
223-
self
224-
}
225-
226233
/// Sets the exact field of the search params.
227234
/// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
228235
/// By default, exact is false.
@@ -270,9 +277,8 @@ impl SearchFilter for MongoDbSearchFilter {
270277
}
271278

272279
impl MongoDbSearchFilter {
273-
/// Render the filter as a MongoDB `$match` expression
274-
pub fn into_document(self) -> Document {
275-
doc! { "$match": self.0 }
280+
fn into_inner(self) -> Document {
281+
self.0
276282
}
277283

278284
pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
@@ -285,7 +291,25 @@ impl MongoDbSearchFilter {
285291

286292
#[allow(clippy::should_implement_trait)]
287293
pub fn not(self) -> Self {
288-
Self(doc! { "$not": self.0 })
294+
Self(doc! { "$nor": [self.0] })
295+
}
296+
297+
/// Tests whether the value at `key` is the BSON type `typ`
298+
pub fn is_type(key: String, typ: &'static str) -> Self {
299+
Self(doc! { key: { "$type": typ } })
300+
}
301+
302+
pub fn size(key: String, size: i32) -> Self {
303+
Self(doc! { key: { "$size": size } })
304+
}
305+
306+
// Array ops
307+
pub fn all(key: String, values: Vec<Bson>) -> Self {
308+
Self(doc! { key: { "$all": values } })
309+
}
310+
311+
pub fn any(key: String, condition: Document) -> Self {
312+
Self(doc! { key: { "$elemMatch": condition } })
289313
}
290314
}
291315

@@ -305,28 +329,16 @@ where
305329
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
306330
let prompt_embedding = self.model.embed_text(req.query()).await?;
307331

308-
let mut pipeline = vec![
309-
self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
332+
let pipeline = vec![
333+
self.pipeline_search_stage(&prompt_embedding, &req),
310334
self.pipeline_score_stage(),
335+
doc! {
336+
"$project": {
337+
self.embedded_field.clone(): 0
338+
}
339+
},
311340
];
312341

313-
if let Some(filter) = req.filter() {
314-
let filter = req
315-
.threshold()
316-
.map(|thresh| {
317-
MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
318-
})
319-
.unwrap_or(filter.clone());
320-
321-
pipeline.push(filter.into_document())
322-
}
323-
324-
pipeline.push(doc! {
325-
"$project": {
326-
self.embedded_field.clone(): 0
327-
}
328-
});
329-
330342
let mut cursor = self
331343
.collection
332344
.aggregate(pipeline)
@@ -361,28 +373,16 @@ where
361373
) -> Result<Vec<(f64, String)>, VectorStoreError> {
362374
let prompt_embedding = self.model.embed_text(req.query()).await?;
363375

364-
let mut pipeline = vec![
365-
self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
376+
let pipeline = vec![
377+
self.pipeline_search_stage(&prompt_embedding, &req),
366378
self.pipeline_score_stage(),
367-
];
368-
369-
if let Some(filter) = req.filter() {
370-
let filter = req
371-
.threshold()
372-
.map(|thresh| {
373-
MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
374-
})
375-
.unwrap_or(filter.clone());
376-
377-
pipeline.push(filter.into_document())
378-
}
379-
380-
pipeline.push(doc! {
381-
"$project": {
382-
"_id": 1,
383-
"score": 1
379+
doc! {
380+
"$project": {
381+
"_id": 1,
382+
"score": 1
383+
},
384384
},
385-
});
385+
];
386386

387387
let mut cursor = self
388388
.collection

0 commit comments

Comments
 (0)