@@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
   return NULL;
 }
 
-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
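+// batch_size is signed: -1 means "no slicing", i.e. return the whole OrtValue as a
+// single tensor; a non-negative value selects the [batch_offset, batch_offset + batch_size)
+// slice of the first dimension.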
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, long long batch_size, RAI_Error *error) {
   OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
 
@@ -215,7 +215,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
     shape[i] = dims[i];
     strides[i] = 1;
   }
-  shape[0] = batch_size;
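+  // batch_size == -1: the caller wants the whole output, so keep the full first
+  // dimension in shape[0] and continue with the tensor's own total_batch_size.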
+  if (batch_size != -1) {
+    shape[0] = batch_size;
+  }
+  else {
+    batch_size = total_batch_size;
+  }
   for (int64_t i = ndims - 2; i >= 0; --i)
   {
     strides[i] *= strides[i + 1] * shape[i + 1];
@@ -412,9 +417,11 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
 
   size_t batch_sizes[nbatches];
   size_t batch_offsets[nbatches];
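+  // Accumulated size of dimension 0 across all batched contexts; compared below
+  // against the first dimension of the model's output.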
+  size_t total_batch_size = 0;
   if (array_len(mctxs[0]->inputs) > 0) {
     for (size_t b = 0; b < nbatches; ++b) {
       batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
+      total_batch_size += batch_sizes[b];
     }
     batch_offsets[0] = 0;
     for (size_t b = 1; b < nbatches; ++b) {
@@ -530,14 +537,48 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
   }
 
   for (size_t i = 0; i < n_output_nodes; i++) {
-    for (size_t b = 0; b < nbatches; b++) {
-      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+    if (nbatches > 1) {
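+      // Several contexts were batched into this run: check that the model produced
+      // one output row per batched input, then slice the output back per context.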
+      OrtTensorTypeAndShapeInfo* info;
+      status = ort->GetTensorTypeAndShape(outputs[i], &info);
+      if (status != NULL) goto error;
+
+      size_t ndims;
+      status = ort->GetDimensionsCount(info, &ndims);
+      if (status != NULL) goto error;
+
+      int64_t dims[ndims];
+      status = ort->GetDimensions(info, dims, ndims);
+      if (status != NULL) goto error;
+
+      if (dims[0] != total_batch_size) {
+        RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
+        ort->ReleaseStatus(status);
+        return 1;
+      }
+
+      for (size_t b = 0; b < nbatches; b++) {
+        RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+        if (error->code != RAI_OK) {
+          ort->ReleaseStatus(status);
+          return 1;
+        }
+        if (output_tensor) {
+          mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+          RAI_TensorFree(output_tensor);
+        }
+        else {
+          printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)");
+        }
+      }
+    }
+    else {
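+      // Single context: no batch check or slicing needed; -1 tells
+      // RAI_TensorCreateFromOrtValue to keep the output's full first dimension.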
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], 0, -1, error);
       if (error->code != RAI_OK) {
         ort->ReleaseStatus(status);
         return 1;
       }
       if (output_tensor) {
-        mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
         RAI_TensorFree(output_tensor);
       }
       else {