22#include "backends/util.h"
33#include "backends/tensorflow.h"
44#include "util/arr.h"
5+ #include "execution/execution_contexts/modelRun_ctx.h"
56#include "redis_ai_objects/model.h"
67#include "redis_ai_objects/tensor.h"
78
@@ -461,17 +462,17 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
461462 TF_DeleteStatus (status );
462463}
463464
464- int RAI_ModelRunTF (RAI_ModelRunCtx * * mctxs , RAI_Error * error ) {
465+ int RAI_ModelRunTF (RAI_Model * model , RAI_ExecutionCtx * * ectxs , RAI_Error * error ) {
465466 TF_Status * status = TF_NewStatus ();
466467
467- const size_t nbatches = array_len (mctxs );
468+ const size_t nbatches = array_len (ectxs );
468469 if (nbatches == 0 ) {
469470 RAI_SetError (error , RAI_EMODELRUN , "ERR No batches to run" );
470471 return 1 ;
471472 }
472473
473- const size_t ninputs = array_len ( mctxs [0 ]-> inputs );
474- const size_t noutputs = array_len ( mctxs [0 ]-> outputs );
474+ const size_t ninputs = RAI_ExecutionCtx_NumInputs ( ectxs [0 ]);
475+ const size_t noutputs = RAI_ExecutionCtx_NumOutputs ( ectxs [0 ]);
475476 TF_Tensor * inputTensorsValues [ninputs ];
476477 TF_Output inputs [ninputs ];
477478 TF_Tensor * outputTensorsValues [noutputs ];
@@ -482,7 +483,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
482483 size_t total_batch_size = 0 ;
483484 if (ninputs > 0 ) {
484485 for (size_t b = 0 ; b < nbatches ; ++ b ) {
485- batch_sizes [b ] = RAI_TensorDim (mctxs [b ]-> inputs [ 0 ]. tensor , 0 );
486+ batch_sizes [b ] = RAI_TensorDim (RAI_ExecutionCtx_GetInput ( ectxs [b ], 0 ) , 0 );
486487 total_batch_size += batch_sizes [b ];
487488 }
488489 batch_offsets [0 ] = 0 ;
@@ -491,15 +492,18 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
491492 }
492493 }
493494
495+ void * tfGraph = RAI_ModelGetModel (model );
496+ void * tfSession = RAI_ModelGetSession (model );
497+
494498 for (size_t i = 0 ; i < ninputs ; ++ i ) {
495499 RAI_Tensor * batched_input_tensors [nbatches ];
496500
497501 for (size_t b = 0 ; b < nbatches ; ++ b ) {
498- batched_input_tensors [b ] = mctxs [b ]-> inputs [ i ]. tensor ;
502+ batched_input_tensors [b ] = RAI_ExecutionCtx_GetInput ( ectxs [b ], i ) ;
499503 }
500504 inputTensorsValues [i ] = RAI_TFTensorFromTensors (batched_input_tensors , nbatches );
501505 TF_Output port ;
502- port .oper = TF_GraphOperationByName (mctxs [ 0 ] -> model -> model , mctxs [ 0 ] -> inputs [ i ]. name );
506+ port .oper = TF_GraphOperationByName (tfGraph , RAI_ModelGetInputName ( model , i ) );
503507 port .index = 0 ;
504508 if (port .oper == NULL ) {
505509 return 1 ;
@@ -509,17 +513,17 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
509513
510514 for (size_t i = 0 ; i < noutputs ; ++ i ) {
511515 TF_Output port ;
512- port .oper = TF_GraphOperationByName (mctxs [ 0 ] -> model -> model , mctxs [ 0 ] -> outputs [ i ]. name );
516+ port .oper = TF_GraphOperationByName (tfGraph , RAI_ModelGetOutputName ( model , i ) );
513517 port .index = 0 ;
514518 if (port .oper == NULL ) {
515519 return 1 ;
516520 }
517521 outputs [i ] = port ;
518522 }
519523
520- TF_SessionRun (mctxs [ 0 ] -> model -> session , NULL /* run_options */ , inputs , inputTensorsValues ,
521- ninputs , outputs , outputTensorsValues , noutputs , NULL /* target_opers */ ,
522- 0 /* ntargets */ , NULL /* run_Metadata */ , status );
524+ TF_SessionRun (tfSession , NULL /* run_options */ , inputs , inputTensorsValues , ninputs , outputs ,
525+ outputTensorsValues , noutputs , NULL /* target_opers */ , 0 /* ntargets */ ,
526+ NULL /* run_Metadata */ , status );
523527
524528 for (size_t i = 0 ; i < ninputs ; ++ i ) {
525529 TF_DeleteTensor (inputTensorsValues [i ]);
@@ -547,12 +551,15 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
547551 }
548552
549553 for (size_t b = 0 ; b < nbatches ; b ++ ) {
550- mctxs [b ]-> outputs [i ].tensor = RAI_TensorCreateFromTFTensor (
551- outputTensorsValues [i ], batch_offsets [b ], batch_sizes [b ]);
554+ RAI_ExecutionCtx_SetOutput (ectxs [b ],
555+ RAI_TensorCreateFromTFTensor (outputTensorsValues [i ],
556+ batch_offsets [b ],
557+ batch_sizes [b ]),
558+ i );
552559 }
553560 } else {
554- mctxs [ 0 ] -> outputs [ i ]. tensor =
555- RAI_TensorCreateFromTFTensor (outputTensorsValues [i ], 0 , -1 );
561+ RAI_ExecutionCtx_SetOutput (
562+ ectxs [ 0 ], RAI_TensorCreateFromTFTensor (outputTensorsValues [i ], 0 , -1 ), i );
556563 }
557564 TF_DeleteTensor (outputTensorsValues [i ]);
558565 }
0 commit comments