Skip to content

Commit a7a430f

Browse files
authored
Adds unused lower level ivf knn query (#127852)
this is a low level query for some basic IVF querying logic. Right now its the simple, just hit ever segment and search. But we needed to fork away from the typical kNN query due to all the logic there around dropping to exact search for filtered search, etc.
1 parent 24f0772 commit a7a430f

File tree

7 files changed

+1469
-3
lines changed

7 files changed

+1469
-3
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
*/
4747
public class IVFVectorsFormat extends KnnVectorsFormat {
4848

49-
static final FeatureFlag IVF_FORMAT_FEATURE_FLAG = new FeatureFlag("ivf_format");
49+
public static final FeatureFlag IVF_FORMAT_FEATURE_FLAG = new FeatureFlag("ivf_format");
5050
public static final String IVF_VECTOR_COMPONENT = "IVF";
5151
public static final String NAME = "IVFVectorsFormat";
5252
// centroid ordinals -> centroid values, offsets

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.util.FixedBitSet;
3333
import org.apache.lucene.util.hnsw.NeighborQueue;
3434
import org.elasticsearch.core.IOUtils;
35+
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3536

3637
import java.io.IOException;
3738
import java.util.function.IntPredicate;
@@ -243,8 +244,11 @@ public final void search(String field, float[] target, KnnCollector knnCollector
243244
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
244245
return;
245246
}
246-
// TODO add new ivf search strategy
247-
int nProbe = 10;
247+
if (fieldInfo.getVectorDimension() != target.length) {
248+
throw new IllegalArgumentException(
249+
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
250+
);
251+
}
248252
float percentFiltered = 1f;
249253
if (acceptDocs instanceof BitSet bitSet) {
250254
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
@@ -257,6 +261,13 @@ public final void search(String field, float[] target, KnnCollector knnCollector
257261
}
258262
return visitedDocs.getAndSet(docId) == false;
259263
};
264+
final int nProbe;
265+
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
266+
nProbe = ivfSearchStrategy.getNProbe();
267+
} else {
268+
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
269+
nProbe = 10;
270+
}
260271

261272
FieldEntry entry = fields.get(fieldInfo.number);
262273
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.IndexReader;
13+
import org.apache.lucene.index.LeafReader;
14+
import org.apache.lucene.index.LeafReaderContext;
15+
import org.apache.lucene.search.BooleanClause;
16+
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.DocIdSetIterator;
18+
import org.apache.lucene.search.FieldExistsQuery;
19+
import org.apache.lucene.search.FilteredDocIdSetIterator;
20+
import org.apache.lucene.search.IndexSearcher;
21+
import org.apache.lucene.search.KnnCollector;
22+
import org.apache.lucene.search.MatchNoDocsQuery;
23+
import org.apache.lucene.search.Query;
24+
import org.apache.lucene.search.QueryVisitor;
25+
import org.apache.lucene.search.ScoreDoc;
26+
import org.apache.lucene.search.ScoreMode;
27+
import org.apache.lucene.search.Scorer;
28+
import org.apache.lucene.search.TaskExecutor;
29+
import org.apache.lucene.search.TopDocs;
30+
import org.apache.lucene.search.TopDocsCollector;
31+
import org.apache.lucene.search.TopKnnCollector;
32+
import org.apache.lucene.search.Weight;
33+
import org.apache.lucene.search.knn.KnnCollectorManager;
34+
import org.apache.lucene.search.knn.KnnSearchStrategy;
35+
import org.apache.lucene.util.BitSet;
36+
import org.apache.lucene.util.BitSetIterator;
37+
import org.apache.lucene.util.Bits;
38+
import org.elasticsearch.search.profile.query.QueryProfiler;
39+
40+
import java.io.IOException;
41+
import java.util.ArrayList;
42+
import java.util.List;
43+
import java.util.Objects;
44+
import java.util.concurrent.Callable;
45+
46+
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
47+
48+
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
49+
50+
protected final String field;
51+
protected final int nProbe;
52+
protected final int k;
53+
protected final Query filter;
54+
protected final KnnSearchStrategy searchStrategy;
55+
protected int vectorOpsCount;
56+
57+
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
58+
this.field = field;
59+
this.nProbe = nProbe;
60+
this.k = k;
61+
this.filter = filter;
62+
this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
63+
}
64+
65+
@Override
66+
public void visit(QueryVisitor visitor) {
67+
if (visitor.acceptField(field)) {
68+
visitor.visitLeaf(this);
69+
}
70+
}
71+
72+
@Override
73+
public boolean equals(Object o) {
74+
if (this == o) return true;
75+
if (o == null || getClass() != o.getClass()) return false;
76+
AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
77+
return k == that.k
78+
&& Objects.equals(field, that.field)
79+
&& Objects.equals(filter, that.filter)
80+
&& Objects.equals(nProbe, that.nProbe);
81+
}
82+
83+
@Override
84+
public int hashCode() {
85+
return Objects.hash(field, k, filter, nProbe);
86+
}
87+
88+
@Override
89+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
90+
vectorOpsCount = 0;
91+
IndexReader reader = indexSearcher.getIndexReader();
92+
93+
final Weight filterWeight;
94+
if (filter != null) {
95+
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
96+
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
97+
.build();
98+
Query rewritten = indexSearcher.rewrite(booleanQuery);
99+
if (rewritten.getClass() == MatchNoDocsQuery.class) {
100+
return rewritten;
101+
}
102+
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
103+
} else {
104+
filterWeight = null;
105+
}
106+
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
107+
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
108+
List<LeafReaderContext> leafReaderContexts = reader.leaves();
109+
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
110+
for (LeafReaderContext context : leafReaderContexts) {
111+
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
112+
}
113+
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
114+
115+
// Merge sort the results
116+
TopDocs topK = TopDocs.merge(k, perLeafResults);
117+
vectorOpsCount = (int) topK.totalHits.value();
118+
if (topK.scoreDocs.length == 0) {
119+
return new MatchNoDocsQuery();
120+
}
121+
return new KnnScoreDocQuery(topK.scoreDocs, reader);
122+
}
123+
124+
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
125+
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
126+
if (ctx.docBase > 0) {
127+
for (ScoreDoc scoreDoc : results.scoreDocs) {
128+
scoreDoc.doc += ctx.docBase;
129+
}
130+
}
131+
return results;
132+
}
133+
134+
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
135+
final LeafReader reader = ctx.reader();
136+
final Bits liveDocs = reader.getLiveDocs();
137+
138+
if (filterWeight == null) {
139+
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
140+
}
141+
142+
Scorer scorer = filterWeight.scorer(ctx);
143+
if (scorer == null) {
144+
return TopDocsCollector.EMPTY_TOPDOCS;
145+
}
146+
147+
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
148+
final int cost = acceptDocs.cardinality();
149+
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
150+
}
151+
152+
abstract TopDocs approximateSearch(
153+
LeafReaderContext context,
154+
Bits acceptDocs,
155+
int visitedLimit,
156+
KnnCollectorManager knnCollectorManager
157+
) throws IOException;
158+
159+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
160+
return new IVFCollectorManager(k, nProbe);
161+
}
162+
163+
@Override
164+
public final void profile(QueryProfiler queryProfiler) {
165+
queryProfiler.addVectorOpsCount(vectorOpsCount);
166+
}
167+
168+
BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
169+
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
170+
// If we already have a BitSet and no deletions, reuse the BitSet
171+
return bitSetIterator.getBitSet();
172+
} else {
173+
// Create a new BitSet from matching and live docs
174+
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
175+
@Override
176+
protected boolean match(int doc) {
177+
return liveDocs == null || liveDocs.get(doc);
178+
}
179+
};
180+
return BitSet.of(filterIterator, maxDoc);
181+
}
182+
}
183+
184+
static class IVFCollectorManager implements KnnCollectorManager {
185+
private final int k;
186+
private final int nprobe;
187+
188+
IVFCollectorManager(int k, int nprobe) {
189+
this.k = k;
190+
this.nprobe = nprobe;
191+
}
192+
193+
@Override
194+
public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
195+
return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
196+
}
197+
}
198+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.search.vectors;
10+
11+
import org.apache.lucene.index.FloatVectorValues;
12+
import org.apache.lucene.index.LeafReader;
13+
import org.apache.lucene.index.LeafReaderContext;
14+
import org.apache.lucene.search.KnnCollector;
15+
import org.apache.lucene.search.Query;
16+
import org.apache.lucene.search.TopDocs;
17+
import org.apache.lucene.search.knn.KnnCollectorManager;
18+
import org.apache.lucene.util.Bits;
19+
20+
import java.io.IOException;
21+
import java.util.Arrays;
22+
23+
/** A {@link IVFKnnFloatVectorQuery} that uses the IVF search strategy. */
24+
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {
25+
26+
private final float[] query;
27+
28+
/**
29+
* Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
30+
* @param field the field to search
31+
* @param query the query vector
32+
* @param k the number of nearest neighbors to return
33+
* @param filter the filter to apply to the results
34+
* @param nProbe the number of probes to use for the IVF search strategy
35+
*/
36+
public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
37+
super(field, nProbe, k, filter);
38+
if (k < 1) {
39+
throw new IllegalArgumentException("k must be at least 1, got: " + k);
40+
}
41+
if (nProbe < 1) {
42+
throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
43+
}
44+
this.query = query;
45+
}
46+
47+
@Override
48+
public String toString(String field) {
49+
StringBuilder buffer = new StringBuilder();
50+
buffer.append(getClass().getSimpleName())
51+
.append(":")
52+
.append(this.field)
53+
.append("[")
54+
.append(query[0])
55+
.append(",...]")
56+
.append("[")
57+
.append(k)
58+
.append("]");
59+
if (this.filter != null) {
60+
buffer.append("[").append(this.filter).append("]");
61+
}
62+
return buffer.toString();
63+
}
64+
65+
@Override
66+
public boolean equals(Object o) {
67+
if (this == o) return true;
68+
if (super.equals(o) == false) return false;
69+
IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
70+
return Arrays.equals(query, that.query);
71+
}
72+
73+
@Override
74+
public int hashCode() {
75+
int result = super.hashCode();
76+
result = 31 * result + Arrays.hashCode(query);
77+
return result;
78+
}
79+
80+
@Override
81+
protected TopDocs approximateSearch(
82+
LeafReaderContext context,
83+
Bits acceptDocs,
84+
int visitedLimit,
85+
KnnCollectorManager knnCollectorManager
86+
) throws IOException {
87+
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
88+
LeafReader reader = context.reader();
89+
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
90+
if (floatVectorValues == null) {
91+
FloatVectorValues.checkField(reader, field);
92+
return NO_RESULTS;
93+
}
94+
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
95+
return NO_RESULTS;
96+
}
97+
reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
98+
TopDocs results = knnCollector.topDocs();
99+
return results != null ? results : NO_RESULTS;
100+
}
101+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.search.vectors;
10+
11+
import org.apache.lucene.search.knn.KnnSearchStrategy;
12+
13+
import java.util.Objects;
14+
15+
public class IVFKnnSearchStrategy extends KnnSearchStrategy {
16+
private final int nProbe;
17+
18+
IVFKnnSearchStrategy(int nProbe) {
19+
this.nProbe = nProbe;
20+
}
21+
22+
public int getNProbe() {
23+
return nProbe;
24+
}
25+
26+
@Override
27+
public boolean equals(Object o) {
28+
if (this == o) return true;
29+
if (o == null || getClass() != o.getClass()) return false;
30+
IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o;
31+
return nProbe == that.nProbe;
32+
}
33+
34+
@Override
35+
public int hashCode() {
36+
return Objects.hashCode(nProbe);
37+
}
38+
39+
@Override
40+
public void nextVectorsBlock() {
41+
// do nothing
42+
}
43+
}

0 commit comments

Comments
 (0)