diff --git a/.gitignore b/.gitignore index ba0430d2..fb8b672d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__pycache__/ \ No newline at end of file +__pycache__/ +redisvl.egg-info/ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..fb9e4c85 --- /dev/null +++ b/Makefile @@ -0,0 +1,73 @@ +MAKEFLAGS += --no-print-directory + +# Do not remove this block. It is used by the 'help' rule when +# constructing the help output. +# help: +# help: Developer Makefile +# help: + + +SHELL:=/bin/bash + +# help: help - display this makefile's help information +.PHONY: help +help: + @grep "^# help\:" Makefile | grep -v grep | sed 's/\# help\: //' | sed 's/\# help\://' + + +# help: +# help: Style +# help: ------- + +# help: style - Sort imports and format with black +.PHONY: style +style: sort-imports format + + +# help: check-style - check code style compliance +.PHONY: check-style +check-style: check-sort-imports check-format + + +# help: format - perform code style format +.PHONY: format +format: + @black ./redisvl ./tests/ + + +# help: sort-imports - apply import sort ordering +.PHONY: sort-imports +sort-imports: + @isort ./redisvl ./tests/ --profile black + + +# help: check-lint - run static analysis checks +.PHONY: check-lint +check-lint: + @pylint --rcfile=.pylintrc ./redisvl + + +# help: +# help: Test +# help: ------- + +# help: test - Run all tests +.PHONY: test +test: + @python -m pytest + +# help: test-verbose - Run all tests verbosely +.PHONY: test-verbose +test-verbose: + @python -m pytest -vv -s + +# help: test-cov - Run all tests with coverage +.PHONY: test-cov +test-cov: + @python -m pytest -vv --cov=./redisvl + +# help: cov - generate html coverage report +.PHONY: cov +cov: + @coverage html + @echo if data was present, coverage report is in ./htmlcov/index.html diff --git a/README.md b/README.md index 63ca59b7..1c014dac 100644 --- a/README.md +++ b/README.md @@ -1,65 +1,97 @@ -# RediSearch Data Loader -The purpose of this script is to assist in loading datasets to a RediSearch instance efficiently. +# RedisVL -The project is brand new and will undergo improvements over time. +A CLI and Library to help with loading data into Redis specifically for +usage with RediSearch and Redis Vector Search capabilities -## Getting Started +### Usage -### Requirements -Install the Python requirements listed in `requirements.txt`. - -```bash -$ pip install -r requirements.txt ``` +usage: redisvl [] -### Data -In order to run the script you need to have a dataset that contains your vectors and metadata. +Commands: + load Load vector data into redis + index Index manipulation (create, delete, etc.) + query Query an existing index ->Currently, the data file must be a pickled pandas dataframe. Support for more data types will be included in future iterations. +Redis Vector load CLI -### Schema -Along with the dataset, you must update the dataset schema for RediSearch in [`data/schema.py`](data/schema.py). +positional arguments: + command Subcommand to run -### Running -The `main.py` script provides an entrypoint with optional arguments to upload your dataset to a Redis server. +optional arguments: + -h, --help show this help message and exit -#### Usage -``` -python main.py - - -h, --help Show this help message and exit - --host Redis host - -p, --port Redis port - -a, --password Redis password - -c , --concurrency Amount of concurrency - -d , --data Path to data file - --prefix Key prefix for all hashes in the search index - -v , --vector Vector field name in df - -i , --index Index name ``` -#### Defaults +For any of the above commands, you will need to have an index schema written +into a yaml file for the cli to read. The format of the schema is as follows + +```yaml +index: + name: sample # index name used for querying + storage_type: hash + key_field: "id" # column name to use for key in redis + prefix: vector # prefix used for all loaded docs + +# all fields to create index with +# sub-items correspond to redis-py Field arguments +fields: + tag: + categories: # name of a tag field used for queries + separator: "|" + year: # name of a tag field used for queries + separator: "|" + vector: + vector: # name of the vector field used for queries + datatype: "float32" + algorithm: "flat" # flat or HSNW + dims: 768 + distance_metric: "cosine" # ip, L2, cosine +``` -| Argument | Default | -| ----------- | ----------- | -| Host | `localhost` | -| Port | `6379` | -| Password | "" | -| Concurrency | `50` | -| Data (Path) | `data/embeddings.pkl` | -| Prefix | `vector:` | -| Vector (Field Name) | `vector` | -| Index Name | `index` | +#### Example Usage +```bash +# load in a pickled dataframe with +redisvl load -s sample.yml -d embeddings.pkl +``` -#### Examples +```bash +# load in a pickled dataframe to a specific address and port +redisvl load -s sample.yml -d embeddings.pkl -h 127.0.0.1 -p 6379 +``` -Load to a local (default) redis server with a custom index name and with concurrency = 100: ```bash -$ python main.py -d data/embeddings.pkl -i myIndex -c 100 +# load in a pickled dataframe to a specific +# address and port and with password +redisvl load -s sample.yml -d embeddings.pkl -h 127.0.0.1 -p 6379 -p supersecret ``` -Load to a cloud redis server with all other defaults: +### Support + +#### Supported Index Fields + + - ``geo`` + - ``tag`` + - ``numeric`` + - ``vector`` + - ``text`` +#### Supported Data Types + - Pandas DataFrame (pickled) +#### Supported Redis Data Types + - Hash + - JSON (soon) + +### Install +Install the Python requirements listed in `requirements.txt`. + ```bash -$ python main.py -h {redis-host} -p {redis-port} -a {redis-password} -``` \ No newline at end of file +git clone https://github.com/RedisVentures/data-loader.git +cd redisvl +pip install . +``` + +### Creating Input Data +#### Pandas DataFrame + + more to come, see tests and sample-data for usage \ No newline at end of file diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..306bbe27 --- /dev/null +++ b/conftest.py @@ -0,0 +1,13 @@ +import os +import pytest + +from redisvl.utils.connection import get_async_redis_connection + +HOST = os.environ.get("REDIS_HOST", "localhost") +PORT = os.environ.get("REDIS_PORT", 6379) +USER = os.environ.get("REDIS_USER", "default") +PASS = os.environ.get("REDIS_PASSWORD", "") + +@pytest.fixture +def async_redis(): + return get_async_redis_connection(HOST, PORT, PASS) \ No newline at end of file diff --git a/data/schema.py b/data/schema.py deleted file mode 100644 index 2bc47868..00000000 --- a/data/schema.py +++ /dev/null @@ -1,23 +0,0 @@ -from redis.commands.search.field import ( - TagField, - VectorField -) - -# Build Schema -def get_schema(size: int): - return [ - # Tag fields - TagField("categories", separator = "|"), - TagField("year", separator = "|"), - # Vector field (FLAT index with COSINE similarity) - VectorField( - "vector", - "FLAT", { - "TYPE": "FLOAT32", - "DIM": 768, - "DISTANCE_METRIC": "COSINE", - "INITIAL_CAP": size, - "BLOCK_SIZE": size - } - ) - ] \ No newline at end of file diff --git a/main.py b/main.py deleted file mode 100644 index d38fd825..00000000 --- a/main.py +++ /dev/null @@ -1,143 +0,0 @@ -import asyncio -import warnings -import argparse -import logging -import pickle -import typing as t -import numpy as np - -from redis.asyncio import Redis -from data.schema import get_schema -from utils.search_index import SearchIndex - - -warnings.filterwarnings("error") - -logging.basicConfig( - level = logging.INFO, - format = "%(asctime)5s:%(filename)25s" - ":%(lineno)3s %(funcName)30s(): %(message)s", -) - -def read_data(data_file: str) -> t.List[dict]: - """ - Read dataset from a pickled dataframe (Pandas) file. - TODO -- add support for other input data types. - - Args: - data_file (str): Path to the destination - of the input data file. - - Returns: - t.List[dict]: List of Hash objects to insert to Redis. - """ - logging.info(f"Reading dataset from file: {data_file}") - with open(data_file, "rb") as f: - df = pickle.load(f) - return df.to_dict("records") - -async def gather_with_concurrency( - *data, - n: int, - vector_field_name: str, - prefix: str, - redis_conn: Redis -): - """ - Gather and load the hashes into Redis using - async connections. - - Args: - n (int): Max number of "concurrent" async connections. - vector_field_name (str): Vector field name in the dataframe. - prefix (str): Redis key prefix for all hashes in the search index. - redis_conn (Redis): Redis connection. - """ - logging.info("Loading dataset into Redis") - semaphore = asyncio.Semaphore(n) - async def load(d: dict): - async with semaphore: - d[vector_field_name] = np.array(d[vector_field_name], dtype = np.float32).tobytes() - key = prefix + str(d["id"]) - await redis_conn.hset(key, mapping = d) - # gather with concurrency - await asyncio.gather(*[load(d) for d in data]) - -async def load_all_data( - redis_conn: Redis, - concurrency: int, - prefix: str, - vector_field_name: str, - data_file: str, - index_name: str -): - """ - Load all data. - - Args: - redis_conn (Redis): Redis connection. - concurrency (int): Max number of "concurrent" async connections. - prefix (str): Redis key prefix for all hashes in the search index. - vector_field_name (str): Vector field name in the dataframe. - data_file (str): Path to the destination of the input data file. - index_name (str): Name of the RediSearch Index. - """ - search_index = SearchIndex( - index_name = index_name, - redis_conn = redis_conn - ) - - # Load from pickled dataframe file - data = read_data(data_file) - - # Gather async - await gather_with_concurrency( - *data, - n = concurrency, - prefix = prefix, - vector_field_name= vector_field_name, - redis_conn = redis_conn - ) - - # Load schema - logging.info("Processing RediSearch schema") - schema = get_schema(len(data)) - await search_index.create(*schema, prefix=prefix) - logging.info("All done. Data uploaded and RediSearch index created.") - - -async def main(): - # Parse script arguments - parser = argparse.ArgumentParser() - parser.add_argument("--host", help="Redis host", type=str, default="localhost") - parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379) - parser.add_argument("-a", "--password", help="Redis password", type=str, default="") - parser.add_argument("-c", "--concurrency", type=int, default=50) - parser.add_argument("-d", "--data", help="Path to data file", type=str, default="data/embeddings.pkl") - parser.add_argument("--prefix", help="Key prefix for all hashes in the search index", type=str, default="vector:") - parser.add_argument("-v", "--vector", help="Vector field name in df", type=str, default="vector") - parser.add_argument("-i", "--index", help="Index name", type=str, default="index") - args = parser.parse_args() - - # Create Redis Connection - connection_args = { - "host": args.host, - "port": args.port - } - if args.password: - connection_args.update({"password": args.password}) - redis_conn = Redis(**connection_args) - - # Perform data loading - await load_all_data( - redis_conn=redis_conn, - concurrency=args.concurrency, - prefix=args.prefix, - vector_field_name=args.vector, - data_file=args.data, - index_name=args.index - ) - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..9c502d63 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[tool.black] +target-version = ['py37', 'py38', 'py39', 'py310'] +exclude = ''' +( + | \.egg + | \.git + | \.hg + | \.mypy_cache + | \.nox + | \.tox + | \.venv + | _build + | build + | dist + | setup.py +) +''' + +[tool.pytest.ini_options] +log_cli = true + +[tool.coverage.run] +source = ["redisvl"] + +[tool.coverage.report] +ignore_errors = true + +[tool.coverage.html] +directory = "htmlcov" \ No newline at end of file diff --git a/data/__init__.py b/redisvl/__init__.py similarity index 100% rename from data/__init__.py rename to redisvl/__init__.py diff --git a/redisvl/cli/__init__.py b/redisvl/cli/__init__.py new file mode 100644 index 00000000..becbae5e --- /dev/null +++ b/redisvl/cli/__init__.py @@ -0,0 +1,3 @@ +from redisvl.cli.main import RedisVlCLI + +RedisVlCLI() diff --git a/utils/__init__.py b/redisvl/cli/index.py similarity index 100% rename from utils/__init__.py rename to redisvl/cli/index.py diff --git a/redisvl/cli/load.py b/redisvl/cli/load.py new file mode 100644 index 00000000..bd7d4ea5 --- /dev/null +++ b/redisvl/cli/load.py @@ -0,0 +1,86 @@ +import argparse +import asyncio +import sys +import typing as t + +from redisvl import readers +from redisvl.index import SearchIndex +from redisvl.load import concurrent_store_as_hash +from redisvl.utils.connection import get_async_redis_connection +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + + +class Load: + def __init__(self): + parser = argparse.ArgumentParser(description="Load vector data into redis") + parser.add_argument( + "-d", "--data", help="Path to data file", type=str, required=True + ) + parser.add_argument( + "-s", "--schema", help="Path to schema file", type=str, required=True + ) + parser.add_argument("--host", help="Redis host", type=str, default="localhost") + parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379) + parser.add_argument( + "-a", "--password", help="Redis password", type=str, default="" + ) + parser.add_argument("-c", "--concurrency", type=int, default=50) + # TODO add argument to optionally not create index + args = parser.parse_args(sys.argv[2:]) + if not args.data: + parser.print_help() + exit(0) + + # Create Redis Connection + try: + logger.info(f"Connecting to {args.host}:{str(args.port)}") + redis_conn = get_async_redis_connection(args.host, args.port, args.password) + logger.info("Connected.") + except: + # TODO: be more specific about the exception + logger.error("Could not connect to redis.") + exit(1) + + # validate schema + index = SearchIndex.from_yaml(redis_conn, args.schema) + + # read in data + logger.info("Reading data...") + data = self.read_data(args) # TODO add other readers and formats + logger.info("Data read.") + + # load data and create the index + asyncio.run(self.load_and_create_index(args.concurrency, data, index)) + + def read_data( + self, args: t.List[str], reader: str = "pandas", format: str = "pickle" + ) -> dict: + if reader == "pandas": + if format == "pickle": + return readers.pandas.from_pickle(args.data) + else: + raise NotImplementedError( + "Only pickle format is supported for pandas reader." + ) + else: + raise NotImplementedError("Only pandas reader is supported.") + + async def load_and_create_index( + self, concurrency: int, data: dict, index: SearchIndex + ): + + logger.info("Loading data...") + if index.storage_type == "hash": + await concurrent_store_as_hash( + data, concurrency, index.key_field, index.prefix, index.redis_conn + ) + else: + raise NotImplementedError("Only hash storage type is supported.") + logger.info("Data loaded.") + + # create index + logger.info("Creating index...") + await index.create() + logger.info("Index created.") diff --git a/redisvl/cli/main.py b/redisvl/cli/main.py new file mode 100644 index 00000000..f9b633de --- /dev/null +++ b/redisvl/cli/main.py @@ -0,0 +1,43 @@ +import argparse +import sys + +from redisvl.cli.load import Load +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + + +def _usage(): + usage = [ + "redisvl []\n", + "Commands:", + "\tload Load vector data into redis", + "\tindex Index manipulation (create, delete, etc.)", + "\tquery Query an existing index", + ] + return "\n".join(usage) + + +class RedisVlCLI: + def __init__(self): + parser = argparse.ArgumentParser( + description="Redis Vector load CLI", usage=_usage() + ) + + parser.add_argument("command", help="Subcommand to run") + + if len(sys.argv) < 2: + parser.print_help() + exit(0) + + args = parser.parse_args(sys.argv[1:2]) + if not hasattr(self, args.command): + parser.print_help() + exit(0) + getattr(self, args.command)() + + def load(self): + Load() + exit(0) + + # TODO index and query functions diff --git a/redisvl/cli/query.py b/redisvl/cli/query.py new file mode 100644 index 00000000..e69de29b diff --git a/redisvl/index.py b/redisvl/index.py new file mode 100644 index 00000000..46baa14e --- /dev/null +++ b/redisvl/index.py @@ -0,0 +1,88 @@ +import re +import typing as t +from typing import Optional, Pattern + +from redis.asyncio import Redis +from redis.commands.search.field import Field +from redis.commands.search.indexDefinition import IndexDefinition, IndexType + +from redisvl.schema import read_field_spec, read_schema + + +class TokenEscaper: + """ + Escape punctuation within an input string. Taken from RedisOM Python. + """ + + # Characters that RediSearch requires us to escape during queries. + # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization + DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" + + def __init__(self, escape_chars_re: Optional[Pattern] = None): + if escape_chars_re: + self.escaped_chars_re = escape_chars_re + else: + self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) + + def escape(self, value: str) -> str: + def escape_symbol(match): + value = match.group(0) + return f"\\{value}" + + return self.escaped_chars_re.sub(escape_symbol, value) + + +class SearchIndex: + """ + SearchIndex is used to wrap and capture all information + and actions applied to a RediSearch index including creation, + manegement, and query construction. + """ + + escaper = TokenEscaper() + + # TODO think about, should this have a redis connection? SearchIndexManupulator? + def __init__( + self, + redis_conn: Redis, + name: str, + storage_type: str = "hash", + key_field: str = "id", + prefix: str = "", + fields: t.List[Field] = None, + ): + self.index_name = name + self.key_field = key_field + self.redis_conn = redis_conn + self.storage_type = storage_type + self.prefix = prefix + self.fields = fields + + @classmethod + def from_yaml(cls, redis_conn: Redis, schema_path: str): + index_attrs, fields = read_schema(schema_path) + return cls(redis_conn, fields=fields, **index_attrs) + + @classmethod + def from_dict(cls, redis_conn: Redis, schema_dict: t.Dict[str, t.Any]): + # TODO error handling + fields = read_field_spec(schema_dict["fields"]) + index_attrs = schema_dict["index"] + return cls(redis_conn, fields=fields, **index_attrs) + + async def create( + self, + ): + # set storage_type, default to hash + storage_type = IndexType.HASH + if self.storage_type.lower() == "json": + self.storage_type = IndexType.JSON + + # Create Index + await self.redis_conn.ft(self.index_name).create_index( + fields=self.fields, + definition=IndexDefinition(prefix=[self.prefix], index_type=storage_type), + ) + + async def delete(self): + await self.redis_conn.ft(self.index_name).dropindex(delete_documents=True) diff --git a/redisvl/load.py b/redisvl/load.py new file mode 100644 index 00000000..1dcdda25 --- /dev/null +++ b/redisvl/load.py @@ -0,0 +1,35 @@ +import asyncio +import typing as t + +import numpy as np +from redis.asyncio import Redis + +# TODO Add conncurrent_store_as_json + + +async def concurrent_store_as_hash( + data: t.List[t.Dict[str, t.Any]], # TODO: be stricter about the type + concurrency: int, + key_field: str, + prefix: str, + redis_conn: Redis, +): + """ + Gather and load the hashes into Redis using + async connections. + + Args: + concurrency (int): Max number of "concurrent" async connections. + key_field: name of the key in each dict to use as top level key in Redis. + prefix (str): Redis key prefix for all hashes in the search index. + redis_conn (Redis): Redis connection. + """ + semaphore = asyncio.Semaphore(concurrency) + + async def load(d: dict): + async with semaphore: + key = prefix + str(d[key_field]) + await redis_conn.hset(key, mapping=d) + + # gather with concurrency + await asyncio.gather(*[load(d) for d in data]) diff --git a/redisvl/query.py b/redisvl/query.py new file mode 100644 index 00000000..a9dfa569 --- /dev/null +++ b/redisvl/query.py @@ -0,0 +1,20 @@ +import typing as t + +from redis.commands.search.query import Query + + +def create_vector_query( + return_fields: t.List[str], + search_type: str = "KNN", + number_of_results: int = 20, + vector_field_name: str = "vector", + tags: str = "*", +): + base_query = f"{tags}=>[{search_type} {number_of_results} @{vector_field_name} $vector AS vector_score]" + return ( + Query(base_query) + .sort_by("vector_score") + .paging(0, number_of_results) + .return_fields(*return_fields) + .dialect(2) + ) diff --git a/redisvl/readers/__init__.py b/redisvl/readers/__init__.py new file mode 100644 index 00000000..8ea67cc2 --- /dev/null +++ b/redisvl/readers/__init__.py @@ -0,0 +1 @@ +from . import pandas diff --git a/redisvl/readers/pandas.py b/redisvl/readers/pandas.py new file mode 100644 index 00000000..0ca3878b --- /dev/null +++ b/redisvl/readers/pandas.py @@ -0,0 +1,20 @@ +import pickle +import typing as t + +import pandas as pd + + +def from_pickle(data_file: str) -> t.List[dict]: + """ + Read dataset from a pickled dataframe (Pandas) file. + + Args: + data_file (str): Path to the destination + of the input data file. + + Returns: + t.List[dict]: List of Hash objects to insert to Redis. + """ + with open(data_file, "rb") as f: + df = pickle.load(f) + return df.to_dict("records") diff --git a/redisvl/schema.py b/redisvl/schema.py new file mode 100644 index 00000000..7dfc9173 --- /dev/null +++ b/redisvl/schema.py @@ -0,0 +1,142 @@ +import typing as t +from pathlib import Path + +import yaml +from redis.commands.search.field import ( + GeoField, + NumericField, + TagField, + TextField, + VectorField, +) + +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + + +def read_schema(file_path: str): + fp = Path(file_path).resolve() + if not fp.exists(): + logger.error(f"Schema file {file_path} does not exist") + raise FileNotFoundError(f"Schema file {file_path} does not exist") + + with open(fp, "r") as f: + schema = yaml.safe_load(f) + + try: + index_schema = schema["index"] + fields_schema = schema["fields"] + except KeyError: + logger.error("Schema file must contain both a 'fields' and 'index' key") + raise + + index_attrs = read_index_spec(index_schema) + fields = read_field_spec(fields_schema) + return index_attrs, fields + + +def read_index_spec(index_spec: t.Dict[str, t.Any]): + """Read index specification and return the fields + + Args: + index_schema (dict): Index specification from schema file. + + Returns: + index_fields (dict): List of index fields. + """ + # TODO parsing and validation here + return index_spec + + +def read_field_spec(field_spec: t.Dict[str, t.Any]): + """ + Read a schema file and return a list of RediSearch fields. + + Args: + field_schema (dict): Field specification from schema file. + + Returns: + fields: list of RediSearch fields. + """ + fields = [] + for key, field in field_spec.items(): + if key.upper() == "TAG": + for name, attrs in field.items(): + fields.append(TagField(name, **attrs)) + elif key.upper() == "VECTOR": + for name, attrs in field.items(): + fields.append(_create_vector_field(name, **attrs)) + elif key.upper() == "GEO": + for name, attrs in field.items(): + fields.append(GeoField(name, **attrs)) + elif key.upper() == "TEXT": + for name, attrs in field.items(): + fields.append(TextField(name, **attrs)) + elif key.upper() == "NUMERIC": + for name, attrs in field.items(): + fields.append(NumericField(name, **attrs)) + else: + logger.error(f"Invalid field type: {key}") + raise ValueError(f"Invalid field type: {key}") + return fields + + +def _create_vector_field( + name: str, + dims: int, + algorithm: str = "FLAT", + datatype: str = "FLOAT32", + distance_metric: str = "COSINE", + initial_cap: int = 1000000, + block_size: int = 1000, + m: int = 16, + ef_construction: int = 200, + ef_runtime: int = 10, + epsilon: float = 0.8, +): + """Create a RediSearch VectorField. + + Args: + name: The name of the field. + algorithm: The algorithm used to index the vector. + dims: The dimensionality of the vector. + datatype: The type of the vector. default: FLOAT32 + distance_metric: The distance metric used to compare vectors. + initial_cap: The initial capacity of the index. + block_size: The block size of the index. + m: The number of outgoing edges in the HNSW graph. + ef_construction: Number of maximum allowed potential outgoing edges + candidates for each node in the graph, during the graph building. + ef_runtime: The umber of maximum top candidates to hold during the KNN search + + returns: + A RediSearch VectorField. + """ + if algorithm.upper() == "HNSW": + return VectorField( + name, + "HNSW", + { + "TYPE": datatype.upper(), + "DIM": dims, + "DISTANCE_METRIC": distance_metric.upper(), + "INITIAL_CAP": initial_cap, + "M": m, + "EF_CONSTRUCTION": ef_construction, + "EF_RUNTIME": ef_runtime, + "EPSILON": epsilon, + }, + ) + else: + return VectorField( + name, + "FLAT", + { + "TYPE": datatype.upper(), + "DIM": dims, + "DISTANCE_METRIC": distance_metric.upper(), + "INITIAL_CAP": initial_cap, + "BLOCK_SIZE": block_size, + }, + ) diff --git a/redisvl/utils/__init__.py b/redisvl/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/redisvl/utils/connection.py b/redisvl/utils/connection.py new file mode 100644 index 00000000..e15575c6 --- /dev/null +++ b/redisvl/utils/connection.py @@ -0,0 +1,8 @@ +def get_async_redis_connection(host: str, port: int, password: str = None): + # TODO add username and ACL/TCL support + from redis.asyncio import Redis + + connection_args = {"host": host, "port": port} + if password: + connection_args.update({"password": password}) + return Redis(**connection_args) diff --git a/redisvl/utils/log.py b/redisvl/utils/log.py new file mode 100644 index 00000000..ec574b49 --- /dev/null +++ b/redisvl/utils/log.py @@ -0,0 +1,21 @@ +import logging +import sys + +import coloredlogs + +# constants for logging +coloredlogs.DEFAULT_DATE_FORMAT = "%H:%M:%S" +coloredlogs.DEFAULT_LOG_FORMAT = ( + "%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s" +) + + +def get_logger(name, log_level="info", fmt=None): + """Return a logger instance""" + + # Use file name if logger is in debug mode + name = "RedisVL" if log_level == "debug" else name + + logger = logging.getLogger(name) + coloredlogs.install(level=log_level, logger=logger, fmt=fmt, stream=sys.stdout) + return logger diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..f052f902 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +black>=20.8b1 +isort>=5.6.4 +pylint>=2.6.0 +pytest>=6.0.0 +pytest-cov>=2.10.1 +pytest-asyncio \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b4a9bb1a..fa992b1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ -numpy==1.23.2 -pandas==1.5.0 -redis==4.3.4 +numpy>=1.23.2 +pandas>=1.5.0 +redis>=4.3.4 +pyyaml +coloredlogs diff --git a/data/embeddings.pkl b/sample-data/pandas-sample.pkl similarity index 52% rename from data/embeddings.pkl rename to sample-data/pandas-sample.pkl index 13f2856a..f86a4774 100644 Binary files a/data/embeddings.pkl and b/sample-data/pandas-sample.pkl differ diff --git a/sample-data/sample.yml b/sample-data/sample.yml new file mode 100644 index 00000000..5f2f7958 --- /dev/null +++ b/sample-data/sample.yml @@ -0,0 +1,20 @@ + + +index: + name: sample + storage_type: hash + prefix: "vector:" + key_field: "id" + +fields: + tag: + categories: + separator: "|" + year: + separator: "|" + vector: + vector: + datatype: "float32" + algorithm: "flat" + dims: 768 + distance_metric: "cosine" diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..e03d0033 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup + +# function to read in requirements.txt to into a python list +def read_requirements(): + with open("requirements.txt") as f: + requirements = f.read().splitlines() + return requirements + +def read_dev_requirements(): + with open("requirements-dev.txt") as f: + requirements = f.read().splitlines() + return requirements + +setup( + name="redisvl", + description="Vector loading utility for Redis vector search", + license="BSD-3-Clause", + version="0.1.0", + python_requires=">=3.6", + install_requires=read_requirements(), + extras_require={"dev": read_dev_requirements()}, + packages=["redisvl"], + zip_safe=False, + entry_points={ + "console_scripts": [ + "redisvl = redisvl.cli.__init__:main" + ] + } +) \ No newline at end of file diff --git a/tests/test_simple.py b/tests/test_simple.py new file mode 100644 index 00000000..cf77b77b --- /dev/null +++ b/tests/test_simple.py @@ -0,0 +1,98 @@ +import asyncio +import time +from pprint import pprint + +import numpy as np +import pandas as pd +import pytest + +from redisvl.index import SearchIndex +from redisvl.load import concurrent_store_as_hash +from redisvl.query import create_vector_query +from redisvl.utils.connection import get_async_redis_connection + +data = pd.DataFrame( + { + "users": ["john", "mary", "joe"], + "age": [1, 2, 3], + "job": ["engineer", "doctor", "dentist"], + "credit_score": ["high", "low", "medium"], + "user_embedding": [ + np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(), + ], + } +) + +schema = { + "index": { + "name": "user_index", + "prefix": "user:", + "key_field": "users", + "storage_type": "hash", + }, + "fields": { + "tag": {"credit_score": {}}, + "text": {"job": {}}, + "numeric": {"age": {}}, + "vector": { + "user_embedding": { + "dims": 3, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + } + }, + }, +} + + +@pytest.mark.asyncio +async def test_simple(async_redis): + index = SearchIndex.from_dict(async_redis, schema) + + await concurrent_store_as_hash( + data.to_dict(orient="records"), + 5, + index.key_field, + index.prefix, + index.redis_conn, + ) + await index.create() + # add assertions here + + # wait for indexing to happen on server side + time.sleep(1) + + query = create_vector_query( + ["users", "age", "job", "credit_score", "vector_score"], + number_of_results=3, + vector_field_name="user_embedding", + ) + + query_vector = np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes() + results = await async_redis.ft(index.index_name).search( + query, query_params={"vector": query_vector} + ) + + # make sure correct users returned + # users = list(results.docs) + # print(len(users)) + users = [doc for doc in results.docs] + assert users[0].users in ["john", "mary"] + assert users[1].users in ["john", "mary"] + + # make sure vector scores are correct + # query vector and first two are the same vector. + # third is different (hence should be positive difference) + assert float(users[0].vector_score) == 0.0 + assert float(users[1].vector_score) == 0.0 + assert float(users[2].vector_score) > 0 + + print() + for doc in results.docs: + print("Score:", doc.vector_score) + pprint(doc) + + await index.delete() diff --git a/utils/search_index.py b/utils/search_index.py deleted file mode 100644 index 17ee9f37..00000000 --- a/utils/search_index.py +++ /dev/null @@ -1,133 +0,0 @@ -import re - -from redis.asyncio import Redis -from redis.commands.search.query import Query -from redis.commands.search.indexDefinition import IndexDefinition, IndexType -from typing import Optional, Pattern - - -class TokenEscaper: - """ - Escape punctuation within an input string. Taken from RedisOM Python. - """ - # Characters that RediSearch requires us to escape during queries. - # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization - DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" - - def __init__(self, escape_chars_re: Optional[Pattern] = None): - if escape_chars_re: - self.escaped_chars_re = escape_chars_re - else: - self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) - - def escape(self, value: str) -> str: - def escape_symbol(match): - value = match.group(0) - return f"\\{value}" - - return self.escaped_chars_re.sub(escape_symbol, value) - -class SearchIndex: - """ - SearchIndex is used to wrap and capture all information - and actions applied to a RediSearch index including creation, - manegement, and query construction. - """ - escaper = TokenEscaper() - - def __init__(self, index_name: str, redis_conn: Redis): - self.index_name = index_name - self.redis_conn = redis_conn - - async def create( - self, - *fields, - prefix: str - ): - # Create Index - await self.redis_conn.ft(self.index_name).create_index( - fields = fields, - definition= IndexDefinition(prefix=[prefix], index_type=IndexType.HASH) - ) - - async def delete(self): - await self.redis_conn.ft(self.index_name).dropindex(delete_documents=True) - - def process_tags(self, categories: list, years: list) -> str: - """ - Helper function to process tags data. TODO - factor this - out so it's agnostic to the name of the field. - - Args: - categories (list): List of categories. - years (list): List of years. - - Returns: - str: RediSearch tag query string. - """ - tag = "(" - if years: - years = "|".join([self.escaper.escape(year) for year in years]) - tag += f"(@year:{{{years}}})" - if categories: - categories = "|".join([self.escaper.escape(cat) for cat in categories]) - if tag: - tag += f" (@categories:{{{categories}}})" - else: - tag += f"(@categories:{{{categories}}})" - tag += ")" - # if no tags are selected - if len(tag) < 3: - tag = "*" - return tag - - def vector_query( - self, - categories: list, - years: list, - search_type: str="KNN", - number_of_results: int=20 - ) -> Query: - """ - Create a RediSearch query to perform hybrid vector and tag based searches. - - - Args: - categories (list): List of categories. - years (list): List of years. - search_type (str, optional): Style of search. Defaults to "KNN". - number_of_results (int, optional): How many results to fetch. Defaults to 20. - - Returns: - Query: RediSearch Query - - """ - # Parse tags to create query - tag_query = self.process_tags(categories, years) - base_query = f'{tag_query}=>[{search_type} {number_of_results} @vector $vec_param AS vector_score]' - return Query(base_query)\ - .sort_by("vector_score")\ - .paging(0, number_of_results)\ - .return_fields("paper_id", "paper_pk", "vector_score")\ - .dialect(2) - - def count_query( - self, - years: list, - categories: list - ) -> Query: - """ - Create a RediSearch query to count available documents. - - Args: - categories (list): List of categories. - years (list): List of years. - - Returns: - Query: RediSearch Query - """ - # Parse tags to create query - tag_query = self.process_tags(categories, years) - return Query(f'{tag_query}')\ - .no_content()\ - .dialect(2)