Skip to content

Commit 9106ee4

Browse files
authored
[Inference API] Add "rerank" task type to "elastic" provider (#126022)
1 parent 3d085b0 commit 9106ee4

20 files changed

+1232
-28
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ static TransportVersion def(int id) {
193193
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
194194
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
195195
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
196+
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
196197
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
197198
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
198199
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -290,6 +291,7 @@ static TransportVersion def(int id) {
290291
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST = def(9_091_0_00);
291292
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM = def(9_092_0_00);
292293
public static final TransportVersion SNAPSHOT_INDEX_SHARD_STATUS_MISSING_STATS = def(9_093_0_00);
294+
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_094_0_00);
293295

294296
/*
295297
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
7171
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
7272
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
73+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
7374
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
7475
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
7576
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
@@ -166,7 +167,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
166167
addAnthropicNamedWritables(namedWriteables);
167168
addAmazonBedrockNamedWriteables(namedWriteables);
168169
addAwsNamedWriteables(namedWriteables);
169-
addEisNamedWriteables(namedWriteables);
170+
addElasticNamedWriteables(namedWriteables);
170171
addAlibabaCloudSearchNamedWriteables(namedWriteables);
171172
addJinaAINamedWriteables(namedWriteables);
172173
addVoyageAINamedWriteables(namedWriteables);
@@ -742,20 +743,32 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
742743
);
743744
}
744745

745-
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
746+
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
747+
// Sparse Text Embeddings
746748
namedWriteables.add(
747749
new NamedWriteableRegistry.Entry(
748750
ServiceSettings.class,
749751
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
750752
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
751753
)
752754
);
755+
756+
// Completion
753757
namedWriteables.add(
754758
new NamedWriteableRegistry.Entry(
755759
ServiceSettings.class,
756760
ElasticInferenceServiceCompletionServiceSettings.NAME,
757761
ElasticInferenceServiceCompletionServiceSettings::new
758762
)
759763
);
764+
765+
// Rerank
766+
namedWriteables.add(
767+
new NamedWriteableRegistry.Entry(
768+
ServiceSettings.class,
769+
ElasticInferenceServiceRerankServiceSettings.NAME,
770+
ElasticInferenceServiceRerankServiceSettings::new
771+
)
772+
);
760773
}
761774
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.apache.http.message.BasicHeader;
15+
import org.elasticsearch.common.Strings;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
19+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
20+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
22+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
23+
24+
import java.net.URI;
25+
import java.nio.charset.StandardCharsets;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {
30+
31+
private final String query;
32+
private final List<String> documents;
33+
private final Integer topN;
34+
private final TraceContextHandler traceContextHandler;
35+
private final ElasticInferenceServiceRerankModel model;
36+
37+
public ElasticInferenceServiceRerankRequest(
38+
String query,
39+
List<String> documents,
40+
Integer topN,
41+
ElasticInferenceServiceRerankModel model,
42+
TraceContext traceContext,
43+
ElasticInferenceServiceRequestMetadata metadata
44+
) {
45+
super(metadata);
46+
this.query = query;
47+
this.documents = documents;
48+
this.topN = topN;
49+
this.model = Objects.requireNonNull(model);
50+
this.traceContextHandler = new TraceContextHandler(traceContext);
51+
}
52+
53+
@Override
54+
public HttpRequestBase createHttpRequestBase() {
55+
var httpPost = new HttpPost(getURI());
56+
var requestEntity = Strings.toString(
57+
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
58+
);
59+
60+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
61+
httpPost.setEntity(byteEntity);
62+
63+
traceContextHandler.propagateTraceContext(httpPost);
64+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
65+
66+
return httpPost;
67+
}
68+
69+
public TraceContext getTraceContext() {
70+
return traceContextHandler.traceContext();
71+
}
72+
73+
@Override
74+
public String getInferenceEntityId() {
75+
return model.getInferenceEntityId();
76+
}
77+
78+
@Override
79+
public URI getURI() {
80+
return model.uri();
81+
}
82+
83+
@Override
84+
public Request truncate() {
85+
// no truncation
86+
return this;
87+
}
88+
89+
@Override
90+
public boolean[] getTruncationInfo() {
91+
// no truncation
92+
return null;
93+
}
94+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record ElasticInferenceServiceRerankRequestEntity(
19+
String query,
20+
List<String> documents,
21+
String modelId,
22+
@Nullable Integer topNDocumentsOnly
23+
) implements ToXContentObject {
24+
25+
private static final String QUERY_FIELD = "query";
26+
private static final String MODEL_FIELD = "model";
27+
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
28+
private static final String DOCUMENTS_FIELD = "documents";
29+
30+
public ElasticInferenceServiceRerankRequestEntity {
31+
Objects.requireNonNull(query);
32+
Objects.requireNonNull(documents);
33+
Objects.requireNonNull(modelId);
34+
}
35+
36+
@Override
37+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
38+
builder.startObject();
39+
40+
builder.field(QUERY_FIELD, query);
41+
42+
builder.field(MODEL_FIELD, modelId);
43+
44+
if (Objects.nonNull(topNDocumentsOnly)) {
45+
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
46+
}
47+
48+
builder.startArray(DOCUMENTS_FIELD);
49+
for (String document : documents) {
50+
builder.value(document);
51+
}
52+
53+
builder.endArray();
54+
55+
builder.endObject();
56+
57+
return builder;
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.elastic;
9+
10+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.ConstructingObjectParser;
13+
import org.elasticsearch.xcontent.ParseField;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
19+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
20+
21+
import java.io.IOException;
22+
import java.util.List;
23+
24+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
25+
26+
public class ElasticInferenceServiceRerankResponseEntity {
27+
28+
record RerankResult(List<RerankResultEntry> entries) {
29+
30+
@SuppressWarnings("unchecked")
31+
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
32+
RerankResult.class.getSimpleName(),
33+
true,
34+
args -> new RerankResult((List<RerankResultEntry>) args[0])
35+
);
36+
37+
static {
38+
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
39+
}
40+
41+
record RerankResultEntry(Integer index, Float relevanceScore) {
42+
43+
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
44+
RerankResultEntry.class.getSimpleName(),
45+
args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
46+
);
47+
48+
static {
49+
PARSER.declareInt(constructorArg(), new ParseField("index"));
50+
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
51+
}
52+
53+
public RankedDocsResults.RankedDoc toRankedDoc() {
54+
return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
55+
}
56+
}
57+
}
58+
59+
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
60+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
61+
62+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
63+
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
64+
65+
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
66+
}
67+
}
68+
69+
private ElasticInferenceServiceRerankResponseEntity() {}
70+
}

0 commit comments

Comments
 (0)