diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 9568669d..3d399dcc 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -22,7 +22,7 @@ { "data": { "text/html": [ - "
useragejobcredit_scoreoffice_locationuser_embedding
john18engineerhigh-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
derrick14doctorlow-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
nancy94doctorhigh-122.4194,37.7749b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
tyler100engineerhigh-122.0839,37.3861b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
tim12dermatologisthigh-122.0839,37.3861b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
taimur15CEOlow-122.0839,37.3861b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
joe35dentistmedium-122.0839,37.3861b'fff?fff?\\xcd\\xcc\\xcc='
" + "
useragejobcredit_scoreoffice_locationuser_embeddinglast_updated
john18engineerhigh-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'1741627789
derrick14doctorlow-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'1741627789
nancy94doctorhigh-122.4194,37.7749b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'1710696589
tyler100engineerhigh-122.0839,37.3861b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'1742232589
tim12dermatologisthigh-122.0839,37.3861b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'1739644189
taimur15CEOlow-122.0839,37.3861b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'1742232589
joe35dentistmedium-122.0839,37.3861b'fff?fff?\\xcd\\xcc\\xcc='1742232589
" ], "text/plain": [ "" @@ -58,6 +58,7 @@ " {\"name\": \"credit_score\", \"type\": \"tag\"},\n", " {\"name\": \"job\", \"type\": \"text\"},\n", " {\"name\": \"age\", \"type\": \"numeric\"},\n", + " {\"name\": \"last_updated\", \"type\": \"numeric\"},\n", " {\"name\": \"office_location\", \"type\": \"geo\"},\n", " {\n", " \"name\": \"user_embedding\",\n", @@ -83,7 +84,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "13:02:18 redisvl.index.index INFO Index already exists, overwriting.\n" + "11:40:25 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -99,28 +100,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float16_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float32_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. float32_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. bfloat_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. user_queries\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. student tutor\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 10. 
tutor\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 11. bfloat_session\n" - ] - } - ], + "outputs": [], "source": [ "# use the CLI to see the created index\n", "!rvl index listall" @@ -128,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -136,6 +118,26 @@ "keys = index.load(data)" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.info()['num_docs']" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -157,13 +159,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0johnhigh18engineer-122.4194,37.77491741627789
0.109129190445tylerhigh100engineer-122.0839,37.38611742232589
0.158808946609timhigh12dermatologist-122.0839,37.38611739644189
0.266666650772nancyhigh94doctor-122.4194,37.77491710696589
" ], "text/plain": [ "" @@ -182,7 +184,7 @@ "v = VectorQuery(\n", " vector=[0.1, 0.1, 0.5],\n", " vector_field_name=\"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\", \"last_updated\"],\n", " filter_expression=t\n", ")\n", "\n", @@ -192,13 +194,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0derricklow14doctor-122.4194,37.77491741627789
0.217882037163taimurlow15CEO-122.0839,37.38611742232589
0.653301358223joemedium35dentist-122.0839,37.38611742232589
" ], "text/plain": [ "" @@ -316,13 +318,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0johnhigh18engineer-122.4194,37.77491741627789
0.217882037163taimurlow15CEO-122.0839,37.38611742232589
0.653301358223joemedium35dentist-122.0839,37.38611742232589
" ], "text/plain": [ "" @@ -335,7 +337,7 @@ "source": [ "from redisvl.query.filter import Num\n", "\n", - "numeric_filter = Num(\"age\") > 15\n", + "numeric_filter = Num(\"age\").between(15, 35)\n", "\n", "v.set_filter(numeric_filter)\n", "result_print(index.query(v))" @@ -393,6 +395,132 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Timestamp Filters\n", + "\n", + "In Redis, all times are stored as numeric epoch timestamps; however, this class allows you to filter with Python datetime objects for ease of use. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch comparison: 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0.109129190445tylerhigh100engineer-122.0839,37.38611742232589
0.217882037163taimurlow15CEO-122.0839,37.38611742232589
0.653301358223joemedium35dentist-122.0839,37.38611742232589
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "print(f'Epoch comparison: {dt.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\") > dt\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch comparison: 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0derricklow14doctor-122.4194,37.77491741627789
0johnhigh18engineer-122.4194,37.77491741627789
0.158808946609timhigh12dermatologist-122.0839,37.38611739644189
0.266666650772nancyhigh94doctor-122.4194,37.77491710696589
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "\n", + "print(f'Epoch comparison: {dt.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\") < dt\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch between: 1736880339.132589 - 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0derricklow14doctor-122.4194,37.77491741627789
0johnhigh18engineer-122.4194,37.77491741627789
0.158808946609timhigh12dermatologist-122.0839,37.38611739644189
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt_1 = datetime(2025, 1, 14, 13, 45, 39, 132589)\n", + "dt_2 = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "\n", + "print(f'Epoch between: {dt_1.timestamp()} - {dt_2.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\").between(dt_1, dt_2)\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -404,13 +532,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_locationlast_updated
0derricklow14doctor-122.4194,37.77491741627789
0.266666650772nancyhigh94doctor-122.4194,37.77491710696589
" ], "text/plain": [ "" @@ -771,13 +899,13 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
" ], "text/plain": [ "" @@ -791,8 +919,9 @@ "t = Tag(\"credit_score\") == \"high\"\n", "low = Num(\"age\") >= 18\n", "high = Num(\"age\") <= 100\n", + "ts = Timestamp(\"last_updated\") > datetime(2025, 3, 16, 13, 45, 39, 132589)\n", "\n", - "combined = t & low & high\n", + "combined = t & low & high & ts\n", "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", @@ -814,13 +943,13 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -1325,7 +1454,7 @@ ], "metadata": { "kernelspec": { - "display_name": "env", + "display_name": "redisvl-Q9FZQJWe-py3.11", "language": "python", "name": "python3" }, @@ -1339,7 +1468,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.11.9" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/hybrid_example_data.pkl b/docs/user_guide/hybrid_example_data.pkl index b5928b91..2c8a92da 100644 Binary files a/docs/user_guide/hybrid_example_data.pkl and b/docs/user_guide/hybrid_example_data.pkl differ diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 41e6e214..c741ca45 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -310,7 +310,8 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) + result = self._vectorizer.embed(prompt) + return result # type: ignore async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -318,7 +319,8 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return await self._vectorizer.aembed(prompt) + result = await self._vectorizer.aembed(prompt) + return result # type: ignore def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 4a7e72c3..8aff7524 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -366,14 +366,14 @@ def __call__( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - 
vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore aggregation_method = ( aggregation_method or self.routing_config.aggregation_method ) # perform route classification - top_route_match = self._classify_route(vector, aggregation_method) + top_route_match = self._classify_route(vector, aggregation_method) # type: ignore return top_route_match @deprecated_argument("distance_threshold") @@ -400,7 +400,7 @@ def route_many( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore max_k = max_k or self.routing_config.max_k aggregation_method = ( @@ -409,7 +409,7 @@ def route_many( # classify routes top_route_matches = self._classify_multi_route( - vector, max_k, aggregation_method + vector, max_k, aggregation_method # type: ignore ) return top_route_matches diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6825afa9..1aa15315 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -349,7 +349,7 @@ def add_messages( role=message[ROLE_FIELD_NAME], content=message[CONTENT_FIELD_NAME], session_tag=session_tag, - vector_field=content_vector, + vector_field=content_vector, # type: ignore ) if TOOL_FIELD_NAME in message: diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 1e8987ff..ced52520 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -1,3 +1,5 @@ +import datetime +import re from enum import Enum from functools import wraps from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -8,6 +10,19 @@ # mypy: disable-error-code="override" +class Inclusive(str, Enum): + """Enum for valid inclusive options""" + + BOTH = "both" + """Inclusive of both sides of range (default)""" + 
NEITHER = "neither" + """Inclusive of neither side of range""" + LEFT = "left" + """Inclusive of only left""" + RIGHT = "right" + """Inclusive of only right""" + + class FilterOperator(Enum): EQ = 1 NE = 2 @@ -19,6 +34,7 @@ class FilterOperator(Enum): AND = 8 LIKE = 9 IN = 10 + BETWEEN = 11 class FilterField: @@ -267,6 +283,7 @@ class Num(FilterField): FilterOperator.GT: ">", FilterOperator.LE: "<=", FilterOperator.GE: ">=", + FilterOperator.BETWEEN: "between", } OPERATOR_MAP: Dict[FilterOperator, str] = { FilterOperator.EQ: "@%s:[%s %s]", @@ -275,8 +292,10 @@ class Num(FilterField): FilterOperator.LT: "@%s:[-inf (%s]", FilterOperator.GE: "@%s:[%s +inf]", FilterOperator.LE: "@%s:[-inf %s]", + FilterOperator.BETWEEN: "@%s:[%s %s]", } - SUPPORTED_VAL_TYPES = (int, float, type(None)) + + SUPPORTED_VAL_TYPES = (int, float, tuple, type(None)) def __eq__(self, other: int) -> "FilterExpression": """Create a Numeric equality filter expression. @@ -373,10 +392,51 @@ def __le__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LE) return FilterExpression(str(self)) + @staticmethod + def _validate_inclusive_string(inclusive: str) -> Inclusive: + try: + return Inclusive(inclusive) + except: + raise ValueError( + f"Invalid inclusive value must be: {[i.value for i in Inclusive]}" + ) + + def _format_inclusive_between( + self, inclusive: Inclusive, start: int, end: int + ) -> str: + if inclusive.value == Inclusive.BOTH.value: + return f"@{self._field}:[{start} {end}]" + + if inclusive.value == Inclusive.NEITHER.value: + return f"@{self._field}:[({start} ({end}]" + + if inclusive.value == Inclusive.LEFT.value: + return f"@{self._field}:[{start} ({end}]" + + if inclusive.value == Inclusive.RIGHT.value: + return f"@{self._field}:[({start} {end}]" + + raise ValueError(f"Inclusive value not found") + + def between( + self, start: int, end: int, inclusive: str = "both" + ) -> "FilterExpression": + """Operator for searching values 
between two numeric values.""" + inclusive = self._validate_inclusive_string(inclusive) + expression = self._format_inclusive_between(inclusive, start, end) + + return FilterExpression(expression) + def __str__(self) -> str: """Return the Redis Query string for the Numeric filter""" if self._value is None: return "*" + if self._operator == FilterOperator.BETWEEN: + return self.OPERATOR_MAP[self._operator] % ( + self._field, + self._value[0], + self._value[1], + ) if self._operator == FilterOperator.EQ or self._operator == FilterOperator.NE: return self.OPERATOR_MAP[self._operator] % ( self._field, @@ -562,3 +622,213 @@ def __str__(self) -> str: if not self._filter: raise ValueError("Improperly initialized FilterExpression") return self._filter + + +class Timestamp(Num): + """ + A timestamp filter for querying date/time fields in Redis. + + This filter can handle various date and time formats, including: + - datetime objects (with or without timezone) + - date objects + - ISO-8601 formatted strings + - Unix timestamps (as integers or floats) + + All timestamps are converted to Unix timestamps in UTC for consistency. + """ + + SUPPORTED_TYPES = ( + datetime.datetime, + datetime.date, + tuple, # Date range + str, # ISO format + int, # Unix timestamp + float, # Unix timestamp with fractional seconds + type(None), + ) + + @staticmethod + def _is_date(value: Any) -> bool: + """Check if the value is a date object. 
Either ISO string or datetime.date.""" + return ( + isinstance(value, datetime.date) + and not isinstance(value, datetime.datetime) + ) or (isinstance(value, str) and Timestamp._is_date_only(value)) + + @staticmethod + def _is_date_only(iso_string: str) -> bool: + """Check if an ISO formatted string only includes date information using regex.""" + # Match YYYY-MM-DD format exactly + date_pattern = r"^\d{4}-\d{2}-\d{2}$" + return bool(re.match(date_pattern, iso_string)) + + def _convert_to_timestamp(self, value, end_date=False): + """ + Convert various inputs to a Unix timestamp (seconds since epoch in UTC). + + Args: + value: A datetime, date, string, int, or float + + Returns: + float: Unix timestamp + """ + if value is None: + return None + + if isinstance(value, (int, float)): + # Already a Unix timestamp + return float(value) + + if isinstance(value, str): + # Parse ISO format + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + raise ValueError(f"String timestamp must be in ISO format: {value}") + + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + # Convert to max or min if for dates based on end or not + if end_date: + value = datetime.datetime.combine(value, datetime.time.max) + else: + value = datetime.datetime.combine(value, datetime.time.min) + + # Ensure the datetime is timezone-aware (UTC) + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + else: + value = value.astimezone(datetime.timezone.utc) + + # Convert to Unix timestamp + return value.timestamp() + + raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") + + def __eq__(self, other) -> FilterExpression: + """ + Filter for timestamps equal to the specified value. + For date objects (without time), this matches the entire day. 
+ + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + if self._is_date(other): + # For date objects, match the entire day + if isinstance(other, str): + other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + start = datetime.datetime.combine(other, datetime.time.min).astimezone( + datetime.timezone.utc + ) + end = datetime.datetime.combine(other, datetime.time.max).astimezone( + datetime.timezone.utc + ) + return self.between(start, end) + + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ) + return FilterExpression(str(self)) + + def __ne__(self, other) -> FilterExpression: + """ + Filter for timestamps not equal to the specified value. + For date objects (without time), this excludes the entire day. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + if self._is_date(other): + # For date objects, exclude the entire day + if isinstance(other, str): + other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + start = datetime.datetime.combine(other, datetime.time.min) + end = datetime.datetime.combine(other, datetime.time.max) + return self.between(start, end) + + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.NE) + return FilterExpression(str(self)) + + def __gt__(self, other): + """ + Filter for timestamps greater than the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GT) + return FilterExpression(str(self)) + + def __lt__(self, other): + """ + Filter for timestamps less than the specified value. 
+ + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LT) + return FilterExpression(str(self)) + + def __ge__(self, other): + """ + Filter for timestamps greater than or equal to the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GE) + return FilterExpression(str(self)) + + def __le__(self, other): + """ + Filter for timestamps less than or equal to the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LE) + return FilterExpression(str(self)) + + def between(self, start, end, inclusive: str = "both"): + """ + Filter for timestamps between start and end (inclusive). 
+ + Args: + start: A datetime, date, ISO string, or Unix timestamp + end: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + inclusive = self._validate_inclusive_string(inclusive) + + start_ts = self._convert_to_timestamp(start) + end_ts = self._convert_to_timestamp(end, end_date=True) + + expression = self._format_inclusive_between(inclusive, start_ts, end_ts) + + return FilterExpression(expression) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index b3a63fa9..189b6e1a 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -49,34 +49,69 @@ def check_dims(cls, value): return value @abstractmethod - def embed_many( + def embed( self, - texts: List[str], + text: str, preprocess: Optional[Callable] = None, - batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[float], bytes]: + """Embed a chunk of text. + + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ raise NotImplementedError @abstractmethod - def embed( + def embed_many( self, - text: str, + texts: List[str], preprocess: Optional[Callable] = None, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[List[float]], List[bytes]]: + """Embed multiple chunks of text. 
+ + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ raise NotImplementedError async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: + """Asynchronously embed multiple chunks of text. + + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs) @@ -86,7 +121,18 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: + """Asynchronously embed a chunk of text. 
+ + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed(text, preprocess, as_buffer, **kwargs) diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 7b3b7d01..410280e5 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -178,7 +178,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the AzureOpenAI API. Args: @@ -191,7 +191,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -205,7 +206,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -224,7 +227,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the AzureOpenAI API. 
Args: @@ -235,7 +238,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -248,7 +252,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -261,10 +267,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the AzureOpenAI API. Args: @@ -277,7 +283,8 @@ async def aembed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -292,7 +299,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -312,7 +319,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the OpenAI API. Args: @@ -323,7 +330,8 @@ async def aembed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. 
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -336,7 +344,9 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index 5858aff8..2d40685d 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -135,8 +135,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using Amazon Bedrock. + ) -> Union[List[float], bytes]: + """Embed a chunk of text using the AWS Bedrock Embeddings API. Args: text (str): Text to embed. @@ -144,7 +144,8 @@ def embed( as_buffer (bool): Whether to return as byte buffer. Returns: - List[float]: The embedding vector. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If text is not a string. 
@@ -156,7 +157,7 @@ def embed( text = preprocess(text) response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) embedding = response_body["embedding"] @@ -177,17 +178,18 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed multiple texts using Amazon Bedrock. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the AWS Bedrock Embeddings API. Args: texts (List[str]): List of texts to embed. preprocess (Optional[Callable]): Optional preprocessing function. - batch_size (int): Size of batches for processing. + batch_size (int): Size of batches for processing. Defaults to 10. as_buffer (bool): Whether to return as byte buffers. Returns: - List[List[float]]: List of embedding vectors. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If texts is not a list of strings. @@ -206,7 +208,7 @@ def embed_many( batch_embeddings = [] for text in batch: response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) batch_embeddings.append(response_body["embedding"]) diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index bd6481fe..4e6192e2 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Dict, List, Optional +import warnings +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -64,7 +65,8 @@ def __init__( Defaults to None. 
dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). - Defaults to 'float32'. + 'float32' will use Cohere's float embeddings, 'int8' and 'uint8' will map + to Cohere's corresponding embedding types. Defaults to 'float32'. Raises: ImportError: If the cohere library is not installed. @@ -114,6 +116,15 @@ def _set_model_dims(self) -> int: raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) + def _get_cohere_embedding_type(self, dtype: str) -> List[str]: + """Map dtype to appropriate Cohere embedding_types value.""" + if dtype == "int8": + return ["int8"] + elif dtype == "uint8": + return ["uint8"] + else: + return ["float"] + @deprecated_argument("dtype") def embed( self, @@ -121,7 +132,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], List[int], bytes]: """Embed a chunk of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method @@ -150,13 +161,17 @@ def embed( Required for embedding models v3 and higher. Returns: - List[float]: Embedding. + Union[List[float], List[int], bytes]: + - If as_buffer=True: Returns a bytes object + - If as_buffer=False: + - For dtype="float32": Returns a list of floats + - For dtype="int8" or "uint8": Returns a list of integers Raises: TypeError: In an invalid input_type is provided. 
""" - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") @@ -171,9 +186,34 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - embedding = self._client.embed( - texts=[text], model=self.model, input_type=input_type - ).embeddings[0] + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + + response = self._client.embed( + texts=[text], + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, + ) + + # Extract the appropriate embedding based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + embedding = getattr(response.embeddings, embed_type)[0] + else: + embedding = response.embeddings[0] # Fallback for older API versions + return self._process_embedding(embedding, as_buffer, dtype) @retry( @@ -189,7 +229,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[List[int]], List[bytes]]: """Embed many chunks of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method @@ -221,13 +261,17 @@ def embed_many( Required for embedding models v3 and higher. Returns: - List[List[float]]: List of embeddings. 
+ Union[List[List[float]], List[List[int]], List[bytes]]: + - If as_buffer=True: Returns a list of bytes objects + - If as_buffer=False: + - For dtype="float32": Returns a list of lists of floats + - For dtype="int8" or "uint8": Returns a list of lists of integers Raises: TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") @@ -241,14 +285,41 @@ def embed_many( dtype = kwargs.pop("dtype", self.dtype) + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, ) + + # Extract the appropriate embeddings based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + batch_embeddings = getattr(response.embeddings, embed_type) + else: + batch_embeddings = ( + response.embeddings + ) # Fallback for older API versions + embeddings += [ self._process_embedding(embedding, as_buffer, dtype) - for embedding in response.embeddings + for embedding in batch_embeddings ] return embeddings diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 4558d4d7..ed284d29 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ 
-1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic import PrivateAttr @@ -162,7 +162,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """ Generate an embedding for a single piece of text using your sync embed function. @@ -172,7 +172,7 @@ def embed( as_buffer (bool): If True, return the embedding as a byte buffer. Returns: - List[float]: The embedding of the input text. + Union[List[float], bytes]: The embedding of the input text. Raises: TypeError: If the input is not a string. @@ -200,7 +200,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Generate embeddings for multiple pieces of text in batches using your sync embed_many function. @@ -211,7 +211,7 @@ def embed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. @@ -226,7 +226,7 @@ def embed_many( raise NotImplementedError("No embed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): @@ -288,10 +288,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Asynchronously generate embeddings for multiple pieces of text in batches. 
@@ -302,7 +302,7 @@ async def aembed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. @@ -317,7 +317,7 @@ async def aembed_many( raise NotImplementedError("No aembed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index 8f81b85c..bafba41d 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic.v1 import PrivateAttr @@ -89,7 +89,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Hugging Face sentence transformer. Args: @@ -100,7 +100,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -121,10 +122,10 @@ def embed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the Hugging Face sentence transformer. @@ -138,7 +139,8 @@ def embed_many( to a byte string. 
Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index e930b3a4..05133b37 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -128,7 +128,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the Mistral API. Args: @@ -141,7 +141,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -155,7 +156,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(model=self.model, inputs=batch) + response = self._client.embeddings.create( + model=self.model, inputs=batch, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -174,7 +177,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Mistral API. Args: @@ -185,7 +188,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. 
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -198,7 +202,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(model=self.model, inputs=[text]) + result = self._client.embeddings.create( + model=self.model, inputs=[text], **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -211,7 +217,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> List[List[float]]: @@ -242,7 +248,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._client.embeddings.create_async( - model=self.model, inputs=batch + model=self.model, inputs=batch, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -287,7 +293,7 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) result = await self._client.embeddings.create_async( - model=self.model, inputs=[text] + model=self.model, inputs=[text], **kwargs ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 25b21c67..eee0764a 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -129,7 +129,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the OpenAI 
API. Args: @@ -142,7 +142,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -156,7 +157,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -175,7 +178,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the OpenAI API. Args: @@ -186,7 +189,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -199,7 +203,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -212,10 +218,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the OpenAI API. Args: @@ -228,7 +234,8 @@ async def aembed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. 
+ Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -243,7 +250,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -263,7 +270,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the OpenAI API. Args: @@ -274,7 +281,8 @@ async def aembed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. 
@@ -287,7 +295,9 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 6d455c67..ebe2a625 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -141,8 +141,8 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed many chunks of texts using the VertexAI API. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the VertexAI Embeddings API. Args: texts (List[str]): List of text chunks to embed. @@ -154,7 +154,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -168,7 +169,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.get_embeddings(batch) + response = self._client.get_embeddings(batch, **kwargs) embeddings += [ self._process_embedding(r.values, as_buffer, dtype) for r in response ] @@ -186,8 +187,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using the VertexAI API. 
+ ) -> Union[List[float], bytes]: + """Embed a chunk of text using the VertexAI Embeddings API. Args: text (str): Chunk of text to embed. @@ -197,7 +198,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -210,7 +212,7 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.get_embeddings([text]) + result = self._client.get_embeddings([text], **kwargs) return self._process_embedding(result[0].values, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index fbcbfd9e..9d015a81 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -124,7 +124,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the VoyageAI Embeddings API. Can provide the embedding `input_type` as a `kwarg` to this method @@ -149,7 +149,8 @@ def embed( Check https://docs.voyageai.com/docs/embeddings Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If an invalid input_type is provided. @@ -171,7 +172,7 @@ def embed_many( batch_size: Optional[int] = None, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of text using the VoyageAI Embeddings API. 
Can provide the embedding `input_type` as a `kwarg` to this method @@ -198,14 +199,15 @@ def embed_many( Check https://docs.voyageai.com/docs/embeddings Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If an invalid input_type is provided. """ - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -235,7 +237,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -284,8 +286,8 @@ async def aembed_many( TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -315,7 +317,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -360,7 +362,6 @@ async def aembed( Raises: TypeError: In an invalid input_type is provided. 
""" - result = await self.aembed_many( texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs ) diff --git a/tests/conftest.py b/tests/conftest.py index 61c5de45..24da05e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from datetime import datetime, timezone import pytest from testcontainers.compose import DockerCompose @@ -68,12 +69,22 @@ def client(redis_url): @pytest.fixture -def sample_data(): +def sample_datetimes(): + return { + "low": datetime(2025, 1, 16, 13).astimezone(timezone.utc), + "mid": datetime(2025, 2, 16, 13).astimezone(timezone.utc), + "high": datetime(2025, 3, 16, 13).astimezone(timezone.utc), + } + + +@pytest.fixture +def sample_data(sample_datetimes): return [ { "user": "john", "age": 18, "job": "engineer", + "last_updated": sample_datetimes["low"].timestamp(), "credit_score": "high", "location": "-122.4194,37.7749", "user_embedding": [0.1, 0.1, 0.5], @@ -82,6 +93,7 @@ def sample_data(): "user": "mary", "age": 14, "job": "doctor", + "last_updated": sample_datetimes["low"].timestamp(), "credit_score": "low", "location": "-122.4194,37.7749", "user_embedding": [0.1, 0.1, 0.5], @@ -90,6 +102,7 @@ def sample_data(): "user": "nancy", "age": 94, "job": "doctor", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-122.4194,37.7749", "user_embedding": [0.7, 0.1, 0.5], @@ -98,6 +111,7 @@ def sample_data(): "user": "tyler", "age": 100, "job": "engineer", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-110.0839,37.3861", "user_embedding": [0.1, 0.4, 0.5], @@ -106,6 +120,7 @@ def sample_data(): "user": "tim", "age": 12, "job": "dermatologist", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-110.0839,37.3861", "user_embedding": [0.4, 0.4, 0.5], @@ -114,6 +129,7 @@ def sample_data(): "user": "taimur", "age": 15, "job": "CEO", + "last_updated": sample_datetimes["high"].timestamp(), 
"credit_score": "low", "location": "-110.0839,37.3861", "user_embedding": [0.6, 0.1, 0.5], @@ -122,6 +138,7 @@ def sample_data(): "user": "joe", "age": 35, "job": "dentist", + "last_updated": sample_datetimes["high"].timestamp(), "credit_score": "medium", "location": "-110.0839,37.3861", "user_embedding": [0.9, 0.9, 0.1], diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 271d36da..deb58cbc 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -1,9 +1,19 @@ +from datetime import timedelta + import pytest from redis.commands.search.result import Result from redisvl.index import SearchIndex from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery -from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text +from redisvl.query.filter import ( + FilterExpression, + Geo, + GeoRadius, + Num, + Tag, + Text, + Timestamp, +) from redisvl.redis.utils import array_to_buffer # TODO expand to multiple schema types and sync + async @@ -14,7 +24,14 @@ def vector_query(): return VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], ) @@ -23,7 +40,14 @@ def sorted_vector_query(): return VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], sort_by="age", ) @@ -31,7 +55,14 @@ def sorted_vector_query(): @pytest.fixture def filter_query(): return FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], filter_expression=Tag("credit_score") == "high", ) @@ -39,7 +70,14 @@ def 
filter_query(): @pytest.fixture def sorted_filter_query(): return FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], filter_expression=Tag("credit_score") == "high", sort_by="age", ) @@ -80,6 +118,7 @@ def index(sample_data, redis_url): {"name": "credit_score", "type": "tag"}, {"name": "job", "type": "text"}, {"name": "age", "type": "numeric"}, + {"name": "last_updated", "type": "numeric"}, {"name": "location", "type": "geo"}, { "name": "user_embedding", @@ -255,7 +294,7 @@ def query(request): return request.getfixturevalue(request.param) -def test_filters(index, query): +def test_filters(index, query, sample_datetimes): # Simple Tag Filter t = Tag("credit_score") == "high" search(query, index, t, 4, credit_check="high") @@ -310,6 +349,34 @@ def test_filters(index, query): t = Text("job") % "" search(query, index, t, 7) + # Timestamps + ts = Timestamp("last_updated") > sample_datetimes["mid"] + search(query, index, ts, 2) + + ts = Timestamp("last_updated") >= sample_datetimes["mid"] + search(query, index, ts, 5) + + ts = Timestamp("last_updated") < sample_datetimes["high"] + search(query, index, ts, 5) + + ts = Timestamp("last_updated") <= sample_datetimes["mid"] + search(query, index, ts, 5) + + ts = Timestamp("last_updated") == sample_datetimes["mid"] + search(query, index, ts, 3) + + ts = (Timestamp("last_updated") == sample_datetimes["low"]) | ( + Timestamp("last_updated") == sample_datetimes["high"] + ) + search(query, index, ts, 4) + + # could drop between if we prefer union syntax + ts = Timestamp("last_updated").between( + sample_datetimes["low"] + timedelta(seconds=1), + sample_datetimes["high"] - timedelta(seconds=1), + ) + search(query, index, ts, 3) + def test_manual_string_filters(index, query): # Simple Tag Filter diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 
e1de4a46..36e444de 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from redisvl.utils.vectorize import ( @@ -287,7 +288,7 @@ def test_default_dtype(vectorizer_): VoyageAITextVectorizer, ], ) -def test_other_dtypes(vectorizer_): +def test_vectorizer_dtype_assignment(vectorizer_): # test initializing dtype in constructor for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: if issubclass(vectorizer_, CustomTextVectorizer): @@ -319,7 +320,7 @@ def test_other_dtypes(vectorizer_): VoyageAITextVectorizer, ], ) -def test_bad_dtypes(vectorizer_): +def test_non_supported_dtypes(vectorizer_): with pytest.raises(ValueError): vectorizer_(dtype="float25") @@ -392,3 +393,95 @@ async def test_avectorizer_bad_input(avectorizer): with pytest.raises(TypeError): avectorizer.embed_many(42) + + +@pytest.mark.requires_api_keys +@pytest.mark.parametrize( + "dtype,expected_type", + [ + ("float32", float), # Float dtype should return floats + ("int8", int), # Int8 dtype should return ints + ("uint8", int), # Uint8 dtype should return ints + ], +) +def test_cohere_dtype_support(dtype, expected_type): + """Test that CohereTextVectorizer properly handles different dtypes for embeddings.""" + text = "This is a test sentence." 
+ texts = ["First test sentence.", "Second test sentence."] + + # Create vectorizer with specified dtype + vectorizer = CohereTextVectorizer(dtype=dtype) + + # Verify the correct mapping of dtype to Cohere embedding_types + if dtype == "int8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["int8"] + elif dtype == "uint8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["uint8"] + else: + # All other dtypes should map to float + assert vectorizer._get_cohere_embedding_type(dtype) == ["float"] + + # Test single embedding + embedding = vectorizer.embed(text, input_type="search_document") + assert isinstance(embedding, list) + assert len(embedding) == vectorizer.dims + + # Check that all elements are of the expected type + assert all( + isinstance(val, expected_type) for val in embedding + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" + + # Test multiple embeddings + embeddings = vectorizer.embed_many(texts, input_type="search_document") + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + assert all( + isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings + ) + + # Check that all elements in all embeddings are of the expected type + for emb in embeddings: + assert all( + isinstance(val, expected_type) for val in emb + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" + + # Test as_buffer output format + embedding_buffer = vectorizer.embed( + text, input_type="search_document", as_buffer=True + ) + assert isinstance(embedding_buffer, bytes) + + # Test embed_many with as_buffer=True + buffer_embeddings = vectorizer.embed_many( + texts, input_type="search_document", as_buffer=True + ) + assert all(isinstance(emb, bytes) for emb in buffer_embeddings) + + # Compare dimensions between buffer and list formats + assert len(np.frombuffer(embedding_buffer, dtype=dtype)) == len(embedding) + + +@pytest.mark.requires_api_keys +def 
test_cohere_embedding_types_warning(): + """Test that a warning is raised when embedding_types parameter is passed.""" + text = "This is a test sentence." + texts = ["First test sentence.", "Second test sentence."] + vectorizer = CohereTextVectorizer() + + # Test warning for single embedding + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embedding = vectorizer.embed( + text, + input_type="search_document", + embedding_types=["uint8"], # explicitly testing the anti-pattern here + ) + assert isinstance(embedding, list) + assert len(embedding) == vectorizer.dims + + # Test warning for multiple embeddings + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embeddings = vectorizer.embed_many( + texts, input_type="search_document", embedding_types=["uint8"] + ) + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py index 067402ea..dae74240 100644 --- a/tests/unit/test_filter.py +++ b/tests/unit/test_filter.py @@ -1,6 +1,8 @@ +from datetime import date, datetime, time, timedelta, timezone + import pytest -from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text +from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text, Timestamp # Test cases for various scenarios of tag usage, combinations, and their string representations. 
@@ -110,6 +112,18 @@ def test_numeric_filter(): nf = Num("numeric_field") != None assert str(nf) == "*" + nf = Num("numeric_field").between(2, 5) + assert str(nf) == "@numeric_field:[2 5]" + + nf = Num("numeric_field").between(2, 5, inclusive="neither") + assert str(nf) == "@numeric_field:[(2 (5]" + + nf = Num("numeric_field").between(2, 5, inclusive="left") + assert str(nf) == "@numeric_field:[2 (5]" + + nf = Num("numeric_field").between(2, 5, inclusive="right") + assert str(nf) == "@numeric_field:[(2 5]" + def test_text_filter(): txt_f = Text("text_field") == "text" @@ -292,3 +306,179 @@ def test_num_filter_zero(): assert ( str(num_filter) == "@chunk_number:[0 0]" ), "Num filter should handle zero correctly" + + +def test_timestamp_datetime(): + """Test Timestamp filter with datetime objects.""" + # Test with timezone-aware datetime + dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc) + ts = Timestamp("created_at") == dt + # Expected timestamp would be the Unix timestamp for the datetime + expected_ts = dt.timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + # Test with timezone-naive datetime (should convert to UTC) + dt = datetime(2023, 3, 17, 14, 30, 0) + ts = Timestamp("created_at") == dt + expected_ts = dt.replace(tzinfo=timezone.utc).timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + +def test_timestamp_date(): + """Test Timestamp filter with date objects (should match full day).""" + d = date(2023, 3, 17) + ts = Timestamp("created_at") == d + + expected_ts_start = ( + datetime.combine(d, time.min).astimezone(timezone.utc).timestamp() + ) + expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp() + + assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]" + + +def test_timestamp_iso_string(): + """Test Timestamp filter with ISO format strings.""" + # Date-only ISO string + ts = Timestamp("created_at") == "2023-03-17" + d = date(2023, 3, 17) + 
expected_ts_start = ( + datetime.combine(d, time.min).astimezone(timezone.utc).timestamp() + ) + expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp() + assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]" + + # Full ISO datetime string + dt_str = "2023-03-17T14:30:00+00:00" + ts = Timestamp("created_at") == dt_str + dt = datetime.fromisoformat(dt_str) + expected_ts = dt.timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + +def test_timestamp_unix(): + """Test Timestamp filter with Unix timestamps.""" + # Integer timestamp + ts = Timestamp("created_at") == 1679062200 # 2023-03-17T14:10:00+00:00 + assert str(ts) == "@created_at:[1679062200.0 1679062200.0]" + + # Float timestamp + ts = Timestamp("created_at") == 1679062200.5 + assert str(ts) == "@created_at:[1679062200.5 1679062200.5]" + + +def test_timestamp_operators(): + """Test all comparison operators for Timestamp filter.""" + dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc) + ts_value = dt.timestamp() + + # Equal + ts = Timestamp("created_at") == dt + assert str(ts) == f"@created_at:[{ts_value} {ts_value}]" + + # Not equal + ts = Timestamp("created_at") != dt + assert str(ts) == f"(-@created_at:[{ts_value} {ts_value}])" + + # Greater than + ts = Timestamp("created_at") > dt + assert str(ts) == f"@created_at:[({ts_value} +inf]" + + # Less than + ts = Timestamp("created_at") < dt + assert str(ts) == f"@created_at:[-inf ({ts_value}]" + + # Greater than or equal + ts = Timestamp("created_at") >= dt + assert str(ts) == f"@created_at:[{ts_value} +inf]" + + # Less than or equal + ts = Timestamp("created_at") <= dt + assert str(ts) == f"@created_at:[-inf {ts_value}]" + + td = timedelta(days=5) + dt2 = dt + td + ts_value2 = dt2.timestamp() + + ts = Timestamp("created_at").between(dt, dt2) + assert str(ts) == f"@created_at:[{ts_value} {ts_value2}]" + + ts = Timestamp("created_at").between(dt, dt2, inclusive="neither") + assert 
str(ts) == f"@created_at:[({ts_value} ({ts_value2}]" + + ts = Timestamp("created_at").between(dt, dt2, inclusive="left") + assert str(ts) == f"@created_at:[{ts_value} ({ts_value2}]" + + ts = Timestamp("created_at").between(dt, dt2, inclusive="right") + assert str(ts) == f"@created_at:[({ts_value} {ts_value2}]" + + +def test_timestamp_between(): + """Test the between method for date ranges.""" + start = datetime(2023, 3, 1, 0, 0, 0, tzinfo=timezone.utc) + end = datetime(2023, 3, 31, 23, 59, 59, tzinfo=timezone.utc) + + ts = Timestamp("created_at").between(start, end) + + start_ts = start.timestamp() + end_ts = end.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + # Test with dates (should expand to full days) + start_date = date(2023, 3, 1) + end_date = date(2023, 3, 31) + + ts = Timestamp("created_at").between(start_date, end_date) + + # Start should be beginning of day + expected_start = datetime.combine(start_date, datetime.min.time()) + expected_start = expected_start.replace(tzinfo=timezone.utc) + + # End should be end of day + expected_end = datetime.combine(end_date, datetime.max.time()) + expected_end = expected_end.replace(tzinfo=timezone.utc) + + expected_start_ts = expected_start.timestamp() + expected_end_ts = expected_end.timestamp() + + assert str(ts) == f"@created_at:[{expected_start_ts} {expected_end_ts}]" + + +def test_timestamp_none(): + """Test handling of None values.""" + ts = Timestamp("created_at") == None + assert str(ts) == "*" + + ts = Timestamp("created_at") != None + assert str(ts) == "*" + + ts = Timestamp("created_at") > None + assert str(ts) == "*" + + +def test_timestamp_invalid_input(): + """Test error handling for invalid inputs.""" + # Invalid ISO format + with pytest.raises(ValueError): + Timestamp("created_at") == "not-a-date" + + # Unsupported type + with pytest.raises(TypeError): + Timestamp("created_at") == object() + + +def test_timestamp_filter_combination(): + """Test combining timestamp filters 
with other filters.""" + from redisvl.query.filter import Num, Tag + + ts = Timestamp("created_at") > datetime(2023, 3, 1) + num = Num("age") > 30 + tag = Tag("status") == "active" + + combined = ts & num & tag + + # The exact string depends on the timestamp value, but we can check structure + assert str(combined).startswith("((@created_at:") + assert "@age:[(30 +inf]" in str(combined) + assert "@status:{active}" in str(combined)