import argparse
import asyncio
import logging
import pickle
import typing as t
import warnings

import numpy as np
from redis.asyncio import Redis

from data.schema import get_schema
from utils.search_index import SearchIndex

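# Promote all warnings to errors so data issues fail loudly during the load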
warnings.filterwarnings("error")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)5s:%(filename)25s"
    ":%(lineno)3s %(funcName)30s(): %(message)s",
)


def read_data(data_file: str) -> t.List[dict]:
    """
    Read dataset from a pickled dataframe (Pandas) file.
    TODO -- add support for other input data types.

    Args:
        data_file (str): Path to the input data file.

    Returns:
        t.List[dict]: List of hash objects to insert into Redis.
    """
    logging.info(f"Reading dataset from file: {data_file}")
    with open(data_file, "rb") as f:
        df = pickle.load(f)
    return df.to_dict("records")


async def gather_with_concurrency(
    *data,
    n: int,
    vector_field_name: str,
    prefix: str,
    redis_conn: Redis,
):
    """
    Gather and load the hashes into Redis using
    async connections.

    Args:
        data (dict): Hash records to load into Redis.
        n (int): Max number of "concurrent" async connections.
        vector_field_name (str): Vector field name in the dataframe.
        prefix (str): Redis key prefix for all hashes in the search index.
        redis_conn (Redis): Redis connection.
    """
    logging.info("Loading dataset into Redis")
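    # The semaphore caps how many hash writes are in flight at once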
    semaphore = asyncio.Semaphore(n)

    async def load(d: dict):
        async with semaphore:
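            # Serialize the vector to raw float32 bytes, the binary blob
            # format that RediSearch vector fields expect in hashes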
            d[vector_field_name] = np.array(
                d[vector_field_name], dtype=np.float32
            ).tobytes()
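            # Key each record by prefix + id and write all fields in one HSET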
            key = prefix + str(d["id"])
            await redis_conn.hset(key, mapping=d)

    # Gather with concurrency
    await asyncio.gather(*[load(d) for d in data])


async def load_all_data(
    redis_conn: Redis,
    concurrency: int,
    prefix: str,
    vector_field_name: str,
    data_file: str,
    index_name: str,
):
    """
    Load all data into Redis and create the search index.

    Args:
        redis_conn (Redis): Redis connection.
        concurrency (int): Max number of "concurrent" async connections.
        prefix (str): Redis key prefix for all hashes in the search index.
        vector_field_name (str): Vector field name in the dataframe.
        data_file (str): Path to the input data file.
        index_name (str): Name of the RediSearch index.
    """
    search_index = SearchIndex(
        index_name=index_name,
        redis_conn=redis_conn,
    )

    # Load from pickled dataframe file
    data = read_data(data_file)

    # Gather async
    await gather_with_concurrency(
        *data,
        n=concurrency,
        prefix=prefix,
        vector_field_name=vector_field_name,
        redis_conn=redis_conn,
    )

    # Process the schema and create the RediSearch index
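    # Creating the index after the data load works because RediSearch
    # indexes existing hashes that match the key prefix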
    logging.info("Processing RediSearch schema")
    schema = get_schema(len(data))
    await search_index.create(*schema, prefix=prefix)
    logging.info("All done. Data uploaded and RediSearch index created.")


async def main():
    # Parse script arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", help="Redis host", type=str, default="localhost")
    parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379)
    parser.add_argument("-a", "--password", help="Redis password", type=str, default="")
    parser.add_argument(
        "-c", "--concurrency", help="Max concurrent async connections", type=int, default=50
    )
    parser.add_argument(
        "-d", "--data", help="Path to data file", type=str, default="data/embeddings.pkl"
    )
    parser.add_argument(
        "--prefix",
        help="Key prefix for all hashes in the search index",
        type=str,
        default="vector:",
    )
    parser.add_argument("-v", "--vector", help="Vector field name in df", type=str, default="vector")
    parser.add_argument("-i", "--index", help="Index name", type=str, default="index")
    args = parser.parse_args()

    # Create the Redis connection
    connection_args = {
        "host": args.host,
        "port": args.port,
    }
    if args.password:
        connection_args["password"] = args.password
    redis_conn = Redis(**connection_args)
130
+
131
+ # Perform data loading
132
+ await load_all_data (
133
+ redis_conn = redis_conn ,
134
+ concurrency = args .concurrency ,
135
+ prefix = args .prefix ,
136
+ vector_field_name = args .vector ,
137
+ data_file = args .data ,
138
+ index_name = args .index
139
+ )


if __name__ == "__main__":
    asyncio.run(main())
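
# Example invocation (assuming this script is saved as load_data.py and the
# pickled dataframe exists at the default path):
#   python load_data.py --host localhost -p 6379 -c 50 \
#       -d data/embeddings.pkl --prefix "vector:" -v vector -i index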