Commit 094f191

use the model type to match safely
1 parent b1efb4b commit 094f191

3 files changed: +81 additions, -78 deletions
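
The core change: the Vertex endpoint no longer deserializes instances into a tagged VertexInstance enum. Each instance arrives as a raw serde_json::Value; a rerank payload is tried first (the only shape that can be matched safely), and everything else is dispatched on the server's model type. Below is a minimal self-contained sketch of that pattern with toy request types and a simplified ModelType, not the router's actual definitions:

use serde::Deserialize;

// Toy request shapes; the router's real RerankRequest/EmbedRequest carry more fields.
#[derive(Deserialize)]
struct RerankRequest {
    query: String,
    texts: Vec<String>,
}

#[derive(Deserialize)]
struct EmbedRequest {
    inputs: String,
}

// Simplified stand-in for the router's ModelType.
enum ModelType {
    Classifier,
    Reranker,
    Embedding,
}

fn dispatch(model_type: &ModelType, instance: serde_json::Value) -> Result<String, String> {
    // A rerank payload is structurally unambiguous (query + texts), so try it first.
    if let Ok(req) = serde_json::from_value::<RerankRequest>(instance.clone()) {
        return Ok(format!("rerank {} texts against {:?}", req.texts.len(), req.query));
    }
    // Otherwise the loaded model decides how the payload is interpreted.
    match model_type {
        ModelType::Classifier | ModelType::Reranker => {
            // the router parses a PredictRequest here; omitted in this sketch
            Ok("predict".to_string())
        }
        ModelType::Embedding => {
            let req = serde_json::from_value::<EmbedRequest>(instance)
                .map_err(|e| e.to_string())?;
            Ok(format!("embed {:?}", req.inputs))
        }
    }
}

fn main() {
    let instance = serde_json::json!({ "inputs": "What is Deep Learning?" });
    println!("{:?}", dispatch(&ModelType::Embedding, instance));
}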

router/src/http/server.rs

Lines changed: 72 additions & 61 deletions

@@ -4,7 +4,7 @@ use crate::http::types::{
     EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
     OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
     PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken,
-    SparseValue, TokenizeRequest, TokenizeResponse, VertexInstance, VertexRequest, VertexResponse,
+    SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest, VertexResponse,
     VertexResponseInstance,
 };
 use crate::{
@@ -1181,11 +1181,6 @@ async fn vertex_compatibility(
         let result = embed(infer, info, Json(req)).await?;
         Ok(VertexResponseInstance::Embed(result.1 .0))
     };
-    let embed_all_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedAllRequest| async move {
-            let result = embed_all(infer, info, Json(req)).await?;
-            Ok(VertexResponseInstance::EmbedAll(result.1 .0))
-        };
     let embed_sparse_future =
         move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
             let result = embed_sparse(infer, info, Json(req)).await?;
@@ -1201,45 +1196,44 @@ async fn vertex_compatibility(
         let result = rerank(infer, info, Json(req)).await?;
         Ok(VertexResponseInstance::Rerank(result.1 .0))
     };
-    let tokenize_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: TokenizeRequest| async move {
-            let result = tokenize(infer, info, Json(req)).await?;
-            Ok(VertexResponseInstance::Tokenize(result.0))
-        };
 
     let mut futures = Vec::with_capacity(req.instances.len());
     for instance in req.instances {
         let local_infer = infer.clone();
         let local_info = info.clone();
 
-        match instance {
-            VertexInstance::Embed(req) => {
-                futures.push(embed_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::EmbedAll(req) => {
-                futures.push(embed_all_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::EmbedSparse(req) => {
-                futures.push(embed_sparse_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::Predict(req) => {
-                futures.push(predict_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::Rerank(req) => {
-                futures.push(rerank_future(local_infer, local_info, req).boxed());
+        // Rerank is the only payload that can be matched safely
+        if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
+            futures.push(rerank_future(local_infer, local_info, instance).boxed());
+            continue;
+        }
+
+        match info.model_type {
+            ModelType::Classifier(_) | ModelType::Reranker(_) => {
+                let instance = serde_json::from_value::<PredictRequest>(instance)
+                    .map_err(ErrorResponse::from)?;
+                futures.push(predict_future(local_infer, local_info, instance).boxed());
             }
-            VertexInstance::Tokenize(req) => {
-                futures.push(tokenize_future(local_infer, local_info, req).boxed());
+            ModelType::Embedding(_) => {
+                if infer.is_splade() {
+                    let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
+                } else {
+                    let instance = serde_json::from_value::<EmbedRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_future(local_infer, local_info, instance).boxed());
+                }
             }
         }
     }
 
-    let results = join_all(futures)
+    let predictions = join_all(futures)
         .await
         .into_iter()
         .collect::<Result<Vec<VertexResponseInstance>, (StatusCode, Json<ErrorResponse>)>>()?;
 
-    Ok(Json(VertexResponse(results)))
+    Ok(Json(VertexResponse { predictions }))
 }
 
 /// Prometheus metrics scrape endpoint
@@ -1353,12 +1347,7 @@ pub async fn run(
     #[derive(OpenApi)]
     #[openapi(
         paths(vertex_compatibility),
-        components(schemas(
-            VertexInstance,
-            VertexRequest,
-            VertexResponse,
-            VertexResponseInstance
-        ))
+        components(schemas(VertexRequest, VertexResponse, VertexResponseInstance))
     )]
     struct VertextApiDoc;
 
@@ -1396,31 +1385,6 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics));
 
-    // Set default routes
-    app = match &info.model_type {
-        ModelType::Classifier(_) => {
-            app.route("/", post(predict))
-                // AWS Sagemaker route
-                .route("/invocations", post(predict))
-        }
-        ModelType::Reranker(_) => {
-            app.route("/", post(rerank))
-                // AWS Sagemaker route
-                .route("/invocations", post(rerank))
-        }
-        ModelType::Embedding(model) => {
-            if model.pooling == "splade" {
-                app.route("/", post(embed_sparse))
-                    // AWS Sagemaker route
-                    .route("/invocations", post(embed_sparse))
-            } else {
-                app.route("/", post(embed))
-                    // AWS Sagemaker route
-                    .route("/invocations", post(embed))
-            }
-        }
-    };
-
     #[cfg(feature = "google")]
     {
         tracing::info!("Built with `google` feature");
@@ -1433,6 +1397,44 @@ pub async fn run(
         if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
             app = app.route(&env_health_route, get(health));
         }
+        let mut app = Router::new().merge(base_routes);
+
+        #[cfg(feature = "google")]
+        {
+            tracing::info!("Built with `google` feature");
+            let env_predict_route = std::env::var("AIP_PREDICT_ROUTE")
+                .context("`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments")?;
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+            let env_health_route = std::env::var("AIP_HEALTH_ROUTE")
+                .context("`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments")?;
+            app = app.route(&env_health_route, get(health));
+        }
+        #[cfg(not(feature = "google"))]
+        {
+            // Set default routes
+            app = match &info.model_type {
+                ModelType::Classifier(_) => {
+                    app.route("/", post(predict))
+                        // AWS Sagemaker route
+                        .route("/invocations", post(predict))
+                }
+                ModelType::Reranker(_) => {
+                    app.route("/", post(rerank))
+                        // AWS Sagemaker route
+                        .route("/invocations", post(rerank))
+                }
+                ModelType::Embedding(model) => {
+                    if model.pooling == "splade" {
+                        app.route("/", post(embed_sparse))
+                            // AWS Sagemaker route
+                            .route("/invocations", post(embed_sparse))
+                    } else {
+                        app.route("/", post(embed))
+                            // AWS Sagemaker route
+                            .route("/invocations", post(embed))
+                    }
+                }
+            };
     }
 
     app = app
@@ -1509,3 +1511,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
         (StatusCode::from(&err.error_type), Json(err.into()))
     }
 }
+
+impl From<serde_json::Error> for ErrorResponse {
+    fn from(err: serde_json::Error) -> Self {
+        ErrorResponse {
+            error: err.to_string(),
+            error_type: ErrorType::Validation,
+        }
+    }
+}
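
On the wire, a VertexRequest now carries untyped instance objects and the handler answers with a predictions array. A hypothetical round trip against the new endpoint; the field names instances and predictions come from the diff, while the concrete values and the rerank result shape are illustrative:

use serde_json::json;

fn main() {
    // Request body: instances are plain JSON objects, classified server-side
    // (rerank first, then by model type).
    let request = json!({
        "instances": [
            { "query": "What is Deep Learning?", "texts": ["Deep Learning is..."] }
        ]
    });

    // Response body: VertexResponse { predictions } with untagged instances,
    // so each prediction is the bare inner response.
    let response = json!({
        "predictions": [
            [ { "index": 0, "score": 0.99 } ]
        ]
    });

    println!("{request}");
    println!("{response}");
}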

router/src/http/types.rs

Lines changed: 5 additions & 16 deletions

@@ -382,32 +382,21 @@ pub(crate) struct SimpleToken {
 #[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
 pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);
 
-#[derive(Deserialize, ToSchema)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub(crate) enum VertexInstance {
-    Embed(EmbedRequest),
-    EmbedAll(EmbedAllRequest),
-    EmbedSparse(EmbedSparseRequest),
-    Predict(PredictRequest),
-    Rerank(RerankRequest),
-    Tokenize(TokenizeRequest),
-}
-
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
-    pub instances: Vec<VertexInstance>,
+    pub instances: Vec<serde_json::Value>,
 }
 
 #[derive(Serialize, ToSchema)]
-#[serde(tag = "type", content = "result", rename_all = "snake_case")]
+#[serde(untagged)]
 pub(crate) enum VertexResponseInstance {
     Embed(EmbedResponse),
-    EmbedAll(EmbedAllResponse),
     EmbedSparse(EmbedSparseResponse),
     Predict(PredictResponse),
     Rerank(RerankResponse),
-    Tokenize(TokenizeResponse),
 }
 
 #[derive(Serialize, ToSchema)]
-pub(crate) struct VertexResponse(pub Vec<VertexResponseInstance>);
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<VertexResponseInstance>,
+}
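
The serde attribute swap on VertexResponseInstance changes the JSON shape: the old adjacently tagged form wrapped every result in a {"type": ..., "result": ...} object, while #[serde(untagged)] serializes the bare variant contents, which is what belongs inside predictions. A small comparison with a toy enum:

use serde::Serialize;

// Old style: adjacently tagged, as VertexResponseInstance used to be.
#[derive(Serialize)]
#[serde(tag = "type", content = "result", rename_all = "snake_case")]
enum Tagged {
    Rerank(Vec<f32>),
}

// New style: untagged, serializing as the bare contents.
#[derive(Serialize)]
#[serde(untagged)]
enum Untagged {
    Rerank(Vec<f32>),
}

fn main() {
    let scores = vec![0.9, 0.1];
    // Prints {"type":"rerank","result":[0.9,0.1]}
    println!("{}", serde_json::to_string(&Tagged::Rerank(scores.clone())).unwrap());
    // Prints [0.9,0.1]
    println!("{}", serde_json::to_string(&Untagged::Rerank(scores)).unwrap());
}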

router/src/lib.rs

Lines changed: 4 additions & 1 deletion

@@ -246,7 +246,7 @@ pub async fn run(
         std::env::var("AIP_HTTP_PORT")
             .ok()
             .and_then(|p| p.parse().ok())
-            .context("Invalid or unset AIP_HTTP_PORT")?
+            .context("`AIP_HTTP_PORT` env var must be set for Google Vertex deployments")?
     } else {
         port
     };
@@ -264,6 +264,9 @@ pub async fn run(
     #[cfg(all(feature = "grpc", feature = "http"))]
     compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");
 
+    #[cfg(all(feature = "grpc", feature = "google"))]
+    compile_error!("Features `grpc` and `google` cannot be enabled at the same time.");
+
     #[cfg(not(any(feature = "http", feature = "grpc")))]
     compile_error!("Either feature `http` or `grpc` must be enabled.");

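The added guard mirrors the existing http/grpc exclusivity check: when two incompatible Cargo features are enabled together, compile_error! aborts the build with a clear message. A minimal illustration with hypothetical feature names foo and bar:

// With `foo = []` and `bar = []` declared under [features] in Cargo.toml,
// `cargo build --features foo,bar` fails at compile time with this message,
// while either feature alone builds normally.
#[cfg(all(feature = "foo", feature = "bar"))]
compile_error!("Features `foo` and `bar` cannot be enabled at the same time.");

fn main() {}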