Skip to content

Commit 9611b33

Browse files
committed
triplet-ext-script
1 parent 9b4a44d commit 9611b33

File tree

2 files changed

+305
-0
lines changed

2 files changed

+305
-0
lines changed

llvm/docs/CommandGuide/llvm-ir2vec.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ embedding training (see
5050
<https://github.com/thunlp/OpenKE/tree/OpenKE-PyTorch?tab=readme-ov-file#data-format>
5151
for details).
5252

53+
See ``llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py`` for more details on how
54+
these two modes are used to generate the triplets and entity mappings.
55+
5356
Triplet Generation Mode
5457
~~~~~~~~~~~~~~~~~~~~~~~
5558

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
"""IR2Vec Triplet Generator
5+
6+
Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
7+
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
8+
files: entity2id.txt, relation2id.txt, and train2id.txt.
9+
10+
Usage:
11+
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
12+
"""
13+
14+
import argparse
15+
import logging
16+
import os
17+
import random
18+
import subprocess
19+
import sys
20+
from concurrent.futures import ThreadPoolExecutor, as_completed
21+
from pathlib import Path
22+
from typing import List, Set, Tuple
23+
24+
# Configuration
# Optimization levels that may be selected for a file; each one is passed to
# `opt` with a leading dash (e.g. "-O2").
OPT_LEVELS = ["O0", "O1", "O2", "O3", "Os", "Oz"]
# Size of the worker thread pool when -j/--max-workers is not given.
DEFAULT_MAX_WORKERS = 100

# Module-level logger; handlers and level are configured in main().
logger = logging.getLogger(__name__)
29+
30+
31+
class TripletResult:
    """Outcome of processing one LLVM IR file.

    Attributes are restricted to the two slots below, so instances carry no
    per-object ``__dict__``.
    """

    # triplets:     unique triplet lines collected for the file
    # max_relation: largest relation id reported while processing the file
    __slots__ = ["triplets", "max_relation"]

    def __init__(self, triplets: Set[str], max_relation: int):
        """Store the collected triplets and the highest relation id seen."""
        self.triplets, self.max_relation = triplets, max_relation
39+
40+
41+
class IR2VecTripletGenerator:
    """Generate IR2Vec training triplets from a list of LLVM IR files.

    Every input file is run through ``opt`` at a random selection of
    optimization levels and the result is piped into ``llvm-ir2vec`` in
    triplet mode.  The unique triplets gathered from all files are written to
    ``train2id.txt``; ``entity2id.txt`` and ``relation2id.txt`` are generated
    alongside it in the output directory.
    """

    def __init__(
        self,
        llvm_build_dir: Path,
        num_optimizations: int,
        output_dir: Path,
        max_workers: int = DEFAULT_MAX_WORKERS,
    ):
        """Record the configuration, validate it and create the output dir.

        Args:
            llvm_build_dir: LLVM build directory containing ``bin/opt`` and
                ``bin/llvm-ir2vec``.
            num_optimizations: How many distinct optimization levels to apply
                to each input file (1 to ``len(OPT_LEVELS)``).
            output_dir: Directory that receives the generated files; created
                (with parents) if it does not exist.
            max_workers: Size of the thread pool used to process files.

        Raises:
            FileNotFoundError: If the build directory or a required tool is
                missing (or the tool is not executable).
            ValueError: If ``num_optimizations`` or ``max_workers`` is out of
                range.
        """
        self.llvm_build_dir = llvm_build_dir
        self.num_optimizations = num_optimizations
        self.output_dir = output_dir
        self.max_workers = max_workers

        # Tool paths
        self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
        self.ir2vec_binary = os.path.join(llvm_build_dir, "bin", "llvm-ir2vec")

        self._validate_setup()

        # Create output directory if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _validate_setup(self):
        """Validate that all required tools exist and the settings are sane.

        Raises:
            FileNotFoundError: If the build directory, ``opt`` or
                ``llvm-ir2vec`` is missing or not executable.
            ValueError: If ``num_optimizations`` is outside
                ``1..len(OPT_LEVELS)`` or ``max_workers`` is below 1.
        """
        if not self.llvm_build_dir.exists():
            raise FileNotFoundError(
                f"LLVM build directory not found: {self.llvm_build_dir}"
            )

        required_tools = (
            ("opt", self.opt_binary),
            ("llvm-ir2vec", self.ir2vec_binary),
        )
        for tool_name, tool_path in required_tools:
            if not os.path.isfile(tool_path) or not os.access(tool_path, os.X_OK):
                raise FileNotFoundError(
                    f"{tool_name} binary not found or not executable: {tool_path}"
                )

        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
            raise ValueError(
                f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
            )

        # Fail here rather than later inside ThreadPoolExecutor, which rejects
        # a non-positive pool size only once processing has already started.
        if self.max_workers < 1:
            raise ValueError("Number of workers must be at least 1")

    def _select_optimization_levels(self) -> List[str]:
        """Return ``num_optimizations`` distinct, randomly chosen levels."""
        return random.sample(OPT_LEVELS, self.num_optimizations)

    def _process_single_file(self, input_file: Path) -> TripletResult:
        """Process one LLVM IR file at several optimization levels.

        Args:
            input_file: Path of the LLVM IR file to process.

        Returns:
            A TripletResult holding the union of the triplets produced at each
            selected level and the largest relation id reported.
        """
        all_triplets = set()
        max_relation = 1

        for opt_level in self._select_optimization_levels():
            triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
            if triplets:
                all_triplets.update(triplets)
                max_relation = max(max_relation, file_max_relation)
                logger.debug(
                    "Generated %d triplets for %s with %s",
                    len(triplets),
                    input_file,
                    opt_level,
                )

        return TripletResult(all_triplets, max_relation)

    def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
        """Run the ``opt | llvm-ir2vec`` pipeline using subprocess pipes.

        Args:
            input_file: Path of the LLVM IR file to optimize.
            opt_level: Optimization level name without the leading dash.

        Returns:
            ``(triplets, max_relation)`` parsed from llvm-ir2vec's output, or
            ``(set(), 1)`` if either process could not be run or exited with a
            non-zero status.
        """
        opt_cmd = [self.opt_binary, f"-{opt_level}", str(input_file), "-o", "-"]
        ir2vec_cmd = [self.ir2vec_binary, "--mode=triplets", "-", "-o", "-"]

        try:
            # opt's stderr is discarded rather than piped: nothing reads it, so
            # a pipe could fill up and block opt indefinitely.  The context
            # managers make sure both children are waited on and their pipes
            # closed even if starting llvm-ir2vec fails.
            with subprocess.Popen(
                opt_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
            ) as opt_proc:
                with subprocess.Popen(
                    ir2vec_cmd,
                    stdin=opt_proc.stdout,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                ) as ir2vec_proc:
                    # Drop our copy of the read end so opt is notified if
                    # llvm-ir2vec exits before consuming all of its input.
                    opt_proc.stdout.close()
                    stdout, stderr = ir2vec_proc.communicate()
        except (subprocess.SubprocessError, OSError) as e:
            logger.debug(
                "Could not run pipeline for %s with %s: %s", input_file, opt_level, e
            )
            return set(), 1

        # Check if either process failed
        if opt_proc.returncode != 0 or ir2vec_proc.returncode != 0:
            logger.debug(
                "Pipeline failed for %s with %s (opt=%s, llvm-ir2vec=%s): %s",
                input_file,
                opt_level,
                opt_proc.returncode,
                ir2vec_proc.returncode,
                stderr.strip(),
            )
            return set(), 1

        return self._parse_triplet_output(stdout)

    def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
        """Parse llvm-ir2vec triplet output and extract the max relation.

        The first non-blank line may be a ``MAX_RELATION=<n>`` metadata line;
        every other non-blank line is kept as a triplet.

        Args:
            output: Text captured from llvm-ir2vec's standard output.

        Returns:
            ``(triplets, max_relation)``; ``max_relation`` is 1 when no
            metadata line is present.

        Raises:
            ValueError: If the ``MAX_RELATION=`` value is not an integer.
        """
        # Blank and whitespace-only lines carry no triplet, so drop them
        # instead of letting them end up in the result set.
        lines = [line.strip() for line in output.splitlines() if line.strip()]
        if not lines:
            return set(), 1

        max_relation = 1

        # Extract max relation from metadata line
        if lines[0].startswith("MAX_RELATION="):
            _, _, value = lines[0].partition("=")
            max_relation = int(value)
            lines = lines[1:]

        # Remove duplicate triplets by converting to a set
        return set(lines), max_relation

    def generate_triplets(self, file_list: Path) -> None:
        """Generate triplets from a list of LLVM IR files and write the outputs.

        Args:
            file_list: Text file naming one LLVM IR file per line.

        Raises:
            ValueError: If the list contains no usable input file.
        """
        input_files = self._read_file_list(file_list)
        logger.info(
            f"Processing {len(input_files)} files with {self.num_optimizations} "
            f"optimization levels using {self.max_workers} workers"
        )

        all_triplets = set()
        global_max_relation = 1

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._process_single_file, file): file
                for file in input_files
            }

            for future in as_completed(future_to_file):
                file_path = future_to_file[future]
                try:
                    result = future.result()
                except (subprocess.SubprocessError, OSError, ValueError) as e:
                    logger.error(f"Error processing {file_path}: {e}")
                    continue

                if not result.triplets:
                    # Pipeline failures are reported as empty results, so make
                    # files that contributed nothing visible to the user.
                    logger.warning("No triplets generated for %s", file_path)
                all_triplets.update(result.triplets)
                global_max_relation = max(global_max_relation, result.max_relation)

        self._generate_output_files(all_triplets, global_max_relation)
        logger.info("Processing completed successfully")

    def _read_file_list(self, file_list: Path) -> List[Path]:
        """Read the list of input files, skipping entries that are not files.

        Args:
            file_list: Text file naming one LLVM IR file per line; blank lines
                are ignored.

        Returns:
            The listed paths that refer to existing regular files, in order.

        Raises:
            ValueError: If no listed path is usable.
        """
        input_files = []
        with open(file_list, "r") as f:
            for line_num, line in enumerate(f, 1):
                if line := line.strip():
                    file_path = Path(line)
                    # is_file() rather than exists(): a directory cannot be
                    # processed as an IR file.
                    if file_path.is_file():
                        input_files.append(file_path)
                    else:
                        logger.warning(f"File not found (line {line_num}): {file_path}")

        if not input_files:
            raise ValueError("No valid input files found")
        return input_files

    def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
        """Write train2id.txt, entity2id.txt and relation2id.txt.

        Args:
            all_triplets: Unique triplet lines gathered from all input files.
            max_relation: Largest relation id seen across all input files.
        """
        logger.info(f"Generating output files with {len(all_triplets)} unique triplets")

        # Write all output files -- train2id.txt, entity2id.txt, relation2id.txt
        train2id_file = self.output_dir / "train2id.txt"
        entity2id_file = self.output_dir / "entity2id.txt"
        relation2id_file = self.output_dir / "relation2id.txt"

        with open(train2id_file, "w") as f:
            f.write(f"{len(all_triplets)}\n")
            # Sorted so the same set of triplets always yields the same file;
            # plain set iteration order changes from run to run.
            f.writelines(f"{triplet}\n" for triplet in sorted(all_triplets))

        self._generate_entity2id(entity2id_file)
        self._generate_relation2id(relation2id_file, max_relation)

    def _generate_entity2id(self, output_file: Path) -> None:
        """Generate entity2id.txt by running llvm-ir2vec in entities mode.

        Args:
            output_file: Destination path handed to llvm-ir2vec via ``-o``.

        Raises:
            subprocess.CalledProcessError: If llvm-ir2vec exits with a
                non-zero status; its stderr is logged before re-raising.
        """
        try:
            subprocess.run(
                [str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # The exception text does not include the tool's diagnostics, so
            # surface them before propagating the failure.
            details = (e.stderr or b"").decode(errors="replace").strip()
            logger.error("llvm-ir2vec failed to generate %s: %s", output_file, details)
            raise

    def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
        """Generate relation2id.txt from the max relation id.

        The file starts with the relation count, followed by ``Type`` (0),
        ``Next`` (1) and ``Arg0``, ``Arg1``, ... for ids 2 and above.

        Args:
            output_file: Destination path of relation2id.txt.
            max_relation: Largest relation id seen; values below 1 are raised
                to 1.
        """
        max_relation = max(max_relation, 1)  # At least Type and Next relations
        num_relations = max_relation + 1

        with open(output_file, "w") as f:
            f.write(f"{num_relations}\n")
            f.write("Type\t0\n")
            f.write("Next\t1\n")
            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
239+
240+
241+
def main():
    """Parse the command line, set up logging and run the triplet generator."""
    cli = argparse.ArgumentParser(
        description="Generate IR2Vec triplets from LLVM IR files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional arguments, in the order they appear on the command line.
    cli.add_argument("llvm_build_dir", type=Path, help="Path to LLVM build directory")
    cli.add_argument(
        "num_optimizations",
        type=int,
        help="Number of optimization levels to apply (1-6)",
    )
    cli.add_argument(
        "ll_file_list",
        type=Path,
        help="File containing list of LLVM IR files to process",
    )
    cli.add_argument("output_dir", type=Path, help="Output directory for generated files")

    # Optional flags.
    cli.add_argument(
        "-j",
        "--max-workers",
        type=int,
        default=DEFAULT_MAX_WORKERS,
        help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
    )
    cli.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    cli.add_argument(
        "-q", "--quiet", action="store_true", help="Suppress all output except errors"
    )

    options = cli.parse_args()

    # Configure logging; --quiet takes precedence over --verbose.
    if options.quiet:
        log_level = logging.ERROR
    elif options.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    IR2VecTripletGenerator(
        options.llvm_build_dir,
        options.num_optimizations,
        options.output_dir,
        options.max_workers,
    ).generate_triplets(options.ll_file_list)
299+
300+
301+
if __name__ == "__main__":
302+
main()

0 commit comments

Comments
 (0)