1+ #!/usr/bin/env python3
12"""
23Experimental script for bulk generation of MaD models based on a list of projects.
34
78import os.path
89import subprocess
910import sys
10- from typing import NotRequired, TypedDict, List
11+ from typing import Required, TypedDict, List, Callable, Optional
1112from concurrent.futures import ThreadPoolExecutor, as_completed
1213import time
1314import argparse
14- import json
15- import requests
1615import zipfile
1716import tarfile
18- from functools import cmp_to_key
17+ import shutil
18+
19+
20+ def missing_module(module_name: str) -> None:
21+ print(
22+ f"ERROR: {module_name} is not installed. Please install it with 'pip install {module_name}'."
23+ )
24+ sys.exit(1)
25+
26+
27+ try:
28+ import yaml
29+ except ImportError:
30+ missing_module("pyyaml")
31+
32+ try:
33+ import requests
34+ except ImportError:
35+ missing_module("requests")
1936
2037import generate_mad as mad
2138
2845
2946
3047# A project to generate models for
31- class Project(TypedDict):
32- """
33- Type definition for projects (acquired via a GitHub repo) to model.
34-
35- Attributes:
36- name: The name of the project
37- git_repo: URL to the git repository
38- git_tag: Optional Git tag to check out
39- """
40-
41- name: str
42- git_repo: NotRequired[str]
43- git_tag: NotRequired[str]
44- with_sinks: NotRequired[bool]
45- with_sinks: NotRequired[bool]
46- with_summaries: NotRequired[bool]
48+ Project = TypedDict(
49+ "Project",
50+ {
51+ "name": Required[str],
52+ "git-repo": str,
53+ "git-tag": str,
54+ "with-sinks": bool,
55+ "with-sources": bool,
56+ "with-summaries": bool,
57+ },
58+ total=False,
59+ )
4760
4861
4962def should_generate_sinks(project: Project) -> bool:
@@ -63,14 +76,14 @@ def clone_project(project: Project) -> str:
6376 Shallow clone a project into the build directory.
6477
6578 Args:
66-         project: A dictionary containing project information with 'name', 'git_repo', and optional 'git_tag' keys.
79+         project: A dictionary containing project information with 'name', 'git-repo', and optional 'git-tag' keys.
6780
6881 Returns:
6982 The path to the cloned project directory.
7083 """
7184 name = project["name"]
72-     repo_url = project["git_repo"]
73-     git_tag = project.get("git_tag")
85+     repo_url = project["git-repo"]
86+     git_tag = project.get("git-tag")
7487
7588 # Determine target directory
7689 target_dir = os.path.join(build_dir, name)
@@ -103,6 +116,39 @@ def clone_project(project: Project) -> str:
103116 return target_dir
104117
105118
119+ def run_in_parallel[
120+ T, U
121+ ](
122+ func: Callable[[T], U],
123+ items: List[T],
124+ *,
125+ on_error=lambda item, exc: None,
126+ error_summary=lambda failures: None,
127+ max_workers=8,
128+ ) -> List[Optional[U]]:
129+ if not items:
130+ return []
131+ max_workers = min(max_workers, len(items))
132+ results = [None for _ in range(len(items))]
133+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
134+ # Start cloning tasks and keep track of them
135+ futures = {
136+ executor.submit(func, item): index for index, item in enumerate(items)
137+ }
138+ # Process results as they complete
139+ for future in as_completed(futures):
140+ index = futures[future]
141+ try:
142+ results[index] = future.result()
143+ except Exception as e:
144+ on_error(items[index], e)
145+ failed = [item for item, result in zip(items, results) if result is None]
146+ if failed:
147+ error_summary(failed)
148+ sys.exit(1)
149+ return results
150+
151+
106152def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
107153 """
108154 Clone all projects in parallel.
@@ -114,40 +160,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
114160 List of (project, project_dir) pairs in the same order as the input projects
115161 """
116162 start_time = time.time()
117- max_workers = min(8, len(projects)) # Use at most 8 threads
118- project_dirs_map = {} # Map to store results by project name
119-
120- with ThreadPoolExecutor(max_workers=max_workers) as executor:
121- # Start cloning tasks and keep track of them
122- future_to_project = {
123- executor.submit(clone_project, project): project for project in projects
124- }
125-
126- # Process results as they complete
127- for future in as_completed(future_to_project):
128- project = future_to_project[future]
129- try:
130- project_dir = future.result()
131- project_dirs_map[project["name"]] = (project, project_dir)
132- except Exception as e:
133- print(f"ERROR: Failed to clone {project['name']}: {e}")
134-
135- if len(project_dirs_map) != len(projects):
136- failed_projects = [
137- project["name"]
138- for project in projects
139- if project["name"] not in project_dirs_map
140- ]
141- print(
142- f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
143- )
144- sys.exit(1)
145-
146- project_dirs = [project_dirs_map[project["name"]] for project in projects]
147-
163+ dirs = run_in_parallel(
164+ clone_project,
165+ projects,
166+ on_error=lambda project, exc: print(
167+ f"ERROR: Failed to clone project {project['name']}: {exc}"
168+ ),
169+ error_summary=lambda failures: print(
170+ f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
171+ ),
172+ )
148173 clone_time = time.time() - start_time
149174 print(f"Cloning completed in {clone_time:.2f} seconds")
150- return project_dirs
175+ return list(zip(projects, dirs))
151176
152177
153178def build_database(
@@ -159,7 +184,7 @@ def build_database(
159184 Args:
160185 language: The language for which to build the database (e.g., "rust").
161186 extractor_options: Additional options for the extractor.
162-         project: A dictionary containing project information with 'name' and 'git_repo' keys.
187+         project: A dictionary containing project information with 'name' and 'git-repo' keys.
163188 project_dir: Path to the CodeQL database.
164189
165190 Returns:
@@ -307,7 +332,10 @@ def pretty_name_from_artifact_name(artifact_name: str) -> str:
307332
308333
309334def download_dca_databases(
310- experiment_name: str, pat: str, projects: List[Project]
335+ language: str,
336+ experiment_name: str,
337+ pat: str,
338+ projects: List[Project],
311339) -> List[tuple[Project, str | None]]:
312340 """
313341 Download databases from a DCA experiment.
@@ -318,14 +346,14 @@ def download_dca_databases(
318346 Returns:
319347 List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
320348 """
321- database_results = {}
322349 print("\n=== Finding projects ===")
323350 response = get_json_from_github(
324351 f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
325352 pat,
326353 )
327354 targets = response["targets"]
328355 project_map = {project["name"]: project for project in projects}
356+ analyzed_databases = {}
329357 for data in targets.values():
330358 downloads = data["downloads"]
331359 analyzed_database = downloads["analyzed_database"]
@@ -336,6 +364,15 @@ def download_dca_databases(
336364 print(f"Skipping {pretty_name} as it is not in the list of projects")
337365 continue
338366
367+ if pretty_name in analyzed_databases:
368+ print(
369+ f"Skipping previous database {analyzed_databases[pretty_name]['artifact_name']} for {pretty_name}"
370+ )
371+
372+ analyzed_databases[pretty_name] = analyzed_database
373+
374+ def download_and_decompress(analyzed_database: dict) -> str:
375+ artifact_name = analyzed_database["artifact_name"]
339376 repository = analyzed_database["repository"]
340377 run_id = analyzed_database["run_id"]
341378 print(f"=== Finding artifact: {artifact_name} ===")
@@ -351,27 +388,40 @@ def download_dca_databases(
351388 artifact_zip_location = download_artifact(
352389 archive_download_url, artifact_name, pat
353390 )
354- print(f"=== Extracting artifact: {artifact_name} ===")
391+ print(f"=== Decompressing artifact: {artifact_name} ===")
355392 # The database is in a zip file, which contains a tar.gz file with the DB
356393 # First we open the zip file
357394 with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
358395 artifact_unzipped_location = os.path.join(build_dir, artifact_name)
396+ # clean up any remnants of previous runs
397+ shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
359398 # And then we extract it to build_dir/artifact_name
360399 zip_ref.extractall(artifact_unzipped_location)
361- # And then we iterate over the contents of the extracted directory
362- # and extract the tar.gz files inside it
363- for entry in os.listdir(artifact_unzipped_location):
364- artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
365- with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
366- # And we just untar it to the same directory as the zip file
367- tar_ref.extractall(artifact_unzipped_location)
368- database_results[pretty_name] = os.path.join(
369- artifact_unzipped_location, remove_extension(entry)
370- )
400+ # And then we extract the language tar.gz file inside it
401+ artifact_tar_location = os.path.join(
402+ artifact_unzipped_location, f"{language}.tar.gz"
403+ )
404+ with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
405+ # And we just untar it to the same directory as the zip file
406+ tar_ref.extractall(artifact_unzipped_location)
407+ ret = os.path.join(artifact_unzipped_location, language)
408+ print(f"Decompression complete: {ret}")
409+ return ret
410+
411+ results = run_in_parallel(
412+ download_and_decompress,
413+ list(analyzed_databases.values()),
414+ on_error=lambda db, exc: print(
415+ f"ERROR: Failed to download and decompress {db["artifact_name"]}: {exc}"
416+ ),
417+ error_summary=lambda failures: print(
418+ f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
419+ ),
420+ )
371421
372-     print(f"\n=== Extracted {len(database_results)} databases ===")
422+     print(f"\n=== Fetched {len(results)} databases ===")
373423
374-     return [(project, database_results[project["name"]]) for project in projects]
424+     return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
375425
376426
377427def get_mad_destination_for_project(config, name: str) -> str:
@@ -422,7 +472,9 @@ def main(config, args) -> None:
422472 case "repo":
423473 extractor_options = config.get("extractor_options", [])
424474 database_results = build_databases_from_projects(
425- language, extractor_options, projects
475+ language,
476+ extractor_options,
477+ projects,
426478 )
427479 case "dca":
428480 experiment_name = args.dca
@@ -439,7 +491,10 @@ def main(config, args) -> None:
439491 with open(args.pat, "r") as f:
440492 pat = f.read().strip()
441493 database_results = download_dca_databases(
442- experiment_name, pat, projects
494+ language,
495+ experiment_name,
496+ pat,
497+ projects,
443498 )
444499
445500 # Generate models for all projects
@@ -492,9 +547,9 @@ def main(config, args) -> None:
492547 sys.exit(1)
493548 try:
494549 with open(args.config, "r") as f:
495-             config = json.load(f)
496- except json.JSONDecodeError as e:
497- print(f"ERROR: Failed to parse JSON file {args.config}: {e}")
550+             config = yaml.safe_load(f)
551+ except yaml.YAMLError as e:
552+ print(f"ERROR: Failed to parse YAML file {args.config}: {e}")
498553 sys.exit(1)
499554
500555 main(config, args)
0 commit comments