diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5820843c64be8..4f4570ddcbb8c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -238,6 +238,8 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45); public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46); public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47); + public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48); + /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 53b527b6a1c17..57b3a777e8085 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -71,6 +71,7 @@ import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import 
org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings; @@ -166,7 +167,7 @@ public static List getNamedWriteables() { addAnthropicNamedWritables(namedWriteables); addAmazonBedrockNamedWriteables(namedWriteables); addAwsNamedWriteables(namedWriteables); - addEisNamedWriteables(namedWriteables); + addElasticNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); @@ -742,7 +743,8 @@ private static void addVoyageAINamedWriteables(List namedWriteables) { + private static void addElasticNamedWriteables(List namedWriteables) { + // Sparse Text Embeddings namedWriteables.add( new NamedWriteableRegistry.Entry( ServiceSettings.class, @@ -750,6 +752,8 @@ private static void addEisNamedWriteables(List nam ElasticInferenceServiceSparseEmbeddingsServiceSettings::new ) ); + + // Completion namedWriteables.add( new NamedWriteableRegistry.Entry( ServiceSettings.class, @@ -757,5 +761,14 @@ private static void addEisNamedWriteables(List nam ElasticInferenceServiceCompletionServiceSettings::new ) ); + + // Rerank + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + ElasticInferenceServiceRerankServiceSettings.NAME, + ElasticInferenceServiceRerankServiceSettings::new + ) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java new file mode 100644 index 0000000000000..08b3fd2384642 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.entity.ByteArrayEntity;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
+import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.telemetry.TraceContext;
+import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Rerank request sent to the URI of an {@link ElasticInferenceServiceRerankModel}.
+ * The JSON body is produced by {@link ElasticInferenceServiceRerankRequestEntity}; this request is never truncated.
+ */
+public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {
+
+    private final String query;
+    private final List<String> documents;
+    private final Integer topN;
+    private final TraceContextHandler traceContextHandler;
+    private final ElasticInferenceServiceRerankModel model;
+
+    public ElasticInferenceServiceRerankRequest(
+        String query,
+        List<String> documents,
+        Integer topN,
+        ElasticInferenceServiceRerankModel model,
+        TraceContext traceContext,
+        ElasticInferenceServiceRequestMetadata metadata
+    ) {
+        super(metadata);
+        this.query = query;
+        this.documents = documents;
+        this.topN = topN;
+        this.model = Objects.requireNonNull(model);
+        this.traceContextHandler = new TraceContextHandler(traceContext);
+    }
+
+    /**
+     * Builds a POST to {@link #getURI()} with a UTF-8 JSON body and a JSON content type header, after handing the request to the
+     * trace context handler. The entity constructor throws a NullPointerException when the query or the documents are null.
+     */
+    @Override
+    public HttpRequestBase createHttpRequestBase() {
+        var httpPost = new HttpPost(getURI());
+        var requestEntity = Strings.toString(
+            new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
+        );
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
+        httpPost.setEntity(byteEntity);
+
+        traceContextHandler.propagateTraceContext(httpPost);
+        httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
+
+        return httpPost;
+    }
+
+    public TraceContext getTraceContext() {
+        return traceContextHandler.traceContext();
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public URI getURI() {
+        return model.uri();
+    }
+
+    @Override
+    public Request truncate() {
+        // no truncation
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // no truncation
+        return null;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java
new file mode 100644
index 0000000000000..b542af93047fa
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * JSON body of a rerank request. Fields are written in the order {@code query}, {@code model}, {@code top_n}, {@code documents}.
+ *
+ * @param query             the query text; must not be null
+ * @param documents         the documents to rerank, written as a JSON array of strings; must not be null
+ * @param modelId           value of the {@code model} field; must not be null
+ * @param topNDocumentsOnly value of the {@code top_n} field; the field is omitted when this is null
+ */
+public record ElasticInferenceServiceRerankRequestEntity(
+    String query,
+    List<String> documents,
+    String modelId,
+    @Nullable Integer topNDocumentsOnly
+) implements ToXContentObject {
+
+    private static final String QUERY_FIELD = "query";
+    private static final String MODEL_FIELD = "model";
+    private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
+    private static final String DOCUMENTS_FIELD = "documents";
+
+    public ElasticInferenceServiceRerankRequestEntity {
+        Objects.requireNonNull(query);
+        Objects.requireNonNull(documents);
+        Objects.requireNonNull(modelId);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.field(QUERY_FIELD, query);
+
+        builder.field(MODEL_FIELD, modelId);
+
+        if (topNDocumentsOnly != null) {
+            builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
+        }
+
+        builder.startArray(DOCUMENTS_FIELD);
+        for (String document : documents) {
+            builder.value(document);
+        }
+
+        builder.endArray();
+
+        builder.endObject();
+
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntity.java
new file mode 100644
index 0000000000000..b226e82ae7d91
--- /dev/null
+++
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntity.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.elastic;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Parses a rerank response body into {@link RankedDocsResults}.
+ * <p>
+ * The parser reads a top-level {@code results} array whose entries carry {@code index} and {@code relevance_score}, for example
+ * {@code {"results": [{"index": 0, "relevance_score": 0.94}]}}. Unknown fields are ignored at both levels.
+ */
+public class ElasticInferenceServiceRerankResponseEntity {
+
+    record RerankResult(List<RerankResultEntry> entries) {
+
+        @SuppressWarnings("unchecked")
+        public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
+            RerankResult.class.getSimpleName(),
+            true,
+            args -> new RerankResult((List<RerankResultEntry>) args[0])
+        );
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
+        }
+
+        record RerankResultEntry(Integer index, Float relevanceScore) {
+
+            public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
+                RerankResultEntry.class.getSimpleName(),
+                true, // ignore unknown fields, like the enclosing parser, so new entry fields do not fail parsing
+                args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
+            );
+
+            static {
+                PARSER.declareInt(constructorArg(), new ParseField("index"));
+                PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
+            }
+
+            public RankedDocsResults.RankedDoc toRankedDoc() {
+                return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
+            }
+        }
+    }
+
+    /**
+     * Parses {@code response.body()} as JSON and maps every result entry to a {@link RankedDocsResults.RankedDoc}.
+     */
+    public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
+
+            return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
+        }
+    }
+
+    private ElasticInferenceServiceRerankResponseEntity() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
index 2910048cbd0a4..280ae2756d62d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
@@ -51,6 +51,7 @@
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import
org.elasticsearch.xpack.inference.telemetry.TraceContext;
 
@@ -77,7 +78,11 @@ public class ElasticInferenceService extends SenderService {
     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
 
-    private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
+    private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.CHAT_COMPLETION,
+        TaskType.RERANK
+    );
     private static final String SERVICE_NAME = "Elastic";
 
     // rainbow-sprinkles
@@ -91,7 +96,7 @@ public class ElasticInferenceService extends SenderService {
     /**
      * The task types that the {@link InferenceAction.Request} can accept.
      */
-    private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+    private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);
 
     public static String defaultEndpointId(String modelId) {
         return Strings.format(".%s-elastic", modelId);
@@ -161,6 +166,18 @@ public void onNodeStarted() {
         authorizationHandler.init();
     }
 
+    @Override
+    protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
+        if (returnDocuments != null) { // any explicit value is rejected, true or false alike; topN is not checked here
+            validationException.addValidationError(
+                org.elasticsearch.core.Strings.format(
+                    "Invalid return_documents [%s]. The return_documents option is not supported by this service",
+                    returnDocuments
+                )
+            ); // the error is collected on the passed-in exception; nothing is thrown from this method
+        }
+    }
+
     /**
      * Only use this in tests.
* @@ -333,7 +350,7 @@ private static ElasticInferenceServiceModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - ElasticInferenceServiceComponents eisServiceComponents, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, String failureMessage, ConfigurationParseContext context ) { @@ -345,7 +362,7 @@ private static ElasticInferenceServiceModel createModel( serviceSettings, taskSettings, secretSettings, - eisServiceComponents, + elasticInferenceServiceComponents, context ); case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel( @@ -355,7 +372,17 @@ private static ElasticInferenceServiceModel createModel( serviceSettings, taskSettings, secretSettings, - eisServiceComponents, + elasticInferenceServiceComponents, + context + ); + case RERANK -> new ElasticInferenceServiceRerankModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + elasticInferenceServiceComponents, context ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); @@ -451,9 +478,8 @@ private LazyInitializable initC configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription( - "The name of the model to use for the inference task." 
- ) + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)) + .setDescription("The name of the model to use for the inference task.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -476,7 +502,9 @@ private LazyInitializable initC ); configurationMap.putAll( - RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)) + RateLimitSettings.toSettingsConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK) + ) ); return new InferenceServiceConfiguration.Builder().setService(NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index e03cc36e62417..34a8086119150 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -7,14 +7,15 @@ package org.elasticsearch.xpack.inference.services.elastic; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class ElasticInferenceServiceModel extends Model { +public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel { private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; @@ -35,12 +36,18 @@ public ElasticInferenceServiceModel( public 
ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) {
         super(model, serviceSettings);
 
-        this.rateLimitServiceSettings = model.rateLimitServiceSettings();
+        this.rateLimitServiceSettings = model.rateLimitServiceSettings;
         this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents();
     }
 
-    public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() {
-        return rateLimitServiceSettings;
+    @Override
+    public int rateLimitGroupingHash() {
+        // Rate limits are grouped by model id only; the task type and endpoint id are not part of the hash
+        return Objects.hash(this.getServiceSettings().modelId());
+    }
+
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitServiceSettings.rateLimitSettings();
     }
 
     public ElasticInferenceServiceComponents elasticInferenceServiceComponents() {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java
index 1c4b6cb340ecc..8d5556e64f3b8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java
@@ -20,7 +20,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM
     private final ElasticInferenceServiceRequestMetadata requestMetadata;
 
     protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings());
         this.requestMetadata =
extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); } @@ -32,7 +32,7 @@ record RateLimitGrouping(int modelIdHash) { public static RateLimitGrouping of(ElasticInferenceServiceModel model) { Objects.requireNonNull(model); - return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + return new RateLimitGrouping(model.rateLimitGroupingHash()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index ecb452d5f4d78..11d3f44f0d4a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -30,6 +30,7 @@ import static org.elasticsearch.xpack.inference.common.Truncator.truncate; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; +// TODO: remove and use GenericRequestManager in ElasticInferenceServiceActionCreator public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends ElasticInferenceServiceRequestManager { private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceSparseEmbeddingsRequestManager.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 7fdcea2d987e6..dca17ef6926cd 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -7,19 +7,27 @@ package org.elasticsearch.xpack.inference.services.elastic.action; +import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; -import java.util.Locale; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; +import static 
org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext;
 
 public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
@@ -29,6 +37,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
 
     private final TraceContext traceContext;
 
+    static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler(
+        "elastic rerank",
+        (request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
+    );
+
     public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
         this.sender = Objects.requireNonNull(sender);
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
@@ -39,8 +52,35 @@ public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents ser
     public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
         var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
         var errorMessage = constructFailedToSendRequestMessage(
-            String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
+            Strings.format("%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
+        );
+        return new SenderExecutableAction(sender, requestManager, errorMessage);
+    }
+
+    /**
+     * Creates the action that sends rerank requests for {@code model}, parsing responses with {@link #RERANK_HANDLER}.
+     */
+    @Override
+    public ExecutableAction create(ElasticInferenceServiceRerankModel model) {
+        var threadPool = serviceComponents.threadPool();
+        // Read the request metadata now, as ElasticInferenceServiceRequestManager does in its constructor, rather than
+        // from the thread context of whichever thread later invokes the request creator.
+        var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
+        var requestManager = new GenericRequestManager<>(
+            threadPool,
+            model,
+            RERANK_HANDLER,
+            (rerankInput) -> new ElasticInferenceServiceRerankRequest(
+                rerankInput.getQuery(),
+                rerankInput.getChunks(),
+                rerankInput.getTopN(),
+                model,
+                traceContext,
+                requestMetadata
+            ),
+            QueryAndDocsInputs.class
         );
+        var errorMessage =
constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)); return new SenderExecutableAction(sender, requestManager, errorMessage); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index a639bfdcbad71..3919c6e2461bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -9,9 +9,12 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; public interface ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); + ExecutableAction create(ElasticInferenceServiceRerankModel model); + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java new file mode 100644 index 0000000000000..7e592406a718a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel;
+import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+public class ElasticInferenceServiceRerankModel extends ElasticInferenceServiceExecutableActionModel {
+
+    private final URI uri;
+
+    public ElasticInferenceServiceRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map serviceSettings,
+        Map taskSettings,
+        Map secrets,
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            ElasticInferenceServiceRerankServiceSettings.fromMap(serviceSettings, context),
+            EmptyTaskSettings.INSTANCE, // the taskSettings map is not parsed; empty task settings are always used
+            EmptySecretSettings.INSTANCE, // the secrets map is not parsed; empty secret settings are always used
+            elasticInferenceServiceComponents
+        );
+    }
+
+    public ElasticInferenceServiceRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        ElasticInferenceServiceRerankServiceSettings serviceSettings,
+        @Nullable TaskSettings taskSettings,
+        @Nullable SecretSettings secretSettings,
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings,
+            elasticInferenceServiceComponents
+        );
+        this.uri = createUri(); // built once at construction; a malformed URL fails model creation with BAD_REQUEST
+    }
+
+    @Override
+    public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) {
+        return visitor.create(this); // the taskSettings argument is not used
+    }
+
+    @Override
+    public ElasticInferenceServiceRerankServiceSettings getServiceSettings() {
+        return (ElasticInferenceServiceRerankServiceSettings) super.getServiceSettings();
+    }
+
+    public URI uri() {
+        return uri;
+    }
+
+    private URI createUri() throws ElasticsearchStatusException {
+        try {
+            // TODO, consider transforming the base URL into a URI for better error handling.
+            return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank");
+        } catch (URISyntaxException e) {
+            throw new ElasticsearchStatusException(
+                "Failed to create URI for service ["
+                    + this.getConfigurations().getService()
+                    + "] with taskType ["
+                    + this.getTaskType()
+                    + "]: "
+                    + e.getMessage(),
+                RestStatus.BAD_REQUEST,
+                e
+            );
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java
new file mode 100644
index 0000000000000..aefce277ee18d
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+
+/**
+ * Service settings of the Elastic Inference Service rerank task: a required model id plus rate limit settings.
+ */
+public class ElasticInferenceServiceRerankServiceSettings extends FilteredXContentObject
+    implements
+        ServiceSettings,
+        ElasticInferenceServiceRateLimitServiceSettings {
+
+    public static final String NAME = "elastic_rerank_service_settings";
+
+    // Used whenever no rate limit is supplied (see the two-argument constructor and fromMap).
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
+
+    /**
+     * Parses the settings from a raw map.
+     *
+     * @param map raw service settings; must contain {@code model_id}
+     * @param context parse context forwarded to the rate limit parser
+     * @return the parsed settings
+     * @throws ValidationException if any setting is missing or invalid, e.g. {@code model_id} is absent
+     */
+    public static ElasticInferenceServiceRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            ElasticInferenceService.NAME,
+            context
+        );
+
+        // Surface the collected problems as a validation failure. Without this a missing model id
+        // reaches Objects.requireNonNull in the constructor and is reported as a NullPointerException.
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimitSettings);
+    }
+
+    private final String modelId;
+
+    private final RateLimitSettings rateLimitSettings;
+
+    /**
+     * @param modelId model id, must not be {@code null}
+     * @param rateLimitSettings rate limit settings; {@code null} selects the default
+     */
+    public ElasticInferenceServiceRerankServiceSettings(String modelId, RateLimitSettings rateLimitSettings) {
+        this.modelId = Objects.requireNonNull(modelId);
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    /**
+     * Reads the settings from the wire; the field order mirrors {@link #writeTo(StreamOutput)}.
+     */
+    public ElasticInferenceServiceRerankServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    @Override
+    public String modelId() {
+        return modelId;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitSettings;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19;
+    }
+
+    /**
+     * Writes the model id and the rate limit settings into an already opened object.
+     */
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    /**
+     * Writes the same fields as {@link #toXContentFragmentOfExposedFields}, wrapped in their own object.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+
+        builder.endObject();
+
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ElasticInferenceServiceRerankServiceSettings that = (ElasticInferenceServiceRerankServiceSettings) o;
+        return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, rateLimitSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java
new file mode 100644
index 0000000000000..407d3e38b4da1
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequestEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
+
+/**
+ * Checks the JSON produced by {@link ElasticInferenceServiceRerankRequestEntity} with and without {@code top_n},
+ * and that null constructor arguments are rejected.
+ */
+public class ElasticInferenceServiceRerankRequestEntityTests extends ESTestCase {
+
+    public void testToXContent_SingleDocument_NoTopN() throws IOException {
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity("query", List.of("document 1"), "rerank-model-id", null);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "documents": ["document 1"]
+            }"""));
+    }
+
+    public void testToXContent_MultipleDocuments_NoTopN() throws IOException {
+        var docs = List.of("document 1", "document 2", "document 3");
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity("query", docs, "rerank-model-id", null);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "documents": [
+                    "document 1",
+                    "document 2",
+                    "document 3"
+                ]
+            }
+            """));
+    }
+
+    public void testToXContent_SingleDocument_WithTopN() throws IOException {
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity("query", List.of("document 1"), "rerank-model-id", 3);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "top_n": 3,
+                "documents": ["document 1"]
+            }
+            """));
+    }
+
+    public void testToXContent_MultipleDocuments_WithTopN() throws IOException {
+        var docs = List.of("document 1", "document 2", "document 3", "document 4", "document 5");
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity("query", docs, "rerank-model-id", 3);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "top_n": 3,
+                "documents": [
+                    "document 1",
+                    "document 2",
+                    "document 3",
+                    "document 4",
+                    "document 5"
+                ]
+            }
+            """));
+    }
+
+    public void testNullQueryThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity(null, List.of("document 1"), "model-id", null)
+            )
+        );
+    }
+
+    public void testNullDocumentsThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity("query", null, "model-id", null)
+            )
+        );
+    }
+
+    public void testNullModelIdThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity("query", List.of("document 1"), null, null)
+            )
+        );
+    }
+
+    /** Serializes the entity with a fresh JSON builder and returns the resulting string. */
+    private String toJson(ElasticInferenceServiceRerankRequestEntity requestEntity) throws IOException {
+        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+        requestEntity.toXContent(jsonBuilder, null);
+        return Strings.toString(jsonBuilder);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java
new file mode 100644
index 0000000000000..4e6efed6faa59
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic;
+
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
+import org.elasticsearch.xpack.inference.telemetry.TraceContext;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests the HTTP request built by {@link ElasticInferenceServiceRerankRequest}: trace headers and request body.
+ */
+public class ElasticInferenceServiceRerankRequestTests extends ESTestCase {
+
+    /** The trace parent and trace state of the request's trace context must be sent as HTTP headers. */
+    public void testTraceContextPropagatedThroughHTTPHeaders() {
+        var url = "http://eis-gateway.com";
+        var query = "query";
+        var documents = List.of("document 1", "document 2", "document 3");
+        var modelId = "my-model-id";
+        var topN = 3;
+
+        var request = createRequest(url, modelId, query, documents, topN);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var traceParent = request.getTraceContext().traceParent();
+        var traceState = request.getTraceContext().traceState();
+
+        assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent));
+        assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState));
+    }
+
+    /** The request returned by {@code truncate()} must still carry the full query, model, documents and top_n. */
+    public void testTruncate_DoesNotTruncate() throws IOException {
+        var url = "http://eis-gateway.com";
+        var query = "query";
+        var documents = List.of("document 1", "document 2", "document 3");
+        var modelId = "my-model-id";
+        var topN = 3;
+
+        var request = createRequest(url, modelId, query, documents, topN);
+        var truncatedRequest = request.truncate();
+
+        var httpRequest = truncatedRequest.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(4));
+        assertThat(requestMap.get("query"), is(query));
+        assertThat(requestMap.get("model"), is(modelId));
+        assertThat(requestMap.get("documents"), is(documents));
+        assertThat(requestMap.get("top_n"), is(topN));
+    }
+
+    /** Builds a rerank request against a test model, with a random trace context and random request metadata. */
+    private ElasticInferenceServiceRerankRequest createRequest(
+        String url,
+        String modelId,
+        String query,
+        List<String> documents,
+        Integer topN
+    ) {
+        var rerankModel = ElasticInferenceServiceRerankModelTests.createModel(url, modelId);
+
+        return new ElasticInferenceServiceRerankRequest(
+            query,
+            documents,
+            topN,
+            rerankModel,
+            new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
+            randomElasticInferenceServiceRequestMetadata()
+        );
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntityTests.java
new file mode 100644
index
0000000000000..2d3b9fb309dbb
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntityTests.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.elastic;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests parsing of rerank responses by {@code ElasticInferenceServiceRerankResponseEntity.fromResponse}.
+ */
+public class ElasticInferenceServiceRerankResponseEntityTests extends ESTestCase {
+
+    public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+        var ranked = parse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.94
+                    }
+                ]
+            }
+            """);
+
+        assertThat(ranked.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.94F, null))));
+    }
+
+    public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
+        var ranked = parse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.94
+                    },
+                    {
+                        "index": 1,
+                        "relevance_score": 0.78
+                    },
+                    {
+                        "index": 2,
+                        "relevance_score": 0.65
+                    }
+                ]
+            }
+            """);
+
+        var expected = List.of(
+            new RankedDocsResults.RankedDoc(0, 0.94F, null),
+            new RankedDocsResults.RankedDoc(1, 0.78F, null),
+            new RankedDocsResults.RankedDoc(2, 0.65F, null)
+        );
+        assertThat(ranked.getRankedDocs(), is(expected));
+    }
+
+    public void testFromResponse_HandlesFloatingPointPrecision() throws IOException {
+        var ranked = parse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.9432156
+                    }
+                ]
+            }
+            """);
+
+        assertThat(ranked.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.9432156F, null))));
+    }
+
+    public void testFromResponse_OrderIsPreserved() throws IOException {
+        var ranked = parse("""
+            {
+                "results": [
+                    {
+                        "index": 2,
+                        "relevance_score": 0.94
+                    },
+                    {
+                        "index": 0,
+                        "relevance_score": 0.78
+                    },
+                    {
+                        "index": 1,
+                        "relevance_score": 0.65
+                    }
+                ]
+            }
+            """);
+
+        // The parsed documents must appear in exactly the order the response listed them.
+        var expected = List.of(
+            new RankedDocsResults.RankedDoc(2, 0.94F, null),
+            new RankedDocsResults.RankedDoc(0, 0.78F, null),
+            new RankedDocsResults.RankedDoc(1, 0.65F, null)
+        );
+        assertThat(ranked.getRankedDocs(), is(expected));
+    }
+
+    public void testFromResponse_HandlesEmptyResultsList() throws IOException {
+        var ranked = parse("""
+            {
+                "results": []
+            }
+            """);
+
+        assertThat(ranked.getRankedDocs(), is(List.of()));
+    }
+
+    /** Wraps the JSON in an {@link HttpResult} with a mocked HTTP response and parses it into ranked docs. */
+    private static RankedDocsResults parse(String responseJson) throws IOException {
+        var httpResult = new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8));
+        return (RankedDocsResults) ElasticInferenceServiceRerankResponseEntity.fromResponse(httpResult);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index c73df7fc90386..d11eb1cbf66e6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
@@ -58,6 +59,8 @@
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
 import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -149,6 +152,23 @@ public void testParseRequestConfig_CreatesASparseEmbeddingsModel() throws IOExce
         }
     }
 
+    public void
testParseRequestConfig_CreatesARerankModel() throws IOException {
+        try (var service = createServiceWithMockSender()) {
+            ActionListener<Model> modelListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(ElasticInferenceServiceRerankModel.class));
+                ElasticInferenceServiceRerankModel rerankModel = (ElasticInferenceServiceRerankModel) model;
+                assertThat(rerankModel.getServiceSettings().modelId(), is("my-rerank-model-id"));
+            }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.RERANK,
+                getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, "my-rerank-model-id"), Map.of(), Map.of()),
+                modelListener
+            );
+        }
+    }
+
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createServiceWithMockSender()) {
             var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of());
@@ -367,6 +387,39 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
         verifyNoMoreInteractions(sender);
     }
 
+    public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
+        var sender = mock(Sender.class); // NOTE(review): not referenced again in this test
+
+        var factory = mock(HttpRequestSender.Factory.class); // NOTE(review): the service below comes from createServiceWithMockSender()
+        when(factory.createSender()).thenReturn(sender);
+
+        try (var service = createServiceWithMockSender()) {
+            var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), "my-rerank-model-id");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            var thrownException = expectThrows(
+                ValidationException.class,
+                () -> service.infer(
+                    model,
+                    "search query",
+                    Boolean.TRUE, // presumably return_documents, judging by the message asserted below
+                    10,
+                    List.of("doc1", "doc2", "doc3"),
+                    false,
+                    new HashMap<>(),
+                    InputType.SEARCH,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+
+            assertThat(
+                thrownException.getMessage(),
+                is("Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this service;")
+            );
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
         var sender = mock(Sender.class);
 
@@ -395,7 +448,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
                 thrownException.getMessage(),
                 is(
                     "Inference entity [model_id] does not support task type [text_embedding] "
-                        + "for inference, the task type must be one of [sparse_embedding]."
+                        + "for inference, the task type must be one of [sparse_embedding, rerank]."
                 )
             );
 
@@ -436,7 +489,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
                 thrownException.getMessage(),
                 is(
                     "Inference entity [model_id] does not support task type [chat_completion] "
-                        + "for inference, the task type must be one of [sparse_embedding]. "
+                        + "for inference, the task type must be one of [sparse_embedding, rerank]. "
                         + "The task type for the inference entity is chat_completion, "
                         + "please use the _inference/chat_completion/model_id/_stream URL."
)
@@ -504,6 +557,76 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException {
         }
     }
 
+    @SuppressWarnings("unchecked")
+    public void testRerank_SendsRerankRequest() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        var elasticInferenceServiceURL = getUrl(webServer);
+
+        try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
+            var modelId = "my-model-id";
+            var topN = 2;
+            String responseJson = """
+                {
+                    "results": [
+                        {"index": 0, "relevance_score": 0.95},
+                        {"index": 1, "relevance_score": 0.85},
+                        {"index": 2, "relevance_score": 0.75}
+                    ]
+                }
+                """;
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = ElasticInferenceServiceRerankModelTests.createModel(elasticInferenceServiceURL, modelId);
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            service.infer(
+                model,
+                "search query",
+                null,
+                topN,
+                List.of("doc1", "doc2", "doc3"),
+                false,
+                new HashMap<>(),
+                InputType.SEARCH,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+            var result = listener.actionGet(TIMEOUT);
+
+            var resultMap = result.asMap();
+            var rerankResults = (List<Map<String, Object>>) resultMap.get("rerank");
+            assertThat(rerankResults.size(), Matchers.is(3)); // all three mocked results come back although top_n is 2
+
+            Map<String, Object> rankedDocOne = (Map<String, Object>) rerankResults.get(0).get("ranked_doc");
+            Map<String, Object> rankedDocTwo = (Map<String, Object>) rerankResults.get(1).get("ranked_doc");
+            Map<String, Object> rankedDocThree = (Map<String, Object>) rerankResults.get(2).get("ranked_doc");
+
+            assertThat(rankedDocOne.get("index"), equalTo(0));
+            assertThat(rankedDocTwo.get("index"), equalTo(1));
+            assertThat(rankedDocThree.get("index"), equalTo(2));
+
+            // Verify the outgoing HTTP request
+            var request = webServer.requests().get(0);
+            assertNull(request.getUri().getQuery());
+            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));
+
+            // Verify the outgoing request body
+            Map<String, Object> requestMap = entityAsMap(request.getBody());
+            Map<String, Object> expectedRequestMap = Map.of(
+                "query",
+                "search query",
+                "model",
+                modelId,
+                "top_n",
+                topN,
+                "documents",
+                List.of("doc1", "doc2", "doc3")
+            );
+            assertThat(requestMap, is(expectedRequestMap));
+        }
+    }
+
     public void testInfer_PropagatesProductUseCaseHeader() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         var elasticInferenceServiceURL = getUrl(webServer);
@@ -850,7 +973,7 @@ public void testGetConfiguration() throws Exception {
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -859,7 +982,7 @@ public void testGetConfiguration() throws Exception {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
@@ -905,7 +1028,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -914,7 +1037,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
@@ -974,7 +1097,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -983,7 +1106,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
index fadf4a899e45d..49957800f3a83 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
@@ -20,12 +20,15 @@
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import
org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 import org.junit.After;
 import org.junit.Before;
@@ -181,6 +184,78 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
         }
     }
 
+    @SuppressWarnings("unchecked")
+    public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "results": [
+                        {
+                            "index": 0,
+                            "relevance_score": 0.94
+                        },
+                        {
+                            "index": 1,
+                            "relevance_score": 0.21
+                        }
+                    ]
+                }
+                """;
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var modelId = "my-model-id";
+            var topN = 3;
+            var query = "query";
+            var documents = List.of("document 1", "document 2", "document 3");
+
+            var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId);
+            var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
+            var action = actionCreator.create(model);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new QueryAndDocsInputs(query, documents, null, topN, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(
+                result.asMap(),
+                equalTo(
+                    RankedDocsResultsTests.buildExpectationRerank(
+                        List.of(
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.94f)),
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.21f))
+                        )
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+
+            assertThat(requestMap.size(), is(4)); // documents, top_n, query and model, each checked below
+
+            assertThat(requestMap.get("documents"), instanceOf(List.class));
+            List<String> requestDocuments = (List<String>) requestMap.get("documents");
+            assertThat(requestDocuments.get(0), equalTo(documents.get(0)));
+            assertThat(requestDocuments.get(1), equalTo(documents.get(1)));
+            assertThat(requestDocuments.get(2), equalTo(documents.get(2)));
+
+            assertThat(requestMap.get("top_n"), equalTo(topN));
+
+            assertThat(requestMap.get("query"), equalTo(query));
+
+            assertThat(requestMap.get("model"), equalTo(modelId));
+        }
+    }
+
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java
new file mode 100644
index 0000000000000..f5da46915e13c
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
+
+/**
+ * Test factory for {@link ElasticInferenceServiceRerankModel} instances.
+ */
+public class ElasticInferenceServiceRerankModelTests extends ESTestCase {
+
+    /**
+     * Creates a rerank model with fixed entity id {@code "id"} and service name {@code "service"}, empty task and
+     * secret settings, and no explicit rate limit.
+     *
+     * @param url base URL of the Elastic Inference Service
+     * @param modelId model id placed in the service settings
+     */
+    public static ElasticInferenceServiceRerankModel createModel(String url, String modelId) {
+        var settings = new ElasticInferenceServiceRerankServiceSettings(modelId, null);
+        var components = ElasticInferenceServiceComponents.of(url);
+
+        return new ElasticInferenceServiceRerankModel(
+            "id",
+            TaskType.RERANK,
+            "service",
+            settings,
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            components
+        );
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java
new file mode 100644
index 0000000000000..8066da9c43683
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link ElasticInferenceServiceRerankServiceSettings}: the checks inherited from
+ * {@link AbstractWireSerializingTestCase}, plus parsing from a settings map and XContent rendering.
+ */
+public class ElasticInferenceServiceRerankServiceSettingsTests extends AbstractWireSerializingTestCase<
+    ElasticInferenceServiceRerankServiceSettings> {
+
+    /**
+     * Supplies the settings constructor reference as the reader for the base class.
+     * The return type is parameterized (not raw) so it matches the type argument given to the base class.
+     */
+    @Override
+    protected Writeable.Reader<ElasticInferenceServiceRerankServiceSettings> instanceReader() {
+        return ElasticInferenceServiceRerankServiceSettings::new;
+    }
+
+    /** Delegates to {@link #createRandom()} so other tests can share the same generator. */
+    @Override
+    protected ElasticInferenceServiceRerankServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /**
+     * Produces an instance that differs from {@code instance}.
+     *
+     * @param instance the value the result must not be equal to
+     * @return a value drawn from {@link #createRandom()} via {@code randomValueOtherThan}
+     */
+    @Override
+    protected ElasticInferenceServiceRerankServiceSettings mutateInstance(ElasticInferenceServiceRerankServiceSettings instance)
+        throws IOException {
+        // createRandom() only varies the model id (two candidates), so a differing value always exists.
+        return randomValueOtherThan(instance, ElasticInferenceServiceRerankServiceSettingsTests::createRandom);
+    }
+
+    /** Parsing a map that holds only a model id must equal settings built from that id with a null rate limit argument. */
+    public void testFromMap() {
+        var modelId = "my-model-id";
+
+        // Map.of is immutable, so a mutable HashMap copy is passed (presumably fromMap removes the keys it reads; confirm).
+        var serviceSettings = ElasticInferenceServiceRerankServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(serviceSettings, is(new ElasticInferenceServiceRerankServiceSettings(modelId, null)));
+    }
+
+    /** Rendering settings with an explicit rate limit must emit both {@code model_id} and {@code rate_limit}. */
+    public void testToXContent_WritesAllFields() throws IOException {
+        var modelId = ".rerank-v1";
+        var rateLimitSettings = new RateLimitSettings(100L);
+        var serviceSettings = new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimitSettings);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        // No ToXContent params are supplied for this rendering.
+        serviceSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is(Strings.format("""
+            {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit())));
+    }
+
+    /**
+     * Creates settings with a model id picked by {@link #randomRerankModel()} and a null rate limit argument.
+     *
+     * @return a new settings instance
+     */
+    public static ElasticInferenceServiceRerankServiceSettings createRandom() {
+        return new ElasticInferenceServiceRerankServiceSettings(randomRerankModel(), null);
+    }
+
+    /** Picks one of the two model ids used by these tests. */
+    private static String randomRerankModel() {
+        return randomFrom(".rerank-v1", ".rerank-v2");
+    }
+}