@@ -4,7 +4,7 @@ use crate::http::types::{
44 EmbedSparseResponse , Input , OpenAICompatEmbedding , OpenAICompatErrorResponse ,
55 OpenAICompatRequest , OpenAICompatResponse , OpenAICompatUsage , PredictInput , PredictRequest ,
66 PredictResponse , Prediction , Rank , RerankRequest , RerankResponse , Sequence , SimpleToken ,
7- SparseValue , TokenizeRequest , TokenizeResponse , VertexInstance , VertexRequest , VertexResponse ,
7+ SparseValue , TokenizeRequest , TokenizeResponse , VertexRequest , VertexResponse ,
88 VertexResponseInstance ,
99} ;
1010use crate :: {
@@ -1181,11 +1181,6 @@ async fn vertex_compatibility(
11811181 let result = embed ( infer, info, Json ( req) ) . await ?;
11821182 Ok ( VertexResponseInstance :: Embed ( result. 1 . 0 ) )
11831183 } ;
1184- let embed_all_future =
1185- move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedAllRequest | async move {
1186- let result = embed_all ( infer, info, Json ( req) ) . await ?;
1187- Ok ( VertexResponseInstance :: EmbedAll ( result. 1 . 0 ) )
1188- } ;
11891184 let embed_sparse_future =
11901185 move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedSparseRequest | async move {
11911186 let result = embed_sparse ( infer, info, Json ( req) ) . await ?;
@@ -1201,45 +1196,44 @@ async fn vertex_compatibility(
12011196 let result = rerank ( infer, info, Json ( req) ) . await ?;
12021197 Ok ( VertexResponseInstance :: Rerank ( result. 1 . 0 ) )
12031198 } ;
1204- let tokenize_future =
1205- move |infer : Extension < Infer > , info : Extension < Info > , req : TokenizeRequest | async move {
1206- let result = tokenize ( infer, info, Json ( req) ) . await ?;
1207- Ok ( VertexResponseInstance :: Tokenize ( result. 0 ) )
1208- } ;
12091199
12101200 let mut futures = Vec :: with_capacity ( req. instances . len ( ) ) ;
12111201 for instance in req. instances {
12121202 let local_infer = infer. clone ( ) ;
12131203 let local_info = info. clone ( ) ;
12141204
1215- match instance {
1216- VertexInstance :: Embed ( req) => {
1217- futures. push ( embed_future ( local_infer, local_info, req) . boxed ( ) ) ;
1218- }
1219- VertexInstance :: EmbedAll ( req) => {
1220- futures. push ( embed_all_future ( local_infer, local_info, req) . boxed ( ) ) ;
1221- }
1222- VertexInstance :: EmbedSparse ( req) => {
1223- futures. push ( embed_sparse_future ( local_infer, local_info, req) . boxed ( ) ) ;
1224- }
1225- VertexInstance :: Predict ( req) => {
1226- futures. push ( predict_future ( local_infer, local_info, req) . boxed ( ) ) ;
1227- }
1228- VertexInstance :: Rerank ( req) => {
1229- futures. push ( rerank_future ( local_infer, local_info, req) . boxed ( ) ) ;
1205+ // Rerank is the only payload that can be matched safely
1206+ if let Ok ( instance) = serde_json:: from_value :: < RerankRequest > ( instance. clone ( ) ) {
1207+ futures. push ( rerank_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1208+ continue ;
1209+ }
1210+
1211+ match info. model_type {
1212+ ModelType :: Classifier ( _) | ModelType :: Reranker ( _) => {
1213+ let instance = serde_json:: from_value :: < PredictRequest > ( instance)
1214+ . map_err ( ErrorResponse :: from) ?;
1215+ futures. push ( predict_future ( local_infer, local_info, instance) . boxed ( ) ) ;
12301216 }
1231- VertexInstance :: Tokenize ( req) => {
1232- futures. push ( tokenize_future ( local_infer, local_info, req) . boxed ( ) ) ;
1217+ ModelType :: Embedding ( _) => {
1218+ if infer. is_splade ( ) {
1219+ let instance = serde_json:: from_value :: < EmbedSparseRequest > ( instance)
1220+ . map_err ( ErrorResponse :: from) ?;
1221+ futures. push ( embed_sparse_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1222+ } else {
1223+ let instance = serde_json:: from_value :: < EmbedRequest > ( instance)
1224+ . map_err ( ErrorResponse :: from) ?;
1225+ futures. push ( embed_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1226+ }
12331227 }
12341228 }
12351229 }
12361230
1237- let results = join_all ( futures)
1231+ let predictions = join_all ( futures)
12381232 . await
12391233 . into_iter ( )
12401234 . collect :: < Result < Vec < VertexResponseInstance > , ( StatusCode , Json < ErrorResponse > ) > > ( ) ?;
12411235
1242- Ok ( Json ( VertexResponse ( results ) ) )
1236+ Ok ( Json ( VertexResponse { predictions } ) )
12431237}
12441238
12451239/// Prometheus metrics scrape endpoint
@@ -1353,12 +1347,7 @@ pub async fn run(
13531347 #[ derive( OpenApi ) ]
13541348 #[ openapi(
13551349 paths( vertex_compatibility) ,
1356- components( schemas(
1357- VertexInstance ,
1358- VertexRequest ,
1359- VertexResponse ,
1360- VertexResponseInstance
1361- ) )
1350+ components( schemas( VertexRequest , VertexResponse , VertexResponseInstance ) )
13621351 ) ]
13631352 struct VertextApiDoc ;
13641353
@@ -1396,31 +1385,6 @@ pub async fn run(
13961385 // Prometheus metrics route
13971386 . route ( "/metrics" , get ( metrics) ) ;
13981387
1399- // Set default routes
1400- app = match & info. model_type {
1401- ModelType :: Classifier ( _) => {
1402- app. route ( "/" , post ( predict) )
1403- // AWS Sagemaker route
1404- . route ( "/invocations" , post ( predict) )
1405- }
1406- ModelType :: Reranker ( _) => {
1407- app. route ( "/" , post ( rerank) )
1408- // AWS Sagemaker route
1409- . route ( "/invocations" , post ( rerank) )
1410- }
1411- ModelType :: Embedding ( model) => {
1412- if model. pooling == "splade" {
1413- app. route ( "/" , post ( embed_sparse) )
1414- // AWS Sagemaker route
1415- . route ( "/invocations" , post ( embed_sparse) )
1416- } else {
1417- app. route ( "/" , post ( embed) )
1418- // AWS Sagemaker route
1419- . route ( "/invocations" , post ( embed) )
1420- }
1421- }
1422- } ;
1423-
14241388 #[ cfg( feature = "google" ) ]
14251389 {
14261390 tracing:: info!( "Built with `google` feature" ) ;
@@ -1433,6 +1397,44 @@ pub async fn run(
14331397 if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
14341398 app = app. route ( & env_health_route, get ( health) ) ;
14351399 }
1400+ let mut app = Router :: new ( ) . merge ( base_routes) ;
1401+
1402+ #[ cfg( feature = "google" ) ]
1403+ {
1404+ tracing:: info!( "Built with `google` feature" ) ;
1405+ let env_predict_route = std:: env:: var ( "AIP_PREDICT_ROUTE" )
1406+ . context ( "`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1407+ app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1408+ let env_health_route = std:: env:: var ( "AIP_HEALTH_ROUTE" )
1409+ . context ( "`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1410+ app = app. route ( & env_health_route, get ( health) ) ;
1411+ }
1412+ #[ cfg( not( feature = "google" ) ) ]
1413+ {
1414+ // Set default routes
1415+ app = match & info. model_type {
1416+ ModelType :: Classifier ( _) => {
1417+ app. route ( "/" , post ( predict) )
1418+ // AWS Sagemaker route
1419+ . route ( "/invocations" , post ( predict) )
1420+ }
1421+ ModelType :: Reranker ( _) => {
1422+ app. route ( "/" , post ( rerank) )
1423+ // AWS Sagemaker route
1424+ . route ( "/invocations" , post ( rerank) )
1425+ }
1426+ ModelType :: Embedding ( model) => {
1427+ if model. pooling == "splade" {
1428+ app. route ( "/" , post ( embed_sparse) )
1429+ // AWS Sagemaker route
1430+ . route ( "/invocations" , post ( embed_sparse) )
1431+ } else {
1432+ app. route ( "/" , post ( embed) )
1433+ // AWS Sagemaker route
1434+ . route ( "/invocations" , post ( embed) )
1435+ }
1436+ }
1437+ } ;
14361438 }
14371439
14381440 app = app
@@ -1509,3 +1511,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
15091511 ( StatusCode :: from ( & err. error_type ) , Json ( err. into ( ) ) )
15101512 }
15111513}
1514+
1515+ impl From < serde_json:: Error > for ErrorResponse {
1516+ fn from ( err : serde_json:: Error ) -> Self {
1517+ ErrorResponse {
1518+ error : err. to_string ( ) ,
1519+ error_type : ErrorType :: Validation ,
1520+ }
1521+ }
1522+ }
0 commit comments