@@ -122,20 +122,36 @@ where
122122{
123123 /// Vector search stage of aggregation pipeline of mongoDB collection.
124124 /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
125- fn pipeline_search_stage ( & self , prompt_embedding : & Embedding , n : usize ) -> bson:: Document {
125+ fn pipeline_search_stage (
126+ & self ,
127+ prompt_embedding : & Embedding ,
128+ req : & VectorSearchRequest < MongoDbSearchFilter > ,
129+ ) -> bson:: Document {
126130 let SearchParams {
127- filter,
128131 exact,
129132 num_candidates,
130133 } = & self . search_params ;
131134
135+ let samples = req. samples ( ) as usize ;
136+
137+ let thresh = req
138+ . threshold ( )
139+ . map ( |thresh| MongoDbSearchFilter :: gte ( "score" . into ( ) , thresh. into ( ) ) ) ;
140+
141+ let filter = match ( thresh, req. filter ( ) ) {
142+ ( Some ( thresh) , Some ( filt) ) => thresh. and ( filt. clone ( ) ) . into_inner ( ) ,
143+ ( Some ( thresh) , _) => thresh. into_inner ( ) ,
144+ ( _, Some ( filt) ) => filt. clone ( ) . into_inner ( ) ,
145+ _ => Default :: default ( ) ,
146+ } ;
147+
132148 doc ! {
133149 "$vectorSearch" : {
134150 "index" : & self . index_name,
135151 "path" : self . embedded_field. clone( ) ,
136152 "queryVector" : & prompt_embedding. vec,
137- "numCandidates" : num_candidates. unwrap_or( ( n * 10 ) as u32 ) ,
138- "limit" : n as u32 ,
153+ "numCandidates" : num_candidates. unwrap_or( ( samples * 10 ) as u32 ) ,
154+ "limit" : samples as u32 ,
139155 "filter" : filter,
140156 "exact" : exact. unwrap_or( false )
141157 }
@@ -201,7 +217,6 @@ where
201217/// on each of the fields
202218#[ derive( Default ) ]
203219pub struct SearchParams {
204- filter : mongodb:: bson:: Document ,
205220 exact : Option < bool > ,
206221 num_candidates : Option < u32 > ,
207222}
@@ -210,19 +225,11 @@ impl SearchParams {
210225 /// Initializes a new `SearchParams` with default values.
211226 pub fn new ( ) -> Self {
212227 Self {
213- filter : doc ! { } ,
214228 exact : None ,
215229 num_candidates : None ,
216230 }
217231 }
218232
219- /// Sets the pre-filter field of the search params.
220- /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
221- pub fn filter ( mut self , filter : mongodb:: bson:: Document ) -> Self {
222- self . filter = filter;
223- self
224- }
225-
226233 /// Sets the exact field of the search params.
227234 /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
228235 /// By default, exact is false.
@@ -270,9 +277,8 @@ impl SearchFilter for MongoDbSearchFilter {
270277}
271278
272279impl MongoDbSearchFilter {
273- /// Render the filter as a MonadDB `$match` expression
274- pub fn into_document ( self ) -> Document {
275- doc ! { "$match" : self . 0 }
280+ fn into_inner ( self ) -> Document {
281+ self . 0
276282 }
277283
278284 pub fn gte ( key : String , value : <Self as SearchFilter >:: Value ) -> Self {
@@ -285,7 +291,25 @@ impl MongoDbSearchFilter {
285291
286292 #[ allow( clippy:: should_implement_trait) ]
287293 pub fn not ( self ) -> Self {
288- Self ( doc ! { "$not" : self . 0 } )
294+ Self ( doc ! { "$nor" : [ self . 0 ] } )
295+ }
296+
297+ /// Tests whether the value at `key` is the BSON type `typ`
298+ pub fn is_type ( key : String , typ : & ' static str ) -> Self {
299+ Self ( doc ! { key: { "$type" : typ } } )
300+ }
301+
302+ pub fn size ( key : String , size : i32 ) -> Self {
303+ Self ( doc ! { key: { "$size" : size } } )
304+ }
305+
306+ // Array ops
307+ pub fn all ( key : String , values : Vec < Bson > ) -> Self {
308+ Self ( doc ! { key: { "$all" : values } } )
309+ }
310+
311+ pub fn any ( key : String , condition : Document ) -> Self {
312+ Self ( doc ! { key: { "$elemMatch" : condition } } )
289313 }
290314}
291315
@@ -305,28 +329,16 @@ where
305329 ) -> Result < Vec < ( f64 , String , T ) > , VectorStoreError > {
306330 let prompt_embedding = self . model . embed_text ( req. query ( ) ) . await ?;
307331
308- let mut pipeline = vec ! [
309- self . pipeline_search_stage( & prompt_embedding, req. samples ( ) as usize ) ,
332+ let pipeline = vec ! [
333+ self . pipeline_search_stage( & prompt_embedding, & req) ,
310334 self . pipeline_score_stage( ) ,
335+ doc! {
336+ "$project" : {
337+ self . embedded_field. clone( ) : 0
338+ }
339+ } ,
311340 ] ;
312341
313- if let Some ( filter) = req. filter ( ) {
314- let filter = req
315- . threshold ( )
316- . map ( |thresh| {
317- MongoDbSearchFilter :: gte ( "score" . into ( ) , thresh. into ( ) ) . and ( filter. clone ( ) )
318- } )
319- . unwrap_or ( filter. clone ( ) ) ;
320-
321- pipeline. push ( filter. into_document ( ) )
322- }
323-
324- pipeline. push ( doc ! {
325- "$project" : {
326- self . embedded_field. clone( ) : 0
327- }
328- } ) ;
329-
330342 let mut cursor = self
331343 . collection
332344 . aggregate ( pipeline)
@@ -361,28 +373,16 @@ where
361373 ) -> Result < Vec < ( f64 , String ) > , VectorStoreError > {
362374 let prompt_embedding = self . model . embed_text ( req. query ( ) ) . await ?;
363375
364- let mut pipeline = vec ! [
365- self . pipeline_search_stage( & prompt_embedding, req. samples ( ) as usize ) ,
376+ let pipeline = vec ! [
377+ self . pipeline_search_stage( & prompt_embedding, & req) ,
366378 self . pipeline_score_stage( ) ,
367- ] ;
368-
369- if let Some ( filter) = req. filter ( ) {
370- let filter = req
371- . threshold ( )
372- . map ( |thresh| {
373- MongoDbSearchFilter :: gte ( "score" . into ( ) , thresh. into ( ) ) . and ( filter. clone ( ) )
374- } )
375- . unwrap_or ( filter. clone ( ) ) ;
376-
377- pipeline. push ( filter. into_document ( ) )
378- }
379-
380- pipeline. push ( doc ! {
381- "$project" : {
382- "_id" : 1 ,
383- "score" : 1
379+ doc! {
380+ "$project" : {
381+ "_id" : 1 ,
382+ "score" : 1
383+ } ,
384384 } ,
385- } ) ;
385+ ] ;
386386
387387 let mut cursor = self
388388 . collection
0 commit comments