Skip to content

Commit 3bce8a9

Browse files
dongxianzhe and Xianzhe Dong authored
feat: add vlm embedding interface. (#398)
* feat: support mm embedding service and vlm embedding model factory.
* refactor: share build messages logic between mm chat and mm embedding services.
* refactor: rename embedding model factory name.
* fix: fix vlm worker input.

---------

Co-authored-by: Xianzhe Dong <[email protected]>
1 parent ea8b42e commit 3bce8a9

18 files changed

+380
-55
lines changed

xllm/api_service/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
stream_call.h
1919
models_service_impl.h
2020
stream_output_parser.h
21+
mm_service_utils.h
2122
SRCS
2223
api_service.cpp
2324
call.cpp

xllm/api_service/api_service.cpp

File mode changed: 100755 → 100644
Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ APIService::APIService(Master* master,
6464
auto vlm_master = dynamic_cast<VLMMaster*>(master);
6565
mm_chat_service_impl_ =
6666
std::make_unique<MMChatServiceImpl>(vlm_master, model_names);
67+
mm_embedding_service_impl_ =
68+
std::make_unique<MMEmbeddingServiceImpl>(vlm_master, model_names);
6769
} else if (FLAGS_backend == "dit") {
6870
image_generation_service_impl_ =
6971
std::make_unique<ImageGenerationServiceImpl>(
@@ -190,10 +192,13 @@ void APIService::Embeddings(::google::protobuf::RpcController* controller,
190192
// TODO with xllm-service
191193
}
192194

193-
void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
194-
const proto::HttpRequest* request,
195-
proto::HttpResponse* response,
196-
::google::protobuf::Closure* done) {
195+
namespace {
196+
template <typename EmbeddingCall, typename Service>
197+
void handle_embedding_request(std::unique_ptr<Service>& embedding_service_impl_,
198+
::google::protobuf::RpcController* controller,
199+
const proto::HttpRequest* request,
200+
proto::HttpResponse* response,
201+
::google::protobuf::Closure* done) {
197202
xllm::ClosureGuard done_guard(
198203
done,
199204
std::bind(request_in_metric, nullptr),
@@ -202,12 +207,13 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
202207
LOG(ERROR) << "brpc request | respose | controller is null";
203208
return;
204209
}
205-
206210
auto arena = response->GetArena();
207211
auto req_pb =
208-
google::protobuf::Arena::CreateMessage<proto::EmbeddingRequest>(arena);
212+
google::protobuf::Arena::CreateMessage<typename EmbeddingCall::ReqType>(
213+
arena);
209214
auto resp_pb =
210-
google::protobuf::Arena::CreateMessage<proto::EmbeddingResponse>(arena);
215+
google::protobuf::Arena::CreateMessage<typename EmbeddingCall::ResType>(
216+
arena);
211217

212218
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
213219
std::string error;
@@ -230,6 +236,22 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
230236
ctrl, done_guard.release(), req_pb, resp_pb);
231237
embedding_service_impl_->process_async(call);
232238
}
239+
} // namespace
240+
241+
// HTTP entry point for embedding requests.
//
// Dispatches on FLAGS_backend:
//   "llm" -> handled by embedding_service_impl_ (EmbeddingCall)
//   "vlm" -> handled by mm_embedding_service_impl_ (MMEmbeddingCall)
// For any other backend no embedding service is selected here, so the request
// is failed explicitly and `done` is still run; otherwise the RPC would never
// be completed.
void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
                                const proto::HttpRequest* request,
                                proto::HttpResponse* response,
                                ::google::protobuf::Closure* done) {
  if (FLAGS_backend == "llm") {
    CHECK(embedding_service_impl_) << " embedding service is invalid.";
    handle_embedding_request<EmbeddingCall, EmbeddingServiceImpl>(
        embedding_service_impl_, controller, request, response, done);
  } else if (FLAGS_backend == "vlm") {
    CHECK(mm_embedding_service_impl_) << " mm embedding service is invalid.";
    handle_embedding_request<MMEmbeddingCall, MMEmbeddingServiceImpl>(
        mm_embedding_service_impl_, controller, request, response, done);
  } else {
    // Unsupported backend: report the failure and complete the RPC instead of
    // silently dropping `done`.
    const std::string reason =
        "embeddings are not supported for backend: " + FLAGS_backend;
    LOG(ERROR) << reason;
    if (controller != nullptr) {
      controller->SetFailed(reason);
    }
    if (done != nullptr) {
      done->Run();
    }
  }
}
233255

234256
void APIService::ImageGeneration(::google::protobuf::RpcController* controller,
235257
const proto::ImageGenerationRequest* request,

xllm/api_service/api_service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class APIService : public proto::XllmAPIService {
120120
std::unique_ptr<ChatServiceImpl> chat_service_impl_;
121121
std::unique_ptr<MMChatServiceImpl> mm_chat_service_impl_;
122122
std::unique_ptr<EmbeddingServiceImpl> embedding_service_impl_;
123+
std::unique_ptr<MMEmbeddingServiceImpl> mm_embedding_service_impl_;
123124
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
124125
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
125126
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;

xllm/api_service/chat_service_impl.cpp

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "core/runtime/vlm_master.h"
3737
#include "core/util/utils.h"
3838
#include "core/util/uuid.h"
39+
#include "mm_service_utils.h"
3940

4041
namespace xllm {
4142
namespace {
@@ -737,43 +738,9 @@ void MMChatServiceImpl::process_async_impl(std::shared_ptr<MMChatCall> call) {
737738
rpc_request, call->get_x_request_id(), call->get_x_request_time());
738739

739740
std::vector<Message> messages;
740-
messages.reserve(rpc_request.messages_size());
741-
742-
for (const auto& req_message : req_messages) {
743-
MMContentVec contents;
744-
for (const auto& input : req_message.content()) {
745-
auto& item = const_cast<::xllm::proto::MMInputData&>(input);
746-
if (item.type() == "text") {
747-
contents.emplace_back(item.type(), *item.release_text());
748-
} else if (item.type() == "image_url") {
749-
ImageURL image_url;
750-
image_url.url = std::move(*item.mutable_image_url()->release_url());
751-
contents.emplace_back(item.type(), image_url);
752-
} else if (item.type() == "video_url") {
753-
VideoURL video_url;
754-
video_url.url = std::move(*item.mutable_video_url()->release_url());
755-
contents.emplace_back(item.type(), video_url);
756-
} else if (item.type() == "audio_url") {
757-
AudioURL audio_url;
758-
audio_url.url = std::move(*item.mutable_audio_url()->release_url());
759-
contents.emplace_back(item.type(), audio_url);
760-
} else {
761-
call->finish_with_error(StatusCode::INVALID_ARGUMENT,
762-
"message content type is invalid.");
763-
return;
764-
}
765-
}
766-
messages.emplace_back(req_message.role(), std::move(contents));
767-
}
768-
769-
// check if the request image number exceeds the allowed image limit.
770-
for (auto& msg : messages) {
771-
if (msg.calc_count("image_url") > master_->get_image_limit()) {
772-
call->finish_with_error(StatusCode::INVALID_ARGUMENT,
773-
"Number of images in a single message exceeds "
774-
"the allowed image limit.");
775-
return;
776-
}
741+
if (!build_messages<MMChatCall>(
742+
req_messages, messages, call, master_->get_image_limit())) {
743+
return;
777744
}
778745

779746
bool include_usage = false;

xllm/api_service/embedding_service_impl.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ limitations under the License.
2121

2222
#include "common/instance_name.h"
2323
#include "framework/request/request_params.h"
24+
#include "mm_service_utils.h"
2425
#include "runtime/llm_master.h"
2526
#include "util/utils.h"
2627
#include "util/uuid.h"
2728

2829
namespace xllm {
2930
namespace {
3031

32+
template <typename EmbeddingCall>
3133
bool send_result_to_client_brpc(std::shared_ptr<EmbeddingCall> call,
3234
const std::string& request_id,
3335
int64_t created_time,
@@ -113,9 +115,59 @@ void EmbeddingServiceImpl::process_async_impl(
113115
}
114116
}
115117

116-
return send_result_to_client_brpc(
118+
return send_result_to_client_brpc<EmbeddingCall>(
117119
call, request_id, created_time, model, req_output);
118120
});
119121
}
120122

123+
MMEmbeddingServiceImpl::MMEmbeddingServiceImpl(
124+
VLMMaster* master,
125+
const std::vector<std::string>& models)
126+
: APIServiceImpl(models), master_(master) {
127+
CHECK(master_ != nullptr);
128+
}
129+
130+
void MMEmbeddingServiceImpl::process_async_impl(
131+
std::shared_ptr<MMEmbeddingCall> call) {
132+
const auto& rpc_request = call->request();
133+
// check if model is supported
134+
const auto& model = rpc_request.model();
135+
if (!models_.contains(model)) {
136+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
137+
return;
138+
}
139+
140+
// create RequestParams for embeddings request
141+
// set is_embeddings and max_tokens = 1 to control engine step once.
142+
RequestParams request_params(
143+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
144+
145+
auto& req_messages = rpc_request.messages();
146+
147+
std::vector<Message> messages;
148+
if (!build_messages<MMEmbeddingCall>(
149+
req_messages, messages, call, master_->get_image_limit())) {
150+
return;
151+
}
152+
auto request_id = request_params.request_id;
153+
// schedule the request
154+
master_->handle_request(
155+
std::move(messages),
156+
std::move(request_params),
157+
[call,
158+
model,
159+
request_id = request_id,
160+
created_time = absl::ToUnixSeconds(absl::Now())](
161+
const RequestOutput& req_output) -> bool {
162+
if (req_output.status.has_value()) {
163+
const auto& status = req_output.status.value();
164+
if (!status.ok()) {
165+
return call->finish_with_error(status.code(), status.message());
166+
}
167+
}
168+
169+
return send_result_to_client_brpc<MMEmbeddingCall>(
170+
call, request_id, created_time, model, req_output);
171+
});
172+
}
121173
} // namespace xllm

xllm/api_service/embedding_service_impl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include "api_service/api_service_impl.h"
2020
#include "api_service/call.h"
2121
#include "api_service/non_stream_call.h"
22+
#include "core/runtime/vlm_master.h"
2223
#include "embedding.pb.h"
2324

2425
namespace xllm {
@@ -40,4 +41,18 @@ class EmbeddingServiceImpl final : public APIServiceImpl<EmbeddingCall> {
4041
LLMMaster* master_ = nullptr;
4142
};
4243

44+
using MMEmbeddingCall =
45+
NonStreamCall<proto::MMEmbeddingRequest, proto::EmbeddingResponse>;
46+
class MMEmbeddingServiceImpl : public APIServiceImpl<MMEmbeddingCall> {
47+
public:
48+
MMEmbeddingServiceImpl(VLMMaster* master,
49+
const std::vector<std::string>& models);
50+
// brpc call_data needs to use shared_ptr
51+
void process_async_impl(std::shared_ptr<MMEmbeddingCall> call);
52+
53+
private:
54+
DISALLOW_COPY_AND_ASSIGN(MMEmbeddingServiceImpl);
55+
VLMMaster* master_ = nullptr;
56+
};
57+
4358
} // namespace xllm
xllm/api_service/mm_service_utils.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#pragma once
18+
19+
#include "core/common/message.h"
20+
#include "core/common/types.h"
21+
#include "multimodal.pb.h"
22+
23+
namespace xllm {
24+
25+
template <typename Call>
26+
bool build_messages(const google::protobuf::RepeatedPtrField<
27+
xllm::proto::MMChatMessage>& req_messages,
28+
std::vector<Message>& out_messages,
29+
std::shared_ptr<Call> call,
30+
int image_limit) {
31+
out_messages.clear();
32+
out_messages.reserve(req_messages.size());
33+
34+
for (const auto& req_message : req_messages) {
35+
MMContentVec contents;
36+
37+
for (const auto& input : req_message.content()) {
38+
auto& item = const_cast<::xllm::proto::MMInputData&>(input);
39+
40+
if (item.type() == "text") {
41+
contents.emplace_back(item.type(), *item.release_text());
42+
43+
} else if (item.type() == "image_url") {
44+
ImageURL image_url;
45+
image_url.url = std::move(*item.mutable_image_url()->release_url());
46+
contents.emplace_back(item.type(), image_url);
47+
48+
} else if (item.type() == "video_url") {
49+
VideoURL video_url;
50+
video_url.url = std::move(*item.mutable_video_url()->release_url());
51+
contents.emplace_back(item.type(), video_url);
52+
53+
} else if (item.type() == "audio_url") {
54+
AudioURL audio_url;
55+
audio_url.url = std::move(*item.mutable_audio_url()->release_url());
56+
contents.emplace_back(item.type(), audio_url);
57+
58+
} else {
59+
call->finish_with_error(StatusCode::INVALID_ARGUMENT,
60+
"message content type is invalid.");
61+
return false;
62+
}
63+
}
64+
65+
out_messages.emplace_back(req_message.role(), std::move(contents));
66+
}
67+
68+
for (auto& msg : out_messages) {
69+
if (msg.calc_count("image_url") > image_limit) {
70+
call->finish_with_error(StatusCode::INVALID_ARGUMENT,
71+
"Number of images in a single message exceeds "
72+
"the allowed image limit.");
73+
return false;
74+
}
75+
}
76+
77+
return true;
78+
};
79+
80+
} // namespace xllm

xllm/api_service/non_stream_call.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ namespace xllm {
3333
template <typename Request, typename Response>
3434
class NonStreamCall : public Call {
3535
public:
36+
using ReqType = Request;
37+
using ResType = Response;
3638
NonStreamCall(brpc::Controller* controller,
3739
::google::protobuf::Closure* done,
3840
Request* request,

xllm/core/framework/model/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_library(
2828
causal_vlm.h
2929
dit_model.h
3030
embedding_lm.h
31+
embedding_vlm.h
3132
model_args.h
3233
npu_dp_ep_padding.h
3334
model_input_params.h

0 commit comments

Comments
 (0)