From 38c0a6041287875d7c8609863fb41f412045b8e7 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 4 Mar 2025 20:42:00 -0500 Subject: [PATCH 1/9] Improve vectorizer kwargs and typing (#291) ## Changes Made 1. **Expanded Type Support**: - Updated return type signatures across all vectorizers to properly reflect the ability to return either data lists (`List[float]`) or binary buffers (`bytes`) - Added special handling for Cohere's integer embedding types (`List[int]`) 2. **Standardized Interface**: - Uniform type annotations and docstrings across all vectorizer implementations - Consistent default batch sizes (10) for better predictability 3. **Improved Provider-Specific Support**: - Enhanced kwargs forwarding to allow passing provider-specific parameters - Better warnings for deprecated parameters (like Cohere's `embedding_types`) 4. **Fixed Type Checking**: - Added strategic type ignores to resolve MyPy errors - Made minimal changes to consumer code to handle the expanded return types ## Motivation These changes create a more consistent and flexible vectorizer interface that: - Accurately represents what the methods can return - Accommodates provider-specific features (like Cohere's integer embeddings) - Provides clearer documentation for users - Maintains backward compatibility ## Future Improvements For future consideration: - Introduce helper methods (like `embed_as_list()`) that guarantee specific return types when needed - Add more robust type conversion in consumer code that relies on specific types - Develop a cleaner separation between the base vectorizer interface and provider-specific extensions - Consider a more structured approach to provider-specific parameters --- redisvl/extensions/llmcache/semantic.py | 6 +- redisvl/extensions/router/semantic.py | 8 +- .../session_manager/semantic_session.py | 2 +- redisvl/utils/vectorize/base.py | 68 ++++++++++--- redisvl/utils/vectorize/text/azureopenai.py | 38 +++++--- redisvl/utils/vectorize/text/bedrock.py | 22 +++-- redisvl/utils/vectorize/text/cohere.py | 97 ++++++++++++++++--- redisvl/utils/vectorize/text/custom.py | 20 ++-- redisvl/utils/vectorize/text/huggingface.py | 14 +-- redisvl/utils/vectorize/text/mistral.py | 26 +++-- redisvl/utils/vectorize/text/openai.py | 38 +++++--- redisvl/utils/vectorize/text/vertexai.py | 20 ++-- redisvl/utils/vectorize/text/voyageai.py | 25 ++--- tests/integration/test_vectorizers.py | 97 ++++++++++++++++++- 14 files changed, 363 insertions(+), 118 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 41e6e214..c741ca45 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -310,7 +310,8 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) + result = self._vectorizer.embed(prompt) + return result # type: ignore async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -318,7 +319,8 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return await self._vectorizer.aembed(prompt) + result = await self._vectorizer.aembed(prompt) + return result # type: ignore def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it 
diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 4a7e72c3..8aff7524 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -366,14 +366,14 @@ def __call__( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore aggregation_method = ( aggregation_method or self.routing_config.aggregation_method ) # perform route classification - top_route_match = self._classify_route(vector, aggregation_method) + top_route_match = self._classify_route(vector, aggregation_method) # type: ignore return top_route_match @deprecated_argument("distance_threshold") @@ -400,7 +400,7 @@ def route_many( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore max_k = max_k or self.routing_config.max_k aggregation_method = ( @@ -409,7 +409,7 @@ def route_many( # classify routes top_route_matches = self._classify_multi_route( - vector, max_k, aggregation_method + vector, max_k, aggregation_method # type: ignore ) return top_route_matches diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6825afa9..1aa15315 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -349,7 +349,7 @@ def add_messages( role=message[ROLE_FIELD_NAME], content=message[CONTENT_FIELD_NAME], session_tag=session_tag, - vector_field=content_vector, + vector_field=content_vector, # type: ignore ) if TOOL_FIELD_NAME in message: diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index b3a63fa9..189b6e1a 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -49,34 +49,69 @@ def check_dims(cls, value): return value @abstractmethod - def embed_many( + def embed( self, - texts: List[str], + text: str, preprocess: Optional[Callable] = None, - batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[float], bytes]: + """Embed a chunk of text. + + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ raise NotImplementedError @abstractmethod - def embed( + def embed_many( self, - text: str, + texts: List[str], preprocess: Optional[Callable] = None, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[List[float]], List[bytes]]: + """Embed multiple chunks of text. 
+ + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ raise NotImplementedError async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: + """Asynchronously embed multiple chunks of text. + + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs) @@ -86,7 +121,18 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: + """Asynchronously embed a chunk of text. + + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed(text, preprocess, as_buffer, **kwargs) diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 7b3b7d01..410280e5 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -178,7 +178,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the AzureOpenAI API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data ] @@ -224,7 +227,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the AzureOpenAI API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding.
+ Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -261,10 +267,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the AzureOpenAI API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -312,7 +319,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the AzureOpenAI API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index 5858aff8..2d40685d 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -135,8 +135,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using Amazon Bedrock. + ) -> Union[List[float], bytes]: + """Embed a chunk of text using the AWS Bedrock Embeddings API. Args: text (str): Text to embed. preprocess (Optional[Callable]): Optional preprocessing function. as_buffer (bool): Whether to return as byte buffer. Returns: - List[float]: The embedding vector. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If text is not a string.
@@ -156,7 +157,7 @@ def embed( text = preprocess(text) response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) embedding = response_body["embedding"] @@ -177,17 +178,18 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed multiple texts using Amazon Bedrock. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the AWS Bedrock Embeddings API. Args: texts (List[str]): List of texts to embed. preprocess (Optional[Callable]): Optional preprocessing function. - batch_size (int): Size of batches for processing. + batch_size (int): Size of batches for processing. Defaults to 10. as_buffer (bool): Whether to return as byte buffers. Returns: - List[List[float]]: List of embedding vectors. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If texts is not a list of strings. @@ -206,7 +208,7 @@ def embed_many( batch_embeddings = [] for text in batch: response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) batch_embeddings.append(response_body["embedding"]) diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index bd6481fe..4e6192e2 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,5 +1,6 @@ import os +import warnings -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -64,7 +65,8 @@ def __init__( Defaults to None. dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). - Defaults to 'float32'. + 'float32' will use Cohere's float embeddings, 'int8' and 'uint8' will map + to Cohere's corresponding embedding types. Defaults to 'float32'. Raises: ImportError: If the cohere library is not installed. @@ -114,6 +116,15 @@ def _set_model_dims(self) -> int: raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) + def _get_cohere_embedding_type(self, dtype: str) -> List[str]: + """Map dtype to appropriate Cohere embedding_types value.""" + if dtype == "int8": + return ["int8"] + elif dtype == "uint8": + return ["uint8"] + else: + return ["float"] + @deprecated_argument("dtype") def embed( self, @@ -121,7 +132,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], List[int], bytes]: """Embed a chunk of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method that specifies the type of input you're giving to the model. @@ -150,13 +161,17 @@ def embed( Required for embedding models v3 and higher. Returns: - List[float]: Embedding. + Union[List[float], List[int], bytes]: + - If as_buffer=True: Returns a bytes object + - If as_buffer=False: + - For dtype="float32": Returns a list of floats + - For dtype="int8" or "uint8": Returns a list of integers Raises: TypeError: If an invalid input_type is provided.
""" - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") @@ -171,9 +186,34 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - embedding = self._client.embed( - texts=[text], model=self.model, input_type=input_type - ).embeddings[0] + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + + response = self._client.embed( + texts=[text], + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, + ) + + # Extract the appropriate embedding based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + embedding = getattr(response.embeddings, embed_type)[0] + else: + embedding = response.embeddings[0] # Fallback for older API versions + return self._process_embedding(embedding, as_buffer, dtype) @retry( @@ -189,7 +229,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[List[int]], List[bytes]]: """Embed many chunks of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method @@ -221,13 +261,17 @@ def embed_many( Required for embedding models v3 and higher. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[List[int]], List[bytes]]: + - If as_buffer=True: Returns a list of bytes objects + - If as_buffer=False: + - For dtype="float32": Returns a list of lists of floats + - For dtype="int8" or "uint8": Returns a list of lists of integers Raises: TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") @@ -241,14 +285,41 @@ def embed_many( dtype = kwargs.pop("dtype", self.dtype) + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. 
Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, ) + + # Extract the appropriate embeddings based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + batch_embeddings = getattr(response.embeddings, embed_type) + else: + batch_embeddings = ( + response.embeddings + ) # Fallback for older API versions + embeddings += [ self._process_embedding(embedding, as_buffer, dtype) - for embedding in response.embeddings + for embedding in batch_embeddings ] return embeddings diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 4558d4d7..ed284d29 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic import PrivateAttr @@ -162,7 +162,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """ Generate an embedding for a single piece of text using your sync embed function. @@ -172,7 +172,7 @@ def embed( as_buffer (bool): If True, return the embedding as a byte buffer. Returns: - List[float]: The embedding of the input text. + Union[List[float], bytes]: The embedding of the input text. Raises: TypeError: If the input is not a string. @@ -200,7 +200,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Generate embeddings for multiple pieces of text in batches using your sync embed_many function. @@ -211,7 +211,7 @@ def embed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. @@ -226,7 +226,7 @@ def embed_many( raise NotImplementedError("No embed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): @@ -288,10 +288,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Asynchronously generate embeddings for multiple pieces of text in batches. @@ -302,7 +302,7 @@ async def aembed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. 
@@ -317,7 +317,7 @@ async def aembed_many( raise NotImplementedError("No aembed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index 8f81b85c..bafba41d 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic.v1 import PrivateAttr @@ -89,7 +89,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Hugging Face sentence transformer. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -121,10 +122,10 @@ def embed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the Hugging Face sentence transformer. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index e930b3a4..05133b37 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -128,7 +128,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the Mistral API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(model=self.model, inputs=batch) + response = self._client.embeddings.create( + model=self.model, inputs=batch, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data ] @@ -174,7 +177,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Mistral API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. 
Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(model=self.model, inputs=[text]) + result = self._client.embeddings.create( + model=self.model, inputs=[text], **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -211,7 +217,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> List[List[float]]: @@ -242,7 +248,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._client.embeddings.create_async( - model=self.model, inputs=batch + model=self.model, inputs=batch, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -287,7 +293,7 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) result = await self._client.embeddings.create_async( - model=self.model, inputs=[text] + model=self.model, inputs=[text], **kwargs ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 25b21c67..eee0764a 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -129,7 +129,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the OpenAI API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data ] @@ -175,7 +178,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the OpenAI API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. 
@@ -199,7 +203,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -212,10 +218,10 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the OpenAI API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -263,7 +270,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the OpenAI API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 6d455c67..ebe2a625 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -141,8 +141,8 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed many chunks of texts using the VertexAI API. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the VertexAI Embeddings API. Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. 
@@ -168,7 +169,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.get_embeddings(batch) + response = self._client.get_embeddings(batch, **kwargs) embeddings += [ self._process_embedding(r.values, as_buffer, dtype) for r in response ] @@ -186,8 +187,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using the VertexAI API. + ) -> Union[List[float], bytes]: + """Embed a chunk of text using the VertexAI Embeddings API. Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) - result = self._client.get_embeddings([text]) + result = self._client.get_embeddings([text], **kwargs) return self._process_embedding(result[0].values, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index fbcbfd9e..9d015a81 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -124,7 +124,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the VoyageAI Embeddings API. Can provide the embedding `input_type` as a `kwarg` to this method that specifies the type of input you're giving to the model. @@ -149,7 +150,8 @@ def embed( Check https://docs.voyageai.com/docs/embeddings Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If an invalid input_type is provided. @@ -171,7 +172,7 @@ def embed_many( batch_size: Optional[int] = None, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of text using the VoyageAI Embeddings API. Can provide the embedding `input_type` as a `kwarg` to this method that specifies the type of input you're giving to the model. @@ -198,14 +199,15 @@ def embed_many( Check https://docs.voyageai.com/docs/embeddings Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If an invalid input_type is provided. """ - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -235,7 +237,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -284,8 +286,8 @@ async def aembed_many( TypeError: If an invalid input_type is provided. 
""" - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -315,7 +317,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -360,7 +362,6 @@ async def aembed( Raises: TypeError: In an invalid input_type is provided. """ - result = await self.aembed_many( texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs ) diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index e1de4a46..36e444de 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from redisvl.utils.vectorize import ( @@ -287,7 +288,7 @@ def test_default_dtype(vectorizer_): VoyageAITextVectorizer, ], ) -def test_other_dtypes(vectorizer_): +def test_vectorizer_dtype_assignment(vectorizer_): # test initializing dtype in constructor for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: if issubclass(vectorizer_, CustomTextVectorizer): @@ -319,7 +320,7 @@ def test_other_dtypes(vectorizer_): VoyageAITextVectorizer, ], ) -def test_bad_dtypes(vectorizer_): +def test_non_supported_dtypes(vectorizer_): with pytest.raises(ValueError): vectorizer_(dtype="float25") @@ -392,3 +393,95 @@ async def test_avectorizer_bad_input(avectorizer): with pytest.raises(TypeError): avectorizer.embed_many(42) + + +@pytest.mark.requires_api_keys +@pytest.mark.parametrize( + "dtype,expected_type", + [ + ("float32", float), # Float dtype should return floats + ("int8", int), # Int8 dtype should return ints + ("uint8", int), # Uint8 dtype should return ints + ], +) +def test_cohere_dtype_support(dtype, expected_type): + """Test that CohereTextVectorizer properly handles different dtypes for embeddings.""" + text = "This is a test sentence." 
+ texts = ["First test sentence.", "Second test sentence."] + + # Create vectorizer with specified dtype + vectorizer = CohereTextVectorizer(dtype=dtype) + + # Verify the correct mapping of dtype to Cohere embedding_types + if dtype == "int8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["int8"] + elif dtype == "uint8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["uint8"] + else: + # All other dtypes should map to float + assert vectorizer._get_cohere_embedding_type(dtype) == ["float"] + + # Test single embedding + embedding = vectorizer.embed(text, input_type="search_document") + assert isinstance(embedding, list) + assert len(embedding) == vectorizer.dims + + # Check that all elements are of the expected type + assert all( + isinstance(val, expected_type) for val in embedding + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" + + # Test multiple embeddings + embeddings = vectorizer.embed_many(texts, input_type="search_document") + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + assert all( + isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings + ) + + # Check that all elements in all embeddings are of the expected type + for emb in embeddings: + assert all( + isinstance(val, expected_type) for val in emb + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" + + # Test as_buffer output format + embedding_buffer = vectorizer.embed( + text, input_type="search_document", as_buffer=True + ) + assert isinstance(embedding_buffer, bytes) + + # Test embed_many with as_buffer=True + buffer_embeddings = vectorizer.embed_many( + texts, input_type="search_document", as_buffer=True + ) + assert all(isinstance(emb, bytes) for emb in buffer_embeddings) + + # Compare dimensions between buffer and list formats + assert len(np.frombuffer(embedding_buffer, dtype=dtype)) == len(embedding) + + +@pytest.mark.requires_api_keys +def test_cohere_embedding_types_warning(): + """Test that a warning is raised when embedding_types parameter is passed.""" + text = "This is a test sentence." 
+ texts = ["First test sentence.", "Second test sentence."] + vectorizer = CohereTextVectorizer() + + # Test warning for single embedding + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embedding = vectorizer.embed( + text, + input_type="search_document", + embedding_types=["uint8"], # explicitly testing the anti-pattern here + ) + assert isinstance(embedding, list) + assert len(embedding) == vectorizer.dims + + # Test warning for multiple embeddings + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embeddings = vectorizer.embed_many( + texts, input_type="search_document", embedding_types=["uint8"] + ) + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) From 22a701ba85fd66ea542ce71d65dc794522b95602 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Mon, 17 Mar 2025 13:14:02 -0400 Subject: [PATCH 2/9] first working unit tests --- redisvl/query/filter.py | 394 ++++++++++++++++++++++++++++++++++++++ tests/unit/test_filter.py | 271 +++++++++++++++++++++++++- 2 files changed, 664 insertions(+), 1 deletion(-) diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 1e8987ff..47af27d8 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -1,3 +1,5 @@ +import datetime +import re from enum import Enum from functools import wraps from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -19,6 +21,7 @@ class FilterOperator(Enum): AND = 8 LIKE = 9 IN = 10 + BETWEEN = 11 class FilterField: @@ -562,3 +565,394 @@ def __str__(self) -> str: if not self._filter: raise ValueError("Improperly initialized FilterExpression") return self._filter + + +class Timestamp(FilterField): + """ + A timestamp filter for querying date/time fields in Redis. + + This filter can handle various date and time formats, including: + - datetime objects (with or without timezone) + - date objects + - ISO-8601 formatted strings + - Unix timestamps (as integers or floats) + + All timestamps are converted to Unix timestamps in UTC for consistency. + """ + + OPERATORS = { + FilterOperator.EQ: "==", + FilterOperator.NE: "!=", + FilterOperator.LT: "<", + FilterOperator.GT: ">", + FilterOperator.LE: "<=", + FilterOperator.GE: ">=", + FilterOperator.BETWEEN: "between", + } + + OPERATOR_MAP = { + FilterOperator.EQ: "@%s:[%s %s]", # For exact timestamp match (converted to range for dates) + FilterOperator.NE: "(-@%s:[%s %s])", + FilterOperator.GT: "@%s:[(%s +inf]", + FilterOperator.LT: "@%s:[-inf (%s]", + FilterOperator.GE: "@%s:[%s +inf]", + FilterOperator.LE: "@%s:[-inf %s]", + FilterOperator.BETWEEN: "@%s:[%s %s]", + } + + SUPPORTED_TYPES = ( + datetime.datetime, + datetime.date, + tuple, # Date range + str, # ISO format + int, # Unix timestamp + float, # Unix timestamp with fractional seconds + type(None), + ) + + def __init__(self, field: str): + """ + Initialize a timestamp filter for the specified field. + + Args: + field: The name of the field to filter on + """ + super().__init__(field) + self._start_value = None + self._end_value = None + + def __eq__(self, other): + """ + Filter for timestamps equal to the specified value. + For date objects (without time), this matches the entire day. 
+ + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + # TODO make this a function + if ( + isinstance(other, datetime.date) + and not isinstance(other, datetime.datetime) + ) or (isinstance(other, str) and self._is_date_only(other)): + # For date objects, match the entire day + if isinstance(other, str): + other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + start = datetime.datetime.combine(other, datetime.time.min) + end = datetime.datetime.combine(other, datetime.time.max) + return self.between(start, end) + + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ) + self._start_value = timestamp + self._end_value = timestamp + return FilterExpression(str(self)) + + def __ne__(self, other): + """ + Filter for timestamps not equal to the specified value. + For date objects (without time), this excludes the entire day. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + # TODO: not sure if we need to support not equal for dates + if ( + isinstance(other, datetime.date) + and not isinstance(other, datetime.datetime) + ) or (isinstance(other, str) and self._is_date_only(other)): + # For date objects, exclude the entire day + if isinstance(other, str): + other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + start = datetime.datetime.combine(other, datetime.time.min) + end = datetime.datetime.combine(other, datetime.time.max) + start_ts = self._convert_to_timestamp(start) + end_ts = self._convert_to_timestamp(end) + self._set_value((start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.NE) + self._start_value = start_ts + self._end_value = end_ts + return FilterExpression(str(self)) + + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.NE) + self._start_value = timestamp + self._end_value = timestamp + return FilterExpression(str(self)) + + def __gt__(self, other): + """ + Filter for timestamps greater than the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GT) + return FilterExpression(str(self)) + + def __lt__(self, other): + """ + Filter for timestamps less than the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LT) + return FilterExpression(str(self)) + + def __ge__(self, other): + """ + Filter for timestamps greater than or equal to the specified value. + + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.GE) + return FilterExpression(str(self)) + + def __le__(self, other): + """ + Filter for timestamps less than or equal to the specified value.
+ + Args: + other: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + timestamp = self._convert_to_timestamp(other) + self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LE) + return FilterExpression(str(self)) + + def between(self, start, end): + """ + Filter for timestamps between start and end (inclusive). + + Args: + start: A datetime, date, ISO string, or Unix timestamp + end: A datetime, date, ISO string, or Unix timestamp + + Returns: + self: The filter object for method chaining + """ + start_ts = self._convert_to_timestamp(start) + end_ts = self._convert_to_timestamp(end) + + # Handle date objects by expanding to full day + # TODO: confirm if this is checked twice. Seems we do the date max thing twice + if isinstance(start, datetime.date) and not isinstance( + start, datetime.datetime + ): + start_ts = self._convert_to_timestamp( + datetime.datetime.combine(start, datetime.time.min) + ) + + if isinstance(end, datetime.date) and not isinstance(end, datetime.datetime): + end_ts = self._convert_to_timestamp( + datetime.datetime.combine(end, datetime.time.max) + ) + + self._set_value( + (start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.BETWEEN + ) + self._start_value = start_ts + self._end_value = end_ts + return FilterExpression(str(self)) + + def day_of(self, date_value): + """ + Match exactly on the specified date (entire day, from 00:00:00 to 23:59:59). + + Args: + date_value: A date or datetime object, or a string date representation + + Returns: + self: The filter object for method chaining + """ + if isinstance(date_value, str): + # Try to parse the string as a date + date_value = datetime.datetime.fromisoformat(date_value).date() + elif isinstance(date_value, datetime.datetime): + date_value = date_value.date() + + start = datetime.datetime.combine(date_value, datetime.time.min) + end = datetime.datetime.combine(date_value, datetime.time.max) + return self.between(start, end) + + def week_of(self, date_value): + """ + Match the week containing the specified date. + Weeks start on Monday (ISO week). + + Args: + date_value: A date or datetime object, or a string date representation + + Returns: + self: The filter object for method chaining + """ + if isinstance(date_value, str): + date_value = datetime.datetime.fromisoformat(date_value).date() + elif isinstance(date_value, datetime.datetime): + date_value = date_value.date() + + # Calculate the Monday of the week + start_date = date_value - datetime.timedelta(days=date_value.weekday()) + # Calculate the Sunday of the week + end_date = start_date + datetime.timedelta(days=6) + + start = datetime.datetime.combine(start_date, datetime.time.min) + end = datetime.datetime.combine(end_date, datetime.time.max) + return self.between(start, end) + + def month_of(self, year, month): + """ + Match the specified month. 
+ + Args: + year: The year as an integer + month: The month as an integer (1-12) + + Returns: + self: The filter object for method chaining + """ + if not (1 <= month <= 12): + raise ValueError("Month must be between 1 and 12") + + start_date = datetime.date(year, month, 1) + + # Calculate the last day of the month + if month == 12: + end_date = datetime.date(year + 1, 1, 1) - datetime.timedelta(days=1) + else: + end_date = datetime.date(year, month + 1, 1) - datetime.timedelta(days=1) + + start = datetime.datetime.combine(start_date, datetime.time.min) + end = datetime.datetime.combine(end_date, datetime.time.max) + return self.between(start, end) + + def year_of(self, year): + """ + Match the specified year. + + Args: + year: The year as an integer + + Returns: + self: The filter object for method chaining + """ + start = datetime.datetime(year, 1, 1, 0, 0, 0) + end = datetime.datetime(year, 12, 31, 23, 59, 59) + return self.between(start, end) + + def last_days(self, days): + """ + Match timestamps from the last N days up to now. + + Args: + days: Number of days to look back + + Returns: + self: The filter object for method chaining + """ + end = datetime.datetime.now(datetime.timezone.utc) + start = end - datetime.timedelta(days=days) + return self.between(start, end) + + @staticmethod + def _is_date_only(iso_string: str) -> bool: + """Check if an ISO formatted string only includes date information using regex.""" + # Match YYYY-MM-DD format exactly + date_pattern = r"^\d{4}-\d{2}-\d{2}$" + return bool(re.match(date_pattern, iso_string)) + + def _convert_to_timestamp(self, value): + """ + Convert various inputs to a Unix timestamp (seconds since epoch in UTC). + + Args: + value: A datetime, date, string, int, or float + + Returns: + float: Unix timestamp + """ + if value is None: + return None + + if isinstance(value, (int, float)): + # Already a Unix timestamp + return float(value) + + if isinstance(value, str): + # Parse ISO format + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + raise ValueError(f"String timestamp must be in ISO format: {value}") + + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + # Convert date to datetime at midnight + value = datetime.datetime.combine(value, datetime.time.min) + + # Ensure the datetime is timezone-aware (UTC) + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + else: + value = value.astimezone(datetime.timezone.utc) + + # Convert to Unix timestamp + return value.timestamp() + + raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") + + def __str__(self): + """Generate the Redis query string for this filter.""" + if self._value is None: + return "*" + if ( + self._operator == FilterOperator.BETWEEN + or self._operator == FilterOperator.EQ + or ( + self._operator == FilterOperator.NE + and self._start_value == self._end_value + ) + ): + # For between and exact matches with range + return self.OPERATOR_MAP[self._operator] % ( + self.escaper.escape(self._field), + self._start_value, + self._end_value, + ) + elif ( + self._operator == FilterOperator.NE and self._start_value != self._end_value + ): + # For not equal with date range + return self.OPERATOR_MAP[self._operator] % ( + self.escaper.escape(self._field), + self._start_value, + self._end_value, + ) + else: + # For other operators + return self.OPERATOR_MAP[self._operator] % ( + self.escaper.escape(self._field), + self._value, + ) 
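
For reviewers, a minimal usage sketch of the new `Timestamp` filter, mirroring the implementation above and the unit tests below. The field name `created_at` is just an illustrative example, and the exact rendered query strings depend on the computed Unix timestamps:

```python
from datetime import date, datetime, timezone

from redisvl.query.filter import Num, Tag, Timestamp

# An exact datetime match renders as a numeric range query on the Unix timestamp
created = Timestamp("created_at") == datetime(2023, 3, 17, 14, 30, tzinfo=timezone.utc)

# A bare date (or a "YYYY-MM-DD" string) matches the entire day, 00:00:00-23:59:59 UTC
on_day = Timestamp("created_at") == date(2023, 3, 17)

# Calendar-range helpers
this_week = Timestamp("created_at").week_of(date(2023, 3, 17))  # ISO week, Mon-Sun
march = Timestamp("created_at").month_of(2023, 3)
recent = Timestamp("created_at").last_days(7)

# Comparison operators return FilterExpression, so they compose with other filters
combined = (
    (Timestamp("created_at") > datetime(2023, 3, 1))
    & (Num("age") > 30)
    & (Tag("status") == "active")
)
print(str(combined))  # e.g. ((@created_at:[(1677628800.0 +inf] @age:[(30 +inf]) @status:{active})
```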
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py index 067402ea..56a52c07 100644 --- a/tests/unit/test_filter.py +++ b/tests/unit/test_filter.py @@ -1,6 +1,8 @@ +from datetime import date, datetime, timezone + import pytest -from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text +from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text, Timestamp # Test cases for various scenarios of tag usage, combinations, and their string representations. @@ -292,3 +294,270 @@ def test_num_filter_zero(): assert ( str(num_filter) == "@chunk_number:[0 0]" ), "Num filter should handle zero correctly" + + +from datetime import date, datetime, timedelta, timezone + +import pytest + +from redisvl.query.filter import Timestamp + + +def test_timestamp_datetime(): + """Test Timestamp filter with datetime objects.""" + # Test with timezone-aware datetime + dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc) + ts = Timestamp("created_at") == dt + # Expected timestamp would be the Unix timestamp for the datetime + expected_ts = dt.timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + # Test with timezone-naive datetime (should convert to UTC) + dt = datetime(2023, 3, 17, 14, 30, 0) + ts = Timestamp("created_at") == dt + expected_ts = dt.replace(tzinfo=timezone.utc).timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + +def test_timestamp_date(): + """Test Timestamp filter with date objects (should match full day).""" + d = date(2023, 3, 17) + ts = Timestamp("created_at") == d + + # Expected start is midnight UTC + start_dt = datetime(2023, 3, 17, 0, 0, 0, tzinfo=timezone.utc) + # Expected end is end of day UTC + end_dt = datetime(2023, 3, 17, 23, 59, 59, 999999, tzinfo=timezone.utc) + + expected_start_ts = start_dt.timestamp() + expected_end_ts = end_dt.timestamp() + + # The filter should create a range query for the entire day + assert str(ts).startswith(f"@created_at:[") + # We can't easily test the exact values due to potential timezone issues + # so we'll check that the values are within the expected day + + # Alternative approach: use the day_of method directly + ts2 = Timestamp("created_at").day_of(d) + assert str(ts) == str(ts2) + + +def test_timestamp_iso_string(): + """Test Timestamp filter with ISO format strings.""" + # Date-only ISO string + ts = Timestamp("created_at") == "2023-03-17" + d = date(2023, 3, 17) + expected_ts = Timestamp("created_at").day_of(d) + assert str(ts) == str(expected_ts) + + # Full ISO datetime string + dt_str = "2023-03-17T14:30:00+00:00" + ts = Timestamp("created_at") == dt_str + dt = datetime.fromisoformat(dt_str) + expected_ts = dt.timestamp() + assert str(ts) == f"@created_at:[{expected_ts} {expected_ts}]" + + +def test_timestamp_unix(): + """Test Timestamp filter with Unix timestamps.""" + # Integer timestamp + ts = Timestamp("created_at") == 1679062200 # 2023-03-17T14:30:00+00:00 + assert str(ts) == "@created_at:[1679062200.0 1679062200.0]" + + # Float timestamp + ts = Timestamp("created_at") == 1679062200.5 + assert str(ts) == "@created_at:[1679062200.5 1679062200.5]" + + +def test_timestamp_operators(): + """Test all comparison operators for Timestamp filter.""" + dt = datetime(2023, 3, 17, 14, 30, 0, tzinfo=timezone.utc) + ts_value = dt.timestamp() + + # Equal + ts = Timestamp("created_at") == dt + assert str(ts) == f"@created_at:[{ts_value} {ts_value}]" + + # Not equal + ts = Timestamp("created_at") != dt + assert str(ts) == f"(-@created_at:[{ts_value} {ts_value}])" + + # 
Greater than + ts = Timestamp("created_at") > dt + assert str(ts) == f"@created_at:[({ts_value} +inf]" + + # Less than + ts = Timestamp("created_at") < dt + assert str(ts) == f"@created_at:[-inf ({ts_value}]" + + # Greater than or equal + ts = Timestamp("created_at") >= dt + assert str(ts) == f"@created_at:[{ts_value} +inf]" + + # Less than or equal + ts = Timestamp("created_at") <= dt + assert str(ts) == f"@created_at:[-inf {ts_value}]" + + +def test_timestamp_between(): + """Test the between method for date ranges.""" + start = datetime(2023, 3, 1, 0, 0, 0, tzinfo=timezone.utc) + end = datetime(2023, 3, 31, 23, 59, 59, tzinfo=timezone.utc) + + ts = Timestamp("created_at").between(start, end) + + start_ts = start.timestamp() + end_ts = end.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + # Test with dates (should expand to full days) + start_date = date(2023, 3, 1) + end_date = date(2023, 3, 31) + + ts = Timestamp("created_at").between(start_date, end_date) + + # Start should be beginning of day + expected_start = datetime.combine(start_date, datetime.min.time()) + expected_start = expected_start.replace(tzinfo=timezone.utc) + + # End should be end of day + expected_end = datetime.combine(end_date, datetime.max.time()) + expected_end = expected_end.replace(tzinfo=timezone.utc) + + expected_start_ts = expected_start.timestamp() + expected_end_ts = expected_end.timestamp() + + assert str(ts) == f"@created_at:[{expected_start_ts} {expected_end_ts}]" + + +def test_timestamp_day_of(): + """Test the day_of helper method.""" + d = date(2023, 3, 17) + ts = Timestamp("created_at").day_of(d) + + # Expected start is midnight UTC + start_dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc) + # Expected end is end of day UTC + end_dt = datetime.combine(d, datetime.max.time()).replace(tzinfo=timezone.utc) + + start_ts = start_dt.timestamp() + end_ts = end_dt.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + # Test with string date + ts = Timestamp("created_at").day_of("2023-03-17") + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + +def test_timestamp_week_of(): + """Test the week_of helper method.""" + # March 17, 2023 was a Friday + d = date(2023, 3, 17) + ts = Timestamp("created_at").week_of(d) + + # Monday of that week is March 13 + monday = date(2023, 3, 13) + # Sunday of that week is March 19 + sunday = date(2023, 3, 19) + + start_dt = datetime.combine(monday, datetime.min.time()).replace( + tzinfo=timezone.utc + ) + end_dt = datetime.combine(sunday, datetime.max.time()).replace(tzinfo=timezone.utc) + + start_ts = start_dt.timestamp() + end_ts = end_dt.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + +def test_timestamp_month_of(): + """Test the month_of helper method.""" + ts = Timestamp("created_at").month_of(2023, 3) + + # First day of March + start_date = date(2023, 3, 1) + # Last day of March + end_date = date(2023, 3, 31) + + start_dt = datetime.combine(start_date, datetime.min.time()).replace( + tzinfo=timezone.utc + ) + end_dt = datetime.combine(end_date, datetime.max.time()).replace( + tzinfo=timezone.utc + ) + + start_ts = start_dt.timestamp() + end_ts = end_dt.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + # Test with invalid month + with pytest.raises(ValueError): + Timestamp("created_at").month_of(2023, 13) + + +def test_timestamp_year_of(): + """Test the year_of helper method.""" + ts = Timestamp("created_at").year_of(2023) + + start_dt = 
datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + end_dt = datetime(2023, 12, 31, 23, 59, 59, tzinfo=timezone.utc) + + start_ts = start_dt.timestamp() + end_ts = end_dt.timestamp() + + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + + +def test_timestamp_last_days(): + """Test the last_days helper method.""" + ts = Timestamp("created_at").last_days(7) + + # This test is tricky because it depends on the current time + # We'll just verify that it generates a valid filter string + assert "@created_at:[" in str(ts) + + # We can mock datetime.now for more precise testing in a real test suite + # but for simplicity, we'll just check the format here + + +def test_timestamp_none(): + """Test handling of None values.""" + ts = Timestamp("created_at") == None + assert str(ts) == "*" + + ts = Timestamp("created_at") != None + assert str(ts) == "*" + + ts = Timestamp("created_at") > None + assert str(ts) == "*" + + +def test_timestamp_invalid_input(): + """Test error handling for invalid inputs.""" + # Invalid ISO format + with pytest.raises(ValueError): + Timestamp("created_at") == "not-a-date" + + # Unsupported type + with pytest.raises(TypeError): + Timestamp("created_at") == object() + + +def test_timestamp_filter_combination(): + """Test combining timestamp filters with other filters.""" + from redisvl.query.filter import Num, Tag + + ts = Timestamp("created_at") > datetime(2023, 3, 1) + num = Num("age") > 30 + tag = Tag("status") == "active" + + combined = ts & num & tag + + # The exact string depends on the timestamp value, but we can check structure + assert str(combined).startswith("((@created_at:") + assert "@age:[(30 +inf]" in str(combined) + assert "@status:{active}" in str(combined) From a60a8b0e33488df168b144abf80c624cc254c6ce Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Mon, 17 Mar 2025 15:39:08 -0400 Subject: [PATCH 3/9] add integration tests --- docs/user_guide/02_hybrid_queries.ipynb | 200 ++++++++++++++++++++---- docs/user_guide/hybrid_example_data.pkl | Bin 494 -> 556 bytes redisvl/query/filter.py | 4 +- tests/conftest.py | 19 ++- tests/integration/test_query.py | 70 ++++++++- 5 files changed, 253 insertions(+), 40 deletions(-) diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 9568669d..90b9d78c 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -16,13 +16,13 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>user</th><th>age</th><th>job</th><th>credit_score</th><th>office_location</th><th>user_embedding</th></tr>
<tr><td>john</td><td>18</td><td>engineer</td><td>high</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr>
<tr><td>derrick</td><td>14</td><td>doctor</td><td>low</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr>
<tr><td>nancy</td><td>94</td><td>doctor</td><td>high</td><td>-122.4194,37.7749</td><td>b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr>
<tr><td>tyler</td><td>100</td><td>engineer</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td></tr>
<tr><td>tim</td><td>12</td><td>dermatologist</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td></tr>
<tr><td>taimur</td><td>15</td><td>CEO</td><td>low</td><td>-122.0839,37.3861</td><td>b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr>
<tr><td>joe</td><td>35</td><td>dentist</td><td>medium</td><td>-122.0839,37.3861</td><td>b'fff?fff?\\xcd\\xcc\\xcc='</td></tr>
" + "
<tr><th>user</th><th>age</th><th>job</th><th>credit_score</th><th>office_location</th><th>user_embedding</th><th>last_updated</th></tr>
<tr><td>john</td><td>18</td><td>engineer</td><td>high</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td><td>1741627789</td></tr>
<tr><td>derrick</td><td>14</td><td>doctor</td><td>low</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td><td>1741627789</td></tr>
<tr><td>nancy</td><td>94</td><td>doctor</td><td>high</td><td>-122.4194,37.7749</td><td>b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td><td>1710696589</td></tr>
<tr><td>tyler</td><td>100</td><td>engineer</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td><td>1742232589</td></tr>
<tr><td>tim</td><td>12</td><td>dermatologist</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td><td>1739644189</td></tr>
<tr><td>taimur</td><td>15</td><td>CEO</td><td>low</td><td>-122.0839,37.3861</td><td>b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td><td>1742232589</td></tr>
<tr><td>joe</td><td>35</td><td>dentist</td><td>medium</td><td>-122.0839,37.3861</td><td>b'fff?fff?\\xcd\\xcc\\xcc='</td><td>1742232589</td></tr>
" ], "text/plain": [ "" @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -58,6 +58,7 @@ " {\"name\": \"credit_score\", \"type\": \"tag\"},\n", " {\"name\": \"job\", \"type\": \"text\"},\n", " {\"name\": \"age\", \"type\": \"numeric\"},\n", + " {\"name\": \"last_updated\", \"type\": \"numeric\"},\n", " {\"name\": \"office_location\", \"type\": \"geo\"},\n", " {\n", " \"name\": \"user_embedding\",\n", @@ -76,14 +77,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "13:02:18 redisvl.index.index INFO Index already exists, overwriting.\n" + "13:53:37 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -99,25 +100,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float16_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float32_session\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. float32_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. bfloat_cache\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. user_queries\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. student tutor\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 10. tutor\n", - "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 11. bfloat_session\n" + "\u001b[32m13:53:42\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m13:53:42\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" ] } ], @@ -128,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -136,6 +127,26 @@ "keys = index.load(data)" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.info()['num_docs']" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -157,13 +168,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td><td>1739644189</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td><td>1710696589</td></tr>
" ], "text/plain": [ "" @@ -182,7 +193,7 @@ "v = VectorQuery(\n", " vector=[0.1, 0.1, 0.5],\n", " vector_field_name=\"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\", \"last_updated\"],\n", " filter_expression=t\n", ")\n", "\n", @@ -316,13 +327,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td><td>1710696589</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
" ], "text/plain": [ "" @@ -393,6 +404,132 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Timestamp Filters\n", + "\n", + "In redis all times are stored as an epoch time numeric however, this class allows you to filter with python datetime for ease of use. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch comparison: 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "print(f'Epoch comparison: {dt.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\") > dt\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch comparison: 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td><td>1739644189</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td><td>1710696589</td></tr>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "\n", + "print(f'Epoch comparison: {dt.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\") < dt\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch between: 1736880339.132589 - 1742147139.132589\n" + ] + }, + { + "data": { + "text/html": [ + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td><td>1739644189</td></tr>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query.filter import Timestamp\n", + "from datetime import datetime\n", + "\n", + "dt_1 = datetime(2025, 1, 14, 13, 45, 39, 132589)\n", + "dt_2 = datetime(2025, 3, 16, 13, 45, 39, 132589)\n", + "\n", + "print(f'Epoch between: {dt_1.timestamp()} - {dt_2.timestamp()}')\n", + "\n", + "timestamp_filter = Timestamp(\"last_updated\").between(dt_1, dt_2)\n", + "\n", + "v.set_filter(timestamp_filter)\n", + "result_print(index.query(v))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -404,13 +541,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td><td>1710696589</td></tr>
" ], "text/plain": [ "" @@ -771,13 +908,13 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
" ], "text/plain": [ "" @@ -791,8 +928,9 @@ "t = Tag(\"credit_score\") == \"high\"\n", "low = Num(\"age\") >= 18\n", "high = Num(\"age\") <= 100\n", + "ts = Timestamp(\"last_updated\") > datetime(2025, 3, 16, 13, 45, 39, 132589)\n", "\n", - "combined = t & low & high\n", + "combined = t & low & high & ts\n", "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", @@ -814,13 +952,13 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
" ], "text/plain": [ "" diff --git a/docs/user_guide/hybrid_example_data.pkl b/docs/user_guide/hybrid_example_data.pkl index b5928b917e3263b461fe62779d384c1ea63d0b61..2c8a92daf5da8a232b3c515428c4254514ce79e3 100644 GIT binary patch delta 190 zcmaFIyoQCffpw}P(?r&Oksh9$#Nv|p(t?!4lGKzbUcHLv(Um>sq)&EW)RdF}DFiCA17Vm# v*}EIkCs#Ad$qE9E0;`xcQxe4}u>SRoS~6hSw6rvP5C&NWvIrzwnyLo?GfY7{ delta 123 zcmZ3(@{XCcfo1Ar#)+)`6Ia-XmDWzt$YAPW%}dNnuAGv=;vJX4n!!Jrkx_E8BBSHv zOh!#1@eFQfp0j7p*nuzu1B3mP$*URV!~`-pGr&@_W=h%vrEDQm<5t!NO^2Y4#ut KlA2PQss{k)1Swbm diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 47af27d8..9e8d7b95 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -597,7 +597,7 @@ class Timestamp(FilterField): FilterOperator.LT: "@%s:[-inf (%s]", FilterOperator.GE: "@%s:[%s +inf]", FilterOperator.LE: "@%s:[-inf %s]", - FilterOperator.BETWEEN: "@%s:[%s %s]", + FilterOperator.BETWEEN: "@%s:[(%s (%s]", # should between be inclusive? } SUPPORTED_TYPES = ( @@ -710,7 +710,7 @@ def __lt__(self, other): """ timestamp = self._convert_to_timestamp(other) self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LT) - return self + return FilterExpression(str(self)) def __ge__(self, other): """ diff --git a/tests/conftest.py b/tests/conftest.py index 61c5de45..24da05e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from datetime import datetime, timezone import pytest from testcontainers.compose import DockerCompose @@ -68,12 +69,22 @@ def client(redis_url): @pytest.fixture -def sample_data(): +def sample_datetimes(): + return { + "low": datetime(2025, 1, 16, 13).astimezone(timezone.utc), + "mid": datetime(2025, 2, 16, 13).astimezone(timezone.utc), + "high": datetime(2025, 3, 16, 13).astimezone(timezone.utc), + } + + +@pytest.fixture +def sample_data(sample_datetimes): return [ { "user": "john", "age": 18, "job": "engineer", + "last_updated": sample_datetimes["low"].timestamp(), "credit_score": "high", "location": "-122.4194,37.7749", "user_embedding": [0.1, 0.1, 0.5], @@ -82,6 +93,7 @@ def sample_data(): "user": "mary", "age": 14, "job": "doctor", + "last_updated": sample_datetimes["low"].timestamp(), "credit_score": "low", "location": "-122.4194,37.7749", "user_embedding": [0.1, 0.1, 0.5], @@ -90,6 +102,7 @@ def sample_data(): "user": "nancy", "age": 94, "job": "doctor", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-122.4194,37.7749", "user_embedding": [0.7, 0.1, 0.5], @@ -98,6 +111,7 @@ def sample_data(): "user": "tyler", "age": 100, "job": "engineer", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-110.0839,37.3861", "user_embedding": [0.1, 0.4, 0.5], @@ -106,6 +120,7 @@ def sample_data(): "user": "tim", "age": 12, "job": "dermatologist", + "last_updated": sample_datetimes["mid"].timestamp(), "credit_score": "high", "location": "-110.0839,37.3861", "user_embedding": [0.4, 0.4, 0.5], @@ -114,6 +129,7 @@ def sample_data(): "user": "taimur", "age": 15, "job": "CEO", + "last_updated": sample_datetimes["high"].timestamp(), "credit_score": "low", "location": "-110.0839,37.3861", "user_embedding": [0.6, 0.1, 0.5], @@ -122,6 +138,7 @@ def sample_data(): "user": "joe", "age": 35, "job": "dentist", + "last_updated": sample_datetimes["high"].timestamp(), "credit_score": "medium", "location": "-110.0839,37.3861", "user_embedding": [0.9, 0.9, 0.1], diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 271d36da..cbdb7634 100644 
--- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -3,7 +3,15 @@ from redisvl.index import SearchIndex from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery -from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text +from redisvl.query.filter import ( + FilterExpression, + Geo, + GeoRadius, + Num, + Tag, + Text, + Timestamp, +) from redisvl.redis.utils import array_to_buffer # TODO expand to multiple schema types and sync + async @@ -14,7 +22,14 @@ def vector_query(): return VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], ) @@ -23,7 +38,14 @@ def sorted_vector_query(): return VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], sort_by="age", ) @@ -31,7 +53,14 @@ def sorted_vector_query(): @pytest.fixture def filter_query(): return FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], filter_expression=Tag("credit_score") == "high", ) @@ -39,7 +68,14 @@ def filter_query(): @pytest.fixture def sorted_filter_query(): return FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], filter_expression=Tag("credit_score") == "high", sort_by="age", ) @@ -80,6 +116,7 @@ def index(sample_data, redis_url): {"name": "credit_score", "type": "tag"}, {"name": "job", "type": "text"}, {"name": "age", "type": "numeric"}, + {"name": "last_updated", "type": "numeric"}, {"name": "location", "type": "geo"}, { "name": "user_embedding", @@ -255,7 +292,7 @@ def query(request): return request.getfixturevalue(request.param) -def test_filters(index, query): +def test_filters(index, query, sample_datetimes): # Simple Tag Filter t = Tag("credit_score") == "high" search(query, index, t, 4, credit_check="high") @@ -310,6 +347,27 @@ def test_filters(index, query): t = Text("job") % "" search(query, index, t, 7) + # Timestamps + t = Timestamp("last_updated") > sample_datetimes["mid"] + search(query, index, t, 2) + + t = Timestamp("last_updated") >= sample_datetimes["mid"] + search(query, index, t, 5) + + t = Timestamp("last_updated") < sample_datetimes["high"] + search(query, index, t, 5) + + t = Timestamp("last_updated") <= sample_datetimes["mid"] + search(query, index, t, 5) + + t = Timestamp("last_updated") == sample_datetimes["mid"] + search(query, index, t, 3) + + t = Timestamp("last_updated").between( + sample_datetimes["low"], sample_datetimes["high"] + ) + search(query, index, t, 3) + def test_manual_string_filters(index, query): # Simple Tag Filter From 8b020fc66af1d5899c4a56892d62220b2ff17eb4 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 18 Mar 2025 09:18:48 -0400 Subject: [PATCH 4/9] test suite running --- tests/unit/test_filter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py index 56a52c07..1992d3af 100644 --- a/tests/unit/test_filter.py +++ b/tests/unit/test_filter.py @@ -409,7 +409,7 @@ def 
test_timestamp_between(): start_ts = start.timestamp() end_ts = end.timestamp() - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" # Test with dates (should expand to full days) start_date = date(2023, 3, 1) @@ -428,7 +428,7 @@ def test_timestamp_between(): expected_start_ts = expected_start.timestamp() expected_end_ts = expected_end.timestamp() - assert str(ts) == f"@created_at:[{expected_start_ts} {expected_end_ts}]" + assert str(ts) == f"@created_at:[({expected_start_ts} ({expected_end_ts}]" def test_timestamp_day_of(): @@ -444,11 +444,11 @@ def test_timestamp_day_of(): start_ts = start_dt.timestamp() end_ts = end_dt.timestamp() - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" # Test with string date ts = Timestamp("created_at").day_of("2023-03-17") - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" def test_timestamp_week_of(): @@ -470,7 +470,7 @@ def test_timestamp_week_of(): start_ts = start_dt.timestamp() end_ts = end_dt.timestamp() - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" def test_timestamp_month_of(): @@ -492,7 +492,7 @@ def test_timestamp_month_of(): start_ts = start_dt.timestamp() end_ts = end_dt.timestamp() - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" # Test with invalid month with pytest.raises(ValueError): @@ -509,7 +509,7 @@ def test_timestamp_year_of(): start_ts = start_dt.timestamp() end_ts = end_dt.timestamp() - assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" + assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" def test_timestamp_last_days(): From 1ac4894f1fa0b005647933d52ca06b371e4ec978 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 18 Mar 2025 12:37:44 -0400 Subject: [PATCH 5/9] refactor timestamp to extend num --- docs/user_guide/02_hybrid_queries.ipynb | 36 ++- redisvl/query/filter.py | 372 ++++++++---------------- tests/integration/test_query.py | 35 ++- tests/unit/test_filter.py | 126 +------- 4 files changed, 181 insertions(+), 388 deletions(-) diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 90b9d78c..00868b0d 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -77,14 +77,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "13:53:37 redisvl.index.index INFO Index already exists, overwriting.\n" + "11:40:25 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -100,15 +100,23 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m13:53:42\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m13:53:42\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. 
user_queries\n" + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_session\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_cache\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float32_session\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float16_session\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. bfloat_session\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. float32_cache\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. bfloat_cache\n", + "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. user_queries\n" ] } ], @@ -119,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -168,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -203,13 +211,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
" ], "text/plain": [ "" @@ -327,13 +335,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td><td>1710696589</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
" + "
<tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th><th>last_updated</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td><td>1741627789</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td><td>1742232589</td></tr>
" ], "text/plain": [ "" @@ -346,7 +354,7 @@ "source": [ "from redisvl.query.filter import Num\n", "\n", - "numeric_filter = Num(\"age\") > 15\n", + "numeric_filter = Num(\"age\").between(15, 35)\n", "\n", "v.set_filter(numeric_filter)\n", "result_print(index.query(v))" diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 9e8d7b95..68cb6978 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -270,6 +270,7 @@ class Num(FilterField): FilterOperator.GT: ">", FilterOperator.LE: "<=", FilterOperator.GE: ">=", + FilterOperator.BETWEEN: "between", } OPERATOR_MAP: Dict[FilterOperator, str] = { FilterOperator.EQ: "@%s:[%s %s]", @@ -278,8 +279,10 @@ class Num(FilterField): FilterOperator.LT: "@%s:[-inf (%s]", FilterOperator.GE: "@%s:[%s +inf]", FilterOperator.LE: "@%s:[-inf %s]", + FilterOperator.BETWEEN: "@%s:[%s %s]", } - SUPPORTED_VAL_TYPES = (int, float, type(None)) + + SUPPORTED_VAL_TYPES = (int, float, tuple, type(None)) def __eq__(self, other: int) -> "FilterExpression": """Create a Numeric equality filter expression. @@ -376,10 +379,31 @@ def __le__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LE) return FilterExpression(str(self)) + def between(self, start: int, end: int) -> "FilterExpression": + """Create a Numeric equality filter expression. + + Args: + other (int): The value to filter on. + + .. code-block:: python + + from redisvl.query.filter import Num + f = Num("zipcode") == 90210 + + """ + self._set_value((start, end), self.SUPPORTED_VAL_TYPES, FilterOperator.BETWEEN) + return FilterExpression(str(self)) + def __str__(self) -> str: """Return the Redis Query string for the Numeric filter""" if self._value is None: return "*" + if self._operator == FilterOperator.BETWEEN: + return self.OPERATOR_MAP[self._operator] % ( + self._field, + self._value[0], + self._value[1], + ) if self._operator == FilterOperator.EQ or self._operator == FilterOperator.NE: return self.OPERATOR_MAP[self._operator] % ( self._field, @@ -567,7 +591,7 @@ def __str__(self) -> str: return self._filter -class Timestamp(FilterField): +class Timestamp(Num): """ A timestamp filter for querying date/time fields in Redis. @@ -580,26 +604,6 @@ class Timestamp(FilterField): All timestamps are converted to Unix timestamps in UTC for consistency. """ - OPERATORS = { - FilterOperator.EQ: "==", - FilterOperator.NE: "!=", - FilterOperator.LT: "<", - FilterOperator.GT: ">", - FilterOperator.LE: "<=", - FilterOperator.GE: ">=", - FilterOperator.BETWEEN: "between", - } - - OPERATOR_MAP = { - FilterOperator.EQ: "@%s:[%s %s]", # For exact timestamp match (converted to range for dates) - FilterOperator.NE: "(-@%s:[%s %s])", - FilterOperator.GT: "@%s:[(%s +inf]", - FilterOperator.LT: "@%s:[-inf (%s]", - FilterOperator.GE: "@%s:[%s +inf]", - FilterOperator.LE: "@%s:[-inf %s]", - FilterOperator.BETWEEN: "@%s:[(%s (%s]", # should between be inclusive? - } - SUPPORTED_TYPES = ( datetime.datetime, datetime.date, @@ -610,16 +614,65 @@ class Timestamp(FilterField): type(None), ) - def __init__(self, field: str): + @staticmethod + def _is_date(value: Any) -> bool: + """Check if the value is a date object. 
Either ISO string or datetime.date.""" + return ( + isinstance(value, datetime.date) + and not isinstance(value, datetime.datetime) + ) or (isinstance(value, str) and Timestamp._is_date_only(value)) + + @staticmethod + def _is_date_only(iso_string: str) -> bool: + """Check if an ISO formatted string only includes date information using regex.""" + # Match YYYY-MM-DD format exactly + date_pattern = r"^\d{4}-\d{2}-\d{2}$" + return bool(re.match(date_pattern, iso_string)) + + def _convert_to_timestamp(self, value, end_date=False): """ - Initialize a timestamp filter for the specified field. + Convert various inputs to a Unix timestamp (seconds since epoch in UTC). Args: - field: The name of the field to filter on + value: A datetime, date, string, int, or float + + Returns: + float: Unix timestamp """ - super().__init__(field) - self._start_value = None - self._end_value = None + if value is None: + return None + + if isinstance(value, (int, float)): + # Already a Unix timestamp + return float(value) + + if isinstance(value, str): + # Parse ISO format + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + raise ValueError(f"String timestamp must be in ISO format: {value}") + + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + # Convert to max or min if for dates based on end or not + if end_date: + value = datetime.datetime.combine(value, datetime.time.max) + else: + value = datetime.datetime.combine(value, datetime.time.min) + + # Ensure the datetime is timezone-aware (UTC) + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + else: + value = value.astimezone(datetime.timezone.utc) + + # Convert to Unix timestamp + return value.timestamp() + + raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") def __eq__(self, other): """ @@ -632,22 +685,20 @@ def __eq__(self, other): Returns: self: The filter object for method chaining """ - # TODO make this a function - if ( - isinstance(other, datetime.date) - and not isinstance(other, datetime.datetime) - ) or (isinstance(other, str) and self._is_date_only(other)): + if self._is_date(other): # For date objects, match the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() - start = datetime.datetime.combine(other, datetime.time.min) - end = datetime.datetime.combine(other, datetime.time.max) + start = datetime.datetime.combine(other, datetime.time.min).astimezone( + datetime.timezone.utc + ) + end = datetime.datetime.combine(other, datetime.time.max).astimezone( + datetime.timezone.utc + ) return self.between(start, end) timestamp = self._convert_to_timestamp(other) self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ) - self._start_value = timestamp - self._end_value = timestamp return FilterExpression(str(self)) def __ne__(self, other): @@ -661,27 +712,16 @@ def __ne__(self, other): Returns: self: The filter object for method chaining """ - # TODO: not sure if we need to support not equal for dates - if ( - isinstance(other, datetime.date) - and not isinstance(other, datetime.datetime) - ) or (isinstance(other, str) and self._is_date_only(other)): + if self._is_date(other): # For date objects, exclude the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() start = datetime.datetime.combine(other, datetime.time.min) end = datetime.datetime.combine(other, datetime.time.max) - start_ts = 
self._convert_to_timestamp(start) - end_ts = self._convert_to_timestamp(end) - self._set_value((start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.NE) - self._start_value = start_ts - self._end_value = end_ts - return FilterExpression(str(self)) + return self.between(start, end) timestamp = self._convert_to_timestamp(other) self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.NE) - self._start_value = timestamp - self._end_value = timestamp return FilterExpression(str(self)) def __gt__(self, other): @@ -752,207 +792,43 @@ def between(self, start, end): self: The filter object for method chaining """ start_ts = self._convert_to_timestamp(start) - end_ts = self._convert_to_timestamp(end) - - # Handle date objects by expanding to full day - # TODO: confirm if this is checked twice. Seems we do the date max thing twice - if isinstance(start, datetime.date) and not isinstance( - start, datetime.datetime - ): - start_ts = self._convert_to_timestamp( - datetime.datetime.combine(start, datetime.time.min) - ) - - if isinstance(end, datetime.date) and not isinstance(end, datetime.datetime): - end_ts = self._convert_to_timestamp( - datetime.datetime.combine(end, datetime.time.max) - ) + end_ts = self._convert_to_timestamp(end, end_date=True) self._set_value( (start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.BETWEEN ) - self._start_value = start_ts - self._end_value = end_ts return FilterExpression(str(self)) - def day_of(self, date_value): - """ - Match exactly on the specified date (entire day, from 00:00:00 to 23:59:59). - - Args: - date_value: A date or datetime object, or a string date representation - - Returns: - self: The filter object for method chaining - """ - if isinstance(date_value, str): - # Try to parse the string as a date - date_value = datetime.datetime.fromisoformat(date_value).date() - elif isinstance(date_value, datetime.datetime): - date_value = date_value.date() - - start = datetime.datetime.combine(date_value, datetime.time.min) - end = datetime.datetime.combine(date_value, datetime.time.max) - return self.between(start, end) - - def week_of(self, date_value): - """ - Match the week containing the specified date. - Weeks start on Monday (ISO week). - - Args: - date_value: A date or datetime object, or a string date representation - - Returns: - self: The filter object for method chaining - """ - if isinstance(date_value, str): - date_value = datetime.datetime.fromisoformat(date_value).date() - elif isinstance(date_value, datetime.datetime): - date_value = date_value.date() - - # Calculate the Monday of the week - start_date = date_value - datetime.timedelta(days=date_value.weekday()) - # Calculate the Sunday of the week - end_date = start_date + datetime.timedelta(days=6) - - start = datetime.datetime.combine(start_date, datetime.time.min) - end = datetime.datetime.combine(end_date, datetime.time.max) - return self.between(start, end) - - def month_of(self, year, month): - """ - Match the specified month. 
- - Args: - year: The year as an integer - month: The month as an integer (1-12) - - Returns: - self: The filter object for method chaining - """ - if not (1 <= month <= 12): - raise ValueError("Month must be between 1 and 12") - - start_date = datetime.date(year, month, 1) - - # Calculate the last day of the month - if month == 12: - end_date = datetime.date(year + 1, 1, 1) - datetime.timedelta(days=1) - else: - end_date = datetime.date(year, month + 1, 1) - datetime.timedelta(days=1) - - start = datetime.datetime.combine(start_date, datetime.time.min) - end = datetime.datetime.combine(end_date, datetime.time.max) - return self.between(start, end) - - def year_of(self, year): - """ - Match the specified year. - - Args: - year: The year as an integer - - Returns: - self: The filter object for method chaining - """ - start = datetime.datetime(year, 1, 1, 0, 0, 0) - end = datetime.datetime(year, 12, 31, 23, 59, 59) - return self.between(start, end) - - def last_days(self, days): - """ - Match timestamps from the last N days up to now. - - Args: - days: Number of days to look back - - Returns: - self: The filter object for method chaining - """ - end = datetime.datetime.now(datetime.timezone.utc) - start = end - datetime.timedelta(days=days) - return self.between(start, end) - - @staticmethod - def _is_date_only(iso_string: str) -> bool: - """Check if an ISO formatted string only includes date information using regex.""" - # Match YYYY-MM-DD format exactly - date_pattern = r"^\d{4}-\d{2}-\d{2}$" - return bool(re.match(date_pattern, iso_string)) - - def _convert_to_timestamp(self, value): - """ - Convert various inputs to a Unix timestamp (seconds since epoch in UTC). - - Args: - value: A datetime, date, string, int, or float - - Returns: - float: Unix timestamp - """ - if value is None: - return None - - if isinstance(value, (int, float)): - # Already a Unix timestamp - return float(value) - - if isinstance(value, str): - # Parse ISO format - try: - value = datetime.datetime.fromisoformat(value) - except ValueError: - raise ValueError(f"String timestamp must be in ISO format: {value}") - - if isinstance(value, datetime.date) and not isinstance( - value, datetime.datetime - ): - # Convert date to datetime at midnight - value = datetime.datetime.combine(value, datetime.time.min) - - # Ensure the datetime is timezone-aware (UTC) - if isinstance(value, datetime.datetime): - if value.tzinfo is None: - value = value.replace(tzinfo=datetime.timezone.utc) - else: - value = value.astimezone(datetime.timezone.utc) - - # Convert to Unix timestamp - return value.timestamp() - - raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") - - def __str__(self): - """Generate the Redis query string for this filter.""" - if self._value is None: - return "*" - if ( - self._operator == FilterOperator.BETWEEN - or self._operator == FilterOperator.EQ - or ( - self._operator == FilterOperator.NE - and self._start_value == self._end_value - ) - ): - # For between and exact matches with range - return self.OPERATOR_MAP[self._operator] % ( - self.escaper.escape(self._field), - self._start_value, - self._end_value, - ) - elif ( - self._operator == FilterOperator.NE and self._start_value != self._end_value - ): - # For not equal with date range - return self.OPERATOR_MAP[self._operator] % ( - self.escaper.escape(self._field), - self._start_value, - self._end_value, - ) - else: - # For other operators - return self.OPERATOR_MAP[self._operator] % ( - self.escaper.escape(self._field), - self._value, - ) 
+ # def __str__(self): + # """Generate the Redis query string for this filter.""" + # if self._value is None: + # return "*" + # if ( + # self._operator == FilterOperator.BETWEEN + # or self._operator == FilterOperator.EQ + # or ( + # self._operator == FilterOperator.NE + # and self._start_value == self._end_value + # ) + # ): + # # For between and exact matches with range + # return self.OPERATOR_MAP[self._operator] % ( + # self.escaper.escape(self._field), + # self._start_value, + # self._end_value, + # ) + # elif ( + # self._operator == FilterOperator.NE and self._start_value != self._end_value + # ): + # # For not equal with date range + # return self.OPERATOR_MAP[self._operator] % ( + # self.escaper.escape(self._field), + # self._start_value, + # self._end_value, + # ) + # else: + # # For other operators + # return self.OPERATOR_MAP[self._operator] % ( + # self.escaper.escape(self._field), + # self._value, + # ) diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index cbdb7634..deb58cbc 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -1,3 +1,5 @@ +from datetime import timedelta + import pytest from redis.commands.search.result import Result @@ -348,25 +350,32 @@ def test_filters(index, query, sample_datetimes): search(query, index, t, 7) # Timestamps - t = Timestamp("last_updated") > sample_datetimes["mid"] - search(query, index, t, 2) + ts = Timestamp("last_updated") > sample_datetimes["mid"] + search(query, index, ts, 2) - t = Timestamp("last_updated") >= sample_datetimes["mid"] - search(query, index, t, 5) + ts = Timestamp("last_updated") >= sample_datetimes["mid"] + search(query, index, ts, 5) - t = Timestamp("last_updated") < sample_datetimes["high"] - search(query, index, t, 5) + ts = Timestamp("last_updated") < sample_datetimes["high"] + search(query, index, ts, 5) - t = Timestamp("last_updated") <= sample_datetimes["mid"] - search(query, index, t, 5) + ts = Timestamp("last_updated") <= sample_datetimes["mid"] + search(query, index, ts, 5) + + ts = Timestamp("last_updated") == sample_datetimes["mid"] + search(query, index, ts, 3) - t = Timestamp("last_updated") == sample_datetimes["mid"] - search(query, index, t, 3) + ts = (Timestamp("last_updated") == sample_datetimes["low"]) | ( + Timestamp("last_updated") == sample_datetimes["high"] + ) + search(query, index, ts, 4) - t = Timestamp("last_updated").between( - sample_datetimes["low"], sample_datetimes["high"] + # could drop between if we prefer union syntax + ts = Timestamp("last_updated").between( + sample_datetimes["low"] + timedelta(seconds=1), + sample_datetimes["high"] - timedelta(seconds=1), ) - search(query, index, t, 3) + search(query, index, ts, 3) def test_manual_string_filters(index, query): diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py index 1992d3af..6d82de02 100644 --- a/tests/unit/test_filter.py +++ b/tests/unit/test_filter.py @@ -1,4 +1,4 @@ -from datetime import date, datetime, timezone +from datetime import date, datetime, time, timezone import pytest @@ -324,22 +324,12 @@ def test_timestamp_date(): d = date(2023, 3, 17) ts = Timestamp("created_at") == d - # Expected start is midnight UTC - start_dt = datetime(2023, 3, 17, 0, 0, 0, tzinfo=timezone.utc) - # Expected end is end of day UTC - end_dt = datetime(2023, 3, 17, 23, 59, 59, 999999, tzinfo=timezone.utc) - - expected_start_ts = start_dt.timestamp() - expected_end_ts = end_dt.timestamp() - - # The filter should create a range query for the entire day - assert 
str(ts).startswith(f"@created_at:[") - # We can't easily test the exact values due to potential timezone issues - # so we'll check that the values are within the expected day + expected_ts_start = ( + datetime.combine(d, time.min).astimezone(timezone.utc).timestamp() + ) + expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp() - # Alternative approach: use the day_of method directly - ts2 = Timestamp("created_at").day_of(d) - assert str(ts) == str(ts2) + assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]" def test_timestamp_iso_string(): @@ -347,8 +337,11 @@ def test_timestamp_iso_string(): # Date-only ISO string ts = Timestamp("created_at") == "2023-03-17" d = date(2023, 3, 17) - expected_ts = Timestamp("created_at").day_of(d) - assert str(ts) == str(expected_ts) + expected_ts_start = ( + datetime.combine(d, time.min).astimezone(timezone.utc).timestamp() + ) + expected_ts_end = datetime.combine(d, time.max).astimezone(timezone.utc).timestamp() + assert str(ts) == f"@created_at:[{expected_ts_start} {expected_ts_end}]" # Full ISO datetime string dt_str = "2023-03-17T14:30:00+00:00" @@ -409,7 +402,7 @@ def test_timestamp_between(): start_ts = start.timestamp() end_ts = end.timestamp() - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" + assert str(ts) == f"@created_at:[{start_ts} {end_ts}]" # Test with dates (should expand to full days) start_date = date(2023, 3, 1) @@ -428,100 +421,7 @@ def test_timestamp_between(): expected_start_ts = expected_start.timestamp() expected_end_ts = expected_end.timestamp() - assert str(ts) == f"@created_at:[({expected_start_ts} ({expected_end_ts}]" - - -def test_timestamp_day_of(): - """Test the day_of helper method.""" - d = date(2023, 3, 17) - ts = Timestamp("created_at").day_of(d) - - # Expected start is midnight UTC - start_dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc) - # Expected end is end of day UTC - end_dt = datetime.combine(d, datetime.max.time()).replace(tzinfo=timezone.utc) - - start_ts = start_dt.timestamp() - end_ts = end_dt.timestamp() - - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" - - # Test with string date - ts = Timestamp("created_at").day_of("2023-03-17") - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" - - -def test_timestamp_week_of(): - """Test the week_of helper method.""" - # March 17, 2023 was a Friday - d = date(2023, 3, 17) - ts = Timestamp("created_at").week_of(d) - - # Monday of that week is March 13 - monday = date(2023, 3, 13) - # Sunday of that week is March 19 - sunday = date(2023, 3, 19) - - start_dt = datetime.combine(monday, datetime.min.time()).replace( - tzinfo=timezone.utc - ) - end_dt = datetime.combine(sunday, datetime.max.time()).replace(tzinfo=timezone.utc) - - start_ts = start_dt.timestamp() - end_ts = end_dt.timestamp() - - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" - - -def test_timestamp_month_of(): - """Test the month_of helper method.""" - ts = Timestamp("created_at").month_of(2023, 3) - - # First day of March - start_date = date(2023, 3, 1) - # Last day of March - end_date = date(2023, 3, 31) - - start_dt = datetime.combine(start_date, datetime.min.time()).replace( - tzinfo=timezone.utc - ) - end_dt = datetime.combine(end_date, datetime.max.time()).replace( - tzinfo=timezone.utc - ) - - start_ts = start_dt.timestamp() - end_ts = end_dt.timestamp() - - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" - - # Test with invalid month - with pytest.raises(ValueError): - 
Timestamp("created_at").month_of(2023, 13) - - -def test_timestamp_year_of(): - """Test the year_of helper method.""" - ts = Timestamp("created_at").year_of(2023) - - start_dt = datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - end_dt = datetime(2023, 12, 31, 23, 59, 59, tzinfo=timezone.utc) - - start_ts = start_dt.timestamp() - end_ts = end_dt.timestamp() - - assert str(ts) == f"@created_at:[({start_ts} ({end_ts}]" - - -def test_timestamp_last_days(): - """Test the last_days helper method.""" - ts = Timestamp("created_at").last_days(7) - - # This test is tricky because it depends on the current time - # We'll just verify that it generates a valid filter string - assert "@created_at:[" in str(ts) - - # We can mock datetime.now for more precise testing in a real test suite - # but for simplicity, we'll just check the format here + assert str(ts) == f"@created_at:[{expected_start_ts} {expected_end_ts}]" def test_timestamp_none(): From bc161bca9c935f1d339815878721968bb30fde86 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 18 Mar 2025 13:29:36 -0400 Subject: [PATCH 6/9] remove old code --- redisvl/query/filter.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 68cb6978..82cc8afb 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -798,37 +798,3 @@ def between(self, start, end): (start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.BETWEEN ) return FilterExpression(str(self)) - - # def __str__(self): - # """Generate the Redis query string for this filter.""" - # if self._value is None: - # return "*" - # if ( - # self._operator == FilterOperator.BETWEEN - # or self._operator == FilterOperator.EQ - # or ( - # self._operator == FilterOperator.NE - # and self._start_value == self._end_value - # ) - # ): - # # For between and exact matches with range - # return self.OPERATOR_MAP[self._operator] % ( - # self.escaper.escape(self._field), - # self._start_value, - # self._end_value, - # ) - # elif ( - # self._operator == FilterOperator.NE and self._start_value != self._end_value - # ): - # # For not equal with date range - # return self.OPERATOR_MAP[self._operator] % ( - # self.escaper.escape(self._field), - # self._start_value, - # self._end_value, - # ) - # else: - # # For other operators - # return self.OPERATOR_MAP[self._operator] % ( - # self.escaper.escape(self._field), - # self._value, - # ) From ba79474cb0ce57349e3ea4a4d36753b157b4e15f Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 19 Mar 2025 09:48:36 -0400 Subject: [PATCH 7/9] add inclusive filter --- docs/user_guide/02_hybrid_queries.ipynb | 101 +++++++++++++++++++++++- redisvl/query/filter.py | 64 ++++++++++++--- tests/unit/test_filter.py | 37 +++++++-- 3 files changed, 182 insertions(+), 20 deletions(-) diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 00868b0d..0c868dd1 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -1467,11 +1467,108 @@ "# Cleanup\n", "index.delete()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from enum import Enum\n", + "\n", + "class Inclusive(str, Enum):\n", + " \"\"\"Enumeration for distance aggregation methods.\"\"\"\n", + "\n", + " BOTH = \"both\"\n", + " \"\"\"Inclusive of both sides of range (default)\"\"\"\n", + " NEITHER = \"neither\"\n", + " \"\"\"Inclusive of neither side of range\"\"\"\n", + " LEFT = 
\"left\"\n", + " \"\"\"Inclusive of only left\"\"\"\n", + " RIGHT = \"right\"\n", + " \"\"\"Inclusive of only right\"\"\"\n", + "\n", + "def my_fn(value: Inclusive) -> str:\n", + " return Inclusive(value).value" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['both', 'neither', 'left', '']" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(e.value for e in Inclusive)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "mappingproxy({'BOTH': ,\n", + " 'NEITHER': ,\n", + " 'LEFT': ,\n", + " 'RIGHT': })" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Inclusive." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "'not' is not a valid Inclusive", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mInclusive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnot\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/enum.py:714\u001b[0m, in \u001b[0;36mEnumType.__call__\u001b[0;34m(cls, value, names, module, qualname, type, start, boundary)\u001b[0m\n\u001b[1;32m 689\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 690\u001b[0m \u001b[38;5;124;03mEither returns an existing member, or creates a new enum class.\u001b[39;00m\n\u001b[1;32m 691\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 711\u001b[0m \u001b[38;5;124;03m`type`, if set, will be mixed in as the first base class.\u001b[39;00m\n\u001b[1;32m 712\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 713\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# simple value lookup\u001b[39;00m\n\u001b[0;32m--> 714\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__new__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;66;03m# otherwise, functional API: we're creating a new Enum type\u001b[39;00m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_create_(\n\u001b[1;32m 717\u001b[0m value,\n\u001b[1;32m 718\u001b[0m names,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 723\u001b[0m boundary\u001b[38;5;241m=\u001b[39mboundary,\n\u001b[1;32m 724\u001b[0m )\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/enum.py:1137\u001b[0m, in \u001b[0;36mEnum.__new__\u001b[0;34m(cls, value)\u001b[0m\n\u001b[1;32m 1135\u001b[0m ve_exc \u001b[38;5;241m=\u001b[39m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m is not a valid \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m 
\u001b[38;5;241m%\u001b[39m (value, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m))\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ve_exc\n\u001b[1;32m 1138\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1139\u001b[0m exc \u001b[38;5;241m=\u001b[39m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 1140\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124merror in \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m._missing_: returned \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m instead of None or a valid member\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1141\u001b[0m \u001b[38;5;241m%\u001b[39m (\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, result)\n\u001b[1;32m 1142\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: 'not' is not a valid Inclusive" + ] + } + ], + "source": [ + "Inclusive(\"not\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "env", + "display_name": "redisvl-Q9FZQJWe-py3.11", "language": "python", "name": "python3" }, @@ -1485,7 +1582,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.11.9" }, "orig_nbformat": 4 }, diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 82cc8afb..e0207852 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -10,6 +10,19 @@ # mypy: disable-error-code="override" +class Inclusive(str, Enum): + """Enumeration for distance aggregation methods.""" + + BOTH = "both" + """Inclusive of both sides of range (default)""" + NEITHER = "neither" + """Inclusive of neither side of range""" + LEFT = "left" + """Inclusive of only left""" + RIGHT = "right" + """Inclusive of only right""" + + class FilterOperator(Enum): EQ = 1 NE = 2 @@ -379,7 +392,35 @@ def __le__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LE) return FilterExpression(str(self)) - def between(self, start: int, end: int) -> "FilterExpression": + @staticmethod + def _validate_inclusive_string(inclusive: str) -> Inclusive: + try: + return Inclusive(inclusive) + except: + raise ValueError( + f"Invalid inclusive value must be: {[i.value for i in Inclusive]}" + ) + + def _format_inclusive_between( + self, inclusive: Inclusive, start: int, end: int + ) -> str: + if inclusive.value == Inclusive.BOTH.value: + return f"@{self._field}:[{start} {end}]" + + if inclusive.value == Inclusive.NEITHER.value: + return f"@{self._field}:[({start} ({end}]" + + if inclusive.value == Inclusive.LEFT.value: + return f"@{self._field}:[{start} ({end}]" + + if inclusive.value == Inclusive.RIGHT.value: + return f"@{self._field}:[({start} {end}]" + + raise ValueError(f"Inclusive value not found") + + def between( + self, start: int, end: int, inclusive: str = "both" + ) -> "FilterExpression": """Create a Numeric equality filter expression. 
Args:

@@ -391,8 +432,10 @@ def between(self, start: int, end: int) -> "FilterExpression":

             f = Num("zipcode") == 90210

         """
-        self._set_value((start, end), self.SUPPORTED_VAL_TYPES, FilterOperator.BETWEEN)
-        return FilterExpression(str(self))
+        inclusive = self._validate_inclusive_string(inclusive)
+        expression = self._format_inclusive_between(inclusive, start, end)
+
+        return FilterExpression(expression)

     def __str__(self) -> str:
         """Return the Redis Query string for the Numeric filter"""
@@ -674,7 +717,7 @@ def _convert_to_timestamp(self, value, end_date=False):

         raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}")

-    def __eq__(self, other):
+    def __eq__(self, other) -> FilterExpression:
         """
         Filter for timestamps equal to the specified value.
         For date objects (without time), this matches the entire day.
@@ -701,7 +744,7 @@ def __eq__(self, other):
         self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ)
         return FilterExpression(str(self))

-    def __ne__(self, other):
+    def __ne__(self, other) -> FilterExpression:
         """
         Filter for timestamps not equal to the specified value.
         For date objects (without time), this excludes the entire day.
@@ -780,7 +823,7 @@ def __le__(self, other):
         self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.LE)
         return FilterExpression(str(self))

-    def between(self, start, end):
+    def between(self, start, end, inclusive: str = "both"):
         """
         Filter for timestamps between start and end (inclusive).

         Args:
@@ -791,10 +834,11 @@
         Returns:
             self: The filter object for method chaining
         """
+        inclusive = self._validate_inclusive_string(inclusive)
+
         start_ts = self._convert_to_timestamp(start)
         end_ts = self._convert_to_timestamp(end, end_date=True)

-        self._set_value(
-            (start_ts, end_ts), self.SUPPORTED_TYPES, FilterOperator.BETWEEN
-        )
-        return FilterExpression(str(self))
+        expression = self._format_inclusive_between(inclusive, start_ts, end_ts)
+
+        return FilterExpression(expression)
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index 6d82de02..dae74240 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -1,4 +1,4 @@
-from datetime import date, datetime, time, timezone
+from datetime import date, datetime, time, timedelta, timezone

 import pytest

@@ -112,6 +112,18 @@ def test_numeric_filter():
     nf = Num("numeric_field") != None
     assert str(nf) == "*"

+    nf = Num("numeric_field").between(2, 5)
+    assert str(nf) == "@numeric_field:[2 5]"
+
+    nf = Num("numeric_field").between(2, 5, inclusive="neither")
+    assert str(nf) == "@numeric_field:[(2 (5]"
+
+    nf = Num("numeric_field").between(2, 5, inclusive="left")
+    assert str(nf) == "@numeric_field:[2 (5]"
+
+    nf = Num("numeric_field").between(2, 5, inclusive="right")
+    assert str(nf) == "@numeric_field:[(2 5]"
+

 def test_text_filter():
     txt_f = Text("text_field") == "text"
@@ -296,13 +308,6 @@ def test_num_filter_zero():
     ), "Num filter should handle zero correctly"


-from datetime import date, datetime, timedelta, timezone
-
-import pytest
-
-from redisvl.query.filter import Timestamp
-
-
 def test_timestamp_datetime():
     """Test Timestamp filter with datetime objects."""
     # Test with timezone-aware datetime
@@ -391,6 +396,22 @@ def test_timestamp_operators():
     ts = Timestamp("created_at") <= dt
     assert str(ts) == f"@created_at:[-inf {ts_value}]"

+    td = timedelta(days=5)
+    dt2 = dt + td
+    ts_value2 = dt2.timestamp()
+
+    ts = Timestamp("created_at").between(dt, dt2)
+    assert str(ts) == f"@created_at:[{ts_value} {ts_value2}]"
+
+    ts = Timestamp("created_at").between(dt, dt2, inclusive="neither")
+    assert str(ts) == f"@created_at:[({ts_value} ({ts_value2}]"
+
+    ts = Timestamp("created_at").between(dt, dt2, inclusive="left")
+    assert str(ts) == f"@created_at:[{ts_value} ({ts_value2}]"
+
+    ts = Timestamp("created_at").between(dt, dt2, inclusive="right")
+    assert str(ts) == f"@created_at:[({ts_value} {ts_value2}]"
+

 def test_timestamp_between():
     """Test the between method for date ranges."""

From 407ff72f6764e3eaa53d17e84a7afa5b7f7bfa3a Mon Sep 17 00:00:00 2001
From: Robert Shelton
Date: Wed, 19 Mar 2025 09:56:01 -0400
Subject: [PATCH 8/9] remove test code

---
 docs/user_guide/02_hybrid_queries.ipynb | 97 -------------------------
 1 file changed, 97 deletions(-)

diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb
index 0c868dd1..7b426c23 100644
--- a/docs/user_guide/02_hybrid_queries.ipynb
+++ b/docs/user_guide/02_hybrid_queries.ipynb
@@ -1467,103 +1467,6 @@
     "# Cleanup\n",
     "index.delete()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from enum import Enum\n",
-    "\n",
-    "class Inclusive(str, Enum):\n",
-    "    \"\"\"Enumeration for distance aggregation methods.\"\"\"\n",
-    "\n",
-    "    BOTH = \"both\"\n",
-    "    \"\"\"Inclusive of both sides of range (default)\"\"\"\n",
-    "    NEITHER = \"neither\"\n",
-    "    \"\"\"Inclusive of neither side of range\"\"\"\n",
-    "    LEFT = \"left\"\n",
-    "    \"\"\"Inclusive of only left\"\"\"\n",
-    "    RIGHT = \"right\"\n",
-    "    \"\"\"Inclusive of only right\"\"\"\n",
-    "\n",
-    "def my_fn(value: Inclusive) -> str:\n",
-    "    return Inclusive(value).value"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 24,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "['both', 'neither', 'left', '']"
-      ]
-     },
-     "execution_count": 24,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "list(e.value for e in Inclusive)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "mappingproxy({'BOTH': <Inclusive.BOTH: 'both'>,\n",
-       "              'NEITHER': <Inclusive.NEITHER: 'neither'>,\n",
-       "              'LEFT': <Inclusive.LEFT: 'left'>,\n",
-       "              'RIGHT': <Inclusive.RIGHT: 'right'>})"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "Inclusive.__members__" 
\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m instead of None or a valid member\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1141\u001b[0m \u001b[38;5;241m%\u001b[39m (\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, result)\n\u001b[1;32m 1142\u001b[0m )\n", - "\u001b[0;31mValueError\u001b[0m: 'not' is not a valid Inclusive" - ] - } - ], - "source": [ - "Inclusive(\"not\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 1e9e76f3fb98066916bfd3382ae07d98131df474 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Thu, 20 Mar 2025 11:06:34 -0400 Subject: [PATCH 9/9] quick tweaks --- docs/user_guide/02_hybrid_queries.ipynb | 21 ++------------------- redisvl/query/filter.py | 14 ++------------ 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 7b426c23..3d399dcc 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -100,26 +100,9 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_session\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_cache\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float32_session\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float16_session\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. bfloat_session\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. float32_cache\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. bfloat_cache\n", - "\u001b[32m11:03:03\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. user_queries\n" - ] - } - ], + "outputs": [], "source": [ "# use the CLI to see the created index\n", "!rvl index listall" diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index e0207852..ced52520 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -11,7 +11,7 @@ class Inclusive(str, Enum): - """Enumeration for distance aggregation methods.""" + """Enum for valid inclusive options""" BOTH = "both" """Inclusive of both sides of range (default)""" @@ -421,17 +421,7 @@ def _format_inclusive_between( def between( self, start: int, end: int, inclusive: str = "both" ) -> "FilterExpression": - """Create a Numeric equality filter expression. - - Args: - other (int): The value to filter on. - - .. code-block:: python - - from redisvl.query.filter import Num - f = Num("zipcode") == 90210 - - """ + """Operator for searching values between two numeric values.""" inclusive = self._validate_inclusive_string(inclusive) expression = self._format_inclusive_between(inclusive, start, end)