@@ -77,31 +77,98 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
   return (DLDataType){ .bits = 0 };
 }
 
-OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
-  // TODO: create outside and pass?
-  OrtAllocatorInfo* allocator_info;
-  OrtStatus* status;
-  status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
-  if (status != NULL) {
-    goto error;
+// OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
+//   // TODO: create outside and pass?
+//   OrtAllocatorInfo* allocator_info;
+//   OrtStatus* status;
+//   status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
+//   if (status != NULL) {
+//     goto error;
+//   }
+//
+//   OrtValue* out;
+//   status = OrtCreateTensorWithDataAsOrtValue(
+//       allocator_info,
+//       t->tensor.dl_tensor.data,
+//       RAI_TensorByteSize(t),
+//       t->tensor.dl_tensor.shape,
+//       t->tensor.dl_tensor.ndim,
+//       RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+//       &out);
+//
+//   if (status != NULL) {
+//     OrtReleaseAllocatorInfo(allocator_info);
+//     goto error;
+//   }
+//
+//   OrtReleaseAllocatorInfo(allocator_info);
+//
+//   return out;
+//
+// error:
+//   RAI_SetError(error, RAI_EMODELCREATE, OrtGetErrorMessage(status));
+//   OrtReleaseStatus(status);
+//   return NULL;
+// }
+
+OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, OrtAllocator* allocator, RAI_Error *error) {
+  if (count == 0) {
+    return NULL;
+  }
+
+  size_t batch_size = 0;
+  size_t batch_byte_size = 0;
+
+  for (size_t i = 0; i < count; i++) {
+    batch_size += ts[i]->tensor.dl_tensor.shape[0];
+    batch_byte_size += RAI_TensorByteSize(ts[i]);
   }
 
+  RAI_Tensor* t0 = ts[0];
+
+  int ndim = t0->tensor.dl_tensor.ndim;
+  int64_t batched_shape[ndim];
+
+  for (size_t i = 0; i < ndim; i++) {
+    batched_shape[i] = t0->tensor.dl_tensor.shape[i];
+  }
+
+  batched_shape[0] = batch_size;
+
+  OrtStatus* status = NULL;
+
   OrtValue* out;
-  status = OrtCreateTensorWithDataAsOrtValue(
-      allocator_info,
-      t->tensor.dl_tensor.data,
-      RAI_TensorByteSize(t),
-      t->tensor.dl_tensor.shape,
-      t->tensor.dl_tensor.ndim,
-      RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  // status = OrtCreateTensorWithDataAsOrtValue(
+  //     allocator_info,
+  //     t->tensor.dl_tensor.data,
+  //     RAI_TensorByteSize(t),
+  //     batched_shape,
+  //     t->tensor.dl_tensor.ndim,
+  //     RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  //     &out);
+  status = OrtCreateTensorAsOrtValue(
+      allocator,
+      batched_shape,
+      t0->tensor.dl_tensor.ndim,
+      RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype),
       &out);
-
   if (status != NULL) {
-    OrtReleaseAllocatorInfo(allocator_info);
     goto error;
   }
+
+  char* ort_data;
+  status = OrtGetTensorMutableData(out, (void**)&ort_data);
+  if (status != NULL) {
+    goto error;
+  }
+
+  size_t offset = 0;
+  for (size_t i = 0; i < count; i++) {
+    // Advance the write offset per tensor so the batches are laid out
+    // back-to-back instead of each copy overwriting the start of the buffer.
+    memcpy(ort_data + offset, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
+    offset += RAI_TensorByteSize(ts[i]);
+  }
 
-  OrtReleaseAllocatorInfo(allocator_info);
+  if (status != NULL) {
+    goto error;
+  }
 
   return out;
 
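Aside: the copy loop in RAI_OrtValueFromTensors depends on an advancing destination offset so that each request's data lands after the previous one. A minimal, self-contained sketch of that concatenation pattern (all names are local to the example, not part of the commit):

```c
#include <stddef.h>
#include <string.h>

/* Concatenate `count` source buffers into `dst` back-to-back: the write
 * offset advances per buffer instead of overwriting the start. */
static void concat_buffers(char *dst, const char *const *srcs,
                           const size_t *nbytes, size_t count) {
  size_t offset = 0;
  for (size_t i = 0; i < count; i++) {
    memcpy(dst + offset, srcs[i], nbytes[i]);
    offset += nbytes[i];
  }
}
```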
@@ -111,7 +178,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
   return NULL;
 }
 
-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
   OrtStatus* status = NULL;
 
   RAI_Tensor* ret = NULL;
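The two added parameters select a contiguous run of samples along dimension 0 of a batched output. A hedged sketch of a call site, assuming a batched `OrtValue*` named `batched_out` and a zero-initialized `RAI_Error` named `err` (both hypothetical):

```c
/* Hypothetical call: extract samples [2, 2+4) from a batched output.
 * `batched_out` and `err` are assumed to exist; offsets count samples. */
RAI_Tensor *slice = RAI_TensorCreateFromOrtValue(batched_out,
                                                 /*batch_offset=*/2,
                                                 /*batch_size=*/4,
                                                 &err);
if (err.code != RAI_OK) {
  /* err carries the message set via RAI_SetError */
}
```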
@@ -152,18 +219,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   status = OrtGetTensorElementType(info, &ort_dtype);
   if (status != NULL) goto error;
 
+  int64_t total_batch_size = dims[0];
+
   shape = RedisModule_Calloc(ndims, sizeof(*shape));
   strides = RedisModule_Calloc(ndims, sizeof(*strides));
-  for (int64_t i = 0; i < ndims; ++i) 
+  for (int64_t i = 0; i < ndims; ++i)
   {
     shape[i] = dims[i];
     strides[i] = 1;
   }
+  shape[0] = batch_size;
   for (int64_t i = ndims - 2; i >= 0; --i)
   {
     strides[i] *= strides[i + 1] * shape[i + 1];
   }
 
+  // size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size;
+
   DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
 #ifdef RAI_COPY_RUN_OUTPUT
   char* ort_data;
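The strides filled in above are ordinary row-major strides. For reference, a standalone equivalent of the same computation:

```c
#include <stdint.h>

/* Row-major strides in elements: the innermost stride is 1, and each
 * outer stride is the product of all inner dimension sizes. */
static void row_major_strides(const int64_t *shape, int64_t *strides,
                              int64_t ndims) {
  if (ndims == 0) return;
  strides[ndims - 1] = 1;
  for (int64_t i = ndims - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
}
```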
@@ -178,8 +250,13 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   }
 
   size_t len = dtype.bits * elem_count;
-  char* data = RedisModule_Calloc(len, sizeof(*data));
-  memcpy(data, ort_data, len);
+
+  size_t total_bytesize = len * sizeof(char);
+  size_t sample_bytesize = total_bytesize / total_batch_size;
+  size_t batch_bytesize = sample_bytesize * batch_size;
+
+  char* data = RedisModule_Calloc(batch_bytesize, sizeof(*data));
+  // batch_offset counts samples, so scale it to bytes before indexing.
+  memcpy(data, ort_data + batch_offset * sample_bytesize, batch_bytesize);
 #endif
 
   OrtReleaseTensorTypeAndShapeInfo(info);
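A quick numeric check of the slicing arithmetic above, with made-up dimensions (a 6x3 float32 output split into batches of 2 and 4 samples); only the C standard library is used:

```c
#include <assert.h>
#include <stddef.h>

int main(void) {
  size_t total_batch_size = 6;       /* dims[0] of the batched OrtValue */
  size_t total_bytesize = 6 * 3 * 4; /* rows x cols x sizeof(float)     */
  size_t sample_bytesize = total_bytesize / total_batch_size; /* 12 bytes */

  /* Second batch: starts after 2 samples, spans 4 samples. */
  size_t batch_offset = 2, batch_size = 4;
  size_t start_byte = batch_offset * sample_bytesize;   /* 24 */
  size_t batch_bytesize = batch_size * sample_bytesize; /* 48 */

  assert(start_byte + batch_bytesize == total_bytesize);
  return 0;
}
```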
@@ -345,6 +422,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
     return 1;
   }
 
+  const size_t nbatches = array_len(mctx->batches);
+  if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
+    return 1;
+  }
+
+  size_t batch_sizes[nbatches];
+  size_t batch_offsets[nbatches];
+  if (array_len(mctx->batches[0].inputs) > 0) {
+    for (size_t b = 0; b < nbatches; ++b) {
+      batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
+    }
+    batch_offsets[0] = 0;
+    for (size_t b = 1; b < nbatches; ++b) {
+      // Accumulate: batch b starts after all preceding batches, not just the last.
+      batch_offsets[b] = batch_offsets[b - 1] + batch_sizes[b - 1];
+    }
+  }
+
   OrtStatus* status = NULL;
 
   OrtAllocator* allocator;
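batch_offsets is a prefix sum of batch_sizes: batch b begins after all samples of the batches before it. The same computation as a standalone helper (illustrative name):

```c
#include <stddef.h>

/* offsets[b] = number of samples in all batches preceding b. */
static void batch_prefix_offsets(const size_t *sizes, size_t *offsets, size_t n) {
  if (n == 0) return;
  offsets[0] = 0;
  for (size_t b = 1; b < n; ++b) {
    offsets[b] = offsets[b - 1] + sizes[b - 1];
  }
}
```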
@@ -374,8 +469,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
   OrtValue* inputs[n_input_nodes];
   OrtValue* outputs[n_output_nodes];
 
-  size_t ninputs = array_len(mctx->inputs);
-  size_t noutputs = array_len(mctx->outputs);
+  size_t ninputs = array_len(mctx->batches[0].inputs);
+  size_t noutputs = array_len(mctx->batches[0].outputs);
 
   if (ninputs != n_input_nodes) {
     char msg[70];
@@ -403,7 +498,14 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
 
     input_names[i] = input_name;
 
-    inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    RAI_Tensor* batched_input_tensors[nbatches];
+    for (size_t b = 0; b < nbatches; b++) {
+      batched_input_tensors[b] = mctx->batches[b].inputs[i].tensor;
+    }
+
+    // TODO: batches
+    // inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, allocator, error);
     if (error->code != RAI_OK) {
       OrtReleaseStatus(status);
       OrtReleaseAllocator(allocator);
@@ -456,20 +558,40 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
   }
 
   for (size_t i = 0; i < n_output_nodes; i++) {
-    RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
-    if (error->code != RAI_OK) {
-      OrtReleaseStatus(status);
-      OrtReleaseAllocator(allocator);
-      return 1;
-    }
-    if (output_tensor) {
-      mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
-      RAI_TensorFree(output_tensor);
-    }
-    else {
-      printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+    // TODO batched
+    for (size_t b = 0; b < nbatches; b++) {
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+      if (error->code != RAI_OK) {
+        // TODO: check everything is deallocated here
+        OrtReleaseStatus(status);
+        OrtReleaseAllocator(allocator);
+        return 1;
+      }
+      if (output_tensor) {
+        mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        RAI_TensorFree(output_tensor);
+      }
+      else {
+        printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+      }
     }
+
   OrtReleaseValue(outputs[i]);
+
+    // // RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
+    // if (error->code != RAI_OK) {
+    //   OrtReleaseStatus(status);
+    //   OrtReleaseAllocator(allocator);
+    //   return 1;
+    // }
+    // if (output_tensor) {
+    //   mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+    //   RAI_TensorFree(output_tensor);
+    // }
+    // else {
+    //   printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+    // }
+    // OrtReleaseValue(outputs[i]);
   }
 
   for (size_t i = 0; i < n_input_nodes; i++) {