|
14 | 14 | #include "tensor.h" |
15 | 15 | #include "err.h" |
16 | 16 | #include "arr.h" |
| 17 | +#include "math.h" |
17 | 18 | #include "redisai.h" |
18 | 19 | #include "version.h" |
19 | 20 | #include "tensor_struct.h" |
|
24 | 25 |
|
25 | 26 | extern RedisModuleType *RedisAI_TensorType; |
26 | 27 |
|
| 28 | +// Check if the given value is in the range of the tensor type. |
| 29 | +bool _ValOverflow(long long val, RAI_Tensor *t) { |
| 30 | + DLDataType dtype = t->tensor.dl_tensor.dtype; |
| 31 | + if (dtype.code == kDLInt) { |
| 32 | + unsigned long long max_abs_val = ((unsigned long long)1 << (uint)(dtype.bits - 1)); |
| 33 | + if ((unsigned long long)val >= max_abs_val || val < -1 * (long long)max_abs_val) { |
| 34 | + return true; |
| 35 | + } |
| 36 | + } else if (dtype.code == kDLUInt) { |
| 37 | + uint max_val = (uint)1 << dtype.bits; |
| 38 | + if (val >= max_val || val < 0) { |
| 39 | + return true; |
| 40 | + } |
| 41 | + } else if (dtype.code == kDLBool) { |
| 42 | + if (val < 0 || val > 1) { |
| 43 | + return true; |
| 44 | + } |
| 45 | + } |
| 46 | + return false; |
| 47 | +} |
| 48 | + |
27 | 49 | DLDataType RAI_TensorDataTypeFromString(const char *typestr) { |
28 | 50 | if (strcasecmp(typestr, RAI_DATATYPE_STR_FLOAT) == 0) { |
29 | 51 | return (DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1}; |
@@ -55,6 +77,9 @@ DLDataType RAI_TensorDataTypeFromString(const char *typestr) { |
55 | 77 | return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1}; |
56 | 78 | } |
57 | 79 | } |
| 80 | + if (strcasecmp(typestr, "BOOL") == 0) { |
| 81 | + return (DLDataType){.code = kDLBool, .bits = 8, .lanes = 1}; |
| 82 | + } |
58 | 83 | return (DLDataType){.bits = 0}; |
59 | 84 | } |
60 | 85 |
|
@@ -93,6 +118,9 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) { |
93 | 118 | strcpy(dtypestr, RAI_DATATYPE_STR_UINT16); |
94 | 119 | result = REDISMODULE_OK; |
95 | 120 | } |
| 121 | + } else if (dtype.code == kDLBool && dtype.bits == 8) { |
| 122 | + strcpy(dtypestr, RAI_DATATYPE_STR_BOOL); |
| 123 | + result = REDISMODULE_OK; |
96 | 124 | } |
97 | 125 | return result; |
98 | 126 | } |
@@ -129,7 +157,7 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in |
129 | 157 | DLDevice device = (DLDevice){.device_type = kDLCPU, .device_id = 0}; |
130 | 158 |
|
131 | 159 | // If we return an empty tensor, we initialize the data with zeros to avoid security |
132 | | - // issues. Otherwise, we only allocate without initializing (for better performance) |
| 160 | + // issues. Otherwise, we only allocate without initializing (for better performance). |
133 | 161 | void *data; |
134 | 162 | if (empty) { |
135 | 163 | data = RedisModule_Calloc(len, dtypeSize); |
@@ -429,8 +457,12 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) { |
429 | 457 | default: |
430 | 458 | return 0; |
431 | 459 | } |
432 | | - } else { |
433 | | - return 0; |
| 460 | + } else if (dtype.code == kDLBool) { |
| 461 | + if (dtype.bits == 8) { |
| 462 | + ((uint8_t *)data)[i] = val; |
| 463 | + } else { |
| 464 | + return 0; |
| 465 | + } |
434 | 466 | } |
435 | 467 | return 1; |
436 | 468 | } |
@@ -518,8 +550,12 @@ int RAI_TensorGetValueAsLongLong(RAI_Tensor *t, long long i, long long *val) { |
518 | 550 | default: |
519 | 551 | return 0; |
520 | 552 | } |
521 | | - } else { |
522 | | - return 0; |
| 553 | + } else if (dtype.code == kDLBool) { |
| 554 | + if (dtype.bits == 8) { |
| 555 | + *val = ((uint8_t *)data)[i]; |
| 556 | + } else { |
| 557 | + return 0; |
| 558 | + } |
523 | 559 | } |
524 | 560 | return 1; |
525 | 561 | } |
@@ -707,7 +743,7 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i |
707 | 743 | } else { |
708 | 744 | long long val; |
709 | 745 | const int retval = RedisModule_StringToLongLong(argv[argpos], &val); |
710 | | - if (retval != REDISMODULE_OK) { |
| 746 | + if (retval != REDISMODULE_OK || _ValOverflow(val, *t)) { |
711 | 747 | RAI_TensorFree(*t); |
712 | 748 | array_free(dims); |
713 | 749 | RAI_SetError(error, RAI_ETENSORSET, "ERR invalid value"); |
|
0 commit comments