Skip to content

[8.19] [Inference API] Add "rerank" task type to "elastic" provider (#126022) #129196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
Expand Down Expand Up @@ -166,7 +167,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addAwsNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
addElasticNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
Expand Down Expand Up @@ -742,20 +743,32 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
// Sparse Text Embeddings
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
)
);

// Completion
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceCompletionServiceSettings.NAME,
ElasticInferenceServiceCompletionServiceSettings::new
)
);

// Rerank
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceRerankServiceSettings.NAME,
ElasticInferenceServiceRerankServiceSettings::new
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {

private final String query;
private final List<String> documents;
private final Integer topN;
private final TraceContextHandler traceContextHandler;
private final ElasticInferenceServiceRerankModel model;

public ElasticInferenceServiceRerankRequest(
String query,
List<String> documents,
Integer topN,
ElasticInferenceServiceRerankModel model,
TraceContext traceContext,
ElasticInferenceServiceRequestMetadata metadata
) {
super(metadata);
this.query = query;
this.documents = documents;
this.topN = topN;
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(getURI());
var requestEntity = Strings.toString(
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return httpPost;
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
// no truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// no truncation
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record ElasticInferenceServiceRerankRequestEntity(
String query,
List<String> documents,
String modelId,
@Nullable Integer topNDocumentsOnly
) implements ToXContentObject {

private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
private static final String DOCUMENTS_FIELD = "documents";

public ElasticInferenceServiceRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(documents);
Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.field(QUERY_FIELD, query);

builder.field(MODEL_FIELD, modelId);

if (Objects.nonNull(topNDocumentsOnly)) {
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
}

builder.startArray(DOCUMENTS_FIELD);
for (String document : documents) {
builder.value(document);
}

builder.endArray();

builder.endObject();

return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class ElasticInferenceServiceRerankResponseEntity {

record RerankResult(List<RerankResultEntry> entries) {

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
RerankResult.class.getSimpleName(),
true,
args -> new RerankResult((List<RerankResultEntry>) args[0])
);

static {
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
}

record RerankResultEntry(Integer index, Float relevanceScore) {

public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
RerankResultEntry.class.getSimpleName(),
args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
);

static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
}

public RankedDocsResults.RankedDoc toRankedDoc() {
return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
}
}
}

public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);

return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
}
}

private ElasticInferenceServiceRerankResponseEntity() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

Expand All @@ -77,7 +78,11 @@ public class ElasticInferenceService extends SenderService {
public static final String NAME = "elastic";
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";

private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION,
TaskType.RERANK
);
private static final String SERVICE_NAME = "Elastic";

// rainbow-sprinkles
Expand All @@ -91,7 +96,7 @@ public class ElasticInferenceService extends SenderService {
/**
* The task types that the {@link InferenceAction.Request} can accept.
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);

public static String defaultEndpointId(String modelId) {
return Strings.format(".%s-elastic", modelId);
Expand Down Expand Up @@ -161,6 +166,18 @@ public void onNodeStarted() {
authorizationHandler.init();
}

@Override
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
if (returnDocuments != null) {
validationException.addValidationError(
org.elasticsearch.core.Strings.format(
"Invalid return_documents [%s]. The return_documents option is not supported by this service",
returnDocuments
)
);
}
}

/**
* Only use this in tests.
*
Expand Down Expand Up @@ -333,7 +350,7 @@ private static ElasticInferenceServiceModel createModel(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings,
ElasticInferenceServiceComponents eisServiceComponents,
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
String failureMessage,
ConfigurationParseContext context
) {
Expand All @@ -345,7 +362,7 @@ private static ElasticInferenceServiceModel createModel(
serviceSettings,
taskSettings,
secretSettings,
eisServiceComponents,
elasticInferenceServiceComponents,
context
);
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
Expand All @@ -355,7 +372,17 @@ private static ElasticInferenceServiceModel createModel(
serviceSettings,
taskSettings,
secretSettings,
eisServiceComponents,
elasticInferenceServiceComponents,
context
);
case RERANK -> new ElasticInferenceServiceRerankModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
elasticInferenceServiceComponents,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
Expand Down Expand Up @@ -451,9 +478,8 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC

configurationMap.put(
MODEL_ID,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"The name of the model to use for the inference task."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK))
.setDescription("The name of the model to use for the inference task.")
.setLabel("Model ID")
.setRequired(true)
.setSensitive(false)
Expand All @@ -476,7 +502,9 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC
);

configurationMap.putAll(
RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))
RateLimitSettings.toSettingsConfiguration(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)
)
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
Expand Down
Loading