This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit cd5b646

Thomas Markovich committed

Added the script to generate all PBG files from a SQL database

1 parent b181444 commit cd5b646

1 file changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
import itertools
import h5py
import json
import logging
import multiprocessing as mp
import pandas as pd
from pathlib import Path
import sqlite3

from sql_templates import (
    edges_partitioned,
    edgelist_cte_mapper,
    remap_relns,
    partitioned_mapped_entities,
    type_tmp_table
)
import sys
import time

logging.basicConfig(
    format='%(process)d-%(levelname)s-%(message)s',
    stream=sys.stdout,
    level=logging.DEBUG
)


def remap_relationships(conn):
    """
    A function to remap relationships using SQL queries.
    """
    logging.info("Remapping relationships")
    start = time.time()
    logging.debug(f"Running query: {remap_relns}\n")
    conn.executescript(remap_relns)

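    # remap_relns is assumed to build a reln_map table assigning each
    # relationship name (`id`) an ordinal `graph_id`; write_relations()
    # later sorts on that mapping.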
    query = """
    select *
    from reln_map
    """
    logging.debug(f"Running query: {query}\n")
    rels = pd.read_sql_query(query, conn)
    end = time.time()
    logging.info(f"Remapped relationships in {end - start}s")
    return rels


def remap_entities(conn, entity2partitions):
    """
    A function to remap entities with partitioning using SQL queries.

    This function is complicated because the partitions have to be
    constructed first, and then we have to generate ordinal mappings of
    entity ids. These mappings will be used to generate buckets of edges
    for training and then for mapping our edges back to their original
    ids for use in downstream tasks.
    """
    logging.info("Remapping entities")
    start = time.time()
    query = ""
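    # type_tmp_table and partitioned_mapped_entities are SQL templates (defined
    # in sql_templates); they are assumed to create the per-partition
    # {type}_ids_map_{n} tables that the edge remapping and write_entities()
    # read from.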
    for entity, npartitions in entity2partitions.items():
        query = type_tmp_table.format(type=entity, nparts=npartitions)

        for i in range(npartitions):
            query += partitioned_mapped_entities.format(type=entity, n=i)
        logging.debug(f"Running query: {query}")
        conn.executescript(query)
    end = time.time()
    logging.info(f"Remapped entities in {end - start}s")


def generate_ctes(lhs_part, rhs_part, rels, entity2partitions):
    """
    This function generates the sub-table CTEs that help us generate
    the completed edgelist.
    """
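    # Relationships whose source or destination type has fewer partitions than
    # the current (lhs_part, rhs_part) bucket are skipped below, since that
    # bucket cannot contain any of their edges.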
    nctes = 0
    ctes = """
    with cte_0 as (
    """
    first = True
    for _, r in rels.iterrows():
        if lhs_part >= entity2partitions[r['source_type']]:
            continue
        if rhs_part >= entity2partitions[r['destination_type']]:
            continue
        if not first:
            ctes += f", cte_{nctes} as ("
        ctes += edgelist_cte_mapper.format(
            rel_name=r['id'],
            lhs_type=r['source_type'],
            rhs_type=r['destination_type'],
            i=lhs_part,
            j=rhs_part,
        )
        ctes += ")"

        nctes += 1
        first = False
    return nctes, ctes


def generate_unions(nctes):
    """
    This function is just a helper function for
    generating the final edge list tables.
    """
    subquery = ""
    first = True
    for i in range(nctes):
        if not first:
            subquery += "\tunion\n"
        subquery += f"\tselect * from cte_{i}\n"
        first = False
    return subquery


def remap_edges(conn, rels, entity2partitions):
    """
    A function to remap all edges to ordinal IDs
    according to their type.
    """
    logging.info("Remapping edges")
    start = time.time()

    nentities_premap = conn.execute("""
        select count(*) from edges
        ;
    """).fetchall()[0][0]

    query = ""
    NPARTS = max(entity2partitions.values())
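    # One edges_{i}_{j} table is built for every (lhs_part, rhs_part) bucket;
    # the edges_partitioned template is assumed to union the per-relationship
    # CTEs into that table.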
    for lhs_part in range(NPARTS):
        for rhs_part in range(NPARTS):
            nctes, ctes = generate_ctes(lhs_part, rhs_part, rels, entity2partitions)
            subquery = generate_unions(nctes)
            query += edges_partitioned.format(
                i=lhs_part,
                j=rhs_part,
                ctes=ctes,
                tables=subquery
            )

    logging.debug(f"Running query: {query}")
    conn.executescript(query)

    logging.debug("Confirming that we didn't drop any edges.")
    nentities_postmap = 0
    for lhs_part in range(NPARTS):
        for rhs_part in range(NPARTS):
            nentities_postmap += conn.execute(f"""
                select count(*) from edges_{lhs_part}_{rhs_part}
            """).fetchall()[0][0]

    if nentities_postmap != nentities_premap:
        logging.warning("DROPPED EDGES DURING REMAPPING.")
        logging.warning(f"We started with {nentities_premap} and finished with {nentities_postmap}")

    end = time.time()
    logging.info(f"Remapped edges in {end - start}s")


def load_edges(fname, conn):
    """
    A simple function to load the edges into the SQL table. It is
    assumed that we will have a file of the form:
    | source_id | source_type | rel | destination_id | destination_type |
    """
    logging.info("Loading edges")
    start = time.time()
    cur = conn.cursor()
    cur.executescript("""
        DROP TABLE IF EXISTS edges
        ;

        CREATE TABLE edges (
            source_id INTEGER,
            source_type TEXT,
            destination_id INTEGER,
            destination_type TEXT,
            rel TEXT
        )
    """)

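    # pandas' to_sql appends by column name, so the CSV header must match the
    # `edges` columns created above.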
    edges = pd.read_csv(fname)
    edges.to_sql('edges', conn, if_exists='append', index=False)
    end = time.time()
    logging.info(f"Loaded edges in {end - start}s")


def write_relations(outdir, rels, conn):
    """
    A simple function to write the relevant relationship information out
    for training.
    """
    logging.info("Writing relations for training")
    start = time.time()
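    # dynamic_rel_names.json lists the relation names in graph_id order; PyTorch
    # BigGraph reads it when dynamic relations are enabled.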
    out = rels.sort_values('graph_id')['id'].to_list()
    with open(f'{outdir}/dynamic_rel_names.json', mode='w') as f:
        json.dump(out, f, indent=4)
    end = time.time()
    logging.info(f"Wrote relations in {end - start}s")


def write_single_edge(work_packet):
    """
    A function to write out a single edge-list in the format that
    PyTorch BigGraph expects.

    The work packet is expected to contain information about
    the lhs and rhs partitions for these edges, the directory
    where we should put this information, and the database
    connection that we should use.
    """
    lhs_part, rhs_part, outdir, conn = work_packet
    query = f"""
    select *
    from edges_{lhs_part}_{rhs_part}
    ;
    """
    df = pd.read_sql_query(query, conn)
    logging.debug(f"Running query: {query}")
    out_name = f'{outdir}/edges_{lhs_part}_{rhs_part}.h5'
    with h5py.File(out_name, mode='w') as f:
        # we need this for https://github.com/facebookresearch/PyTorch-BigGraph/blob/main/torchbiggraph/graph_storages.py#L400
        f.attrs['format_version'] = 1
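        # PBG expects three parallel integer datasets (lhs, rhs, rel) of equal
        # length in each edges_<lhs_part>_<rhs_part>.h5 bucket file.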
        for dset, colname in [('lhs', 'source_id'), ('rhs', 'destination_id'), ('rel', 'rel_id')]:
            f.create_dataset(dset, dtype='i', shape=(len(df),), maxshape=(None, ))
            f[dset][0 : len(df)] = df[colname].tolist()


def write_edges(outdir, LHS_PARTS, RHS_PARTS, conn):
    """
    A function to write out all edge-lists in the format
    that PyTorch BigGraph expects.
    """
    logging.info(f"Writing edges, {LHS_PARTS}, {RHS_PARTS}")
    start = time.time()

    # I would write these using multiprocessing but SQLite connections
    # aren't picklable, and I'd like to keep this simple
    worklist = list(itertools.product(range(LHS_PARTS), range(RHS_PARTS), [outdir], [conn]))
    for w in worklist:
        write_single_edge(w)

    end = time.time()
    logging.info(f"Wrote edges in {end - start}s")


def write_entities(outdir, entity2partitions, conn):
    """
    A function to write out all of the training-relevant
    entity information that PyTorch BigGraph expects.
    """
    logging.info("Writing entities for training")
    start = time.time()
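    # PBG looks for entity_count_<entity>_<partition>.txt files containing the
    # number of entities in each partition; the `_id` suffix in the filename
    # suggests the entity types are registered as e.g. `paper_id` in the config.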
    for entity_type, nparts in entity2partitions.items():
        for i in range(nparts):
            query = f"""
            select count(*)
            from {entity_type}_ids_map_{i}
            """
            sz = conn.execute(query).fetchall()[0][0]
            with open(f'{outdir}/entity_count_{entity_type}_id_{i}.txt', mode='w') as f:
                f.write(f"{sz}\n")
    end = time.time()
    logging.info(f"Wrote entities in {end - start}s")


def write_training_data(outdir, rels, entity2partitions, conn):
    """
    A function to write out all of the training-relevant
    information that PyTorch BigGraph expects.
    """
    LHS_PARTS = 1
    RHS_PARTS = 1
    for i, r in rels.iterrows():
        if entity2partitions[r['source_type']] > LHS_PARTS:
            LHS_PARTS = entity2partitions[r['source_type']]
        if entity2partitions[r['destination_type']] > RHS_PARTS:
            RHS_PARTS = entity2partitions[r['destination_type']]

    write_relations(outdir, rels, conn)
    write_edges(outdir, LHS_PARTS, RHS_PARTS, conn)
    write_entities(outdir, entity2partitions, conn)


def main(NPARTS=2, edge_file_name='edges.csv', outdir='training_data/'):
    conn = sqlite3.connect("citationv2.db")
    load_edges(edge_file_name, conn)

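    # Partition counts per entity type; citationv2.db and the paper/year entity
    # types are specific to this citation dataset.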
    entity2partitions = {
        'paper': NPARTS,
        'year': 1,
    }

    rels = remap_relationships(conn)
    remap_entities(conn, entity2partitions)
    remap_edges(conn, rels, entity2partitions)
    Path(outdir).mkdir(parents=True, exist_ok=True)
    write_training_data(outdir, rels, entity2partitions, conn)


if __name__ == '__main__':
    main()
