
Commit d06d0d9

Add auto-batching to TFLite backend
1 parent dd589fa commit d06d0d9

2 files changed: +49 -11 lines changed


src/backends/tflite.c

Lines changed: 46 additions & 11 deletions
@@ -83,23 +83,52 @@ void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) {
 
 int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) {
 
-  size_t ninputs = array_len(mctx->inputs);
-  size_t noutputs = array_len(mctx->outputs);
+  const size_t nbatches = array_len(mctx->batches);
+  if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
+    return 1;
+  }
+
+  size_t total_batch_size = 0;
+  size_t batch_sizes[nbatches];
+  size_t batch_offsets[nbatches];
+  if (array_len(mctx->batches[0].inputs) > 0) {
+    for (size_t b=0; b<nbatches; ++b) {
+      batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
+      total_batch_size += batch_sizes[b];
+    }
+    batch_offsets[0] = 0;
+    for (size_t b=1; b<nbatches; ++b) {
+      batch_offsets[b] = batch_offsets[b-1] + batch_sizes[b-1];
+    }
+  }
+
+  size_t ninputs = array_len(mctx->batches[0].inputs);
+  size_t noutputs = array_len(mctx->batches[0].outputs);
+
+  RAI_Tensor* inputs[ninputs];
 
-  DLManagedTensor* inputs[ninputs];
-  DLManagedTensor* outputs[noutputs];
+  DLManagedTensor* inputs_dl[ninputs];
+  DLManagedTensor* outputs_dl[noutputs];
 
   for (size_t i=0 ; i<ninputs; ++i) {
-    inputs[i] = &mctx->inputs[i].tensor->tensor;
+    RAI_Tensor* batch[nbatches];
+
+    for (size_t b=0; b<nbatches; b++) {
+      batch[b] = mctx->batches[b].inputs[i].tensor;
+    }
+
+    inputs[i] = RAI_TensorCreateByConcatenatingTensors(batch, nbatches);
+    inputs_dl[i] = &inputs[i]->tensor;
   }
 
   for (size_t i=0 ; i<noutputs; ++i) {
-    outputs[i] = mctx->outputs[i].tensor ? &mctx->outputs[i].tensor->tensor : NULL;
+    outputs_dl[i] = NULL;
   }
 
   char* error_descr = NULL;
   tfliteRunModel(mctx->model->model,
-                 ninputs, inputs, noutputs, outputs,
+                 ninputs, inputs_dl, noutputs, outputs_dl,
                  &error_descr, RedisModule_Alloc);
 
   if (error_descr != NULL) {
@@ -108,16 +137,22 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) {
     return 1;
   }
 
-  for(size_t i=0 ; i<array_len(mctx->outputs) ; ++i) {
-    if (outputs[i] == NULL) {
+  for(size_t i=0 ; i<noutputs; ++i) {
+    if (outputs_dl[i] == NULL) {
       RAI_SetError(error, RAI_EMODELRUN, "Model did not generate the expected number of outputs.");
       return 1;
     }
-    RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs[i]);
-    mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+    RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]);
+    for (size_t b=0; b<nbatches; b++) {
+      mctx->batches[b].outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]);
+    }
     RAI_TensorFree(output_tensor);
   }
 
+  for (size_t i=0 ; i<ninputs; ++i) {
+    RAI_TensorFree(inputs[i]);
+  }
+
   return 0;
 }

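For context, the hunks above follow a concatenate-run-slice pattern: each queued request contributes one tensor per model input, the per-request tensors are concatenated along dimension 0, the backend is invoked once, and the combined output is sliced back per request using the recorded offsets. Below is a minimal, self-contained sketch of that pattern using plain float buffers and a hypothetical run_model() stub instead of RAI_Tensor and tfliteRunModel(); none of its names are RedisAI API, it only illustrates how per-request batch sizes become offsets (a running sum) and how the single combined output is split back up.

/* Standalone sketch (not RedisAI code): auto-batching by concatenation.
 * run_model() is a hypothetical stand-in for a single backend call. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Pretend "model": identity over a [rows x cols] float matrix. */
static void run_model(const float *in, float *out, size_t rows, size_t cols) {
  memcpy(out, in, rows * cols * sizeof(float));
}

int main(void) {
  enum { NBATCHES = 3, COLS = 2 };
  /* Dimension 0 of each queued request's input tensor. */
  size_t batch_sizes[NBATCHES] = {1, 2, 1};

  /* Offsets are a running sum of the preceding batch sizes. */
  size_t batch_offsets[NBATCHES];
  size_t total = 0;
  for (size_t b = 0; b < NBATCHES; ++b) {
    batch_offsets[b] = total;
    total += batch_sizes[b];
  }

  float *input  = malloc(total * COLS * sizeof(float));
  float *output = malloc(total * COLS * sizeof(float));

  /* "Concatenate" the requests along dimension 0: request b owns the
   * rows [batch_offsets[b], batch_offsets[b] + batch_sizes[b]). */
  for (size_t b = 0; b < NBATCHES; ++b) {
    for (size_t r = 0; r < batch_sizes[b]; ++r) {
      for (size_t c = 0; c < COLS; ++c) {
        input[(batch_offsets[b] + r) * COLS + c] = (float)(b + 1);
      }
    }
  }

  /* One backend call services every queued request. */
  run_model(input, output, total, COLS);

  /* Slice the combined output back per request. */
  for (size_t b = 0; b < NBATCHES; ++b) {
    printf("request %zu: offset=%zu rows=%zu first value=%.0f\n",
           b, batch_offsets[b], batch_sizes[b],
           output[batch_offsets[b] * COLS]);
  }

  free(input);
  free(output);
  return 0;
}

In the backend itself, RAI_TensorCreateByConcatenatingTensors() and RAI_TensorCreateBySlicingTensor() play the roles of the manual copy and the offset-based split shown here.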
src/redisai.c

Lines changed: 3 additions & 0 deletions
@@ -888,13 +888,16 @@ int RedisAI_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
   REDISMODULE_NOT_USED(argc);
   struct RedisAI_RunInfo *rinfo = RedisModule_GetBlockedClientPrivateData(ctx);
 
+  printf("A\n");
   if (rinfo->status) {
     RedisModule_Log(ctx, "warning", "ERR %s", rinfo->err->detail);
+    printf("A1\n");
     int ret = RedisModule_ReplyWithError(ctx, rinfo->err->detail_oneline);
     RedisAI_FreeRunInfo(ctx, rinfo);
     return ret;
   }
 
+  printf("B\n");
   size_t num_outputs = 0;
   if (rinfo->mctx) {
     (rinfo->mctx->model->backend_calls)++;
