66#include "util/arr.h"
77#include "backends/onnxruntime.h"
88#include "redis_ai_objects/tensor.h"
9+ #include "onnx_allocator/onnx_allocator.h"
910
1011#include "onnxruntime_c_api.h"
1112#include "backends_api.h"
@@ -21,63 +22,7 @@ OrtEnv *env = NULL;
2122// For model that run on GPU, onnx will not use the custom allocator (redis allocator), but
2223// the onnx allocator for GPU. But for the auxiliary allocations of the input and output names,
2324// we will use the custom global allocator for models that run on GPU as well.
24- OrtMemoryInfo * mem_info = NULL ;
2525OrtAllocator * global_allocator = NULL ;
26- unsigned long long OnnxMemory = 0 ;
27- unsigned long long OnnxMemoryAccessCounter = 0 ;
28-
29- const OrtMemoryInfo * AllocatorInfo (const OrtAllocator * allocator ) {
30- (void )allocator ;
31- const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
32- if (mem_info != NULL ) {
33- return mem_info ;
34- }
35- if (ort -> CreateCpuMemoryInfo (OrtDeviceAllocator , OrtMemTypeDefault , & mem_info ) != NULL ) {
36- return NULL ;
37- }
38- return mem_info ;
39- }
40-
41- // Allocate address with 64-byte alignment to cope with onnx optimizations.
42- void * AllocatorAlloc (OrtAllocator * ptr , size_t size ) {
43-
44- (void )ptr ;
45- // Allocate an additional 63 bytes to ensure that we can return an address which is
46- // 64-byte aligned, and an additional space in the size of a pointer to store
47- // the address that RedisModule_Alloc returns.
48- int offset = 63 + sizeof (void * );
49- void * allocated_address = (void * )RedisModule_Alloc (size + offset );
50- size_t allocated_size = RedisModule_MallocSize (allocated_address );
51- // Update the total number of bytes that onnx is using and the number of accesses
52- // that onnx made to the allocator.
53- atomic_fetch_add (& OnnxMemory , allocated_size );
54- atomic_fetch_add (& OnnxMemoryAccessCounter , 1 );
55- // This operation guarantees that p2 is the closest 64-aligned address to (p1+size_t).
56- void * * aligned_address = (void * * )(((size_t )(allocated_address ) + offset ) & (~63 ));
57- // This stores the address p1 right before p2 (so we can retrieve it when we free).
58- aligned_address [-1 ] = allocated_address ;
59- return aligned_address ;
60- }
61-
62- void AllocatorFree (OrtAllocator * ptr , void * aligned_address ) {
63- (void )ptr ;
64- if (aligned_address == NULL ) {
65- return ;
66- }
67- // Retrieve the address that we originally received from RedisModule_Alloc
68- // (this is the address that we need to sent to RedisModule_Free).
69- void * allocated_address = ((void * * )aligned_address )[-1 ];
70- size_t allocated_size = RedisModule_MallocSize (allocated_address );
71- // Update the total number of bytes that onnx is using and the number of accesses
72- // that onnx made to the allocator.
73- atomic_fetch_sub (& OnnxMemory , allocated_size );
74- atomic_fetch_add (& OnnxMemoryAccessCounter , 1 );
75- return RedisModule_Free (allocated_address );
76- }
77-
78- unsigned long long RAI_GetMemoryInfoORT () { return OnnxMemory ; }
79-
80- unsigned long long RAI_GetMemoryAccessORT () { return OnnxMemoryAccessCounter ; }
8126
8227int RAI_InitBackendORT (int (* get_api_fn )(const char * , void * * )) {
8328 // Export redis callbacks.
@@ -95,6 +40,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) {
9540 get_api_fn ("GetThreadId" , ((void * * )& RedisAI_GetThreadId ));
9641 get_api_fn ("GetNumThreadsPerQueue" , ((void * * )& RedisAI_GetNumThreadsPerQueue ));
9742 get_api_fn ("GetModelExecutionTimeout" , ((void * * )& RedisAI_GetModelExecutionTimeout ));
43+ get_api_fn ("GetBackendMemoryLimit" , ((void * * )& RedisAI_GetMemoryLimit ));
9844 get_api_fn ("GetThreadsCount" , ((void * * )& RedisAI_GetThreadsCount ));
9945
10046 // Create a global array of onnx runSessions, with an entry for every working thread.
@@ -389,8 +335,9 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
389335 // allocating buffers when creating and running models that run on CPU, and for allocations of
390336 // models inputs and outputs names (for both models that run on CPU and GPU)
391337 if (env == NULL ) {
392- ONNX_VALIDATE_STATUS (ort -> CreateEnv (ORT_LOGGING_LEVEL_WARNING , "test" , & env ))
393- ONNX_VALIDATE_STATUS (ort -> GetAllocatorWithDefaultOptions (& global_allocator ));
338+ ONNX_VALIDATE_STATUS (ort -> CreateEnv (ORT_LOGGING_LEVEL_WARNING , "RedisAI" , & env ))
339+ global_allocator = CreateCustomAllocator (RedisAI_GetMemoryLimit ());
340+ ONNX_VALIDATE_STATUS (ort -> RegisterAllocator (env , global_allocator ))
394341 }
395342
396343 ONNX_VALIDATE_STATUS (ort -> CreateSessionOptions (& session_options ))
0 commit comments