@@ -78,7 +78,7 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
7878
7979/**
8080 * DAGRUN Building Block to parse [PERSIST <nkeys> key1 key2... ]
81- *
81+ * @param ctx Context in which Redis modules operate
8282 * @param argv Redis command arguments, as an array of strings
8383 * @param argc Redis command number of arguments
8484 * @param persistTensorsNames local hash table containing DAG's
@@ -87,8 +87,8 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
8787 * argument after the chaining operator is not considered
8888 * @return processed number of arguments on success, or -1 if the parsing failed
8989 */
90- static int _ParseDAGPersistArgs (RedisModuleString * * argv , int argc , AI_dict * persistTensorsNames ,
91- RAI_Error * err ) {
90+ static int _ParseDAGPersistArgs (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ,
91+ AI_dict * persistTensorsNames , RAI_Error * err ) {
9292 if (argc < 3 ) {
9393 RAI_SetError (err , RAI_EDAGBUILDER ,
9494 "ERR missing arguments after PERSIST keyword in DAG command" );
@@ -106,11 +106,16 @@ static int _ParseDAGPersistArgs(RedisModuleString **argv, int argc, AI_dict *per
106106 // Go over the given args and save the tensor key names to persist.
107107 int number_keys_to_persist = 0 ;
108108 for (size_t argpos = 2 ; (argpos < argc ) && (number_keys_to_persist < n_keys ); argpos ++ ) {
109- const char * arg_string = RedisModule_StringPtrLen (argv [argpos ], NULL );
110109 if (AI_dictFind (persistTensorsNames , (void * )argv [argpos ]) != NULL ) {
111110 RAI_SetError (err , RAI_EDAGBUILDER , "ERR PERSIST keys must be unique" );
112111 return -1 ;
113112 }
113+ if (!VerifyKeyInThisShard (ctx , argv [argpos ])) { // Relevant for enterprise cluster.
114+ RAI_SetError (
115+ err , RAI_EDAGBUILDER ,
116+ "ERR Found keys to persist in DAG command that don't hash to the local shard" );
117+ return -1 ;
118+ }
114119 AI_dictAdd (persistTensorsNames , (void * )argv [argpos ], NULL );
115120 number_keys_to_persist ++ ;
116121 }
@@ -267,7 +272,7 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
267272 /* Store the keys to persist in persistTensors dict, these keys will
268273 * be mapped later to the indices in the dagSharedTensors array in which the
269274 * tensors to persist will be found by the end of the DAG run. */
270- const int parse_result = _ParseDAGPersistArgs (& argv [arg_pos ], argc - arg_pos ,
275+ const int parse_result = _ParseDAGPersistArgs (ctx , & argv [arg_pos ], argc - arg_pos ,
271276 rinfo -> persistTensors , rinfo -> err );
272277 if (parse_result <= 0 )
273278 return REDISMODULE_ERR ;
0 commit comments