Skip to content

Commit 73e3cc0

Browse files
committed
Add a direct IO option to rescore_vector for bbq_hnsw
1 parent 3e47504 commit 73e3cc0

File tree

12 files changed

+252
-161
lines changed

12 files changed

+252
-161
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ static Codec createCodec(CmdLineArgs args) {
9595
if (args.indexType() == IndexType.FLAT) {
9696
format = new ES818BinaryQuantizedVectorsFormat();
9797
} else {
98-
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null);
98+
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, false, null);
9999
}
100100
} else if (args.quantizeBits() < 32) {
101101
if (args.indexType() == IndexType.FLAT) {

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio
6868
}
6969

7070
static boolean shouldUseDirectIO(SegmentReadState state) {
71-
assert ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO;
7271
return FsDirectoryFactory.isHybridFs(state.directory);
7372
}
7473

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@
8787
*/
8888
public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
8989

90-
public static final boolean USE_DIRECT_IO = Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"));
91-
9290
public static final String BINARIZED_VECTOR_COMPONENT = "BVEC";
9391
public static final String NAME = "ES818BinaryQuantizedVectorsFormat";
9492

@@ -100,17 +98,24 @@ public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
10098
static final String VECTOR_DATA_EXTENSION = "veb";
10199
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
102100

103-
private static final FlatVectorsFormat rawVectorFormat = USE_DIRECT_IO
104-
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
105-
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
106-
107101
private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
108102
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
109103
);
110104

105+
private final FlatVectorsFormat rawVectorFormat;
106+
111107
/** Creates a new instance with the default number of vectors per cluster. */
112108
public ES818BinaryQuantizedVectorsFormat() {
109+
this(false);
110+
}
111+
112+
/** Creates a new instance with the default number of vectors per cluster,
113+
* and whether direct IO should be used to access raw vectors. */
114+
public ES818BinaryQuantizedVectorsFormat(boolean useDirectIO) {
113115
super(NAME);
116+
rawVectorFormat = useDirectIO
117+
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
118+
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
114119
}
115120

116121
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
6262
private final int beamWidth;
6363

6464
/** The format for storing, reading, merging vectors on disk */
65-
private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat();
65+
private final FlatVectorsFormat flatVectorsFormat;
6666

6767
private final int numMergeWorkers;
6868
private final TaskExecutor mergeExec;
6969

7070
/** Constructs a format using default graph construction parameters */
7171
public ES818HnswBinaryQuantizedVectorsFormat() {
72-
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
72+
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, false, null);
7373
}
7474

7575
/**
@@ -79,7 +79,18 @@ public ES818HnswBinaryQuantizedVectorsFormat() {
7979
* @param beamWidth the size of the queue maintained during graph construction.
8080
*/
8181
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
82-
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
82+
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, false, null);
83+
}
84+
85+
/**
86+
* Constructs a format using the given graph construction parameters.
87+
*
88+
* @param maxConn the maximum number of connections to a node in the HNSW graph
89+
* @param beamWidth the size of the queue maintained during graph construction.
90+
* @param useDirectIO whether direct IO should be used to access raw vectors
91+
*/
92+
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO) {
93+
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, useDirectIO, null);
8394
}
8495

8596
/**
@@ -92,7 +103,13 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
92103
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
93104
* generated by this format to do the merge
94105
*/
95-
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
106+
public ES818HnswBinaryQuantizedVectorsFormat(
107+
int maxConn,
108+
int beamWidth,
109+
int numMergeWorkers,
110+
boolean useDirectIO,
111+
ExecutorService mergeExec
112+
) {
96113
super(NAME);
97114
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
98115
throw new IllegalArgumentException(
@@ -110,6 +127,9 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num
110127
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
111128
}
112129
this.numMergeWorkers = numMergeWorkers;
130+
131+
flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat(useDirectIO);
132+
113133
if (mergeExec != null) {
114134
this.mergeExec = new TaskExecutor(mergeExec);
115135
} else {

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo
387387
return new BBQHnswIndexOptions(
388388
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
389389
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
390-
new RescoreVector(DEFAULT_OVERSAMPLE)
390+
null
391391
);
392392
} else if (defaultInt8Hnsw) {
393393
return new Int8HnswIndexOptions(
@@ -1632,9 +1632,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16321632
RescoreVector rescoreVector = null;
16331633
if (hasRescoreIndexVersion(indexVersion)) {
16341634
rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1635-
if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) {
1636-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1637-
}
16381635
}
16391636
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
16401637
return new BBQHnswIndexOptions(m, efConstruction, rescoreVector);
@@ -1656,9 +1653,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16561653
RescoreVector rescoreVector = null;
16571654
if (hasRescoreIndexVersion(indexVersion)) {
16581655
rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1659-
if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) {
1660-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1661-
}
16621656
}
16631657
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
16641658
return new BBQFlatIndexOptions(rescoreVector);
@@ -1693,9 +1687,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16931687
}
16941688
}
16951689
RescoreVector rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1696-
if (rescoreVector == null) {
1697-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1698-
}
16991690
Object nProbeNode = indexOptionsMap.remove("default_n_probe");
17001691
int nProbe = -1;
17011692
if (nProbeNode != null) {
@@ -2183,7 +2174,8 @@ public BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVecto
21832174
@Override
21842175
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
21852176
assert elementType == ElementType.FLOAT;
2186-
return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction);
2177+
boolean directIO = rescoreVector != null && rescoreVector.useDirectIO != null && rescoreVector.useDirectIO;
2178+
return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction, directIO);
21872179
}
21882180

21892181
@Override
@@ -2342,36 +2334,46 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
23422334
}
23432335
}
23442336

2345-
public record RescoreVector(float oversample) implements ToXContentObject {
2337+
public record RescoreVector(Float oversample, Boolean useDirectIO) implements ToXContentObject {
23462338
static final String NAME = "rescore_vector";
23472339
static final String OVERSAMPLE = "oversample";
2340+
static final String DIRECT_IO = "direct_io";
23482341

23492342
static RescoreVector fromIndexOptions(Map<String, ?> indexOptionsMap, IndexVersion indexVersion) {
23502343
Object rescoreVectorNode = indexOptionsMap.remove(NAME);
23512344
if (rescoreVectorNode == null) {
23522345
return null;
23532346
}
23542347
Map<String, Object> mappedNode = XContentMapValues.nodeMapValue(rescoreVectorNode, NAME);
2348+
2349+
Float oversampleValue = null;
23552350
Object oversampleNode = mappedNode.get(OVERSAMPLE);
2356-
if (oversampleNode == null) {
2357-
throw new IllegalArgumentException("Invalid rescore_vector value. Missing required field " + OVERSAMPLE);
2358-
}
2359-
float oversampleValue = (float) XContentMapValues.nodeDoubleValue(oversampleNode);
2360-
if (oversampleValue == 0 && allowsZeroRescore(indexVersion) == false) {
2361-
throw new IllegalArgumentException("oversample must be greater than 1");
2362-
}
2363-
if (oversampleValue < 1 && oversampleValue != 0) {
2364-
throw new IllegalArgumentException("oversample must be greater than 1 or exactly 0");
2365-
} else if (oversampleValue > 10) {
2366-
throw new IllegalArgumentException("oversample must be less than or equal to 10");
2351+
if (oversampleNode != null) {
2352+
oversampleValue = (float) XContentMapValues.nodeDoubleValue(oversampleNode);
2353+
if (oversampleValue == 0 && allowsZeroRescore(indexVersion) == false) {
2354+
throw new IllegalArgumentException("oversample must be greater than 1");
2355+
}
2356+
if (oversampleValue < 1 && oversampleValue != 0) {
2357+
throw new IllegalArgumentException("oversample must be greater than 1 or exactly 0");
2358+
} else if (oversampleValue > 10) {
2359+
throw new IllegalArgumentException("oversample must be less than or equal to 10");
2360+
}
23672361
}
2368-
return new RescoreVector(oversampleValue);
2362+
2363+
Boolean directIO = (Boolean) mappedNode.get(DIRECT_IO);
2364+
2365+
return new RescoreVector(oversampleValue, directIO);
23692366
}
23702367

23712368
@Override
23722369
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
23732370
builder.startObject(NAME);
2374-
builder.field(OVERSAMPLE, oversample);
2371+
if (oversample != null) {
2372+
builder.field(OVERSAMPLE, oversample);
2373+
}
2374+
if (useDirectIO != null) {
2375+
builder.field(DIRECT_IO, useDirectIO);
2376+
}
23752377
builder.endObject();
23762378
return builder;
23772379
}
@@ -2710,6 +2712,10 @@ && isNotUnitVector(squaredMagnitude)) {
27102712
&& quantizedIndexOptions.rescoreVector != null) {
27112713
oversample = quantizedIndexOptions.rescoreVector.oversample;
27122714
}
2715+
if (oversample == null) {
2716+
oversample = DEFAULT_OVERSAMPLE;
2717+
}
2718+
27132719
boolean rescore = needsRescore(oversample);
27142720
if (rescore) {
27152721
// Will get k * oversample for rescoring, and get the top k

server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
3939
import org.apache.lucene.index.Term;
4040
import org.apache.lucene.index.VectorSimilarityFunction;
41-
import org.apache.lucene.misc.store.DirectIODirectory;
4241
import org.apache.lucene.search.FieldExistsQuery;
4342
import org.apache.lucene.search.IndexSearcher;
4443
import org.apache.lucene.search.KnnFloatVectorQuery;
@@ -52,32 +51,19 @@
5251
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
5352
import org.apache.lucene.search.join.QueryBitSetProducer;
5453
import org.apache.lucene.store.Directory;
55-
import org.apache.lucene.store.FSDirectory;
56-
import org.apache.lucene.store.IOContext;
57-
import org.apache.lucene.store.IndexOutput;
5854
import org.apache.lucene.store.MMapDirectory;
5955
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
6056
import org.apache.lucene.tests.store.MockDirectoryWrapper;
6157
import org.apache.lucene.tests.util.TestUtil;
6258
import org.elasticsearch.common.logging.LogConfigurator;
63-
import org.elasticsearch.common.settings.Settings;
64-
import org.elasticsearch.index.IndexModule;
65-
import org.elasticsearch.index.IndexSettings;
6659
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
6760
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
68-
import org.elasticsearch.index.shard.ShardId;
69-
import org.elasticsearch.index.shard.ShardPath;
70-
import org.elasticsearch.index.store.FsDirectoryFactory;
71-
import org.elasticsearch.test.IndexSettingsModule;
7261

7362
import java.io.IOException;
74-
import java.nio.file.Files;
75-
import java.nio.file.Path;
7663
import java.util.ArrayList;
7764
import java.util.Arrays;
7865
import java.util.List;
7966
import java.util.Locale;
80-
import java.util.OptionalLong;
8167

8268
import static java.lang.String.format;
8369
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
@@ -268,14 +254,6 @@ public void testSimpleOffHeapSize() throws IOException {
268254
}
269255
}
270256

271-
public void testSimpleOffHeapSizeFSDir() throws IOException {
272-
checkDirectIOSupported();
273-
var config = newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO
274-
try (Directory dir = newFSDirectory()) {
275-
testSimpleOffHeapSizeImpl(dir, config, false);
276-
}
277-
}
278-
279257
public void testSimpleOffHeapSizeMMapDir() throws IOException {
280258
try (Directory dir = newMMapDirectory()) {
281259
testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true);
@@ -315,39 +293,4 @@ static Directory newMMapDirectory() throws IOException {
315293
}
316294
return dir;
317295
}
318-
319-
private Directory newFSDirectory() throws IOException {
320-
Settings settings = Settings.builder()
321-
.put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT))
322-
.build();
323-
IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings);
324-
Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0");
325-
Files.createDirectories(tempDir);
326-
ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0));
327-
Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path);
328-
if (random().nextBoolean()) {
329-
dir = new MockDirectoryWrapper(random(), dir);
330-
}
331-
return dir;
332-
}
333-
334-
static void checkDirectIOSupported() {
335-
assumeTrue("Direct IO is not enabled", ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO);
336-
337-
Path path = createTempDir("directIOProbe");
338-
try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) {
339-
out.writeString("test");
340-
} catch (IOException e) {
341-
assumeNoException("test requires a filesystem that supports Direct IO", e);
342-
}
343-
}
344-
345-
static DirectIODirectory open(Path path) throws IOException {
346-
return new DirectIODirectory(FSDirectory.open(path)) {
347-
@Override
348-
protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) {
349-
return true;
350-
}
351-
};
352-
}
353296
}

0 commit comments

Comments
 (0)