Skip to content

Commit 2b3f391

Browse files
committed
Small refactor of RDB encode/decode (so the blob is not copied on RDB load) + small fixes
1 parent 8a1d3b7 commit 2b3f391

File tree

7 files changed: +39 additions, -26 deletions

src/redis_ai_objects/tensor.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,10 @@ static int _RAI_TensorParseStringsBlob(const char *tensor_blob, size_t blob_len,
136136
}
137137
}
138138
if (tensor_blob[blob_len - 1] != '\0' || elements_counter != tensor_len) {
139-
RAI_SetError(err, RAI_ETENSORSET,
140-
"ERR Number of string elements in data blob does not match tensor length");
139+
if (err) {
140+
RAI_SetError(err, RAI_ETENSORSET,
141+
"ERR Number of string elements in data blob does not match tensor length");
142+
}
141143
return REDISMODULE_ERR;
142144
}
143145
return REDISMODULE_OK;
@@ -647,11 +649,9 @@ int RAI_TensorGetValueAsCString(RAI_Tensor *t, long long i, const char **val) {
647649
int RAI_TensorSetData(RAI_Tensor *t, const char *data, size_t len) {
648650
DLDataType data_type = RAI_TensorDataType(t);
649651
if (data_type.code == kDLString) {
650-
RAI_Error err = {0};
651652
if (_RAI_TensorParseStringsBlob(data, len, RAI_TensorLength(t),
652653
RAI_TensorStringElementsOffsets(t),
653-
&err) != REDISMODULE_OK) {
654-
RAI_ClearError(&err);
654+
NULL) != REDISMODULE_OK) {
655655
return 0;
656656
}
657657
RedisModule_Free(RAI_TensorData(t));

src/serialization/RDB/decoder/current/v4/decode_v4.c

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,36 @@ void *RAI_RDBLoadTensor_v4(RedisModuleIO *io) {
1515
DLDataType data_type = (DLDataType){.code = code, .bits = bits, .lanes = 1};
1616

1717
int ndims = (int)RedisModule_LoadSigned(io);
18-
int64_t shape[ndims];
18+
size_t shape[ndims];
1919
for (size_t i = 0; i < ndims; ++i) {
2020
shape[i] = RedisModule_LoadSigned(io);
2121
}
2222

23-
RAI_Error err = {0};
23+
RAI_Tensor *tensor = RAI_TensorNew(data_type, shape, ndims);
24+
2425
size_t blob_len;
2526
char *data = RedisModule_LoadStringBuffer(io, &blob_len);
2627
if (RedisModule_IsIOError(io))
2728
goto error;
28-
RAI_Tensor *t =
29-
RAI_TensorCreateFromBlob(data_type, (const size_t *)shape, ndims, data, blob_len, &err);
30-
RedisModule_Free(data);
31-
if (t == NULL)
32-
goto error;
3329

34-
return t;
30+
tensor->blobSize = blob_len;
31+
tensor->tensor.dl_tensor.data = data;
32+
33+
if (data_type.code == kDLString) {
34+
for (size_t i = 0; i < RAI_TensorLength(tensor); i++) {
35+
tensor->tensor.dl_tensor.elements_length[i] = RedisModule_LoadUnsigned(io);
36+
}
37+
}
38+
if (RedisModule_IsIOError(io))
39+
goto error;
40+
return tensor;
3541

3642
error:
37-
RAI_ClearError(&err);
3843
RedisModule_LogIOError(io, "error", "Experienced a short read while reading a tensor from RDB");
44+
RAI_TensorFree(tensor);
45+
if (data) {
46+
RedisModule_Free(data);
47+
}
3948
return NULL;
4049
}
4150

src/serialization/RDB/decoder/previous/v0/decode_v0.c

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ void *RAI_RDBLoadTensor_v0(RedisModuleIO *io) {
1717
dtype.code = RedisModule_LoadUnsigned(io);
1818
dtype.lanes = RedisModule_LoadUnsigned(io);
1919

20-
size_t ndims = RedisModule_LoadUnsigned(io);
21-
int64_t shape[ndims];
20+
int ndims = RedisModule_LoadUnsigned(io);
21+
size_t shape[ndims];
2222
for (size_t i = 0; i < ndims; ++i) {
2323
shape[i] = RedisModule_LoadUnsigned(io);
2424
}
@@ -29,22 +29,18 @@ void *RAI_RDBLoadTensor_v0(RedisModuleIO *io) {
2929
}
3030
size_t byte_offset = RedisModule_LoadUnsigned(io);
3131

32-
RAI_Error err = {0};
3332
size_t blob_len;
3433
char *data = RedisModule_LoadStringBuffer(io, &blob_len);
3534
if (RedisModule_IsIOError(io))
3635
goto error;
3736

38-
RAI_Tensor *t =
39-
RAI_TensorCreateFromBlob(dtype, (const size_t *)shape, (int)ndims, data, blob_len, &err);
40-
RedisModule_Free(data);
41-
if (t == NULL)
42-
goto error;
37+
RAI_Tensor *tensor = RAI_TensorNew(dtype, shape, ndims);
38+
tensor->blobSize = blob_len;
39+
tensor->tensor.dl_tensor.data = data;
4340

44-
return t;
41+
return tensor;
4542

4643
error:
47-
RAI_ClearError(&err);
4844
RedisModule_LogIOError(io, "error", "Experienced a short read while reading a tensor from RDB");
4945
return NULL;
5046
}

src/serialization/RDB/encoder/v4/encode_v4.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ void RAI_RDBSaveTensor_v4(RedisModuleIO *io, void *value) {
1414

1515
size_t size = RAI_TensorByteSize(tensor);
1616
RedisModule_SaveStringBuffer(io, tensor->tensor.dl_tensor.data, size);
17+
18+
if (tensor->tensor.dl_tensor.dtype.code == kDLString) {
19+
for (size_t i = 0; i < RAI_TensorLength(tensor); i++) {
20+
RedisModule_SaveUnsigned(io, tensor->tensor.dl_tensor.elements_length[i]);
21+
}
22+
}
1723
}
1824

1925
void RAI_RDBSaveModel_v4(RedisModuleIO *io, void *value) {

tests/flow/test_serializations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def test_v4_onnx_model(self):
322322
def test_v4_tensor(self):
323323
key_name = "tensor{1}"
324324
con = get_connection(self.env, key_name)
325-
tensor_rdb = b"\a\x81\x00\x8f\xd3\x10\xd4\x8eD\x04\x02\x00\x02 \x02\x02\x02\x02\x02\x01\x05\b\x01\x00\x00\x00\x02\x00\x00\x00\x00\t\x00Viy\xab4\xbe\xdd\x82"
325+
tensor_rdb = b'\x07\x81\x00\x8f\xd3\x10\xd4\x8eD\x04\x02\x00\x02 \x02\x02\x02\x02\x02\x01\x05\x08\x01\x00\x00\x00\x02\x00\x00\x00\x00\t\x00Viy\xab4\xbe\xdd\x82'
326326
self.env.assertEqual(con.execute_command('FLUSHALL'), True)
327327
con.restore(key_name, 0, tensor_rdb, True)
328328
_, tensor_type, _, tensor_shape = con.execute_command('AI.TENSORGET', key_name, 'META')
@@ -331,7 +331,7 @@ def test_v4_tensor(self):
331331
self.env.assertEqual(values, [1, 2])
332332

333333
# test RDB load of string tensor
334-
str_tensor_rdb = b"\a\x81\x00\x8f\xd3\x10\xd4\x8eD\x04\x02\a\x02\b\x02\x01\x02\x02\x05\x12str_val1\x00str_val2\x00\x00\t\x00\xf9;\xe0\xd12.\x06z"
334+
str_tensor_rdb = b'\x07\x81\x00\x8f\xd3\x10\xd4\x8eD\x04\x02\x07\x02\x08\x02\x01\x02\x02\x05\x12str_val1\x00str_val2\x00\x02\x00\x02\t\x00\t\x00\x8b\x05Z\x0f:\x877O'
335335
con.restore('string_tensor{1}', 0, str_tensor_rdb, True)
336336
_, tensor_type, _, tensor_shape = con.execute_command('AI.TENSORGET', 'string_tensor{1}', 'META')
337337
self.env.assertEqual([tensor_type, tensor_shape], [b"STRING", [2]])

tests/flow/tests_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_onnx_string_tensors(env):
9090
env.assertEqual(tensor_values, [b'input11', b'input12', b'input21', b'input22'])
9191

9292
if env.useSlaves:
93+
ensureSlaveSynced(con, env)
9394
slave_con = env.getSlaveConnection()
9495
slave_tensor_values = slave_con.execute_command('AI.TENSORGET', 'out_tensor{1}', 'VALUES')
9596
env.assertEqual(tensor_values, slave_tensor_values)

tests/flow/tests_tensorflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ def test_tf_string_tensors(env):
726726
b'here we want to test a string longer than 24 chars, to force heap alloc in tf'])
727727

728728
if env.useSlaves:
729+
ensureSlaveSynced(con, env)
729730
slave_con = env.getSlaveConnection()
730731
slave_tensor_values = slave_con.execute_command('AI.TENSORGET', 'out_tensor{1}', 'VALUES')
731732
env.assertEqual(tensor_values, slave_tensor_values)

0 commit comments

Comments
 (0)