Skip to content

Commit 2f6b4be

Browse files
authored
Support bool type for tensors (#775)
* Support tensors of type bool. Add validation that an input value does not overflow the tensor type in TENSORSET. * Support bool tensors in backends * Use the forked dlpack, which contains the new type kDLBool; use the v0.5_RAI branch of dlpack instead of main
1 parent 993f978 commit 2f6b4be

File tree

10 files changed

+114
-47
lines changed

10 files changed

+114
-47
lines changed

get_deps.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ MKL=mkl
7070
ONNXRUNTIME=onnxruntime
7171

7272
######################################################################################## DLPACK
73-
DLPACK_VERSION="v0.4"
73+
DLPACK_VERSION="v0.5_RAI"
7474
if [[ $WITH_DLPACK != 0 ]]; then
7575
[[ $FORCE == 1 ]] && rm -rf $DLPACK
7676

7777
if [[ ! -d $DLPACK ]]; then
7878
echo "Cloning dlpack ..."
79-
git clone --depth 1 --branch $DLPACK_VERSION https://github.com/dmlc/dlpack.git $DLPACK
79+
git clone --depth 1 --branch $DLPACK_VERSION https://github.com/RedisAI/dlpack.git $DLPACK
8080
echo "Done."
8181
else
8282
echo "dlpack is in place."

src/backends/libtflite_c/tflite_c.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ static DLDataType getDLDataType(const TfLiteTensor *tensor) {
4242
dtype.bits = 16;
4343
dtype.code = DLDataTypeCode::kDLFloat;
4444
break;
45+
case kTfLiteBool:
46+
dtype.bits = 8;
47+
dtype.code = DLDataTypeCode::kDLBool;
48+
break;
4549
default:
4650
break;
4751
}
@@ -55,23 +59,6 @@ static DLDevice getDLDevice(const TfLiteTensor *tensor, const int64_t &device_id
5559
return device;
5660
}
5761

58-
#if 0
59-
static at::DeviceType getATenDeviceType(DLDeviceType device_type) {
60-
switch (device_type) {
61-
case DLDeviceType::kDLCPU:
62-
return at::DeviceType::CPU;
63-
case DLDeviceType::kDLGPU:
64-
return at::DeviceType::CUDA;
65-
case DLDeviceType::kDLOpenCL:
66-
return at::DeviceType::OPENCL;
67-
case DLDeviceType::kDLROCM:
68-
return at::DeviceType::HIP;
69-
default:
70-
throw std::logic_error("Unsupported device_type: " + std::to_string(device_type));
71-
}
72-
return at::DeviceType::CPU; // impossible
73-
}
74-
#endif
7562

7663
size_t dltensorBytes(DLManagedTensor *t) {
7764
int64_t *shape = t->dl_tensor.shape;
@@ -110,9 +97,10 @@ void copyToTfLiteTensor(std::shared_ptr<tflite::Interpreter> interpreter, int tf
11097
case kTfLiteFloat32:
11198
memcpy(interpreter->typed_tensor<float>(tflite_input), input->dl_tensor.data, nbytes);
11299
break;
100+
case kTfLiteBool:
101+
memcpy(interpreter->typed_tensor<bool>(tflite_input), input->dl_tensor.data, nbytes);
113102
case kTfLiteFloat16:
114103
throw std::logic_error("Float16 not currently supported as input tensor data type");
115-
break;
116104
default:
117105
throw std::logic_error("Unsupported input data type");
118106
}
@@ -174,9 +162,11 @@ DLManagedTensor *toManagedDLPack(std::shared_ptr<tflite::Interpreter> interprete
174162
case kTfLiteFloat32:
175163
memcpy(dl_tensor.data, interpreter->typed_tensor<float>(tflite_output), tensor->bytes);
176164
break;
165+
case kTfLiteBool:
166+
memcpy(dl_tensor.data, interpreter->typed_tensor<bool>(tflite_output), tensor->bytes);
167+
break;
177168
case kTfLiteFloat16:
178169
throw std::logic_error("Float16 not currently supported as output tensor data type");
179-
break;
180170
default:
181171
throw std::logic_error("Unsupported output data type");
182172
}
@@ -231,7 +221,7 @@ extern "C" void *tfliteLoadModel(const char *graph, size_t graphlen, DLDeviceTyp
231221
}
232222

233223
#if RAI_TFLITE_USE_CUDA
234-
if (device == DLDeviceType::kDLGPU) {
224+
if (device == DLDeviceType::kDLCUDA) {
235225
tflite::Interpreter::TfLiteDelegatePtr delegate =
236226
tflite::evaluation::CreateGPUDelegate(model.get());
237227
if (interpreter_->ModifyGraphWithDelegate(std::move(delegate)) != kTfLiteOk) {

src/backends/libtorch_c/torch_c.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ static DLDataType getDLDataType(const at::Tensor &t) {
4141
dtype.code = DLDataTypeCode::kDLFloat;
4242
break;
4343
case at::ScalarType::Bool:
44-
throw std::logic_error("Bool is not supported by dlpack");
44+
dtype.code = DLDataTypeCode::kDLBool;
45+
break;
4546
case at::ScalarType::BFloat16:
4647
throw std::logic_error("BFloat16 is not supported by dlpack");
4748
case at::ScalarType::QInt8:
@@ -68,7 +69,7 @@ static DLDevice getDLDevice(const at::Tensor &tensor, const int64_t &device_id)
6869
DLDevice device;
6970
device.device_id = device_id;
7071
if (tensor.is_cuda()) {
71-
device.device_type = DLDeviceType::kDLGPU;
72+
device.device_type = DLDeviceType::kDLCUDA;
7273
} else {
7374
device.device_type = DLDeviceType::kDLCPU;
7475
}
@@ -79,7 +80,7 @@ static at::DeviceType getATenDeviceType(DLDeviceType device_type) {
7980
switch (device_type) {
8081
case DLDeviceType::kDLCPU:
8182
return at::DeviceType::CPU;
82-
case DLDeviceType::kDLGPU:
83+
case DLDeviceType::kDLCUDA:
8384
return at::DeviceType::CUDA;
8485
case DLDeviceType::kDLOpenCL:
8586
return at::DeviceType::OPENCL;
@@ -138,6 +139,15 @@ at::ScalarType toScalarType(const DLDataType &dtype) {
138139
throw std::logic_error("Unsupported kFloat bits " + std::to_string(dtype.bits));
139140
}
140141
break;
142+
case DLDataTypeCode::kDLBool:
143+
switch (dtype.bits) {
144+
case 8:
145+
stype = at::ScalarType::Bool;
146+
break;
147+
default:
148+
throw std::logic_error("Unsupported kOpaque bits " + std::to_string(dtype.bits));
149+
}
150+
break;
141151
default:
142152
throw std::logic_error("Unsupported code " + std::to_string(dtype.code));
143153
}
@@ -310,7 +320,7 @@ static torch::DeviceType getDeviceType(ModuleContext *ctx) {
310320
switch (ctx->device) {
311321
case kDLCPU:
312322
return torch::kCPU;
313-
case kDLGPU:
323+
case kDLCUDA:
314324
return torch::kCUDA;
315325
default:
316326
throw std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device));

src/backends/onnxruntime.c

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <cuda_provider_factory.h>
33
#include "backends/util.h"
44
#include <stdatomic.h>
5+
#include <math.h>
56
#include "util/arr.h"
67
#include "backends/onnxruntime.h"
78
#include "redis_ai_objects/tensor.h"
@@ -152,6 +153,14 @@ ONNXTensorElementDataType RAI_GetOrtDataTypeFromDL(DLDataType dtype) {
152153
default:
153154
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
154155
}
156+
} else if (dtype.code == kDLBool) {
157+
switch (dtype.bits) {
158+
case 8:
159+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
160+
break;
161+
default:
162+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
163+
}
155164
}
156165
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
157166
}
@@ -174,6 +183,8 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
174183
return (DLDataType){.code = kDLUInt, .bits = 8, .lanes = 1};
175184
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
176185
return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1};
186+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
187+
return (DLDataType){.code = kDLBool, .bits = 8, .lanes = 1};
177188
default:
178189
return (DLDataType){.bits = 0};
179190
}
@@ -281,7 +292,7 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
281292
size_t elem_count;
282293
ONNX_VALIDATE_STATUS(ort->GetTensorShapeElementCount(info, &elem_count))
283294

284-
const size_t len = dtype.bits * elem_count / 8;
295+
const size_t len = ceil((double)dtype.bits * elem_count / 8);
285296
const size_t total_bytesize = len * sizeof(char);
286297
const size_t sample_bytesize = total_bytesize / total_batch_size;
287298
const size_t batch_bytesize = sample_bytesize * batch_size;

src/backends/tensorflow.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,37 @@ TF_DataType RAI_GetTFDataTypeFromDL(DLDataType dtype) {
2424
switch (dtype.bits) {
2525
case 32:
2626
return TF_FLOAT;
27-
break;
2827
case 64:
2928
return TF_DOUBLE;
30-
break;
3129
default:
3230
return 0;
3331
}
3432
} else if (dtype.code == kDLInt) {
3533
switch (dtype.bits) {
3634
case 8:
3735
return TF_INT8;
38-
break;
3936
case 16:
4037
return TF_INT16;
41-
break;
4238
case 32:
4339
return TF_INT32;
44-
break;
4540
case 64:
4641
return TF_INT64;
47-
break;
4842
default:
4943
return 0;
5044
}
5145
} else if (dtype.code == kDLUInt) {
5246
switch (dtype.bits) {
5347
case 8:
5448
return TF_UINT8;
55-
break;
5649
case 16:
5750
return TF_UINT16;
58-
break;
51+
default:
52+
return 0;
53+
}
54+
} else if (dtype.code == kDLBool) {
55+
switch (dtype.bits) {
56+
case 8:
57+
return TF_BOOL;
5958
default:
6059
return 0;
6160
}
@@ -81,6 +80,8 @@ DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) {
8180
return (DLDataType){.code = kDLUInt, .bits = 8, .lanes = 1};
8281
case TF_UINT16:
8382
return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1};
83+
case TF_BOOL:
84+
return (DLDataType){.code = kDLBool, .bits = 8, .lanes = 1};
8485
default:
8586
return (DLDataType){.bits = 0};
8687
}

src/backends/tflite.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
3232
dl_device = kDLCPU;
3333
break;
3434
case RAI_DEVICE_GPU:
35-
dl_device = kDLGPU;
35+
dl_device = kDLCUDA;
3636
break;
3737
default:
3838
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Error configuring model: unsupported device");

src/backends/torch.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
5353
dl_device = kDLCPU;
5454
break;
5555
case RAI_DEVICE_GPU:
56-
dl_device = kDLGPU;
56+
dl_device = kDLCUDA;
5757
break;
5858
default:
5959
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Error configuring model: unsupported device");
@@ -304,7 +304,7 @@ RAI_Script *RAI_ScriptCreateTorch(const char *devicestr, const char *scriptdef,
304304
dl_device = kDLCPU;
305305
break;
306306
case RAI_DEVICE_GPU:
307-
dl_device = kDLGPU;
307+
dl_device = kDLCUDA;
308308
break;
309309
default:
310310
RAI_SetError(error, RAI_ESCRIPTCONFIGURE,

src/redis_ai_objects/tensor.c

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "tensor.h"
1515
#include "err.h"
1616
#include "arr.h"
17+
#include "math.h"
1718
#include "redisai.h"
1819
#include "version.h"
1920
#include "tensor_struct.h"
@@ -24,6 +25,27 @@
2425

2526
extern RedisModuleType *RedisAI_TensorType;
2627

28+
// Check if the given value is in the range of the tensor type.
29+
bool _ValOverflow(long long val, RAI_Tensor *t) {
30+
DLDataType dtype = t->tensor.dl_tensor.dtype;
31+
if (dtype.code == kDLInt) {
32+
unsigned long long max_abs_val = ((unsigned long long)1 << (uint)(dtype.bits - 1));
33+
if ((unsigned long long)val >= max_abs_val || val < -1 * (long long)max_abs_val) {
34+
return true;
35+
}
36+
} else if (dtype.code == kDLUInt) {
37+
uint max_val = (uint)1 << dtype.bits;
38+
if (val >= max_val || val < 0) {
39+
return true;
40+
}
41+
} else if (dtype.code == kDLBool) {
42+
if (val < 0 || val > 1) {
43+
return true;
44+
}
45+
}
46+
return false;
47+
}
48+
2749
DLDataType RAI_TensorDataTypeFromString(const char *typestr) {
2850
if (strcasecmp(typestr, RAI_DATATYPE_STR_FLOAT) == 0) {
2951
return (DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1};
@@ -55,6 +77,9 @@ DLDataType RAI_TensorDataTypeFromString(const char *typestr) {
5577
return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1};
5678
}
5779
}
80+
if (strcasecmp(typestr, "BOOL") == 0) {
81+
return (DLDataType){.code = kDLBool, .bits = 8, .lanes = 1};
82+
}
5883
return (DLDataType){.bits = 0};
5984
}
6085

@@ -93,6 +118,9 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) {
93118
strcpy(dtypestr, RAI_DATATYPE_STR_UINT16);
94119
result = REDISMODULE_OK;
95120
}
121+
} else if (dtype.code == kDLBool && dtype.bits == 8) {
122+
strcpy(dtypestr, RAI_DATATYPE_STR_BOOL);
123+
result = REDISMODULE_OK;
96124
}
97125
return result;
98126
}
@@ -129,7 +157,7 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
129157
DLDevice device = (DLDevice){.device_type = kDLCPU, .device_id = 0};
130158

131159
// If we return an empty tensor, we initialize the data with zeros to avoid security
132-
// issues. Otherwise, we only allocate without initializing (for better performance)
160+
// issues. Otherwise, we only allocate without initializing (for better performance).
133161
void *data;
134162
if (empty) {
135163
data = RedisModule_Calloc(len, dtypeSize);
@@ -429,8 +457,12 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) {
429457
default:
430458
return 0;
431459
}
432-
} else {
433-
return 0;
460+
} else if (dtype.code == kDLBool) {
461+
if (dtype.bits == 8) {
462+
((uint8_t *)data)[i] = val;
463+
} else {
464+
return 0;
465+
}
434466
}
435467
return 1;
436468
}
@@ -518,8 +550,12 @@ int RAI_TensorGetValueAsLongLong(RAI_Tensor *t, long long i, long long *val) {
518550
default:
519551
return 0;
520552
}
521-
} else {
522-
return 0;
553+
} else if (dtype.code == kDLBool) {
554+
if (dtype.bits == 8) {
555+
*val = ((uint8_t *)data)[i];
556+
} else {
557+
return 0;
558+
}
523559
}
524560
return 1;
525561
}
@@ -707,7 +743,7 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
707743
} else {
708744
long long val;
709745
const int retval = RedisModule_StringToLongLong(argv[argpos], &val);
710-
if (retval != REDISMODULE_OK) {
746+
if (retval != REDISMODULE_OK || _ValOverflow(val, *t)) {
711747
RAI_TensorFree(*t);
712748
array_free(dims);
713749
RAI_SetError(error, RAI_ETENSORSET, "ERR invalid value");

src/redis_ai_objects/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ static const char *RAI_DATATYPE_STR_INT32 = "INT32";
3131
static const char *RAI_DATATYPE_STR_INT64 = "INT64";
3232
static const char *RAI_DATATYPE_STR_UINT8 = "UINT8";
3333
static const char *RAI_DATATYPE_STR_UINT16 = "UINT16";
34+
static const char *RAI_DATATYPE_STR_BOOL = "BOOL";
3435

3536
#define TENSOR_NONE 0
3637
#define TENSOR_VALUES (1 << 0)

0 commit comments

Comments
 (0)