@@ -83,23 +83,52 @@ void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) {
8383
8484int RAI_ModelRunTFLite (RAI_ModelRunCtx * mctx , RAI_Error * error ) {
8585
86- size_t ninputs = array_len (mctx -> inputs );
87- size_t noutputs = array_len (mctx -> outputs );
86+ const size_t nbatches = array_len (mctx -> batches ); // batches are concatenated per input, run through tfliteRunModel once, then the outputs are sliced back per batch
87+ if (nbatches == 0 ) {
88+ RAI_SetError (error , RAI_EMODELRUN , "No batches to run\n" );
89+ return 1 ;
90+ }
91+
92+ size_t total_batch_size = 0 ; // NOTE(review): accumulated below but not read anywhere in the visible hunks
93+ size_t batch_sizes [nbatches ]; // NOTE(review): left unset when batch 0 has no inputs, yet read by the output slicing loop
94+ size_t batch_offsets [nbatches ]; // NOTE(review): same caveat as batch_sizes
95+ if (array_len (mctx -> batches [0 ].inputs ) > 0 ) {
96+ for (size_t b = 0 ; b < nbatches ; ++ b ) {
97+ batch_sizes [b ] = RAI_TensorDim (mctx -> batches [b ].inputs [0 ].tensor , 0 ); // size of batch b = dimension 0 of its first input
98+ total_batch_size += batch_sizes [b ];
99+ }
100+ batch_offsets [0 ] = 0 ;
101+ for (size_t b = 1 ; b < nbatches ; ++ b ) {
102+ batch_offsets [b ] = batch_offsets [b - 1 ] + batch_sizes [b - 1 ]; // running total: batch b starts where batch b-1 ends (previously stored only the preceding size)
103+ }
104+ }
105+
106+ size_t ninputs = array_len (mctx -> batches [0 ].inputs ); // input/output counts are taken from batch 0 only
107+ size_t noutputs = array_len (mctx -> batches [0 ].outputs );
108+
109+ RAI_Tensor * inputs [ninputs ]; // concatenated tensors created here; released at the end of the function
88110
89- DLManagedTensor * inputs [ninputs ];
90- DLManagedTensor * outputs [noutputs ];
111+ DLManagedTensor * inputs_dl [ninputs ];
112+ DLManagedTensor * outputs_dl [noutputs ];
91113
92114 for (size_t i = 0 ; i < ninputs ; ++ i ) {
93- inputs [i ] = & mctx -> inputs [i ].tensor -> tensor ;
115+ RAI_Tensor * batch [nbatches ]; // the i-th input tensor of every batch, in batch order
116+
117+ for (size_t b = 0 ; b < nbatches ; b ++ ) {
118+ batch [b ] = mctx -> batches [b ].inputs [i ].tensor ;
119+ }
120+
121+ inputs [i ] = RAI_TensorCreateByConcatenatingTensors (batch , nbatches ); // NOTE(review): result is dereferenced on the next line without a NULL check
122+ inputs_dl [i ] = & inputs [i ]-> tensor ;
94123 }
95124
96125 for (size_t i = 0 ; i < noutputs ; ++ i ) {
97- outputs [i ] = mctx -> outputs [ i ]. tensor ? & mctx -> outputs [ i ]. tensor -> tensor : NULL ;
126+ outputs_dl [i ] = NULL ; // presumably filled in by tfliteRunModel; a slot still NULL afterwards is reported as an error
98127 }
99128
100129 char * error_descr = NULL ;
101130 tfliteRunModel (mctx -> model -> model ,
102- ninputs , inputs , noutputs , outputs ,
131+ ninputs , inputs_dl , noutputs , outputs_dl ,
103132 & error_descr , RedisModule_Alloc );
104133
105134 if (error_descr != NULL ) {
@@ -108,16 +137,22 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) {
108137 return 1 ;
109138 }
110139
111- for (size_t i = 0 ; i < array_len ( mctx -> outputs ) ; ++ i ) {
112- if (outputs [i ] == NULL ) {
140+ for (size_t i = 0 ; i < noutputs ; ++ i ) {
141+ if (outputs_dl [i ] == NULL ) { // NOTE(review): the early return in this branch skips the RAI_TensorFree (inputs [i ]) cleanup at the end of the function
113142 RAI_SetError (error , RAI_EMODELRUN , "Model did not generate the expected number of outputs." );
114143 return 1 ;
115144 }
116- RAI_Tensor * output_tensor = RAI_TensorCreateFromDLTensor (outputs [i ]);
117- mctx -> outputs [i ].tensor = RAI_TensorGetShallowCopy (output_tensor );
145+ RAI_Tensor * output_tensor = RAI_TensorCreateFromDLTensor (outputs_dl [i ]);
146+ for (size_t b = 0 ; b < nbatches ; b ++ ) {
147+ mctx -> batches [b ].outputs [i ].tensor = RAI_TensorCreateBySlicingTensor (output_tensor , batch_offsets [b ], batch_sizes [b ]); // batch b receives its own (offset, size) slice of the combined output
148+ }
118149 RAI_TensorFree (output_tensor );
119150 }
120151
152+ for (size_t i = 0 ; i < ninputs ; ++ i ) {
153+ RAI_TensorFree (inputs [i ]); // release the concatenated inputs created before the model run
154+ }
155+
121156 return 0 ;
122157}
123158
0 commit comments