Skip to content

Commit 1a3f394

Browse files
authored
Add support for variadic arguments to SCRIPT (#395)
* Add support for variadic arguments to SCRIPT * Add negative errors
1 parent 0557118 commit 1a3f394

File tree

8 files changed

+174
-10
lines changed

8 files changed

+174
-10
lines changed

docs/commands.md

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,14 +458,16 @@ The **`AI.SCRIPTRUN`** command runs a script stored as a key's value on its spec
458458
**Redis API**
459459

460460
```
461-
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] OUTPUTS <output> [output ...]
461+
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] [$ input ...] OUTPUTS <output> [output ...]
462462
```
463463

464464
_Arguments_
465465

466466
* **key**: the script's key name
467467
* **function**: the name of the function to run
468-
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names
468+
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names;
469+
variadic arguments are supported by prepending the list with `$`, in this case the
470+
script is expected an argument of type `List[Tensor]` as its last argument
469471
* **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by one or more key names
470472

471473
_Return_
@@ -489,6 +491,29 @@ redis> AI.TENSORGET result VALUES
489491
3) 1) "42"
490492
```
491493

494+
If 'myscript' supports variadic arguments:
495+
```python
496+
def addn(a, args : List[Tensor]):
497+
return a + torch.stack(args).sum()
498+
```
499+
500+
then one can provide an arbitrary number of inputs after the `$` sign:
501+
502+
```
503+
redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40
504+
OK
505+
redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 1
506+
OK
507+
redis> AI.TENSORSET mytensor3 FLOAT 1 VALUES 1
508+
OK
509+
redis> AI.SCRIPTRUN myscript addn INPUTS mytensor1 $ mytensor2 mytensor3 OUTPUTS result
510+
OK
511+
redis> AI.TENSORGET result VALUES
512+
1) FLOAT
513+
2) 1) (integer) 1
514+
3) 1) "42"
515+
```
516+
492517
!!! warning "Intermediate memory overhead"
493518
The execution of scripts may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis.
494519

src/backends/torch.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* error) {
252252

253253
char* error_descr = NULL;
254254
torchRunScript(sctx->script->script, sctx->fnname,
255+
sctx->variadic,
255256
nInputs, inputs, nOutputs, outputs,
256257
&error_descr, RedisModule_Alloc);
257258

src/libtorch_c/torch_c.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ struct ModuleContext {
190190
int64_t device_id;
191191
};
192192

193-
void torchRunModule(ModuleContext* ctx, const char* fnName,
193+
void torchRunModule(ModuleContext* ctx, const char* fnName, int variadic,
194194
long nInputs, DLManagedTensor** inputs,
195195
long nOutputs, DLManagedTensor** outputs) {
196196
// Checks device, if GPU then move input to GPU before running
@@ -214,11 +214,25 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
214214
torch::jit::Stack stack;
215215

216216
for (int i=0; i<nInputs; i++) {
217+
if (i == variadic) {
218+
break;
219+
}
217220
DLTensor* input = &(inputs[i]->dl_tensor);
218221
torch::Tensor tensor = fromDLPack(input);
219222
stack.push_back(tensor.to(device));
220223
}
221224

225+
if (variadic != -1 ) {
226+
std::vector<torch::Tensor> args;
227+
for (int i=variadic; i<nInputs; i++) {
228+
DLTensor* input = &(inputs[i]->dl_tensor);
229+
torch::Tensor tensor = fromDLPack(input);
230+
tensor.to(device);
231+
args.emplace_back(tensor);
232+
}
233+
stack.push_back(args);
234+
}
235+
222236
if (ctx->module) {
223237
torch::NoGradGuard guard;
224238
torch::jit::script::Method method = ctx->module->get_method(fnName);
@@ -351,14 +365,14 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
351365
return ctx;
352366
}
353367

354-
extern "C" void torchRunScript(void* scriptCtx, const char* fnName,
368+
extern "C" void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
355369
long nInputs, DLManagedTensor** inputs,
356370
long nOutputs, DLManagedTensor** outputs,
357371
char **error, void* (*alloc)(size_t))
358372
{
359373
ModuleContext* ctx = (ModuleContext*)scriptCtx;
360374
try {
361-
torchRunModule(ctx, fnName, nInputs, inputs, nOutputs, outputs);
375+
torchRunModule(ctx, fnName, variadic, nInputs, inputs, nOutputs, outputs);
362376
}
363377
catch(std::exception& e) {
364378
const size_t len = strlen(e.what());
@@ -376,7 +390,7 @@ extern "C" void torchRunModel(void* modelCtx,
376390
{
377391
ModuleContext* ctx = (ModuleContext*)modelCtx;
378392
try {
379-
torchRunModule(ctx, "forward", nInputs, inputs, nOutputs, outputs);
393+
torchRunModule(ctx, "forward", -1, nInputs, inputs, nOutputs, outputs);
380394
}
381395
catch(std::exception& e) {
382396
const size_t len = strlen(e.what());

src/libtorch_c/torch_c.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void* torchCompileScript(const char* script, DLDeviceType device, int64_t device
1919
void* torchLoadModel(const char* model, size_t modellen, DLDeviceType device, int64_t device_id,
2020
char **error, void* (*alloc)(size_t));
2121

22-
void torchRunScript(void* scriptCtx, const char* fnName,
22+
void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
2323
long nInputs, DLManagedTensor** inputs,
2424
long nOutputs, DLManagedTensor** outputs,
2525
char **error, void* (*alloc)(size_t));

src/script.c

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
150150
sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
151151
sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
152152
sctx->fnname = RedisModule_Strdup(fnname);
153+
sctx->variadic = -1;
153154
return sctx;
154155
}
155156

@@ -285,6 +286,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
285286
is_input = 1;
286287
outputs_flag_count = 1;
287288
} else {
289+
if (!strcasecmp(arg_string, "$")) {
290+
(*sctx)->variadic = argpos - 4;
291+
continue;
292+
}
288293
RedisModule_RetainString(ctx, argv[argpos]);
289294
if (is_input == 0) {
290295
RAI_Tensor *inputTensor;
@@ -299,18 +304,18 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
299304
RedisModule_CloseKey(tensorKey);
300305
} else {
301306
const int get_result = RAI_getTensorFromLocalContext(
302-
ctx, *localContextDict, arg_string, &inputTensor,error);
307+
ctx, *localContextDict, arg_string, &inputTensor, error);
303308
if (get_result == REDISMODULE_ERR) {
304309
return -1;
305310
}
306311
}
307312
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) {
308-
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Input key not found");
313+
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found");
309314
return -1;
310315
}
311316
} else {
312317
if (!RAI_ScriptRunCtxAddOutput(*sctx)) {
313-
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Output key not found");
318+
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found");
314319
return -1;
315320
}
316321
*outkeys=array_append(*outkeys,argv[argpos]);

src/script_struct.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ typedef struct RAI_ScriptRunCtx {
2727
char* fnname;
2828
RAI_ScriptCtxParam* inputs;
2929
RAI_ScriptCtxParam* outputs;
30+
int variadic;
3031
} RAI_ScriptRunCtx;
3132

3233
#endif /* SRC_SCRIPT_STRUCT_H_ */

test/test_data/script.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
def bar(a, b):
22
return a + b
3+
4+
def bar_variadic(a, args : List[Tensor]):
5+
return args[0] + args[1]

test/tests_pytorch.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,61 @@ def test_pytorch_scriptrun(env):
426426
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
427427
env.assertEqual(values2, values)
428428

429+
430+
def test_pytorch_scriptrun_variadic(env):
431+
if not TEST_PT:
432+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
433+
return
434+
435+
con = env.getConnection()
436+
437+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
438+
script_filename = os.path.join(test_data_path, 'script.txt')
439+
440+
with open(script_filename, 'rb') as f:
441+
script = f.read()
442+
443+
ret = con.execute_command('AI.SCRIPTSET', 'myscript', DEVICE, 'TAG', 'version1', 'SOURCE', script)
444+
env.assertEqual(ret, b'OK')
445+
446+
ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
447+
env.assertEqual(ret, b'OK')
448+
ret = con.execute_command('AI.TENSORSET', 'b1', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
449+
env.assertEqual(ret, b'OK')
450+
ret = con.execute_command('AI.TENSORSET', 'b2', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
451+
env.assertEqual(ret, b'OK')
452+
453+
ensureSlaveSynced(con, env)
454+
455+
for _ in range( 0,100):
456+
ret = con.execute_command('AI.SCRIPTRUN', 'myscript', 'bar_variadic', 'INPUTS', 'a', '$', 'b1', 'b2', 'OUTPUTS', 'c')
457+
env.assertEqual(ret, b'OK')
458+
459+
ensureSlaveSynced(con, env)
460+
461+
info = con.execute_command('AI.INFO', 'myscript')
462+
info_dict_0 = info_to_dict(info)
463+
464+
env.assertEqual(info_dict_0['key'], 'myscript')
465+
env.assertEqual(info_dict_0['type'], 'SCRIPT')
466+
env.assertEqual(info_dict_0['backend'], 'TORCH')
467+
env.assertEqual(info_dict_0['tag'], 'version1')
468+
env.assertTrue(info_dict_0['duration'] > 0)
469+
env.assertEqual(info_dict_0['samples'], -1)
470+
env.assertEqual(info_dict_0['calls'], 100)
471+
env.assertEqual(info_dict_0['errors'], 0)
472+
473+
values = con.execute_command('AI.TENSORGET', 'c', 'VALUES')
474+
env.assertEqual(values, [b'4', b'6', b'4', b'6'])
475+
476+
ensureSlaveSynced(con, env)
477+
478+
if env.useSlaves:
479+
con2 = env.getSlaveConnection()
480+
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
481+
env.assertEqual(values2, values)
482+
483+
429484
def test_pytorch_scriptrun_errors(env):
430485
if not TEST_PT:
431486
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
@@ -528,6 +583,66 @@ def test_pytorch_scriptrun_errors(env):
528583
env.assertEqual(type(exception), redis.exceptions.ResponseError)
529584

530585

586+
def test_pytorch_scriptrun_errors(env):
587+
if not TEST_PT:
588+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
589+
return
590+
591+
con = env.getConnection()
592+
593+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
594+
script_filename = os.path.join(test_data_path, 'script.txt')
595+
596+
with open(script_filename, 'rb') as f:
597+
script = f.read()
598+
599+
ret = con.execute_command('AI.SCRIPTSET', 'ket', DEVICE, 'TAG', 'asdf', 'SOURCE', script)
600+
env.assertEqual(ret, b'OK')
601+
602+
ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
603+
env.assertEqual(ret, b'OK')
604+
ret = con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
605+
env.assertEqual(ret, b'OK')
606+
607+
ensureSlaveSynced(con, env)
608+
609+
# ERR Variadic input key is empty
610+
try:
611+
con.execute_command('DEL', 'EMPTY')
612+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$', 'EMPTY', 'b', 'OUTPUTS', 'c')
613+
except Exception as e:
614+
exception = e
615+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
616+
env.assertEqual("tensor key is empty", exception.__str__())
617+
618+
# ERR Variadic input key not tensor
619+
try:
620+
con.execute_command('SET', 'NOT_TENSOR', 'BAR')
621+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$' , 'NOT_TENSOR', 'b', 'OUTPUTS', 'c')
622+
except Exception as e:
623+
exception = e
624+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
625+
env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__())
626+
627+
try:
628+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS', 'c')
629+
except Exception as e:
630+
exception = e
631+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
632+
633+
try:
634+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS')
635+
except Exception as e:
636+
exception = e
637+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
638+
639+
try:
640+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS')
641+
except Exception as e:
642+
exception = e
643+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
644+
645+
531646
def test_pytorch_scriptinfo(env):
532647
if not TEST_PT:
533648
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)

0 commit comments

Comments
 (0)