88
99RedisModuleType * RedisAI_TensorType = NULL ;
1010
11- static DLDataType Tensor_GetDataType (const char * typestr ){
11+ DLDataType RAI_TensorDataTypeFromString (const char * typestr ){
1212 if (strcasecmp (typestr , "FLOAT" ) == 0 ){
1313 return (DLDataType ){ .code = kDLFloat , .bits = 32 , .lanes = 1 };
1414 }
@@ -223,8 +223,7 @@ int RAI_TensorInit(RedisModuleCtx* ctx){
223223 return RedisAI_TensorType != NULL ;
224224}
225225
226- RAI_Tensor * RAI_TensorCreate (const char * dataTypeStr , long long * dims , int ndims , int hasdata ) {
227- DLDataType dtype = Tensor_GetDataType (dataTypeStr );
226+ RAI_Tensor * RAI_TensorCreateWithDLDataType (DLDataType dtype , long long * dims , int ndims , int hasdata ) {
228227 const size_t dtypeSize = Tensor_DataTypeSize (dtype );
229228 if ( dtypeSize == 0 ){
230229 return NULL ;
@@ -279,6 +278,11 @@ RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims
279278 return ret ;
280279}
281280
281+ RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
282+ DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
283+ return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , hasdata );
284+ }
285+
282286#if 0
283287void RAI_TensorMoveFrom (RAI_Tensor * dst , RAI_Tensor * src ) {
284288 if (-- dst -> refCount <= 0 ){
@@ -296,6 +300,76 @@ void RAI_TensorMoveFrom(RAI_Tensor* dst, RAI_Tensor* src) {
296300}
297301#endif
298302
303+ RAI_Tensor * RAI_TensorCreateByConcatenatingTensors (RAI_Tensor * * ts , long long n ) {
304+
305+ if (n == 0 ) {
306+ return NULL ;
307+ }
308+
309+ long long total_batch_size = 0 ;
310+ long long batch_sizes [n ];
311+ long long batch_offsets [n ];
312+
313+ long long ndims = RAI_TensorNumDims (ts [0 ]);
314+ long long dims [ndims ];
315+
316+ // TODO check that all tensors have compatible dims
317+
318+ for (long long i = 0 ; i < n ; i ++ ) {
319+ batch_sizes [i ] = RAI_TensorDim (ts [i ], 0 );
320+ total_batch_size += batch_sizes [i ];
321+ }
322+
323+ batch_offsets [0 ] = 0 ;
324+ for (long long i = 1 ; i < n ; i ++ ) {
325+ batch_offsets [i ] = batch_sizes [i - 1 ];
326+ }
327+
328+ long long sample_size = 0 ;
329+
330+ for (long long i = 1 ; i < ndims ; i ++ ) {
331+ dims [i ] = RAI_TensorDim (ts [0 ], i );
332+ sample_size *= dims [i ];
333+ }
334+ dims [0 ] = total_batch_size ;
335+
336+ long long dtype_size = RAI_TensorDataSize (ts [0 ]);
337+
338+ DLDataType dtype = RAI_TensorDataType (ts [0 ]);
339+
340+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , 1 );
341+
342+ for (long long i = 0 ; i < n ; i ++ ) {
343+ memcpy (RAI_TensorData (ret ) + batch_offsets [i ] * sample_size * dtype_size , RAI_TensorData (ts [i ]), RAI_TensorByteSize (ts [i ]));
344+ }
345+
346+ return ret ;
347+ }
348+
349+ RAI_Tensor * RAI_TensorCreateBySlicingTensor (RAI_Tensor * t , long long offset , long long len ) {
350+
351+ long long ndims = RAI_TensorNumDims (t );
352+ long long dims [ndims ];
353+
354+ long long dtype_size = RAI_TensorDataSize (t );
355+ long long sample_size = 0 ;
356+
357+ for (long long i = 1 ; i < ndims ; i ++ ) {
358+ dims [i ] = RAI_TensorDim (t , i );
359+ sample_size *= dims [i ];
360+ }
361+
362+ dims [0 ] = len ;
363+
364+ DLDataType dtype = RAI_TensorDataType (t );
365+
366+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , 1 );
367+
368+ memcpy (RAI_TensorData (ret ), RAI_TensorData (t ) + offset * sample_size * dtype_size , len * sample_size * dtype_size );
369+
370+ return ret ;
371+ }
372+
299373// Beware: this will take ownership of dltensor
300374RAI_Tensor * RAI_TensorCreateFromDLTensor (DLManagedTensor * dl_tensor ) {
301375
@@ -332,8 +406,16 @@ size_t RAI_TensorLength(RAI_Tensor* t) {
332406 return len ;
333407}
334408
335- size_t RAI_TensorGetDataSize (const char * dataTypeStr ) {
336- DLDataType dtype = Tensor_GetDataType (dataTypeStr );
409+ size_t RAI_TensorDataSize (RAI_Tensor * t ) {
410+ return Tensor_DataTypeSize (RAI_TensorDataType (t ));
411+ }
412+
413+ size_t RAI_TensorDataSizeFromString (const char * dataTypeStr ) {
414+ DLDataType dtype = RAI_TensorDataTypeFromString (dataTypeStr );
415+ return Tensor_DataTypeSize (dtype );
416+ }
417+
418+ size_t RAI_TensorDataSizeFromDLDataType (DLDataType dtype ) {
337419 return Tensor_DataTypeSize (dtype );
338420}
339421
0 commit comments