Skip to content

Commit d434e4b

Browse files
committed
Fix torch API, still WIP
1 parent b4ae009 commit d434e4b

File tree

6 files changed

+149
-27
lines changed

6 files changed

+149
-27
lines changed

src/backends/torch.c

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,52 @@ void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) {
6565

6666
int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *error) {
6767

68-
size_t ninputs = array_len(mctx->inputs);
69-
size_t noutputs = array_len(mctx->outputs);
68+
const size_t nbatches = array_len(mctx->batches);
69+
if (nbatches == 0) {
70+
RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
71+
return 1;
72+
}
73+
74+
size_t total_batch_size = 0;
75+
size_t batch_sizes[nbatches];
76+
size_t batch_offsets[nbatches];
77+
if (array_len(mctx->batches[0].inputs) > 0) {
78+
for (size_t b=0; b<nbatches; ++b) {
79+
batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
80+
total_batch_size += batch_sizes[b];
81+
}
82+
batch_offsets[0] = 0;
83+
for (size_t b=1; b<nbatches; ++b) {
84+
batch_offsets[b] = batch_sizes[b-1];
85+
}
86+
}
87+
88+
size_t ninputs = array_len(mctx->batches[0].inputs);
89+
size_t noutputs = array_len(mctx->batches[0].outputs);
90+
91+
RAI_Tensor* inputs[ninputs];
7092

71-
DLManagedTensor* inputs[ninputs];
72-
DLManagedTensor* outputs[noutputs];
93+
DLManagedTensor* inputs_dl[ninputs];
94+
DLManagedTensor* outputs_dl[noutputs];
7395

7496
for (size_t i=0 ; i<ninputs; ++i) {
75-
inputs[i] = &mctx->inputs[i].tensor->tensor;
97+
RAI_Tensor* batch[nbatches];
98+
99+
for (size_t b=0; b<nbatches; b++) {
100+
batch[b] = mctx->batches[b].inputs[i].tensor;
101+
}
102+
103+
inputs[i] = RAI_TensorCreateByConcatenatingTensors(batch, nbatches);
104+
inputs_dl[i] = &inputs[i]->tensor;
76105
}
77106

78107
for (size_t i=0 ; i<noutputs; ++i) {
79-
outputs[i] = mctx->outputs[i].tensor ? &mctx->outputs[i].tensor->tensor : NULL;
108+
outputs_dl[i] = mctx->outputs[i].tensor ? &mctx->outputs[i].tensor->tensor : NULL;
80109
}
81110

82111
char* error_descr = NULL;
83112
torchRunModel(mctx->model->model,
84-
ninputs, inputs, noutputs, outputs,
113+
ninputs, inputs_dl, noutputs, outputs_dl,
85114
&error_descr, RedisModule_Alloc);
86115

87116
if (error_descr != NULL) {
@@ -90,16 +119,22 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *error) {
90119
return 1;
91120
}
92121

93-
for(size_t i=0 ; i<array_len(mctx->outputs) ; ++i) {
94-
if (outputs[i] == NULL) {
122+
for(size_t i=0 ; i<noutputs; ++i) {
123+
if (outputs_dl[i] == NULL) {
95124
RAI_SetError(error, RAI_EMODELRUN, "Model did not generate the expected number of outputs.");
96125
return 1;
97126
}
98-
RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs[i]);
99-
mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
127+
RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]);
128+
for (size_t b=0; b<nbatches; b++) {
129+
mctx->batches[b].outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]);
130+
}
100131
RAI_TensorFree(output_tensor);
101132
}
102133

134+
for (size_t i=0 ; i<ninputs; ++i) {
135+
RAI_TensorFree(inputs[i]);
136+
}
137+
103138
return 0;
104139
}
105140

src/model_struct.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ typedef struct RAI_ModelRunCtx {
3636
size_t ctxtype;
3737
RAI_Model* model;
3838
RAI_ModelCtxBatch* batches;
39-
// TODO: REMOVE THIS
40-
RAI_ModelCtxParam* inputs;
41-
RAI_ModelCtxParam* outputs;
42-
//
4339
} RAI_ModelRunCtx;
4440

4541
#endif /* SRC_MODEL_STRUCT_H_ */

src/redisai.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
234234
const char* typestr;
235235
AC_GetString(&ac, &typestr, NULL, 0);
236236

237-
size_t datasize = RAI_TensorGetDataSize(typestr);
237+
size_t datasize = RAI_TensorDataSizeFromString(typestr);
238238
if (!datasize){
239239
return RedisModule_ReplyWithError(ctx, "ERR invalid data type");
240240
}
@@ -1557,7 +1557,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) {
15571557
REGISTER_API(GetLLAPIVersion, ctx);
15581558

15591559
REGISTER_API(TensorCreate, ctx);
1560-
REGISTER_API(TensorGetDataSize, ctx);
1560+
REGISTER_API(TensorDataSize, ctx);
15611561
REGISTER_API(TensorFree, ctx);
15621562
REGISTER_API(TensorSetData, ctx);
15631563
REGISTER_API(TensorSetValueFromLongLong, ctx);

src/redisai.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ typedef struct RAI_Error RAI_Error;
3232
#define REDISAI_INFOMSG_THREADS_PER_QUEUE "Setting THREADS_PER_QUEUE parameter to"
3333

3434
RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreate)(const char* dataTypeStr, long long* dims, int ndims);
35+
RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreateByConcatenatingTensors)(RAI_Tensor** ts, long long n);
36+
RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreateBySlicingTensor)(RAI_Tensor* t, long long offset, long long len);
3537
size_t MODULE_API_FUNC(RedisAI_TensorLength)(RAI_Tensor* t);
36-
size_t MODULE_API_FUNC(RedisAI_TensorGetDataSize)(const char* dataTypeStr);
38+
size_t MODULE_API_FUNC(RedisAI_TensorDataSize)(RAI_Tensor* t);
3739
size_t MODULE_API_FUNC(RedisAI_TensorDataType)(RAI_Tensor* t);
3840
void MODULE_API_FUNC(RedisAI_TensorFree)(RAI_Tensor* t);
3941
int MODULE_API_FUNC(RedisAI_TensorSetData)(RAI_Tensor* tensor, const char* data, size_t len);
@@ -92,7 +94,9 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){
9294
REDISAI_MODULE_INIT_FUNCTION(ctx, GetLLAPIVersion);
9395

9496
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreate);
95-
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorGetDataSize);
97+
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreateByConcatenatingTensors);
98+
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreateBySlicingTensor);
99+
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorDataSize);
96100
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorFree);
97101
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorSetData);
98102
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorSetValueFromLongLong);

src/tensor.c

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
RedisModuleType *RedisAI_TensorType = NULL;
1010

11-
static DLDataType Tensor_GetDataType(const char* typestr){
11+
DLDataType RAI_TensorDataTypeFromString(const char* typestr){
1212
if (strcasecmp(typestr, "FLOAT") == 0){
1313
return (DLDataType){ .code = kDLFloat, .bits = 32, .lanes = 1};
1414
}
@@ -223,8 +223,7 @@ int RAI_TensorInit(RedisModuleCtx* ctx){
223223
return RedisAI_TensorType != NULL;
224224
}
225225

226-
RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims, int hasdata) {
227-
DLDataType dtype = Tensor_GetDataType(dataTypeStr);
226+
RAI_Tensor* RAI_TensorCreateWithDLDataType(DLDataType dtype, long long* dims, int ndims, int hasdata) {
228227
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
229228
if ( dtypeSize == 0){
230229
return NULL;
@@ -279,6 +278,11 @@ RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims
279278
return ret;
280279
}
281280

281+
RAI_Tensor* RAI_TensorCreate(const char* dataType, long long* dims, int ndims, int hasdata) {
282+
DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
283+
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, hasdata);
284+
}
285+
282286
#if 0
283287
void RAI_TensorMoveFrom(RAI_Tensor* dst, RAI_Tensor* src) {
284288
if (--dst->refCount <= 0){
@@ -296,6 +300,76 @@ void RAI_TensorMoveFrom(RAI_Tensor* dst, RAI_Tensor* src) {
296300
}
297301
#endif
298302

303+
RAI_Tensor* RAI_TensorCreateByConcatenatingTensors(RAI_Tensor** ts, long long n) {
304+
305+
if (n == 0) {
306+
return NULL;
307+
}
308+
309+
long long total_batch_size = 0;
310+
long long batch_sizes[n];
311+
long long batch_offsets[n];
312+
313+
long long ndims = RAI_TensorNumDims(ts[0]);
314+
long long dims[ndims];
315+
316+
// TODO check that all tensors have compatible dims
317+
318+
for (long long i=0; i<n; i++) {
319+
batch_sizes[i] = RAI_TensorDim(ts[i], 0);
320+
total_batch_size += batch_sizes[i];
321+
}
322+
323+
batch_offsets[0] = 0;
324+
for (long long i=1; i<n; i++) {
325+
batch_offsets[i] = batch_sizes[i-1];
326+
}
327+
328+
long long sample_size = 0;
329+
330+
for (long long i=1; i<ndims; i++) {
331+
dims[i] = RAI_TensorDim(ts[0], i);
332+
sample_size *= dims[i];
333+
}
334+
dims[0] = total_batch_size;
335+
336+
long long dtype_size = RAI_TensorDataSize(ts[0]);
337+
338+
DLDataType dtype = RAI_TensorDataType(ts[0]);
339+
340+
RAI_Tensor* ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims, 1);
341+
342+
for (long long i=0; i<n; i++) {
343+
memcpy(RAI_TensorData(ret) + batch_offsets[i] * sample_size * dtype_size, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
344+
}
345+
346+
return ret;
347+
}
348+
349+
RAI_Tensor* RAI_TensorCreateBySlicingTensor(RAI_Tensor* t, long long offset, long long len) {
350+
351+
long long ndims = RAI_TensorNumDims(t);
352+
long long dims[ndims];
353+
354+
long long dtype_size = RAI_TensorDataSize(t);
355+
long long sample_size = 0;
356+
357+
for (long long i=1; i<ndims; i++) {
358+
dims[i] = RAI_TensorDim(t, i);
359+
sample_size *= dims[i];
360+
}
361+
362+
dims[0] = len;
363+
364+
DLDataType dtype = RAI_TensorDataType(t);
365+
366+
RAI_Tensor* ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims, 1);
367+
368+
memcpy(RAI_TensorData(ret), RAI_TensorData(t) + offset * sample_size * dtype_size, len * sample_size * dtype_size);
369+
370+
return ret;
371+
}
372+
299373
// Beware: this will take ownership of dltensor
300374
RAI_Tensor* RAI_TensorCreateFromDLTensor(DLManagedTensor* dl_tensor) {
301375

@@ -332,8 +406,16 @@ size_t RAI_TensorLength(RAI_Tensor* t) {
332406
return len;
333407
}
334408

335-
size_t RAI_TensorGetDataSize(const char* dataTypeStr) {
336-
DLDataType dtype = Tensor_GetDataType(dataTypeStr);
409+
size_t RAI_TensorDataSize(RAI_Tensor* t) {
410+
return Tensor_DataTypeSize(RAI_TensorDataType(t));
411+
}
412+
413+
size_t RAI_TensorDataSizeFromString(const char* dataTypeStr) {
414+
DLDataType dtype = RAI_TensorDataTypeFromString(dataTypeStr);
415+
return Tensor_DataTypeSize(dtype);
416+
}
417+
418+
size_t RAI_TensorDataSizeFromDLDataType(DLDataType dtype) {
337419
return Tensor_DataTypeSize(dtype);
338420
}
339421

src/tensor.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99
extern RedisModuleType *RedisAI_TensorType;
1010

1111
int RAI_TensorInit(RedisModuleCtx* ctx);
12-
RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims, int hasdata);
12+
RAI_Tensor* RAI_TensorCreate(const char* dataType, long long* dims, int ndims, int hasdata);
13+
RAI_Tensor* RAI_TensorCreateWithDLDataType(DLDataType dtype, long long* dims, int ndims, int hasdata);
1314
RAI_Tensor* RAI_TensorCreateFromDLTensor(DLManagedTensor* dl_tensor);
15+
RAI_Tensor* RAI_TensorCreateByConcatenatingTensors(RAI_Tensor** ts, long long n);
16+
RAI_Tensor* RAI_TensorCreateBySlicingTensor(RAI_Tensor* t, long long offset, long long len);
1417
size_t RAI_TensorLength(RAI_Tensor* t);
15-
size_t RAI_TensorGetDataSize(const char* dataTypeStr);
18+
size_t RAI_TensorDataSize(RAI_Tensor* t);
19+
size_t RAI_TensorDataSizeFromDLDataType(DLDataType dtype);
20+
size_t RAI_TensorDataSizeFromString(const char* dataType);
1621
DLDataType RAI_TensorDataType(RAI_Tensor* t);
17-
void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr);
22+
DLDataType RAI_TensorDataTypeFromString(const char* dataType);
1823
void RAI_TensorFree(RAI_Tensor* t);
1924
int RAI_TensorSetData(RAI_Tensor* t, const char* data, size_t len);
2025
int RAI_TensorSetValueFromLongLong(RAI_Tensor* t, long long i, long long val);

0 commit comments

Comments
 (0)