Skip to content

Commit bac3e16

Browse files
authored
[8.19] [Inference API] Add "rerank" task type to "elastic" provider (#126022) #129196
1 parent 18adc3f commit bac3e16

20 files changed

+1232
-28
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ static TransportVersion def(int id) {
238238
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
239239
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
240240
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
241+
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
242+
241243
/*
242244
* STOP! READ THIS FIRST! No, really,
243245
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
7272
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
7373
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
74+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
7475
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
7576
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
7677
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
@@ -166,7 +167,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
166167
addAnthropicNamedWritables(namedWriteables);
167168
addAmazonBedrockNamedWriteables(namedWriteables);
168169
addAwsNamedWriteables(namedWriteables);
169-
addEisNamedWriteables(namedWriteables);
170+
addElasticNamedWriteables(namedWriteables);
170171
addAlibabaCloudSearchNamedWriteables(namedWriteables);
171172
addJinaAINamedWriteables(namedWriteables);
172173
addVoyageAINamedWriteables(namedWriteables);
@@ -742,20 +743,32 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
742743
);
743744
}
744745

745-
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
746+
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
747+
// Sparse Text Embeddings
746748
namedWriteables.add(
747749
new NamedWriteableRegistry.Entry(
748750
ServiceSettings.class,
749751
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
750752
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
751753
)
752754
);
755+
756+
// Completion
753757
namedWriteables.add(
754758
new NamedWriteableRegistry.Entry(
755759
ServiceSettings.class,
756760
ElasticInferenceServiceCompletionServiceSettings.NAME,
757761
ElasticInferenceServiceCompletionServiceSettings::new
758762
)
759763
);
764+
765+
// Rerank
766+
namedWriteables.add(
767+
new NamedWriteableRegistry.Entry(
768+
ServiceSettings.class,
769+
ElasticInferenceServiceRerankServiceSettings.NAME,
770+
ElasticInferenceServiceRerankServiceSettings::new
771+
)
772+
);
760773
}
761774
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.apache.http.message.BasicHeader;
15+
import org.elasticsearch.common.Strings;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
19+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
20+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
22+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
23+
24+
import java.net.URI;
25+
import java.nio.charset.StandardCharsets;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {
30+
31+
private final String query;
32+
private final List<String> documents;
33+
private final Integer topN;
34+
private final TraceContextHandler traceContextHandler;
35+
private final ElasticInferenceServiceRerankModel model;
36+
37+
public ElasticInferenceServiceRerankRequest(
38+
String query,
39+
List<String> documents,
40+
Integer topN,
41+
ElasticInferenceServiceRerankModel model,
42+
TraceContext traceContext,
43+
ElasticInferenceServiceRequestMetadata metadata
44+
) {
45+
super(metadata);
46+
this.query = query;
47+
this.documents = documents;
48+
this.topN = topN;
49+
this.model = Objects.requireNonNull(model);
50+
this.traceContextHandler = new TraceContextHandler(traceContext);
51+
}
52+
53+
@Override
54+
public HttpRequestBase createHttpRequestBase() {
55+
var httpPost = new HttpPost(getURI());
56+
var requestEntity = Strings.toString(
57+
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
58+
);
59+
60+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
61+
httpPost.setEntity(byteEntity);
62+
63+
traceContextHandler.propagateTraceContext(httpPost);
64+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
65+
66+
return httpPost;
67+
}
68+
69+
public TraceContext getTraceContext() {
70+
return traceContextHandler.traceContext();
71+
}
72+
73+
@Override
74+
public String getInferenceEntityId() {
75+
return model.getInferenceEntityId();
76+
}
77+
78+
@Override
79+
public URI getURI() {
80+
return model.uri();
81+
}
82+
83+
@Override
84+
public Request truncate() {
85+
// no truncation
86+
return this;
87+
}
88+
89+
@Override
90+
public boolean[] getTruncationInfo() {
91+
// no truncation
92+
return null;
93+
}
94+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record ElasticInferenceServiceRerankRequestEntity(
19+
String query,
20+
List<String> documents,
21+
String modelId,
22+
@Nullable Integer topNDocumentsOnly
23+
) implements ToXContentObject {
24+
25+
private static final String QUERY_FIELD = "query";
26+
private static final String MODEL_FIELD = "model";
27+
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
28+
private static final String DOCUMENTS_FIELD = "documents";
29+
30+
public ElasticInferenceServiceRerankRequestEntity {
31+
Objects.requireNonNull(query);
32+
Objects.requireNonNull(documents);
33+
Objects.requireNonNull(modelId);
34+
}
35+
36+
@Override
37+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
38+
builder.startObject();
39+
40+
builder.field(QUERY_FIELD, query);
41+
42+
builder.field(MODEL_FIELD, modelId);
43+
44+
if (Objects.nonNull(topNDocumentsOnly)) {
45+
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
46+
}
47+
48+
builder.startArray(DOCUMENTS_FIELD);
49+
for (String document : documents) {
50+
builder.value(document);
51+
}
52+
53+
builder.endArray();
54+
55+
builder.endObject();
56+
57+
return builder;
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.elastic;
9+
10+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.ConstructingObjectParser;
13+
import org.elasticsearch.xcontent.ParseField;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
19+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
20+
21+
import java.io.IOException;
22+
import java.util.List;
23+
24+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
25+
26+
public class ElasticInferenceServiceRerankResponseEntity {
27+
28+
record RerankResult(List<RerankResultEntry> entries) {
29+
30+
@SuppressWarnings("unchecked")
31+
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
32+
RerankResult.class.getSimpleName(),
33+
true,
34+
args -> new RerankResult((List<RerankResultEntry>) args[0])
35+
);
36+
37+
static {
38+
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
39+
}
40+
41+
record RerankResultEntry(Integer index, Float relevanceScore) {
42+
43+
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
44+
RerankResultEntry.class.getSimpleName(),
45+
args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
46+
);
47+
48+
static {
49+
PARSER.declareInt(constructorArg(), new ParseField("index"));
50+
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
51+
}
52+
53+
public RankedDocsResults.RankedDoc toRankedDoc() {
54+
return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
55+
}
56+
}
57+
}
58+
59+
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
60+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
61+
62+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
63+
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
64+
65+
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
66+
}
67+
}
68+
69+
private ElasticInferenceServiceRerankResponseEntity() {}
70+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
5252
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5353
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
54+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
5455
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
5556
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
5657

@@ -77,7 +78,11 @@ public class ElasticInferenceService extends SenderService {
7778
public static final String NAME = "elastic";
7879
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
7980

80-
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
81+
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
82+
TaskType.SPARSE_EMBEDDING,
83+
TaskType.CHAT_COMPLETION,
84+
TaskType.RERANK
85+
);
8186
private static final String SERVICE_NAME = "Elastic";
8287

8388
// rainbow-sprinkles
@@ -91,7 +96,7 @@ public class ElasticInferenceService extends SenderService {
9196
/**
9297
* The task types that the {@link InferenceAction.Request} can accept.
9398
*/
94-
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
99+
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);
95100

96101
public static String defaultEndpointId(String modelId) {
97102
return Strings.format(".%s-elastic", modelId);
@@ -161,6 +166,18 @@ public void onNodeStarted() {
161166
authorizationHandler.init();
162167
}
163168

169+
@Override
170+
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
171+
if (returnDocuments != null) {
172+
validationException.addValidationError(
173+
org.elasticsearch.core.Strings.format(
174+
"Invalid return_documents [%s]. The return_documents option is not supported by this service",
175+
returnDocuments
176+
)
177+
);
178+
}
179+
}
180+
164181
/**
165182
* Only use this in tests.
166183
*
@@ -333,7 +350,7 @@ private static ElasticInferenceServiceModel createModel(
333350
Map<String, Object> serviceSettings,
334351
Map<String, Object> taskSettings,
335352
@Nullable Map<String, Object> secretSettings,
336-
ElasticInferenceServiceComponents eisServiceComponents,
353+
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
337354
String failureMessage,
338355
ConfigurationParseContext context
339356
) {
@@ -345,7 +362,7 @@ private static ElasticInferenceServiceModel createModel(
345362
serviceSettings,
346363
taskSettings,
347364
secretSettings,
348-
eisServiceComponents,
365+
elasticInferenceServiceComponents,
349366
context
350367
);
351368
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
@@ -355,7 +372,17 @@ private static ElasticInferenceServiceModel createModel(
355372
serviceSettings,
356373
taskSettings,
357374
secretSettings,
358-
eisServiceComponents,
375+
elasticInferenceServiceComponents,
376+
context
377+
);
378+
case RERANK -> new ElasticInferenceServiceRerankModel(
379+
inferenceEntityId,
380+
taskType,
381+
NAME,
382+
serviceSettings,
383+
taskSettings,
384+
secretSettings,
385+
elasticInferenceServiceComponents,
359386
context
360387
);
361388
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
@@ -451,9 +478,8 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC
451478

452479
configurationMap.put(
453480
MODEL_ID,
454-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
455-
"The name of the model to use for the inference task."
456-
)
481+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK))
482+
.setDescription("The name of the model to use for the inference task.")
457483
.setLabel("Model ID")
458484
.setRequired(true)
459485
.setSensitive(false)
@@ -476,7 +502,9 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC
476502
);
477503

478504
configurationMap.putAll(
479-
RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))
505+
RateLimitSettings.toSettingsConfiguration(
506+
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)
507+
)
480508
);
481509

482510
return new InferenceServiceConfiguration.Builder().setService(NAME)

0 commit comments

Comments
 (0)