import itertools
import h5py
import json
import logging
import multiprocessing as mp
import pandas as pd
from pathlib import Path
import sqlite3

from sql_templates import (
    edges_partitioned,
    edgelist_cte_mapper,
    remap_relns,
    partitioned_mapped_entities,
    type_tmp_table
)
import sys
import time

logging.basicConfig(
    format='%(process)d-%(levelname)s-%(message)s',
    stream=sys.stdout,
    level=logging.DEBUG
)

def remap_relationships(conn):
    """
    A function to remap relationships using SQL queries.
    """
    logging.info("Remapping relationships")
    start = time.time()
    logging.debug(f"Running query: {remap_relns}\n")
    conn.executescript(remap_relns)

    query = """
    select *
    from reln_map
    """
    logging.debug(f"Running query: {query}\n")
    rels = pd.read_sql_query(query, conn)
    end = time.time()
    logging.info(f"Remapped relationships in {end - start}s")
    return rels


def remap_entities(conn, entity2partitions):
    """
    A function to remap entities with partitioning using SQL queries.

    This function is complicated because the partitions have to be
    constructed first, and then we have to generate ordinal mappings of
    entity ids. These mappings will be used to generate buckets of edges
    for training and then for mapping our edges back to their original
    ids for use in downstream tasks.
    """
    logging.info("Remapping entities")
    start = time.time()
    for entity, npartitions in entity2partitions.items():
        # Build the partition temp table for this entity type, then one
        # ordinal id mapping per partition, and run them as one script.
        query = type_tmp_table.format(type=entity, nparts=npartitions)
        for i in range(npartitions):
            query += partitioned_mapped_entities.format(type=entity, n=i)
        logging.debug(f"Running query: {query}")
        conn.executescript(query)
    end = time.time()
    logging.info(f"Remapped entities in {end - start}s")


def generate_ctes(lhs_part, rhs_part, rels, entity2partitions):
    """
    This function generates the sub-table CTEs that help us generate
    the completed edgelist.
    """
    nctes = 0
    ctes = """
    with cte_0 as (
    """
    first = True
    for _, r in rels.iterrows():
        # Skip relations whose endpoint types don't have this many partitions.
        if lhs_part >= entity2partitions[r['source_type']]:
            continue
        if rhs_part >= entity2partitions[r['destination_type']]:
            continue
        if not first:
            ctes += f", cte_{nctes} as ("
        ctes += edgelist_cte_mapper.format(
            rel_name=r['id'],
            lhs_type=r['source_type'],
            rhs_type=r['destination_type'],
            i=lhs_part,
            j=rhs_part,
        )
        ctes += ")"

        nctes += 1
        first = False
    return nctes, ctes


def generate_unions(nctes):
    """
    This function is just a helper function for
    generating the final edge list tables.
    """
    subquery = ""
    first = True
    for i in range(nctes):
        if not first:
            subquery += "\tunion\n"
        subquery += f"\tselect * from cte_{i}\n"
        first = False
    return subquery


def remap_edges(conn, rels, entity2partitions):
    """
    A function to remap all edges to ordinal IDs
    according to their type.
    """
    logging.info("Remapping edges")
    start = time.time()

    nentities_premap = conn.execute("""
        select count(*) from edges
        ;
    """).fetchall()[0][0]

    query = ""
    NPARTS = max(entity2partitions.values())
    for lhs_part in range(NPARTS):
        for rhs_part in range(NPARTS):
            nctes, ctes = generate_ctes(lhs_part, rhs_part, rels, entity2partitions)
            subquery = generate_unions(nctes)
            query += edges_partitioned.format(
                i=lhs_part,
                j=rhs_part,
                ctes=ctes,
                tables=subquery
            )

    logging.debug(f"Running query: {query}")
    conn.executescript(query)

    logging.debug("Confirming that we didn't drop any edges.")
    nentities_postmap = 0
    for lhs_part in range(NPARTS):
        for rhs_part in range(NPARTS):
            nentities_postmap += conn.execute(f"""
                select count(*) from edges_{lhs_part}_{rhs_part}
            """).fetchall()[0][0]

    if nentities_postmap != nentities_premap:
        logging.warning("DROPPED EDGES DURING REMAPPING.")
        logging.warning(f"We started with {nentities_premap} and finished with {nentities_postmap}")

    end = time.time()
    logging.info(f"Remapped edges in {end - start}s")


def load_edges(fname, conn):
    """
    A simple function to load the edges into the SQL table. It is
    assumed that we will have a CSV file whose columns match the
    edges table:
    | source_id | source_type | destination_id | destination_type | rel |
    """
    logging.info("Loading edges")
    start = time.time()
    cur = conn.cursor()
    cur.executescript("""
        DROP TABLE IF EXISTS edges
        ;

        CREATE TABLE edges (
            source_id INTEGER,
            source_type TEXT,
            destination_id INTEGER,
            destination_type TEXT,
            rel TEXT
        )
    """)

    edges = pd.read_csv(fname)
    edges.to_sql('edges', conn, if_exists='append', index=False)
    end = time.time()
    logging.info(f"Loaded edges in {end - start}s")


def write_relations(outdir, rels, conn):
    """
    A simple function to write the relevant relationship information out
    for training.
    """
    logging.info("Writing relations for training")
    start = time.time()
    out = rels.sort_values('graph_id')['id'].to_list()
    with open(f'{outdir}/dynamic_rel_names.json', mode='w') as f:
        json.dump(out, f, indent=4)
    end = time.time()
    logging.info(f"Wrote relations in {end - start}s")


def write_single_edge(work_packet):
    """
    A function to write out a single edge list in the format that
    PyTorch BigGraph expects.

    The work packet is expected to contain the lhs and rhs partitions
    for these edges, the directory where we should put this information,
    and the database connection that we should use.
    """
    lhs_part, rhs_part, outdir, conn = work_packet
    query = f"""
    select *
    from edges_{lhs_part}_{rhs_part}
    ;
    """
    logging.debug(f"Running query: {query}")
    df = pd.read_sql_query(query, conn)
    out_name = f'{outdir}/edges_{lhs_part}_{rhs_part}.h5'
    with h5py.File(out_name, mode='w') as f:
        # we need this for https://github.com/facebookresearch/PyTorch-BigGraph/blob/main/torchbiggraph/graph_storages.py#L400
        f.attrs['format_version'] = 1
        for dset, colname in [('lhs', 'source_id'), ('rhs', 'destination_id'), ('rel', 'rel_id')]:
            f.create_dataset(dset, dtype='i', shape=(len(df),), maxshape=(None,))
            f[dset][0:len(df)] = df[colname].tolist()


def write_edges(outdir, LHS_PARTS, RHS_PARTS, conn):
    """
    A function to write out all edge-lists in the format
    that PyTorch BigGraph expects.
    """
    logging.info(f"Writing edges, {LHS_PARTS}, {RHS_PARTS}")
    start = time.time()

    # I would write these using multiprocessing but SQLite connections
    # aren't picklable, and I'd like to keep this simple
    worklist = list(itertools.product(range(LHS_PARTS), range(RHS_PARTS), [outdir], [conn]))
    for w in worklist:
        write_single_edge(w)

    end = time.time()
    logging.info(f"Wrote edges in {end - start}s")


def write_entities(outdir, entity2partitions, conn):
    """
    A function to write out all of the training-relevant
    entity information that PyTorch BigGraph expects
    """
    logging.info("Writing entities for training")
    start = time.time()
    for entity_type, nparts in entity2partitions.items():
        for i in range(nparts):
            query = f"""
            select count(*)
            from {entity_type}_ids_map_{i}
            """
            sz = conn.execute(query).fetchall()[0][0]
            with open(f'{outdir}/entity_count_{entity_type}_id_{i}.txt', mode='w') as f:
                f.write(f"{sz}\n")
    end = time.time()
    logging.info(f"Wrote entities in {end - start}s")


def write_training_data(outdir, rels, entity2partitions, conn):
    """
    A function to write out all of the training-relevant
    information that PyTorch BigGraph expects
    """
    LHS_PARTS = 1
    RHS_PARTS = 1
    for _, r in rels.iterrows():
        if entity2partitions[r['source_type']] > LHS_PARTS:
            LHS_PARTS = entity2partitions[r['source_type']]
        if entity2partitions[r['destination_type']] > RHS_PARTS:
            RHS_PARTS = entity2partitions[r['destination_type']]

    write_relations(outdir, rels, conn)
    write_edges(outdir, LHS_PARTS, RHS_PARTS, conn)
    write_entities(outdir, entity2partitions, conn)


def main(NPARTS=2, edge_file_name='edges.csv', outdir='training_data/'):
    conn = sqlite3.connect("citationv2.db")
    load_edges(edge_file_name, conn)

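    # Split 'paper' entities across NPARTS buckets; 'year' stays in a single partition.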
    entity2partitions = {
        'paper': NPARTS,
        'year': 1,
    }

    rels = remap_relationships(conn)
    remap_entities(conn, entity2partitions)
    remap_edges(conn, rels, entity2partitions)
    Path(outdir).mkdir(parents=True, exist_ok=True)
    write_training_data(outdir, rels, entity2partitions, conn)


if __name__ == '__main__':
    main()