Skip to content

[ML] Remove Voyageai request manager classes #124512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
Expand All @@ -26,6 +33,15 @@
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
*/
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
/**
 * Shared response handler for embeddings requests. It is constructed with the label
 * {@code "voyageai text embedding"} and {@link VoyageAIEmbeddingsResponseEntity#fromResponse}.
 */
public static final ResponseHandler EMBEDDINGS_HANDLER = new VoyageAIResponseHandler(
"voyageai text embedding",
VoyageAIEmbeddingsResponseEntity::fromResponse
);
/**
 * Shared response handler for rerank requests. It is constructed with the label
 * {@code "voyageai rerank"} and a lambda that ignores its {@code request} argument and
 * forwards only the response to {@code VoyageAIRerankResponseEntity.fromResponse}.
 */
static final ResponseHandler RERANK_HANDLER = new VoyageAIResponseHandler(
"voyageai rerank",
(request, response) -> VoyageAIRerankResponseEntity.fromResponse(response)
);

private final Sender sender;
private final ServiceComponents serviceComponents;

Expand All @@ -37,16 +53,30 @@ public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents)
/**
 * Builds the {@link ExecutableAction} for an embeddings model.
 *
 * @param model        the configured embeddings model
 * @param taskSettings settings passed to {@code VoyageAIEmbeddingsModel.of} together with the model
 * @param inputType    input type passed to {@code VoyageAIEmbeddingsModel.of}
 * @return a {@link SenderExecutableAction} wrapping the sender, a request manager and an error message
 */
@Override
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
// Generic manager: thread pool, the overridden model, the shared embeddings handler, a factory that
// turns a DocumentsOnlyInput into a VoyageAIEmbeddingsRequest, and the expected input class.
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
overriddenModel,
EMBEDDINGS_HANDLER,
(documentsOnlyInput) -> new VoyageAIEmbeddingsRequest(documentsOnlyInput.getInputs(), overriddenModel),
DocumentsOnlyInput.class
);

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI embeddings");
// NOTE(review): the next two lines look like the pre-change (removed) side of this diff hunk;
// the +/- markers did not survive in this view.
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
// NOTE(review): post-change side - returns the action built on the generic manager above.
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
}

/**
 * Builds the {@link ExecutableAction} for a rerank model.
 *
 * @param model        the configured rerank model
 * @param taskSettings settings passed to {@code VoyageAIRerankModel.of} together with the model
 * @return a {@link SenderExecutableAction} wrapping the sender, a request manager and an error message
 */
@Override
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
// Generic manager: thread pool, the overridden model, the shared rerank handler, a factory that
// turns a QueryAndDocsInputs into a VoyageAIRerankRequest, and the expected input class.
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
overriddenModel,
RERANK_HANDLER,
// Build the request from overriddenModel (not the raw model) so the task settings merged in above
// are the ones sent, matching the embeddings create(...) overload and the pre-change manager below.
(rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), overriddenModel),
QueryAndDocsInputs.class
);

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI rerank");
// NOTE(review): the next two lines look like the pre-change (removed) side of this diff hunk;
// the +/- markers did not survive in this view.
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
// NOTE(review): post-change side - returns the action built on the generic manager above.
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
}
}

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,9 @@ public static URI buildUri(URI accountUri, String service, CheckedSupplier<URI,
}
}

/**
 * Convenience overload that delegates to the three-argument {@code buildUri} with {@code null}
 * as its first ({@code accountUri}) argument.
 *
 * @param service    forwarded unchanged to the three-argument overload
 * @param uriBuilder supplier that produces the URI and may throw {@link URISyntaxException}
 * @return whatever the three-argument overload returns
 */
public static URI buildUri(String service, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
return buildUri(null, service, uriBuilder);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a helper that converts a URI exception to an ElasticsearchStatusException.

}

private RequestUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;

import java.net.URI;
import java.nio.charset.StandardCharsets;
Expand All @@ -24,47 +22,43 @@

public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {

// NOTE(review): this field list mixes both sides of the diff. Judging by the constructor and
// accessors further down, `input` and `embeddingsModel` are the fields the post-change code uses;
// `account`, `serviceSettings`, `taskSettings`, `model` and `inferenceEntityId` appear to be the
// pre-change fields that the change replaces with calls on `embeddingsModel` - confirm against the PR.
private final VoyageAIAccount account;
private final List<String> input;
private final VoyageAIEmbeddingsServiceSettings serviceSettings;
private final VoyageAIEmbeddingsTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
private final VoyageAIEmbeddingsModel embeddingsModel;

/**
 * Creates an embeddings request for the given inputs and model.
 *
 * @param input           the texts to embed; must not be {@code null}
 * @param embeddingsModel the model the request is built from; must not be {@code null}
 * @throws NullPointerException if either argument is {@code null}
 */
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
// NOTE(review): this null check and the `account = ...` line below look like the pre-change side of the diff.
Objects.requireNonNull(embeddingsModel);

account = VoyageAIAccount.of(embeddingsModel);
// NOTE(review): post-change side - keeps the model itself instead of copying values out of it.
this.embeddingsModel = Objects.requireNonNull(embeddingsModel);
this.input = Objects.requireNonNull(input);
// NOTE(review): the four assignments below look like the pre-change side (values cached from the model).
serviceSettings = embeddingsModel.getServiceSettings();
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}

/**
 * Builds the HTTP POST for this embeddings request: the body is the UTF-8 bytes of the string form
 * of a {@code VoyageAIEmbeddingsRequestEntity}, headers are applied by {@code decorateWithHeaders},
 * and the result is wrapped in an {@link HttpRequest} together with the inference entity id.
 *
 * @return the wrapped HTTP request
 */
@Override
public HttpRequest createHttpRequest() {
// NOTE(review): in each adjacent pair below, the `account`/cached-field variant looks like the
// pre-change side of the diff and the `embeddingsModel` variant the post-change side.
HttpPost httpPost = new HttpPost(account.uri());
HttpPost httpPost = new HttpPost(embeddingsModel.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
.getBytes(StandardCharsets.UTF_8)
// Post-change form of the same expression. The model id is read via getServiceSettings().modelId()
// here, whereas the old constructor cached getServiceSettings().getCommonSettings().modelId().
Strings.toString(
new VoyageAIEmbeddingsRequestEntity(
input,
embeddingsModel.getServiceSettings(),
embeddingsModel.getTaskSettings(),
embeddingsModel.getServiceSettings().modelId()
)
).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

decorateWithHeaders(httpPost, account);
decorateWithHeaders(httpPost, embeddingsModel);

return new HttpRequest(httpPost, getInferenceEntityId());
}

/**
 * Returns the inference entity id associated with this request.
 */
@Override
public String getInferenceEntityId() {
// NOTE(review): first return looks like the pre-change side (cached field); second is the
// post-change side that reads the id from the model.
return inferenceEntityId;
return embeddingsModel.getInferenceEntityId();
}

/**
 * Returns the URI this request is sent to.
 */
@Override
public URI getURI() {
// NOTE(review): first return looks like the pre-change side (account URI); second is the
// post-change side that reads the URI from the model.
return account.uri();
return embeddingsModel.uri();
}

@Override
Expand All @@ -77,11 +71,7 @@ public boolean[] getTruncationInfo() {
return null;
}

// NOTE(review): this getter appears to be on the removed side of the diff; the author's inline
// review note embedded below describes it as unused.
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused

return taskSettings;
}

/**
 * Returns the embeddings service settings for this request.
 */
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
// NOTE(review): first return looks like the pre-change side (cached field); second is the
// post-change side that reads the settings from the model.
return serviceSettings;
return embeddingsModel.getServiceSettings();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

public abstract class VoyageAIRequest implements Request {

/**
 * Sets the headers on a request: a JSON {@code Content-Type}, a bearer authorization header
 * built from the API key, and the header returned by {@code VoyageAIUtils.createRequestSourceHeader()}.
 *
 * NOTE(review): two signatures and two auth-header lines appear here because both sides of the diff
 * are shown; the {@code VoyageAIAccount} variant looks like the pre-change side and the
 * {@code VoyageAIModel} variant the post-change side.
 *
 * @param request the POST request to decorate
 */
public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) {
public static void decorateWithHeaders(HttpPost request, VoyageAIModel model) {
request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
request.setHeader(createAuthBearerHeader(account.apiKey()));
request.setHeader(createAuthBearerHeader(model.apiKey()));
request.setHeader(VoyageAIUtils.createRequestSourceHeader());
}

Expand Down
Loading