1616#include "util/arr_rm_alloc.h"
1717#include "util/dict.h"
1818
19+
20+ static uint64_t RAI_TensorDictKeyHashFunction (const void * key ){
21+ return AI_dictGenHashFunction (key , strlen ((char * )key ));
22+ }
23+
24+ static int RAI_TensorDictKeyStrcmp (void * privdata , const void * key1 , const void * key2 ){
25+ const char * strKey1 = key1 ;
26+ const char * strKey2 = key2 ;
27+ return strcmp (strKey1 , strKey2 ) == 0 ;
28+ }
29+
30+ static void RAI_TensorDictKeyFree (void * privdata , void * key ){
31+ RedisModule_Free (key );
32+ }
33+
34+ static void * RAI_TensorDictKeyDup (void * privdata , const void * key ){
35+ return RedisModule_Strdup ((char * )key );
36+ }
37+
38+ static void RAI_TensorDictValFree (void * privdata , const void * obj ){
39+ return RAI_TensorFree ((RAI_Tensor * )obj );
40+ }
41+
42+
43+ AI_dictType AI_dictTypeTensorVals = {
44+ .hashFunction = RAI_TensorDictKeyHashFunction ,
45+ .keyDup = RAI_TensorDictKeyDup ,
46+ .valDup = NULL ,
47+ .keyCompare = RAI_TensorDictKeyStrcmp ,
48+ .keyDestructor = RAI_TensorDictKeyFree ,
49+ .valDestructor = RAI_TensorDictValFree ,
50+ };
51+
52+
1953/**
2054 * Allocate the memory and initialise the RAI_DagOp.
2155 * @param result Output parameter to capture allocated RAI_DagOp.
@@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
76110 return REDISMODULE_ERR ;
77111 }
78112 rinfo -> use_local_context = 0 ;
79- rinfo -> dagTensorsContext = AI_dictCreate (& AI_dictTypeHeapStrings , NULL );
113+ rinfo -> dagTensorsContext = AI_dictCreate (& AI_dictTypeTensorVals , NULL );
80114 if (!(rinfo -> dagTensorsContext )) {
81115 return REDISMODULE_ERR ;
82116 }
@@ -116,6 +150,13 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
116150 }
117151 array_free (dagOp -> outTensors );
118152
153+ if (dagOp -> mctx ) {
154+ RAI_ModelRunCtxFree (dagOp -> mctx , false);
155+ }
156+ if (dagOp -> sctx ) {
157+ RAI_ScriptRunCtxFree (dagOp -> sctx , false);
158+ }
159+
119160 RedisModule_Free (dagOp );
120161 }
121162}
@@ -125,37 +166,48 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
125166 return ;
126167 }
127168 if (rinfo -> mctx ) {
128- RAI_ModelRunCtxFree (rinfo -> mctx );
169+ RAI_ModelRunCtxFree (rinfo -> mctx , true );
129170 }
130171 if (rinfo -> sctx ) {
131- RAI_ScriptRunCtxFree (rinfo -> sctx );
172+ RAI_ScriptRunCtxFree (rinfo -> sctx , true );
132173 }
133174 RAI_FreeError (rinfo -> err );
134175
135176 if (rinfo -> dagTensorsContext ) {
136177 AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
137- AI_dictEntry * stats_entry = AI_dictNext (iter );
178+ AI_dictEntry * entry = AI_dictNext (iter );
138179 RAI_Tensor * tensor = NULL ;
139180
140- while (stats_entry ) {
141- tensor = AI_dictGetVal (stats_entry );
142- char * key = (char * )AI_dictGetKey (stats_entry );
181+ while (entry ) {
182+ tensor = AI_dictGetVal (entry );
183+ char * key = (char * )AI_dictGetKey (entry );
143184
144- if (tensor && key != NULL ) {
185+ if (tensor && key != NULL ) {
145186 // if the key is persistent then we should not delete it
146187 AI_dictEntry * persistent_entry =
147188 AI_dictFind (rinfo -> dagTensorsPersistentContext , key );
148- // if the key was loaded from the keyspace then we should not delete
149- // it
189+ // if the key was loaded from the keyspace then we should not delete it
150190 AI_dictEntry * loaded_entry =
151191 AI_dictFind (rinfo -> dagTensorsLoadedContext , key );
192+
152193 if (persistent_entry == NULL && loaded_entry == NULL ) {
153- RAI_TensorFree (tensor );
194+ AI_dictDelete (rinfo -> dagTensorsContext , key );
195+ }
196+
197+ if (persistent_entry ) {
198+ AI_dictDelete (rinfo -> dagTensorsPersistentContext , key );
199+ }
200+ if (loaded_entry ) {
201+ AI_dictDelete (rinfo -> dagTensorsLoadedContext , key );
154202 }
155203 }
156- stats_entry = AI_dictNext (iter );
204+ entry = AI_dictNext (iter );
157205 }
158206 AI_dictReleaseIterator (iter );
207+
208+ RedisModule_Free (rinfo -> dagTensorsContext );
209+ RedisModule_Free (rinfo -> dagTensorsLoadedContext );
210+ RedisModule_Free (rinfo -> dagTensorsPersistentContext );
159211 }
160212
161213 if (rinfo -> dagOps ) {
0 commit comments