Skip to content

Commit f37e501

Browse files
committed
Merge branch 'master' into Support_BOOL_type_for_tensors
2 parents 2962804 + 993f978 commit f37e501

32 files changed

+474
-318
lines changed

src/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ file (GLOB BACKEND_COMMON_SRC
1717
util/dict.c
1818
util/dictionaries.c
1919
redis_ai_objects/tensor.c
20+
redis_ai_objects/model.c
21+
redis_ai_objects/stats.c
22+
redis_ai_objects/script.c
2023
util/string_utils.c
2124
execution/utils.c
25+
execution/execution_contexts/execution_ctx.c
2226
serialization/ai_datatypes.c)
2327

2428
ADD_LIBRARY(redisai_obj OBJECT
@@ -41,6 +45,7 @@ ADD_LIBRARY(redisai_obj OBJECT
4145
execution/DAG/dag_builder.c
4246
execution/DAG/dag_execute.c
4347
execution/DAG/dag_op.c
48+
execution/execution_contexts/execution_ctx.c
4449
execution/execution_contexts/modelRun_ctx.c
4550
execution/execution_contexts/scriptRun_ctx.c
4651
backends/backends.c

src/backends/backends.c

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) {
111111
return REDISMODULE_ERR;
112112
}
113113

114-
backend.model_run =
115-
(int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunTF");
114+
backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))(
115+
unsigned long)dlsym(handle, "RAI_ModelRunTF");
116116
if (backend.model_run == NULL) {
117117
dlclose(handle);
118118
RedisModule_Log(ctx, "warning",
@@ -202,8 +202,8 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
202202
return REDISMODULE_ERR;
203203
}
204204

205-
backend.model_run = (int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(
206-
handle, "RAI_ModelRunTFLite");
205+
backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))(
206+
unsigned long)dlsym(handle, "RAI_ModelRunTFLite");
207207
if (backend.model_run == NULL) {
208208
dlclose(handle);
209209
RedisModule_Log(ctx, "warning",
@@ -294,8 +294,8 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
294294
return REDISMODULE_ERR;
295295
}
296296

297-
backend.model_run =
298-
(int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunTorch");
297+
backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))(
298+
unsigned long)dlsym(handle, "RAI_ModelRunTorch");
299299
if (backend.model_run == NULL) {
300300
dlclose(handle);
301301
RedisModule_Log(ctx, "warning",
@@ -338,8 +338,8 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
338338
return REDISMODULE_ERR;
339339
}
340340

341-
backend.script_run = (int (*)(RAI_ScriptRunCtx *, RAI_Error *))(unsigned long)dlsym(
342-
handle, "RAI_ScriptRunTorch");
341+
backend.script_run = (int (*)(RAI_Script *, const char *, RAI_ExecutionCtx *, RAI_Error *))(
342+
unsigned long)dlsym(handle, "RAI_ScriptRunTorch");
343343
if (backend.script_run == NULL) {
344344
dlclose(handle);
345345
RedisModule_Log(ctx, "warning",
@@ -418,8 +418,8 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) {
418418
return REDISMODULE_ERR;
419419
}
420420

421-
backend.model_run =
422-
(int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunORT");
421+
backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))(
422+
unsigned long)dlsym(handle, "RAI_ModelRunORT");
423423
if (backend.model_run == NULL) {
424424
dlclose(handle);
425425
RedisModule_Log(ctx, "warning",

src/backends/backends.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
#include "config/config.h"
1212
#include "redis_ai_objects/err.h"
1313
#include "redis_ai_objects/tensor.h"
14-
#include "redis_ai_objects/model_struct.h"
15-
#include "redis_ai_objects/script_struct.h"
14+
#include "redis_ai_objects/model.h"
15+
#include "redis_ai_objects/script.h"
16+
#include "execution/execution_contexts/execution_ctx.h"
1617

1718
/*
1819
* To register a new backend to be loaded by the module, the backend needs to
@@ -25,7 +26,7 @@
2526
* the RAI_ModelOpts.
2627
*
2728
* * ** model_run **: A callback function pointer that runs a model given the
28-
* RAI_ModelRunCtx pointer.
29+
* RAI_Model pointer and an array of RAI_ExecutionCtx pointers.
2930
*
3031
* * ** model_serialize **: A callback function pointer that serializes a model
3132
* given the RAI_Model pointer.
@@ -36,7 +37,7 @@
3637
* the RAI_Script pointer.
3738
*
3839
* * ** script_run **: A callback function pointer that runs a model given the
39-
* RAI_ScriptRunCtx pointer.
40+
* RAI_Script pointer and .
4041
*/
4142
typedef struct RAI_LoadedBackend {
4243
// ** model_create_with_nodes **: A callback function pointer that creates a
@@ -55,8 +56,8 @@ typedef struct RAI_LoadedBackend {
5556
void (*model_free)(RAI_Model *, RAI_Error *);
5657

5758
// ** model_run **: A callback function pointer that runs a model given the
58-
// RAI_ModelRunCtx pointer
59-
int (*model_run)(RAI_ModelRunCtx **, RAI_Error *);
59+
// RAI_Model pointer and an array of RAI_ExecutionCtx pointers
60+
int (*model_run)(RAI_Model *, RAI_ExecutionCtx **, RAI_Error *);
6061

6162
// ** model_serialize **: A callback function pointer that serializes a model
6263
// given the RAI_Model pointer
@@ -71,7 +72,7 @@ typedef struct RAI_LoadedBackend {
7172

7273
// ** script_run **: A callback function pointer that runs a model given the
7374
// RAI_ScriptRunCtx pointer
74-
int (*script_run)(RAI_ScriptRunCtx *, RAI_Error *);
75+
int (*script_run)(RAI_Script *, const char *function, RAI_ExecutionCtx *, RAI_Error *);
7576

7677
// Returns the backend version.
7778
const char *(*get_version)(void);

src/backends/onnxruntime.c

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,16 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
482482
ort->ReleaseStatus(status);
483483
}
484484

485-
int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
485+
int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error) {
486486
const OrtApi *ort = OrtGetApiBase()->GetApi(1);
487487

488-
OrtSession *session = mctxs[0]->model->session;
488+
OrtSession *session = RAI_ModelGetSession(model);
489489
if (session == NULL) {
490490
RAI_SetError(error, RAI_EMODELRUN, "ERR ONNXRuntime session was not allocated");
491491
return REDISMODULE_ERR;
492492
}
493493

494-
const size_t nbatches = array_len(mctxs);
494+
const size_t nbatches = array_len(ectxs);
495495
if (nbatches == 0) {
496496
RAI_SetError(error, RAI_EMODELRUN, "ERR No batches to run");
497497
return REDISMODULE_ERR;
@@ -500,9 +500,11 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
500500
size_t batch_sizes[nbatches];
501501
size_t batch_offsets[nbatches];
502502
size_t total_batch_size = 0;
503-
if (array_len(mctxs[0]->inputs) > 0) {
503+
const size_t ninputs = RAI_ExecutionCtx_NumInputs(ectxs[0]);
504+
const size_t noutputs = RAI_ExecutionCtx_NumOutputs(ectxs[0]);
505+
if (ninputs > 0) {
504506
for (size_t b = 0; b < nbatches; ++b) {
505-
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
507+
batch_sizes[b] = RAI_TensorDim(RAI_ExecutionCtx_GetInput(ectxs[b], 0), 0);
506508
total_batch_size += batch_sizes[b];
507509
}
508510
batch_offsets[0] = 0;
@@ -512,8 +514,6 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
512514
}
513515

514516
OrtStatus *status = NULL;
515-
const size_t ninputs = array_len(mctxs[0]->inputs);
516-
const size_t noutputs = array_len(mctxs[0]->outputs);
517517
array_new_on_stack(const char *, 5, input_names);
518518
array_new_on_stack(const char *, 5, output_names);
519519
array_new_on_stack(OrtValue *, 5, inputs);
@@ -547,7 +547,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
547547

548548
RAI_Tensor *batched_input_tensors[nbatches];
549549
for (size_t b = 0; b < nbatches; b++) {
550-
batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
550+
batched_input_tensors[b] = RAI_ExecutionCtx_GetInput(ectxs[b], i);
551551
}
552552
OrtValue *input;
553553
if (RAI_OrtValueFromTensors(batched_input_tensors, nbatches, &input, &status) !=
@@ -600,8 +600,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
600600
goto error;
601601
}
602602
if (output_tensor) {
603-
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
604-
RAI_TensorFree(output_tensor);
603+
RAI_ExecutionCtx_SetOutput(ectxs[b], output_tensor, i);
605604
} else {
606605
RedisModule_Log(NULL, "warning",
607606
"non-tensor output from ONNX models, ignoring (currently "
@@ -614,8 +613,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
614613
goto error;
615614
}
616615
if (output_tensor) {
617-
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
618-
RAI_TensorFree(output_tensor);
616+
RAI_ExecutionCtx_SetOutput(ectxs[0], output_tensor, i);
619617
} else {
620618
RedisModule_Log(NULL, "warning",
621619
"non-tensor output from ONNX models, ignoring (currently "

src/backends/onnxruntime.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
#include "config/config.h"
44
#include "redis_ai_objects/err.h"
5-
#include "redis_ai_objects/tensor_struct.h"
6-
#include "redis_ai_objects/model_struct.h"
5+
#include "redis_ai_objects/model.h"
6+
#include "execution/execution_contexts/execution_ctx.h"
77

88
unsigned long long RAI_GetMemoryInfoORT(void);
99

@@ -16,7 +16,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
1616

1717
void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error);
1818

19-
int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error);
19+
int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error);
2020

2121
int RAI_ModelSerializeORT(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error);
2222

src/backends/tensorflow.c

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "backends/util.h"
33
#include "backends/tensorflow.h"
44
#include "util/arr.h"
5+
#include "execution/execution_contexts/modelRun_ctx.h"
56
#include "redis_ai_objects/model.h"
67
#include "redis_ai_objects/tensor.h"
78

@@ -461,17 +462,17 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
461462
TF_DeleteStatus(status);
462463
}
463464

464-
int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
465+
int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error) {
465466
TF_Status *status = TF_NewStatus();
466467

467-
const size_t nbatches = array_len(mctxs);
468+
const size_t nbatches = array_len(ectxs);
468469
if (nbatches == 0) {
469470
RAI_SetError(error, RAI_EMODELRUN, "ERR No batches to run");
470471
return 1;
471472
}
472473

473-
const size_t ninputs = array_len(mctxs[0]->inputs);
474-
const size_t noutputs = array_len(mctxs[0]->outputs);
474+
const size_t ninputs = RAI_ExecutionCtx_NumInputs(ectxs[0]);
475+
const size_t noutputs = RAI_ExecutionCtx_NumOutputs(ectxs[0]);
475476
TF_Tensor *inputTensorsValues[ninputs];
476477
TF_Output inputs[ninputs];
477478
TF_Tensor *outputTensorsValues[noutputs];
@@ -482,7 +483,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
482483
size_t total_batch_size = 0;
483484
if (ninputs > 0) {
484485
for (size_t b = 0; b < nbatches; ++b) {
485-
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
486+
batch_sizes[b] = RAI_TensorDim(RAI_ExecutionCtx_GetInput(ectxs[b], 0), 0);
486487
total_batch_size += batch_sizes[b];
487488
}
488489
batch_offsets[0] = 0;
@@ -491,15 +492,18 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
491492
}
492493
}
493494

495+
void *tfGraph = RAI_ModelGetModel(model);
496+
void *tfSession = RAI_ModelGetSession(model);
497+
494498
for (size_t i = 0; i < ninputs; ++i) {
495499
RAI_Tensor *batched_input_tensors[nbatches];
496500

497501
for (size_t b = 0; b < nbatches; ++b) {
498-
batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
502+
batched_input_tensors[b] = RAI_ExecutionCtx_GetInput(ectxs[b], i);
499503
}
500504
inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
501505
TF_Output port;
502-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
506+
port.oper = TF_GraphOperationByName(tfGraph, RAI_ModelGetInputName(model, i));
503507
port.index = 0;
504508
if (port.oper == NULL) {
505509
return 1;
@@ -509,17 +513,17 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
509513

510514
for (size_t i = 0; i < noutputs; ++i) {
511515
TF_Output port;
512-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
516+
port.oper = TF_GraphOperationByName(tfGraph, RAI_ModelGetOutputName(model, i));
513517
port.index = 0;
514518
if (port.oper == NULL) {
515519
return 1;
516520
}
517521
outputs[i] = port;
518522
}
519523

520-
TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
521-
ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
522-
0 /* ntargets */, NULL /* run_Metadata */, status);
524+
TF_SessionRun(tfSession, NULL /* run_options */, inputs, inputTensorsValues, ninputs, outputs,
525+
outputTensorsValues, noutputs, NULL /* target_opers */, 0 /* ntargets */,
526+
NULL /* run_Metadata */, status);
523527

524528
for (size_t i = 0; i < ninputs; ++i) {
525529
TF_DeleteTensor(inputTensorsValues[i]);
@@ -547,12 +551,15 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
547551
}
548552

549553
for (size_t b = 0; b < nbatches; b++) {
550-
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(
551-
outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
554+
RAI_ExecutionCtx_SetOutput(ectxs[b],
555+
RAI_TensorCreateFromTFTensor(outputTensorsValues[i],
556+
batch_offsets[b],
557+
batch_sizes[b]),
558+
i);
552559
}
553560
} else {
554-
mctxs[0]->outputs[i].tensor =
555-
RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
561+
RAI_ExecutionCtx_SetOutput(
562+
ectxs[0], RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1), i);
556563
}
557564
TF_DeleteTensor(outputTensorsValues[i]);
558565
}

src/backends/tensorflow.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
#include "config/config.h"
44
#include "redis_ai_objects/err.h"
5-
#include "redis_ai_objects/tensor_struct.h"
6-
#include "redis_ai_objects/model_struct.h"
5+
#include "redis_ai_objects/model.h"
6+
#include "execution/execution_contexts/execution_ctx.h"
77

88
int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *));
99

@@ -14,7 +14,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
1414

1515
void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error);
1616

17-
int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error);
17+
int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error);
1818

1919
int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error);
2020

0 commit comments

Comments
 (0)