Skip to content

Commit 7121f5c

Browse files
committed
Bulk generator: Use the 'Project' type throughout the file.
1 parent 7c89d6d commit 7121f5c

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

misc/scripts/models-as-data/bulk_generate_mad.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Project(TypedDict):
3939
"""
4040

4141
name: str
42-
git_repo: str
42+
git_repo: NotRequired[str]
4343
git_tag: NotRequired[str]
4444

4545

@@ -185,7 +185,7 @@ def build_database(
185185
return database_dir
186186

187187

188-
def generate_models(args, name: str, database_dir: str) -> None:
188+
def generate_models(args, project: Project, database_dir: str) -> None:
189189
"""
190190
Generate models for a project.
191191
@@ -194,6 +194,7 @@ def generate_models(args, name: str, database_dir: str) -> None:
194194
name: The name of the project.
195195
database_dir: Path to the CodeQL database.
196196
"""
197+
name = project["name"]
197198

198199
generator = mad.Generator(args.lang)
199200
generator.generateSinks = args.with_sinks
@@ -205,7 +206,7 @@ def generate_models(args, name: str, database_dir: str) -> None:
205206

206207
def build_databases_from_projects(
207208
language: str, extractor_options, projects: List[Project]
208-
) -> List[tuple[str, str | None]]:
209+
) -> List[tuple[Project, str | None]]:
209210
"""
210211
Build databases for all projects in parallel.
211212
@@ -225,7 +226,7 @@ def build_databases_from_projects(
225226
print("\n=== Building databases ===")
226227
database_results = [
227228
(
228-
project["name"],
229+
project,
229230
build_database(language, extractor_options, project, project_dir),
230231
)
231232
for project, project_dir in project_dirs
@@ -290,8 +291,8 @@ def pretty_name_from_artifact_name(artifact_name: str) -> str:
290291

291292

292293
def download_dca_databases(
293-
experiment_name: str, pat: str, projects
294-
) -> List[tuple[str, str | None]]:
294+
experiment_name: str, pat: str, projects: List[Project]
295+
) -> List[tuple[Project, str | None]]:
295296
"""
296297
Download databases from a DCA experiment.
297298
Args:
@@ -308,7 +309,7 @@ def download_dca_databases(
308309
pat,
309310
)
310311
targets = response["targets"]
311-
for target, data in targets.items():
312+
for data in targets.values():
312313
downloads = data["downloads"]
313314
analyzed_database = downloads["analyzed_database"]
314315
artifact_name = analyzed_database["artifact_name"]
@@ -349,20 +350,21 @@ def download_dca_databases(
349350
tar_ref.extractall(artifact_unzipped_location)
350351
database_results.append(
351352
(
352-
pretty_name,
353+
{"name": pretty_name},
353354
os.path.join(
354355
artifact_unzipped_location, remove_extension(entry)
355356
),
356357
)
357358
)
359+
358360
print(f"\n=== Extracted {len(database_results)} databases ===")
359361

360362
def compare(a, b):
361363
a_index = next(
362-
i for i, project in enumerate(projects) if project["name"] == a[0]
364+
i for i, project in enumerate(projects) if project["name"] == a[0]["name"]
363365
)
364366
b_index = next(
365-
i for i, project in enumerate(projects) if project["name"] == b[0]
367+
i for i, project in enumerate(projects) if project["name"] == b[0]["name"]
366368
)
367369
return a_index - b_index
368370

@@ -431,7 +433,9 @@ def main(config, args) -> None:
431433
# Generate models for all projects
432434
print("\n=== Generating models ===")
433435

434-
failed_builds = [project for project, db_dir in database_results if db_dir is None]
436+
failed_builds = [
437+
project["name"] for project, db_dir in database_results if db_dir is None
438+
]
435439
if failed_builds:
436440
print(
437441
f"ERROR: {len(failed_builds)} database builds failed: {', '.join(failed_builds)}"
@@ -440,7 +444,7 @@ def main(config, args) -> None:
440444

441445
# Delete the MaD directory for each project
442446
for project, database_dir in database_results:
443-
mad_dir = get_mad_destination_for_project(config, project)
447+
mad_dir = get_mad_destination_for_project(config, project["name"])
444448
if os.path.exists(mad_dir):
445449
print(f"Deleting existing MaD directory at {mad_dir}")
446450
subprocess.check_call(["rm", "-rf", mad_dir])

0 commit comments

Comments
 (0)