
Commit 9c0ee05

first commit
0 parents  commit 9c0ee05

File tree

11 files changed: +368 -0 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
__pycache__/

CONTRIBUTING.md

Whitespace-only changes.

LICENSE

Whitespace-only changes.

README.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# RediSearch Data Loader

This script helps you load datasets into a RediSearch instance efficiently.

The project is brand new and will undergo improvements over time.

## Getting Started

### Requirements

Install the Python requirements listed in `requirements.txt`:

```bash
$ pip install -r requirements.txt
```
### Data

To run the script, you need a dataset that contains your vectors and metadata.

> Currently, the data file must be a pickled pandas DataFrame. Support for more data types will be included in future iterations.
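For illustration, here is a minimal sketch of how a compatible pickle file could be produced. The column names (`id`, `categories`, `year`) and the 768-dimensional vectors mirror the fields assumed by `main.py` and `data/schema.py`; your own dataset may use different metadata columns:

```python
import numpy as np
import pandas as pd

# Hypothetical example rows -- swap in your real vectors and metadata.
df = pd.DataFrame({
    "id": [1, 2],                          # used to build each Redis key (prefix + id)
    "vector": [
        np.random.rand(768).tolist(),      # 768 dims to match data/schema.py
        np.random.rand(768).tolist(),
    ],
    "categories": ["ai|ml", "databases"],  # "|"-separated tags (see schema separator)
    "year": ["2021", "2022"],
})
df.to_pickle("data/embeddings.pkl")        # readable by main.py's pickle.load
```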
### Schema

Along with the dataset, you must update the dataset schema for RediSearch in [`data/schema.py`](data/schema.py).
### Running

The `main.py` script provides an entrypoint with optional arguments to upload your dataset to a Redis server.

#### Usage

```
python main.py

-h, --help           Show this help message and exit
--host               Redis host
-p, --port           Redis port
-a, --password       Redis password
-c, --concurrency    Amount of concurrency
-d, --data           Path to data file
--prefix             Key prefix for all hashes in the search index
-v, --vector         Vector field name in df
-i, --index          Index name
```
#### Defaults

| Argument            | Default               |
| ------------------- | --------------------- |
| Host                | `localhost`           |
| Port                | `6379`                |
| Password            | `""`                  |
| Concurrency         | `50`                  |
| Data (Path)         | `data/embeddings.pkl` |
| Prefix              | `vector:`             |
| Vector (Field Name) | `vector`              |
| Index Name          | `index`               |
#### Examples

Load to a local (default) Redis server with a custom index name and with concurrency = 100:

```bash
$ python main.py -d data/embeddings.pkl -i myIndex -c 100
```

Load to a cloud Redis server with all other defaults (note that the host flag is `--host`, since `-h` is reserved for help):

```bash
$ python main.py --host {redis-host} -p {redis-port} -a {redis-password}
```

data/__init__.py

Whitespace-only changes.

data/embeddings.pkl

85.6 MB
Binary file not shown.

data/schema.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
from redis.commands.search.field import (
    TagField,
    VectorField
)

# Build Schema
def get_schema(size: int):
    return [
        # Tag fields
        TagField("categories", separator="|"),
        TagField("year", separator="|"),
        # Vector field (FLAT index with COSINE similarity)
        VectorField(
            "vector",
            "FLAT", {
                "TYPE": "FLOAT32",
                "DIM": 768,
                "DISTANCE_METRIC": "COSINE",
                "INITIAL_CAP": size,
                "BLOCK_SIZE": size
            }
        )
    ]

main.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import asyncio
import warnings
import argparse
import logging
import pickle
import typing as t
import numpy as np

from redis.asyncio import Redis
from data.schema import get_schema
from utils.search_index import SearchIndex


warnings.filterwarnings("error")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)5s:%(filename)25s"
    ":%(lineno)3s %(funcName)30s(): %(message)s",
)

def read_data(data_file: str) -> t.List[dict]:
    """
    Read the dataset from a pickled pandas dataframe file.
    TODO -- add support for other input data types.

    Args:
        data_file (str): Path to the input data file.

    Returns:
        t.List[dict]: List of hash objects to insert into Redis.
    """
    logging.info(f"Reading dataset from file: {data_file}")
    with open(data_file, "rb") as f:
        df = pickle.load(f)
    return df.to_dict("records")

async def gather_with_concurrency(
    *data,
    n: int,
    vector_field_name: str,
    prefix: str,
    redis_conn: Redis
):
    """
    Gather and load the hashes into Redis using async connections.

    Args:
        n (int): Max number of "concurrent" async connections.
        vector_field_name (str): Vector field name in the dataframe.
        prefix (str): Redis key prefix for all hashes in the search index.
        redis_conn (Redis): Redis connection.
    """
    logging.info("Loading dataset into Redis")
    # The semaphore caps the number of in-flight HSET operations
    semaphore = asyncio.Semaphore(n)
    async def load(d: dict):
        async with semaphore:
            # Serialize the vector to raw FLOAT32 bytes for RediSearch
            d[vector_field_name] = np.array(d[vector_field_name], dtype=np.float32).tobytes()
            key = prefix + str(d["id"])
            await redis_conn.hset(key, mapping=d)
    # Gather with concurrency
    await asyncio.gather(*[load(d) for d in data])

async def load_all_data(
    redis_conn: Redis,
    concurrency: int,
    prefix: str,
    vector_field_name: str,
    data_file: str,
    index_name: str
):
    """
    Load all data.

    Args:
        redis_conn (Redis): Redis connection.
        concurrency (int): Max number of "concurrent" async connections.
        prefix (str): Redis key prefix for all hashes in the search index.
        vector_field_name (str): Vector field name in the dataframe.
        data_file (str): Path to the input data file.
        index_name (str): Name of the RediSearch index.
    """
    search_index = SearchIndex(
        index_name=index_name,
        redis_conn=redis_conn
    )

    # Load from pickled dataframe file
    data = read_data(data_file)

    # Gather async
    await gather_with_concurrency(
        *data,
        n=concurrency,
        prefix=prefix,
        vector_field_name=vector_field_name,
        redis_conn=redis_conn
    )

    # Load schema
    logging.info("Processing RediSearch schema")
    schema = get_schema(len(data))
    await search_index.create(*schema, prefix=prefix)
    logging.info("All done. Data uploaded and RediSearch index created.")


async def main():
    # Parse script arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", help="Redis host", type=str, default="localhost")
    parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379)
    parser.add_argument("-a", "--password", help="Redis password", type=str, default="")
    parser.add_argument("-c", "--concurrency", help="Amount of concurrency", type=int, default=50)
    parser.add_argument("-d", "--data", help="Path to data file", type=str, default="data/embeddings.pkl")
    parser.add_argument("--prefix", help="Key prefix for all hashes in the search index", type=str, default="vector:")
    parser.add_argument("-v", "--vector", help="Vector field name in df", type=str, default="vector")
    parser.add_argument("-i", "--index", help="Index name", type=str, default="index")
    args = parser.parse_args()

    # Create Redis connection
    connection_args = {
        "host": args.host,
        "port": args.port
    }
    if args.password:
        connection_args.update({"password": args.password})
    redis_conn = Redis(**connection_args)

    # Perform data loading
    await load_all_data(
        redis_conn=redis_conn,
        concurrency=args.concurrency,
        prefix=args.prefix,
        vector_field_name=args.vector,
        data_file=args.data,
        index_name=args.index
    )


if __name__ == "__main__":
    asyncio.run(main())
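
Once the loader finishes, the index can be queried. Below is a hypothetical sketch (not part of this commit) of a KNN query against the resulting index, assuming the default index name `index` and the 768-dim FLOAT32 `vector` field from `data/schema.py`; it uses redis-py's synchronous search API, and vector queries require RediSearch query dialect 2:

```python
import numpy as np
from redis import Redis
from redis.commands.search.query import Query

redis_conn = Redis(host="localhost", port=6379)

# Hypothetical query vector -- must match the indexed DIM (768) and FLOAT32 type.
query_vector = np.random.rand(768).astype(np.float32).tobytes()

# KNN 5 search on the "vector" field, returning tag fields and the distance score.
query = (
    Query("*=>[KNN 5 @vector $vec AS score]")
    .sort_by("score")
    .return_fields("categories", "year", "score")
    .dialect(2)  # vector queries require DIALECT 2 (RediSearch 2.4+)
)
results = redis_conn.ft("index").search(query, query_params={"vec": query_vector})
for doc in results.docs:
    print(doc.id, doc.score)
```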

requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
numpy==1.23.2
pandas==1.5.0
redis==4.3.4

utils/__init__.py

Whitespace-only changes.
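
Note that `utils/search_index.py`, the `SearchIndex` wrapper imported by `main.py`, is not shown in this view. As a rough sketch only, such a wrapper might look like the following, assuming redis-py's async RediSearch API and inferring the interface from how `main.py` calls it (`SearchIndex(index_name=..., redis_conn=...)` and `await search_index.create(*fields, prefix=...)`); the commit's actual implementation may differ:

```python
from redis.asyncio import Redis
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

class SearchIndex:
    """Hypothetical thin wrapper around RediSearch index creation."""

    def __init__(self, index_name: str, redis_conn: Redis):
        self.index_name = index_name
        self.redis_conn = redis_conn

    async def create(self, *fields, prefix: str):
        # Create the index over HASH keys that share the given key prefix.
        await self.redis_conn.ft(self.index_name).create_index(
            fields,
            definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
        )
```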
