diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb
index 9568669d..3d399dcc 100644
--- a/docs/user_guide/02_hybrid_queries.ipynb
+++ b/docs/user_guide/02_hybrid_queries.ipynb
@@ -22,7 +22,7 @@
{
"data": {
"text/html": [
- "
user | age | job | credit_score | office_location | user_embedding |
---|
john | 18 | engineer | high | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' |
derrick | 14 | doctor | low | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' |
nancy | 94 | doctor | high | -122.4194,37.7749 | b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' |
tyler | 100 | engineer | high | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?' |
tim | 12 | dermatologist | high | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?' |
taimur | 15 | CEO | low | -122.0839,37.3861 | b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' |
joe | 35 | dentist | medium | -122.0839,37.3861 | b'fff?fff?\\xcd\\xcc\\xcc=' |
"
+ "user | age | job | credit_score | office_location | user_embedding | last_updated |
---|
john | 18 | engineer | high | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' | 1741627789 |
derrick | 14 | doctor | low | -122.4194,37.7749 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' | 1741627789 |
nancy | 94 | doctor | high | -122.4194,37.7749 | b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' | 1710696589 |
tyler | 100 | engineer | high | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?' | 1742232589 |
tim | 12 | dermatologist | high | -122.0839,37.3861 | b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?' | 1739644189 |
taimur | 15 | CEO | low | -122.0839,37.3861 | b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?' | 1742232589 |
joe | 35 | dentist | medium | -122.0839,37.3861 | b'fff?fff?\\xcd\\xcc\\xcc=' | 1742232589 |
"
],
"text/plain": [
""
@@ -58,6 +58,7 @@
" {\"name\": \"credit_score\", \"type\": \"tag\"},\n",
" {\"name\": \"job\", \"type\": \"text\"},\n",
" {\"name\": \"age\", \"type\": \"numeric\"},\n",
+ " {\"name\": \"last_updated\", \"type\": \"numeric\"},\n",
" {\"name\": \"office_location\", \"type\": \"geo\"},\n",
" {\n",
" \"name\": \"user_embedding\",\n",
@@ -83,7 +84,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "13:02:18 redisvl.index.index INFO Index already exists, overwriting.\n"
+ "11:40:25 redisvl.index.index INFO Index already exists, overwriting.\n"
]
}
],
@@ -99,28 +100,9 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_cache\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_session\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float16_session\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float32_session\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. float32_cache\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. bfloat_cache\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. user_queries\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. student tutor\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 10. tutor\n",
- "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 11. bfloat_session\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# use the CLI to see the created index\n",
"!rvl index listall"
@@ -128,7 +110,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@@ -136,6 +118,26 @@
"keys = index.load(data)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "7"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "index.info()['num_docs']"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
@@ -157,13 +159,13 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 | 1741627789 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 | 1742232589 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 | 1739644189 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 | 1710696589 |
"
],
"text/plain": [
""
@@ -182,7 +184,7 @@
"v = VectorQuery(\n",
" vector=[0.1, 0.1, 0.5],\n",
" vector_field_name=\"user_embedding\",\n",
- " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n",
+ " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\", \"last_updated\"],\n",
" filter_expression=t\n",
")\n",
"\n",
@@ -192,13 +194,13 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 | 1741627789 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 | 1742232589 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 | 1742232589 |
"
],
"text/plain": [
""
@@ -316,13 +318,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 | 1741627789 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 | 1742232589 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 | 1742232589 |
"
],
"text/plain": [
""
@@ -335,7 +337,7 @@
"source": [
"from redisvl.query.filter import Num\n",
"\n",
- "numeric_filter = Num(\"age\") > 15\n",
+ "numeric_filter = Num(\"age\").between(15, 35)\n",
"\n",
"v.set_filter(numeric_filter)\n",
"result_print(index.query(v))"
@@ -393,6 +395,132 @@
"result_print(index.query(v))"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Timestamp Filters\n",
+ "\n",
+ "In redis all times are stored as an epoch time numeric however, this class allows you to filter with python datetime for ease of use. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch comparison: 1742147139.132589\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 | 1742232589 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 | 1742232589 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 | 1742232589 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query.filter import Timestamp\n",
+ "from datetime import datetime\n",
+ "\n",
+ "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n",
+ "print(f'Epoch comparison: {dt.timestamp()}')\n",
+ "\n",
+ "timestamp_filter = Timestamp(\"last_updated\") > dt\n",
+ "\n",
+ "v.set_filter(timestamp_filter)\n",
+ "result_print(index.query(v))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch comparison: 1742147139.132589\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 | 1741627789 |
0 | john | high | 18 | engineer | -122.4194,37.7749 | 1741627789 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 | 1739644189 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 | 1710696589 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query.filter import Timestamp\n",
+ "from datetime import datetime\n",
+ "\n",
+ "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n",
+ "\n",
+ "print(f'Epoch comparison: {dt.timestamp()}')\n",
+ "\n",
+ "timestamp_filter = Timestamp(\"last_updated\") < dt\n",
+ "\n",
+ "v.set_filter(timestamp_filter)\n",
+ "result_print(index.query(v))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch between: 1736880339.132589 - 1742147139.132589\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 | 1741627789 |
0 | john | high | 18 | engineer | -122.4194,37.7749 | 1741627789 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 | 1739644189 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query.filter import Timestamp\n",
+ "from datetime import datetime\n",
+ "\n",
+ "dt_1 = datetime(2025, 1, 14, 13, 45, 39, 132589)\n",
+ "dt_2 = datetime(2025, 3, 16, 13, 45, 39, 132589)\n",
+ "\n",
+ "print(f'Epoch between: {dt_1.timestamp()} - {dt_2.timestamp()}')\n",
+ "\n",
+ "timestamp_filter = Timestamp(\"last_updated\").between(dt_1, dt_2)\n",
+ "\n",
+ "v.set_filter(timestamp_filter)\n",
+ "result_print(index.query(v))"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -404,13 +532,13 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
+ "vector_distance | user | credit_score | age | job | office_location | last_updated |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 | 1741627789 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 | 1710696589 |
"
],
"text/plain": [
""
@@ -771,13 +899,13 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
+ "vector_distance | user | credit_score | age | job | office_location |
---|
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
"
],
"text/plain": [
""
@@ -791,8 +919,9 @@
"t = Tag(\"credit_score\") == \"high\"\n",
"low = Num(\"age\") >= 18\n",
"high = Num(\"age\") <= 100\n",
+ "ts = Timestamp(\"last_updated\") > datetime(2025, 3, 16, 13, 45, 39, 132589)\n",
"\n",
- "combined = t & low & high\n",
+ "combined = t & low & high & ts\n",
"\n",
"v = VectorQuery([0.1, 0.1, 0.5],\n",
" \"user_embedding\",\n",
@@ -814,13 +943,13 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
+ "vector_distance | user | credit_score | age | job | office_location |
---|
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
],
"text/plain": [
""
@@ -1325,7 +1454,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "env",
+ "display_name": "redisvl-Q9FZQJWe-py3.11",
"language": "python",
"name": "python3"
},
@@ -1339,7 +1468,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.11.9"
},
"orig_nbformat": 4
},
diff --git a/docs/user_guide/hybrid_example_data.pkl b/docs/user_guide/hybrid_example_data.pkl
index b5928b91..2c8a92da 100644
Binary files a/docs/user_guide/hybrid_example_data.pkl and b/docs/user_guide/hybrid_example_data.pkl differ
diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 41e6e214..c741ca45 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -310,7 +310,8 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")
- return self._vectorizer.embed(prompt)
+ result = self._vectorizer.embed(prompt)
+ return result # type: ignore
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
@@ -318,7 +319,8 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")
- return await self._vectorizer.aembed(prompt)
+ result = await self._vectorizer.aembed(prompt)
+ return result # type: ignore
def _check_vector_dims(self, vector: List[float]):
"""Checks the size of the provided vector and raises an error if it
diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py
index 4a7e72c3..8aff7524 100644
--- a/redisvl/extensions/router/semantic.py
+++ b/redisvl/extensions/router/semantic.py
@@ -366,14 +366,14 @@ def __call__(
if not vector:
if not statement:
raise ValueError("Must provide a vector or statement to the router")
- vector = self.vectorizer.embed(statement)
+ vector = self.vectorizer.embed(statement) # type: ignore
aggregation_method = (
aggregation_method or self.routing_config.aggregation_method
)
# perform route classification
- top_route_match = self._classify_route(vector, aggregation_method)
+ top_route_match = self._classify_route(vector, aggregation_method) # type: ignore
return top_route_match
@deprecated_argument("distance_threshold")
@@ -400,7 +400,7 @@ def route_many(
if not vector:
if not statement:
raise ValueError("Must provide a vector or statement to the router")
- vector = self.vectorizer.embed(statement)
+ vector = self.vectorizer.embed(statement) # type: ignore
max_k = max_k or self.routing_config.max_k
aggregation_method = (
@@ -409,7 +409,7 @@ def route_many(
# classify routes
top_route_matches = self._classify_multi_route(
- vector, max_k, aggregation_method
+ vector, max_k, aggregation_method # type: ignore
)
return top_route_matches
diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py
index 6825afa9..1aa15315 100644
--- a/redisvl/extensions/session_manager/semantic_session.py
+++ b/redisvl/extensions/session_manager/semantic_session.py
@@ -349,7 +349,7 @@ def add_messages(
role=message[ROLE_FIELD_NAME],
content=message[CONTENT_FIELD_NAME],
session_tag=session_tag,
- vector_field=content_vector,
+ vector_field=content_vector, # type: ignore
)
if TOOL_FIELD_NAME in message:
diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py
index 1e8987ff..ced52520 100644
--- a/redisvl/query/filter.py
+++ b/redisvl/query/filter.py
@@ -1,3 +1,5 @@
+import datetime
+import re
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -8,6 +10,19 @@
# mypy: disable-error-code="override"
+class Inclusive(str, Enum):
+ """Enum for valid inclusive options"""
+
+ BOTH = "both"
+ """Inclusive of both sides of range (default)"""
+ NEITHER = "neither"
+ """Inclusive of neither side of range"""
+ LEFT = "left"
+ """Inclusive of only left"""
+ RIGHT = "right"
+ """Inclusive of only right"""
+
+
class FilterOperator(Enum):
EQ = 1
NE = 2
@@ -19,6 +34,7 @@ class FilterOperator(Enum):
AND = 8
LIKE = 9
IN = 10
+ BETWEEN = 11
class FilterField:
@@ -267,6 +283,7 @@ class Num(FilterField):
FilterOperator.GT: ">",
FilterOperator.LE: "<=",
FilterOperator.GE: ">=",
+ FilterOperator.BETWEEN: "between",
}
OPERATOR_MAP: Dict[FilterOperator, str] = {
FilterOperator.EQ: "@%s:[%s %s]",
@@ -275,8 +292,10 @@ class Num(FilterField):
FilterOperator.LT: "@%s:[-inf (%s]",
FilterOperator.GE: "@%s:[%s +inf]",
FilterOperator.LE: "@%s:[-inf %s]",
+ FilterOperator.BETWEEN: "@%s:[%s %s]",
}
- SUPPORTED_VAL_TYPES = (int, float, type(None))
+
+ SUPPORTED_VAL_TYPES = (int, float, tuple, type(None))
def __eq__(self, other: int) -> "FilterExpression":
"""Create a Numeric equality filter expression.
@@ -373,10 +392,51 @@ def __le__(self, other: int) -> "FilterExpression":
self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LE)
return FilterExpression(str(self))
+ @staticmethod
+ def _validate_inclusive_string(inclusive: str) -> Inclusive:
+ try:
+ return Inclusive(inclusive)
+        except ValueError:
+            raise ValueError(
+                f"Invalid inclusive value; must be one of: {[i.value for i in Inclusive]}"
+ )
+
+ def _format_inclusive_between(
+        self, inclusive: Inclusive, start: float, end: float
+ ) -> str:
+        if inclusive == Inclusive.BOTH:
+            return f"@{self._field}:[{start} {end}]"
+
+        if inclusive == Inclusive.NEITHER:
+            return f"@{self._field}:[({start} ({end}]"
+
+        if inclusive == Inclusive.LEFT:
+            return f"@{self._field}:[{start} ({end}]"
+
+        if inclusive == Inclusive.RIGHT:
+            return f"@{self._field}:[({start} {end}]"
+
+ raise ValueError(f"Inclusive value not found")
+
+ def between(
+ self, start: int, end: int, inclusive: str = "both"
+ ) -> "FilterExpression":
+ """Operator for searching values between two numeric values."""
+        inclusive_enum = self._validate_inclusive_string(inclusive)
+        expression = self._format_inclusive_between(inclusive_enum, start, end)
+
+ return FilterExpression(expression)
+
def __str__(self) -> str:
"""Return the Redis Query string for the Numeric filter"""
if self._value is None:
return "*"
+ if self._operator == FilterOperator.BETWEEN:
+ return self.OPERATOR_MAP[self._operator] % (
+ self._field,
+ self._value[0],
+ self._value[1],
+ )
if self._operator == FilterOperator.EQ or self._operator == FilterOperator.NE:
return self.OPERATOR_MAP[self._operator] % (
self._field,
@@ -562,3 +622,213 @@ def __str__(self) -> str:
if not self._filter:
raise ValueError("Improperly initialized FilterExpression")
return self._filter
+
+
+class Timestamp(Num):
+ """
+ A timestamp filter for querying date/time fields in Redis.
+
+ This filter can handle various date and time formats, including:
+ - datetime objects (with or without timezone)
+ - date objects
+ - ISO-8601 formatted strings
+ - Unix timestamps (as integers or floats)
+
+ All timestamps are converted to Unix timestamps in UTC for consistency.
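+
+    For example, with an illustrative ``last_updated`` numeric field, each of
+    the following builds an equivalent epoch-based range filter:
+
+    .. code-block:: python
+
+        from datetime import datetime, timezone
+        from redisvl.query.filter import Timestamp
+
+        dt = datetime(2025, 3, 16, 13, 45, tzinfo=timezone.utc)
+        f1 = Timestamp("last_updated") > dt
+        f2 = Timestamp("last_updated") > "2025-03-16T13:45:00+00:00"
+        f3 = Timestamp("last_updated") > dt.timestamp()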
+ """
+
+ SUPPORTED_TYPES = (
+ datetime.datetime,
+ datetime.date,
+ tuple, # Date range
+ str, # ISO format
+ int, # Unix timestamp
+ float, # Unix timestamp with fractional seconds
+ type(None),
+ )
+
+ @staticmethod
+ def _is_date(value: Any) -> bool:
+ """Check if the value is a date object. Either ISO string or datetime.date."""
+ return (
+ isinstance(value, datetime.date)
+ and not isinstance(value, datetime.datetime)
+ ) or (isinstance(value, str) and Timestamp._is_date_only(value))
+
+ @staticmethod
+ def _is_date_only(iso_string: str) -> bool:
+ """Check if an ISO formatted string only includes date information using regex."""
+ # Match YYYY-MM-DD format exactly
+ date_pattern = r"^\d{4}-\d{2}-\d{2}$"
+ return bool(re.match(date_pattern, iso_string))
+
+ def _convert_to_timestamp(self, value, end_date=False):
+ """
+ Convert various inputs to a Unix timestamp (seconds since epoch in UTC).
+
+        Args:
+            value: A datetime, date, ISO-8601 string, int, or float
+            end_date: If True, date-only values resolve to the end of the
+                day rather than the start.
+
+        Returns:
+            Optional[float]: Unix timestamp, or None if value is None.
+ """
+ if value is None:
+ return None
+
+ if isinstance(value, (int, float)):
+ # Already a Unix timestamp
+ return float(value)
+
+ if isinstance(value, str):
+ # Parse ISO format
+ try:
+ value = datetime.datetime.fromisoformat(value)
+ except ValueError:
+ raise ValueError(f"String timestamp must be in ISO format: {value}")
+
+ if isinstance(value, datetime.date) and not isinstance(
+ value, datetime.datetime
+ ):
+            # For date-only values, expand to the start or end of the day
+ if end_date:
+ value = datetime.datetime.combine(value, datetime.time.max)
+ else:
+ value = datetime.datetime.combine(value, datetime.time.min)
+
+ # Ensure the datetime is timezone-aware (UTC)
+ if isinstance(value, datetime.datetime):
+ if value.tzinfo is None:
+ value = value.replace(tzinfo=datetime.timezone.utc)
+ else:
+ value = value.astimezone(datetime.timezone.utc)
+
+ # Convert to Unix timestamp
+ return value.timestamp()
+
+ raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}")
+
+ def __eq__(self, other) -> FilterExpression:
+ """
+ Filter for timestamps equal to the specified value.
+ For date objects (without time), this matches the entire day.
+
+ Args:
+ other: A datetime, date, ISO string, or Unix timestamp
+
+ Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ if self._is_date(other):
+ # For date objects, match the entire day
+ if isinstance(other, str):
+ other = datetime.datetime.strptime(other, "%Y-%m-%d").date()
+            start = datetime.datetime.combine(other, datetime.time.min)
+            end = datetime.datetime.combine(other, datetime.time.max)
+ return self.between(start, end)
+
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ)
+ return FilterExpression(str(self))
+
+ def __ne__(self, other) -> FilterExpression:
+ """
+ Filter for timestamps not equal to the specified value.
+ For date objects (without time), this excludes the entire day.
+
+ Args:
+ other: A datetime, date, ISO string, or Unix timestamp
+
+ Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ if self._is_date(other):
+ # For date objects, exclude the entire day
+ if isinstance(other, str):
+ other = datetime.datetime.strptime(other, "%Y-%m-%d").date()
+            start = datetime.datetime.combine(other, datetime.time.min)
+            end = datetime.datetime.combine(other, datetime.time.max)
+            # Negate the whole-day range so the entire day is excluded,
+            # mirroring the (-@field:[start end]) form used for numeric NE
+            return FilterExpression(f"(-{str(self.between(start, end))})")
+
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.NE)
+ return FilterExpression(str(self))
+
+    def __gt__(self, other) -> FilterExpression:
+        """
+        Filter for timestamps greater than the specified value.
+
+        Args:
+            other: A datetime, date, ISO string, or Unix timestamp
+
+        Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GT)
+ return FilterExpression(str(self))
+
+    def __lt__(self, other) -> FilterExpression:
+        """
+        Filter for timestamps less than the specified value.
+
+        Args:
+            other: A datetime, date, ISO string, or Unix timestamp
+
+        Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LT)
+ return FilterExpression(str(self))
+
+    def __ge__(self, other) -> FilterExpression:
+        """
+        Filter for timestamps greater than or equal to the specified value.
+
+        Args:
+            other: A datetime, date, ISO string, or Unix timestamp
+
+        Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GE)
+ return FilterExpression(str(self))
+
+    def __le__(self, other) -> FilterExpression:
+        """
+        Filter for timestamps less than or equal to the specified value.
+
+        Args:
+            other: A datetime, date, ISO string, or Unix timestamp
+
+        Returns:
+            FilterExpression: A filter expression for the query.
+ """
+ timestamp = self._convert_to_timestamp(other)
+ self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LE)
+ return FilterExpression(str(self))
+
+    def between(self, start, end, inclusive: str = "both") -> FilterExpression:
+        """
+        Filter for timestamps between start and end.
+
+        Args:
+            start: A datetime, date, ISO string, or Unix timestamp
+            end: A datetime, date, ISO string, or Unix timestamp
+            inclusive: Which bounds of the range are inclusive: "both"
+                (default), "neither", "left", or "right".
+
+        Returns:
+            FilterExpression: A filter expression for the query.
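+
+        For example, with an illustrative ``last_updated`` field (date-only
+        values expand to cover the whole day):
+
+        .. code-block:: python
+
+            from datetime import date
+            from redisvl.query.filter import Timestamp
+
+            ts = Timestamp("last_updated").between(date(2025, 1, 1), date(2025, 1, 31))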
+ """
+        inclusive_enum = self._validate_inclusive_string(inclusive)
+
+        start_ts = self._convert_to_timestamp(start)
+        end_ts = self._convert_to_timestamp(end, end_date=True)
+
+        expression = self._format_inclusive_between(inclusive_enum, start_ts, end_ts)
+
+ return FilterExpression(expression)
diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py
index b3a63fa9..189b6e1a 100644
--- a/redisvl/utils/vectorize/base.py
+++ b/redisvl/utils/vectorize/base.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union
from pydantic import BaseModel, Field, field_validator
@@ -49,34 +49,69 @@ def check_dims(cls, value):
return value
@abstractmethod
- def embed_many(
+ def embed(
self,
- texts: List[str],
+ text: str,
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[float], bytes]:
+ """Embed a chunk of text.
+
+ Args:
+ text: Text to embed
+ preprocess: Optional function to preprocess text
+ as_buffer: If True, returns a bytes object instead of a list
+
+ Returns:
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
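+
+        For example (``vectorizer`` stands in for any concrete subclass
+        instance):
+
+        .. code-block:: python
+
+            floats = vectorizer.embed("hello world")
+            buffer = vectorizer.embed("hello world", as_buffer=True)  # bytes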
+ """
raise NotImplementedError
@abstractmethod
- def embed(
+ def embed_many(
self,
- text: str,
+ texts: List[str],
preprocess: Optional[Callable] = None,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[List[float]], List[bytes]]:
+ """Embed multiple chunks of text.
+
+ Args:
+ texts: List of texts to embed
+ preprocess: Optional function to preprocess text
+ batch_size: Number of texts to process in each batch
+ as_buffer: If True, returns each embedding as a bytes object
+
+ Returns:
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
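+
+        For example (``vectorizer`` stands in for any concrete subclass
+        instance):
+
+        .. code-block:: python
+
+            vectors = vectorizer.embed_many(["a", "b", "c"], batch_size=2)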
+ """
raise NotImplementedError
async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
+ """Asynchronously embed multiple chunks of text.
+
+ Args:
+ texts: List of texts to embed
+ preprocess: Optional function to preprocess text
+ batch_size: Number of texts to process in each batch
+ as_buffer: If True, returns each embedding as a bytes object
+
+ Returns:
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
+ """
# Fallback to standard embedding call if no async support
return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs)
@@ -86,7 +121,18 @@ async def aembed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
+ """Asynchronously embed a chunk of text.
+
+ Args:
+ text: Text to embed
+ preprocess: Optional function to preprocess text
+ as_buffer: If True, returns a bytes object instead of a list
+
+ Returns:
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
+ """
# Fallback to standard embedding call if no async support
return self.embed(text, preprocess, as_buffer, **kwargs)
diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py
index 7b3b7d01..410280e5 100644
--- a/redisvl/utils/vectorize/text/azureopenai.py
+++ b/redisvl/utils/vectorize/text/azureopenai.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -178,7 +178,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of texts using the AzureOpenAI API.
Args:
@@ -191,7 +191,8 @@ def embed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -205,7 +206,9 @@ def embed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
- response = self._client.embeddings.create(input=batch, model=self.model)
+ response = self._client.embeddings.create(
+ input=batch, model=self.model, **kwargs
+ )
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
for r in response.data
@@ -224,7 +227,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Embed a chunk of text using the AzureOpenAI API.
Args:
@@ -235,7 +238,8 @@ def embed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -248,7 +252,9 @@ def embed(
dtype = kwargs.pop("dtype", self.dtype)
- result = self._client.embeddings.create(input=[text], model=self.model)
+ result = self._client.embeddings.create(
+ input=[text], model=self.model, **kwargs
+ )
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@retry(
@@ -261,10 +267,10 @@ async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Asynchronously embed many chunks of texts using the AzureOpenAI API.
Args:
@@ -277,7 +283,8 @@ async def aembed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -292,7 +299,7 @@ async def aembed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self._aclient.embeddings.create(
- input=batch, model=self.model
+ input=batch, model=self.model, **kwargs
)
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -312,7 +319,7 @@ async def aembed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Asynchronously embed a chunk of text using the OpenAI API.
Args:
@@ -323,7 +330,8 @@ async def aembed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -336,7 +344,9 @@ async def aembed(
dtype = kwargs.pop("dtype", self.dtype)
- result = await self._aclient.embeddings.create(input=[text], model=self.model)
+ result = await self._aclient.embeddings.create(
+ input=[text], model=self.model, **kwargs
+ )
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@property
diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py
index 5858aff8..2d40685d 100644
--- a/redisvl/utils/vectorize/text/bedrock.py
+++ b/redisvl/utils/vectorize/text/bedrock.py
@@ -1,6 +1,6 @@
import json
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -135,8 +135,8 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
- """Embed a chunk of text using Amazon Bedrock.
+ ) -> Union[List[float], bytes]:
+ """Embed a chunk of text using the AWS Bedrock Embeddings API.
Args:
text (str): Text to embed.
@@ -144,7 +144,8 @@ def embed(
as_buffer (bool): Whether to return as byte buffer.
Returns:
- List[float]: The embedding vector.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
TypeError: If text is not a string.
@@ -156,7 +157,7 @@ def embed(
text = preprocess(text)
response = self._client.invoke_model(
- modelId=self.model, body=json.dumps({"inputText": text})
+ modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
)
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]
@@ -177,17 +178,18 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
- """Embed multiple texts using Amazon Bedrock.
+ ) -> Union[List[List[float]], List[bytes]]:
+ """Embed many chunks of text using the AWS Bedrock Embeddings API.
Args:
texts (List[str]): List of texts to embed.
preprocess (Optional[Callable]): Optional preprocessing function.
- batch_size (int): Size of batches for processing.
+ batch_size (int): Size of batches for processing. Defaults to 10.
as_buffer (bool): Whether to return as byte buffers.
Returns:
- List[List[float]]: List of embedding vectors.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
TypeError: If texts is not a list of strings.
@@ -206,7 +208,7 @@ def embed_many(
batch_embeddings = []
for text in batch:
response = self._client.invoke_model(
- modelId=self.model, body=json.dumps({"inputText": text})
+ modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
)
response_body = json.loads(response["body"].read())
batch_embeddings.append(response_body["embedding"])
diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py
index bd6481fe..4e6192e2 100644
--- a/redisvl/utils/vectorize/text/cohere.py
+++ b/redisvl/utils/vectorize/text/cohere.py
@@ -1,5 +1,6 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -64,7 +65,8 @@ def __init__(
Defaults to None.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
- Defaults to 'float32'.
+ 'float32' will use Cohere's float embeddings, 'int8' and 'uint8' will map
+ to Cohere's corresponding embedding types. Defaults to 'float32'.
Raises:
ImportError: If the cohere library is not installed.
@@ -114,6 +116,15 @@ def _set_model_dims(self) -> int:
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
return len(embedding)
+ def _get_cohere_embedding_type(self, dtype: str) -> List[str]:
+ """Map dtype to appropriate Cohere embedding_types value."""
+ if dtype == "int8":
+ return ["int8"]
+ elif dtype == "uint8":
+ return ["uint8"]
+ else:
+ return ["float"]
+
@deprecated_argument("dtype")
def embed(
self,
@@ -121,7 +132,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], List[int], bytes]:
"""Embed a chunk of text using the Cohere Embeddings API.
Must provide the embedding `input_type` as a `kwarg` to this method
@@ -150,13 +161,17 @@ def embed(
Required for embedding models v3 and higher.
Returns:
- List[float]: Embedding.
+ Union[List[float], List[int], bytes]:
+ - If as_buffer=True: Returns a bytes object
+ - If as_buffer=False:
+ - For dtype="float32": Returns a list of floats
+ - For dtype="int8" or "uint8": Returns a list of integers
Raises:
            TypeError: If an invalid input_type is provided.
"""
- input_type = kwargs.get("input_type")
+ input_type = kwargs.pop("input_type", None)
if not isinstance(text, str):
raise TypeError("Must pass in a str value to embed.")
@@ -171,9 +186,34 @@ def embed(
dtype = kwargs.pop("dtype", self.dtype)
- embedding = self._client.embed(
- texts=[text], model=self.model, input_type=input_type
- ).embeddings[0]
+ # Check if embedding_types was provided and warn user
+ if "embedding_types" in kwargs:
+ warnings.warn(
+ "The 'embedding_types' parameter is not supported in CohereTextVectorizer. "
+ "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.",
+ UserWarning,
+ stacklevel=2,
+ )
+ kwargs.pop("embedding_types")
+
+ # Map dtype to appropriate embedding_type
+ embedding_types = self._get_cohere_embedding_type(dtype)
+
+ response = self._client.embed(
+ texts=[text],
+ model=self.model,
+ input_type=input_type,
+ embedding_types=embedding_types,
+ **kwargs,
+ )
+
+ # Extract the appropriate embedding based on embedding_types
+ embed_type = embedding_types[0]
+ if hasattr(response.embeddings, embed_type):
+ embedding = getattr(response.embeddings, embed_type)[0]
+ else:
+ embedding = response.embeddings[0] # Fallback for older API versions
+
return self._process_embedding(embedding, as_buffer, dtype)
@retry(
@@ -189,7 +229,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[List[int]], List[bytes]]:
"""Embed many chunks of text using the Cohere Embeddings API.
Must provide the embedding `input_type` as a `kwarg` to this method
@@ -221,13 +261,17 @@ def embed_many(
Required for embedding models v3 and higher.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[List[int]], List[bytes]]:
+ - If as_buffer=True: Returns a list of bytes objects
+ - If as_buffer=False:
+ - For dtype="float32": Returns a list of lists of floats
+ - For dtype="int8" or "uint8": Returns a list of lists of integers
Raises:
            TypeError: If an invalid input_type is provided.
"""
- input_type = kwargs.get("input_type")
+ input_type = kwargs.pop("input_type", None)
if not isinstance(texts, list):
raise TypeError("Must pass in a list of str values to embed.")
@@ -241,14 +285,41 @@ def embed_many(
dtype = kwargs.pop("dtype", self.dtype)
+ # Check if embedding_types was provided and warn user
+ if "embedding_types" in kwargs:
+ warnings.warn(
+ "The 'embedding_types' parameter is not supported in CohereTextVectorizer. "
+ "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.",
+ UserWarning,
+ stacklevel=2,
+ )
+ kwargs.pop("embedding_types")
+
+ # Map dtype to appropriate embedding_type
+ embedding_types = self._get_cohere_embedding_type(dtype)
+
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self._client.embed(
- texts=batch, model=self.model, input_type=input_type
+ texts=batch,
+ model=self.model,
+ input_type=input_type,
+ embedding_types=embedding_types,
+ **kwargs,
)
+
+ # Extract the appropriate embeddings based on embedding_types
+ embed_type = embedding_types[0]
+ if hasattr(response.embeddings, embed_type):
+ batch_embeddings = getattr(response.embeddings, embed_type)
+ else:
+ batch_embeddings = (
+ response.embeddings
+ ) # Fallback for older API versions
+
embeddings += [
self._process_embedding(embedding, as_buffer, dtype)
- for embedding in response.embeddings
+ for embedding in batch_embeddings
]
return embeddings
diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py
index 4558d4d7..ed284d29 100644
--- a/redisvl/utils/vectorize/text/custom.py
+++ b/redisvl/utils/vectorize/text/custom.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union
from pydantic import PrivateAttr
@@ -162,7 +162,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""
Generate an embedding for a single piece of text using your sync embed function.
@@ -172,7 +172,7 @@ def embed(
as_buffer (bool): If True, return the embedding as a byte buffer.
Returns:
- List[float]: The embedding of the input text.
+ Union[List[float], bytes]: The embedding of the input text.
Raises:
TypeError: If the input is not a string.
@@ -200,7 +200,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""
Generate embeddings for multiple pieces of text in batches using your sync embed_many function.
@@ -211,7 +211,7 @@ def embed_many(
as_buffer (bool): If True, convert each embedding to a byte buffer.
Returns:
- List[List[float]]: A list of embeddings, where each embedding is a list of floats.
+ Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes.
Raises:
TypeError: If the input is not a list of strings.
@@ -226,7 +226,7 @@ def embed_many(
raise NotImplementedError("No embed_many function was provided.")
dtype = kwargs.pop("dtype", self.dtype)
- embeddings: List[List[float]] = []
+ embeddings: Union[List[List[float]], List[bytes]] = []
try:
for batch in self.batchify(texts, batch_size, preprocess):
@@ -288,10 +288,10 @@ async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""
Asynchronously generate embeddings for multiple pieces of text in batches.
@@ -302,7 +302,7 @@ async def aembed_many(
as_buffer (bool): If True, convert each embedding to a byte buffer.
Returns:
- List[List[float]]: A list of embeddings, where each embedding is a list of floats.
+ Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes.
Raises:
TypeError: If the input is not a list of strings.
@@ -317,7 +317,7 @@ async def aembed_many(
raise NotImplementedError("No aembed_many function was provided.")
dtype = kwargs.pop("dtype", self.dtype)
- embeddings: List[List[float]] = []
+ embeddings: Union[List[List[float]], List[bytes]] = []
try:
for batch in self.batchify(texts, batch_size, preprocess):
diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py
index 8f81b85c..bafba41d 100644
--- a/redisvl/utils/vectorize/text/huggingface.py
+++ b/redisvl/utils/vectorize/text/huggingface.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union
from pydantic.v1 import PrivateAttr
@@ -89,7 +89,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Embed a chunk of text using the Hugging Face sentence transformer.
Args:
@@ -100,7 +100,8 @@ def embed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
TypeError: If the wrong input type is passed in for the text.
@@ -121,10 +122,10 @@ def embed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Asynchronously embed many chunks of texts using the Hugging Face
sentence transformer.
@@ -138,7 +139,8 @@ def embed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py
index e930b3a4..05133b37 100644
--- a/redisvl/utils/vectorize/text/mistral.py
+++ b/redisvl/utils/vectorize/text/mistral.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -128,7 +128,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of texts using the Mistral API.
Args:
@@ -141,7 +141,8 @@ def embed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -155,7 +156,9 @@ def embed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
- response = self._client.embeddings.create(model=self.model, inputs=batch)
+ response = self._client.embeddings.create(
+ model=self.model, inputs=batch, **kwargs
+ )
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
for r in response.data
@@ -174,7 +177,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Embed a chunk of text using the Mistral API.
Args:
@@ -185,7 +188,8 @@ def embed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -198,7 +202,9 @@ def embed(
dtype = kwargs.pop("dtype", self.dtype)
- result = self._client.embeddings.create(model=self.model, inputs=[text])
+ result = self._client.embeddings.create(
+ model=self.model, inputs=[text], **kwargs
+ )
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@retry(
@@ -211,7 +217,7 @@ async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
@@ -242,7 +248,7 @@ async def aembed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self._client.embeddings.create_async(
- model=self.model, inputs=batch
+ model=self.model, inputs=batch, **kwargs
)
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -287,7 +293,7 @@ async def aembed(
dtype = kwargs.pop("dtype", self.dtype)
result = await self._client.embeddings.create_async(
- model=self.model, inputs=[text]
+ model=self.model, inputs=[text], **kwargs
)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py
index 25b21c67..eee0764a 100644
--- a/redisvl/utils/vectorize/text/openai.py
+++ b/redisvl/utils/vectorize/text/openai.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -129,7 +129,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of texts using the OpenAI API.
Args:
@@ -142,7 +142,8 @@ def embed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
TypeError: If the wrong input type is passed in for the text.
@@ -156,7 +157,9 @@ def embed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
- response = self._client.embeddings.create(input=batch, model=self.model)
+ response = self._client.embeddings.create(
+ input=batch, model=self.model, **kwargs
+ )
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
for r in response.data
@@ -175,7 +178,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Embed a chunk of text using the OpenAI API.
Args:
@@ -186,7 +189,8 @@ def embed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
TypeError: If the wrong input type is passed in for the text.
@@ -199,7 +203,9 @@ def embed(
dtype = kwargs.pop("dtype", self.dtype)
- result = self._client.embeddings.create(input=[text], model=self.model)
+ result = self._client.embeddings.create(
+ input=[text], model=self.model, **kwargs
+ )
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@retry(
@@ -212,10 +218,10 @@ async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
- batch_size: int = 1000,
+ batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Asynchronously embed many chunks of texts using the OpenAI API.
Args:
@@ -228,7 +234,8 @@ async def aembed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
TypeError: If the wrong input type is passed in for the text.
@@ -243,7 +250,7 @@ async def aembed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self._aclient.embeddings.create(
- input=batch, model=self.model
+ input=batch, model=self.model, **kwargs
)
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -263,7 +270,7 @@ async def aembed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Asynchronously embed a chunk of text using the OpenAI API.
Args:
@@ -274,7 +281,8 @@ async def aembed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
TypeError: If the wrong input type is passed in for the text.
@@ -287,7 +295,9 @@ async def aembed(
dtype = kwargs.pop("dtype", self.dtype)
- result = await self._aclient.embeddings.create(input=[text], model=self.model)
+ result = await self._aclient.embeddings.create(
+ input=[text], model=self.model, **kwargs
+ )
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@property
diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py
index 6d455c67..ebe2a625 100644
--- a/redisvl/utils/vectorize/text/vertexai.py
+++ b/redisvl/utils/vectorize/text/vertexai.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -141,8 +141,8 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
- """Embed many chunks of texts using the VertexAI API.
+ ) -> Union[List[List[float]], List[bytes]]:
+ """Embed many chunks of text using the VertexAI Embeddings API.
Args:
texts (List[str]): List of text chunks to embed.
@@ -154,7 +154,8 @@ def embed_many(
to a byte string. Defaults to False.
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -168,7 +169,7 @@ def embed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
- response = self._client.get_embeddings(batch)
+ response = self._client.get_embeddings(batch, **kwargs)
embeddings += [
self._process_embedding(r.values, as_buffer, dtype) for r in response
]
@@ -186,8 +187,8 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
- """Embed a chunk of text using the VertexAI API.
+ ) -> Union[List[float], bytes]:
+ """Embed a chunk of text using the VertexAI Embeddings API.
Args:
text (str): Chunk of text to embed.
@@ -197,7 +198,8 @@ def embed(
to a byte string. Defaults to False.
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
            TypeError: If the wrong input type is passed in for the text.
@@ -210,7 +212,7 @@ def embed(
dtype = kwargs.pop("dtype", self.dtype)
- result = self._client.get_embeddings([text])
+ result = self._client.get_embeddings([text], **kwargs)
return self._process_embedding(result[0].values, as_buffer, dtype)
@property
diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py
index fbcbfd9e..9d015a81 100644
--- a/redisvl/utils/vectorize/text/voyageai.py
+++ b/redisvl/utils/vectorize/text/voyageai.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -124,7 +124,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[float]:
+ ) -> Union[List[float], bytes]:
"""Embed a chunk of text using the VoyageAI Embeddings API.
Can provide the embedding `input_type` as a `kwarg` to this method
@@ -149,7 +149,8 @@ def embed(
Check https://docs.voyageai.com/docs/embeddings
Returns:
- List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
+ object if as_buffer=True
Raises:
TypeError: If an invalid input_type is provided.
@@ -171,7 +172,7 @@ def embed_many(
batch_size: Optional[int] = None,
as_buffer: bool = False,
**kwargs,
- ) -> List[List[float]]:
+ ) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of text using the VoyageAI Embeddings API.
Can provide the embedding `input_type` as a `kwarg` to this method
@@ -198,14 +199,15 @@ def embed_many(
Check https://docs.voyageai.com/docs/embeddings
Returns:
- List[List[float]]: List of embeddings.
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
+ or as bytes objects if as_buffer=True
Raises:
TypeError: If an invalid input_type is provided.
"""
- input_type = kwargs.get("input_type")
- truncation = kwargs.get("truncation")
+ input_type = kwargs.pop("input_type", None)
+ truncation = kwargs.pop("truncation", None)
dtype = kwargs.pop("dtype", self.dtype)
if not isinstance(texts, list):
@@ -235,7 +237,7 @@ def embed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self._client.embed(
- texts=batch, model=self.model, input_type=input_type
+ texts=batch, model=self.model, input_type=input_type, **kwargs
)
embeddings += [
self._process_embedding(embedding, as_buffer, dtype)
@@ -284,8 +286,8 @@ async def aembed_many(
TypeError: If an invalid input_type is provided.
"""
- input_type = kwargs.get("input_type")
- truncation = kwargs.get("truncation")
+ input_type = kwargs.pop("input_type", None)
+ truncation = kwargs.pop("truncation", None)
dtype = kwargs.pop("dtype", self.dtype)
if not isinstance(texts, list):
@@ -315,7 +317,7 @@ async def aembed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self._aclient.embed(
- texts=batch, model=self.model, input_type=input_type
+ texts=batch, model=self.model, input_type=input_type, **kwargs
)
embeddings += [
self._process_embedding(embedding, as_buffer, dtype)
@@ -360,7 +362,6 @@ async def aembed(
Raises:
TypeError: If an invalid input_type is provided.
"""
-
result = await self.aembed_many(
texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs
)
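The switch from `kwargs.get` to `kwargs.pop` here is load-bearing: the remaining `**kwargs` are now forwarded to `self._client.embed(...)`, and `get` would leave `input_type`/`truncation` in the dict so the client would receive them twice. A toy sketch of the failure mode, using a stand-in function rather than the real client:

def fake_embed(texts, model, input_type=None, **extra):
    return texts, model, input_type, extra

kwargs = {"input_type": "search_document"}

# With get(): the key stays in kwargs and is passed again via **kwargs.
# fake_embed(["hi"], "voyage-3", input_type=kwargs.get("input_type"), **kwargs)
# -> TypeError: fake_embed() got multiple values for keyword argument 'input_type'

# With pop(): the key is removed first, so the forwarded kwargs are clean.
input_type = kwargs.pop("input_type", None)
fake_embed(["hi"], "voyage-3", input_type=input_type, **kwargs)  # works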
diff --git a/tests/conftest.py b/tests/conftest.py
index 61c5de45..24da05e5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
import os
+from datetime import datetime, timezone
import pytest
from testcontainers.compose import DockerCompose
@@ -68,12 +69,22 @@ def client(redis_url):
@pytest.fixture
-def sample_data():
+def sample_datetimes():
+ return {
+ "low": datetime(2025, 1, 16, 13).astimezone(timezone.utc),
+ "mid": datetime(2025, 2, 16, 13).astimezone(timezone.utc),
+ "high": datetime(2025, 3, 16, 13).astimezone(timezone.utc),
+ }
+
+
+@pytest.fixture
+def sample_data(sample_datetimes):
return [
{
"user": "john",
"age": 18,
"job": "engineer",
+ "last_updated": sample_datetimes["low"].timestamp(),
"credit_score": "high",
"location": "-122.4194,37.7749",
"user_embedding": [0.1, 0.1, 0.5],
@@ -82,6 +93,7 @@ def sample_data():
"user": "mary",
"age": 14,
"job": "doctor",
+ "last_updated": sample_datetimes["low"].timestamp(),
"credit_score": "low",
"location": "-122.4194,37.7749",
"user_embedding": [0.1, 0.1, 0.5],
@@ -90,6 +102,7 @@ def sample_data():
"user": "nancy",
"age": 94,
"job": "doctor",
+ "last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-122.4194,37.7749",
"user_embedding": [0.7, 0.1, 0.5],
@@ -98,6 +111,7 @@ def sample_data():
"user": "tyler",
"age": 100,
"job": "engineer",
+ "last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-110.0839,37.3861",
"user_embedding": [0.1, 0.4, 0.5],
@@ -106,6 +120,7 @@ def sample_data():
"user": "tim",
"age": 12,
"job": "dermatologist",
+ "last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-110.0839,37.3861",
"user_embedding": [0.4, 0.4, 0.5],
@@ -114,6 +129,7 @@ def sample_data():
"user": "taimur",
"age": 15,
"job": "CEO",
+ "last_updated": sample_datetimes["high"].timestamp(),
"credit_score": "low",
"location": "-110.0839,37.3861",
"user_embedding": [0.6, 0.1, 0.5],
@@ -122,6 +138,7 @@ def sample_data():
"user": "joe",
"age": 35,
"job": "dentist",
+ "last_updated": sample_datetimes["high"].timestamp(),
"credit_score": "medium",
"location": "-110.0839,37.3861",
"user_embedding": [0.9, 0.9, 0.1],
diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py
index 271d36da..deb58cbc 100644
--- a/tests/integration/test_query.py
+++ b/tests/integration/test_query.py
@@ -1,9 +1,19 @@
+from datetime import timedelta
+
import pytest
from redis.commands.search.result import Result
from redisvl.index import SearchIndex
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
-from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
+from redisvl.query.filter import (
+ FilterExpression,
+ Geo,
+ GeoRadius,
+ Num,
+ Tag,
+ Text,
+ Timestamp,
+)
from redisvl.redis.utils import array_to_buffer
# TODO expand to multiple schema types and sync + async
@@ -14,7 +24,14 @@ def vector_query():
return VectorQuery(
vector=[0.1, 0.1, 0.5],
vector_field_name="user_embedding",
- return_fields=["user", "credit_score", "age", "job", "location"],
+ return_fields=[
+ "user",
+ "credit_score",
+ "age",
+ "job",
+ "location",
+ "last_updated",
+ ],
)
@@ -23,7 +40,14 @@ def sorted_vector_query():
return VectorQuery(
vector=[0.1, 0.1, 0.5],
vector_field_name="user_embedding",
- return_fields=["user", "credit_score", "age", "job", "location"],
+ return_fields=[
+ "user",
+ "credit_score",
+ "age",
+ "job",
+ "location",
+ "last_updated",
+ ],
sort_by="age",
)
@@ -31,7 +55,14 @@ def sorted_vector_query():
@pytest.fixture
def filter_query():
return FilterQuery(
- return_fields=["user", "credit_score", "age", "job", "location"],
+ return_fields=[
+ "user",
+ "credit_score",
+ "age",
+ "job",
+ "location",
+ "last_updated",
+ ],
filter_expression=Tag("credit_score") == "high",
)
@@ -39,7 +70,14 @@ def filter_query():
@pytest.fixture
def sorted_filter_query():
return FilterQuery(
- return_fields=["user", "credit_score", "age", "job", "location"],
+ return_fields=[
+ "user",
+ "credit_score",
+ "age",
+ "job",
+ "location",
+ "last_updated",
+ ],
filter_expression=Tag("credit_score") == "high",
sort_by="age",
)
@@ -80,6 +118,7 @@ def index(sample_data, redis_url):
{"name": "credit_score", "type": "tag"},
{"name": "job", "type": "text"},
{"name": "age", "type": "numeric"},
+ {"name": "last_updated", "type": "numeric"},
{"name": "location", "type": "geo"},
{
"name": "user_embedding",
@@ -255,7 +294,7 @@ def query(request):
return request.getfixturevalue(request.param)
-def test_filters(index, query):
+def test_filters(index, query, sample_datetimes):
# Simple Tag Filter
t = Tag("credit_score") == "high"
search(query, index, t, 4, credit_check="high")
@@ -310,6 +349,34 @@ def test_filters(index, query):
t = Text("job") % ""
search(query, index, t, 7)
+ # Timestamps
+ ts = Timestamp("last_updated") > sample_datetimes["mid"]
+ search(query, index, ts, 2)
+
+ ts = Timestamp("last_updated") >= sample_datetimes["mid"]
+ search(query, index, ts, 5)
+
+ ts = Timestamp("last_updated") < sample_datetimes["high"]
+ search(query, index, ts, 5)
+
+ ts = Timestamp("last_updated") <= sample_datetimes["mid"]
+ search(query, index, ts, 5)
+
+ ts = Timestamp("last_updated") == sample_datetimes["mid"]
+ search(query, index, ts, 3)
+
+ ts = (Timestamp("last_updated") == sample_datetimes["low"]) | (
+ Timestamp("last_updated") == sample_datetimes["high"]
+ )
+ search(query, index, ts, 4)
+
+    # the between method could be dropped if we prefer the union syntax above
+ ts = Timestamp("last_updated").between(
+ sample_datetimes["low"] + timedelta(seconds=1),
+ sample_datetimes["high"] - timedelta(seconds=1),
+ )
+ search(query, index, ts, 3)
+
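Each `Timestamp` expression compiles down to an ordinary numeric range over the epoch value, so the equivalent manual string filter is easy to derive. A small sketch; the exact printed number depends on the datetime used:

from datetime import datetime, timezone
from redisvl.query.filter import Timestamp

dt = datetime(2025, 2, 16, 13, tzinfo=timezone.utc)
ts = Timestamp("last_updated") > dt
print(str(ts))  # e.g. "@last_updated:[(1739710800.0 +inf]"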
def test_manual_string_filters(index, query):
# Simple Tag Filter
diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py
index e1de4a46..36e444de 100644
--- a/tests/integration/test_vectorizers.py
+++ b/tests/integration/test_vectorizers.py
@@ -1,5 +1,6 @@
import os
+import numpy as np
import pytest
from redisvl.utils.vectorize import (
@@ -287,7 +288,7 @@ def test_default_dtype(vectorizer_):
VoyageAITextVectorizer,
],
)
-def test_other_dtypes(vectorizer_):
+def test_vectorizer_dtype_assignment(vectorizer_):
# test initializing dtype in constructor
for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]:
if issubclass(vectorizer_, CustomTextVectorizer):
@@ -319,7 +320,7 @@ def test_other_dtypes(vectorizer_):
VoyageAITextVectorizer,
],
)
-def test_bad_dtypes(vectorizer_):
+def test_unsupported_dtypes(vectorizer_):
with pytest.raises(ValueError):
vectorizer_(dtype="float25")
@@ -392,3 +393,95 @@ async def test_avectorizer_bad_input(avectorizer):
with pytest.raises(TypeError):
avectorizer.embed_many(42)
+
+
+@pytest.mark.requires_api_keys
+@pytest.mark.parametrize(
+ "dtype,expected_type",
+ [
+ ("float32", float), # Float dtype should return floats
+ ("int8", int), # Int8 dtype should return ints
+ ("uint8", int), # Uint8 dtype should return ints
+ ],
+)
+def test_cohere_dtype_support(dtype, expected_type):
+ """Test that CohereTextVectorizer properly handles different dtypes for embeddings."""
+ text = "This is a test sentence."
+ texts = ["First test sentence.", "Second test sentence."]
+
+ # Create vectorizer with specified dtype
+ vectorizer = CohereTextVectorizer(dtype=dtype)
+
+ # Verify the correct mapping of dtype to Cohere embedding_types
+ if dtype == "int8":
+ assert vectorizer._get_cohere_embedding_type(dtype) == ["int8"]
+ elif dtype == "uint8":
+ assert vectorizer._get_cohere_embedding_type(dtype) == ["uint8"]
+ else:
+ # All other dtypes should map to float
+ assert vectorizer._get_cohere_embedding_type(dtype) == ["float"]
+
+ # Test single embedding
+ embedding = vectorizer.embed(text, input_type="search_document")
+ assert isinstance(embedding, list)
+ assert len(embedding) == vectorizer.dims
+
+ # Check that all elements are of the expected type
+ assert all(
+ isinstance(val, expected_type) for val in embedding
+ ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}"
+
+ # Test multiple embeddings
+ embeddings = vectorizer.embed_many(texts, input_type="search_document")
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == len(texts)
+ assert all(
+ isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings
+ )
+
+ # Check that all elements in all embeddings are of the expected type
+ for emb in embeddings:
+ assert all(
+ isinstance(val, expected_type) for val in emb
+ ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}"
+
+ # Test as_buffer output format
+ embedding_buffer = vectorizer.embed(
+ text, input_type="search_document", as_buffer=True
+ )
+ assert isinstance(embedding_buffer, bytes)
+
+ # Test embed_many with as_buffer=True
+ buffer_embeddings = vectorizer.embed_many(
+ texts, input_type="search_document", as_buffer=True
+ )
+ assert all(isinstance(emb, bytes) for emb in buffer_embeddings)
+
+ # Compare dimensions between buffer and list formats
+ assert len(np.frombuffer(embedding_buffer, dtype=dtype)) == len(embedding)
+
+
+@pytest.mark.requires_api_keys
+def test_cohere_embedding_types_warning():
+ """Test that a warning is raised when embedding_types parameter is passed."""
+ text = "This is a test sentence."
+ texts = ["First test sentence.", "Second test sentence."]
+ vectorizer = CohereTextVectorizer()
+
+ # Test warning for single embedding
+ with pytest.warns(UserWarning, match="embedding_types.*not supported"):
+ embedding = vectorizer.embed(
+ text,
+ input_type="search_document",
+ embedding_types=["uint8"], # explicitly testing the anti-pattern here
+ )
+ assert isinstance(embedding, list)
+ assert len(embedding) == vectorizer.dims
+
+ # Test warning for multiple embeddings
+ with pytest.warns(UserWarning, match="embedding_types.*not supported"):
+ embeddings = vectorizer.embed_many(
+ texts, input_type="search_document", embedding_types=["uint8"]
+ )
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == len(texts)
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index 067402ea..dae74240 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -1,6 +1,8 @@
+from datetime import date, datetime, time, timedelta, timezone
+
import pytest
-from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
+from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text, Timestamp
# Test cases for various scenarios of tag usage, combinations, and their string representations.
@@ -110,6 +112,18 @@ def test_numeric_filter():
nf = Num("numeric_field") != None
assert str(nf) == "*"
+ nf = Num("numeric_field").between(2, 5)
+ assert str(nf) == "@numeric_field:[2 5]"
+
+    nf = Num("numeric_field").between(2, 5, inclusive="neither")
+    assert str(nf) == "@numeric_field:[(2 (5]"
+
+ nf = Num("numeric_field").between(2, 5, inclusive="left")
+ assert str(nf) == "@numeric_field:[2 (5]"
+
+ nf = Num("numeric_field").between(2, 5, inclusive="right")
+ assert str(nf) == "@numeric_field:[(2 5]"
+
def test_text_filter():
txt_f = Text("text_field") == "text"
@@ -292,3 +306,179 @@ def test_num_filter_zero():
assert (
str(num_filter) == "@chunk_number:[0 0]"
), "Num filter should handle zero correctly"
+
+
+def test_timestamp_datetime():
+ """Test Timestamp filter with datetime objects."""
+ # Test with timezone-aware datetime
+ dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc)
+ ts = Timestamp("created_at") == dt
+ # Expected timestamp would be the Unix timestamp for the datetime
+ expected_ts = dt.timestamp()
+ assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]"
+
+ # Test with timezone-naive datetime (should convert to UTC)
+ dt = datetime(2023, 3, 17, 14, 30, 0)
+ ts = Timestamp("created_at") == dt
+ expected_ts = dt.replace(tzinfo=timezone.utc).timestamp()
+ assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]"
+
+
+def test_timestamp_date():
+ """Test Timestamp filter with date objects (should match full day)."""
+ d = date(2023, 3, 17)
+ ts = Timestamp("created_at") == d
+
+ expected_ts_start = (
+ datetime.combine(d, time.min).astimezone(timezone.utc).timestamp()
+ )
+ expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp()
+
+ assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]"
+
+
+def test_timestamp_iso_string():
+ """Test Timestamp filter with ISO format strings."""
+ # Date-only ISO string
+ ts = Timestamp("created_at") == "2023-03-17"
+ d = date(2023, 3, 17)
+ expected_ts_start = (
+ datetime.combine(d, time.min).astimezone(timezone.utc).timestamp()
+ )
+ expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp()
+ assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]"
+
+ # Full ISO datetime string
+ dt_str = "2023-03-17T14:30:00+00:00"
+ ts = Timestamp("created_at") == dt_str
+ dt = datetime.fromisoformat(dt_str)
+ expected_ts = dt.timestamp()
+ assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]"
+
+
+def test_timestamp_unix():
+ """Test Timestamp filter with Unix timestamps."""
+ # Integer timestamp
+    ts = Timestamp("created_at") == 1679062200  # 2023-03-17T14:10:00+00:00
+ assert str(ts) == "@created_at:[1679062200.0 1679062200.0]"
+
+ # Float timestamp
+ ts = Timestamp("created_at") == 1679062200.5
+ assert str(ts) == "@created_at:[1679062200.5 1679062200.5]"
+
+
+def test_timestamp_operators():
+ """Test all comparison operators for Timestamp filter."""
+ dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc)
+ ts_value = dt.timestamp()
+
+ # Equal
+ ts = Timestamp("created_at") == dt
+ assert str(ts) == f"@created_at:[{ts_value} {ts_value}]"
+
+ # Not equal
+ ts = Timestamp("created_at") != dt
+ assert str(ts) == f"(-@created_at:[{ts_value} {ts_value}])"
+
+ # Greater than
+ ts = Timestamp("created_at") > dt
+ assert str(ts) == f"@created_at:[({ts_value} +inf]"
+
+ # Less than
+ ts = Timestamp("created_at") < dt
+ assert str(ts) == f"@created_at:[-inf ({ts_value}]"
+
+ # Greater than or equal
+ ts = Timestamp("created_at") >= dt
+ assert str(ts) == f"@created_at:[{ts_value} +inf]"
+
+ # Less than or equal
+ ts = Timestamp("created_at") <= dt
+ assert str(ts) == f"@created_at:[-inf {ts_value}]"
+
+ td = timedelta(days=5)
+ dt2 = dt + td
+ ts_value2 = dt2.timestamp()
+
+ ts = Timestamp("created_at").between(dt, dt2)
+ assert str(ts) == f"@created_at:[{ts_value} {ts_value2}]"
+
+ ts = Timestamp("created_at").between(dt, dt2, inclusive="neither")
+ assert str(ts) == f"@created_at:[({ts_value} ({ts_value2}]"
+
+ ts = Timestamp("created_at").between(dt, dt2, inclusive="left")
+ assert str(ts) == f"@created_at:[{ts_value} ({ts_value2}]"
+
+ ts = Timestamp("created_at").between(dt, dt2, inclusive="right")
+ assert str(ts) == f"@created_at:[({ts_value} {ts_value2}]"
+
+
+def test_timestamp_between():
+ """Test the between method for date ranges."""
+ start = datetime(2023, 3, 1, 0, 0, 0, tzinfo=timezone.utc)
+ end = datetime(2023, 3, 31, 23, 59, 59, tzinfo=timezone.utc)
+
+ ts = Timestamp("created_at").between(start, end)
+
+ start_ts = start.timestamp()
+ end_ts = end.timestamp()
+
+ assert str(ts) == f"@created_at:[{start_ts} {end_ts}]"
+
+ # Test with dates (should expand to full days)
+ start_date = date(2023, 3, 1)
+ end_date = date(2023, 3, 31)
+
+ ts = Timestamp("created_at").between(start_date, end_date)
+
+ # Start should be beginning of day
+ expected_start = datetime.combine(start_date, datetime.min.time())
+ expected_start = expected_start.replace(tzinfo=timezone.utc)
+
+ # End should be end of day
+ expected_end = datetime.combine(end_date, datetime.max.time())
+ expected_end = expected_end.replace(tzinfo=timezone.utc)
+
+ expected_start_ts = expected_start.timestamp()
+ expected_end_ts = expected_end.timestamp()
+
+ assert str(ts) == f"@created_at:[{expected_start_ts} {expected_end_ts}]"
+
+
+def test_timestamp_none():
+ """Test handling of None values."""
+ ts = Timestamp("created_at") == None
+ assert str(ts) == "*"
+
+ ts = Timestamp("created_at") != None
+ assert str(ts) == "*"
+
+ ts = Timestamp("created_at") > None
+ assert str(ts) == "*"
+
+
+def test_timestamp_invalid_input():
+ """Test error handling for invalid inputs."""
+ # Invalid ISO format
+ with pytest.raises(ValueError):
+ Timestamp("created_at") == "not-a-date"
+
+ # Unsupported type
+ with pytest.raises(TypeError):
+ Timestamp("created_at") == object()
+
+
+def test_timestamp_filter_combination():
+ """Test combining timestamp filters with other filters."""
+
+ ts = Timestamp("created_at") > datetime(2023, 3, 1)
+ num = Num("age") > 30
+ tag = Tag("status") == "active"
+
+ combined = ts & num & tag
+
+ # The exact string depends on the timestamp value, but we can check structure
+ assert str(combined).startswith("((@created_at:")
+ assert "@age:[(30 +inf]" in str(combined)
+ assert "@status:{active}" in str(combined)
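Taken together, `Timestamp` composes with the existing query builders like any other filter field. A minimal end-to-end sketch, assuming an index whose schema includes `last_updated` as a numeric field (as configured in the fixtures above) and a connected `SearchIndex` named `index`:

from datetime import datetime, timezone
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag, Timestamp

recent = Timestamp("last_updated") >= datetime(2025, 2, 1, tzinfo=timezone.utc)
good_credit = Tag("credit_score") == "high"

query = FilterQuery(
    return_fields=["user", "credit_score", "last_updated"],
    filter_expression=recent & good_credit,
)
# results = index.query(query)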