Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions fluster/fluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from shutil import rmtree
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple

from fluster import utils
from fluster.codec import Codec, Profile
from fluster.decoder import DECODERS, Decoder

Expand Down Expand Up @@ -849,11 +850,16 @@ def download_test_suites(
download_test_suites = self.test_suites
print(f"Test suites: {[ts.name for ts in download_test_suites]}")

for test_suite in download_test_suites:
test_suite.download(
jobs,
self.resources_dir,
verify=True,
keep_file=keep_file,
retries=retries,
)
manager = utils.DownloadManager(verify=True, keep_file=keep_file, retries=retries)
try:
for test_suite in download_test_suites:
test_suite.download(
jobs,
self.resources_dir,
verify=True,
keep_file=keep_file,
retries=retries,
download_manager=manager,
)
finally:
manager.cleanup()
191 changes: 136 additions & 55 deletions fluster/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import os.path
import sys
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import lru_cache
from multiprocessing import Pool
from shutil import rmtree
from time import perf_counter
from typing import Any, Dict, List, Optional, Set, Type, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Type, cast
from unittest.result import TestResult

from fluster import utils
Expand All @@ -47,13 +48,15 @@ def __init__(
keep_file: bool,
test_suite_name: str,
retries: int,
archive_path: Optional[str] = None,
):
self.out_dir = out_dir
self.verify = verify
self.extract_all = extract_all
self.keep_file = keep_file
self.test_suite_name = test_suite_name
self.retries = retries
self.archive_path = archive_path

# This is added to avoid having to create an extra ancestor class
def set_test_vector(self, test_vector: TestVector) -> None:
Expand All @@ -74,8 +77,9 @@ def __init__(
test_suite_name: str,
test_vectors: Dict[str, TestVector],
retries: int,
archive_path: Optional[str] = None,
):
super().__init__(out_dir, verify, extract_all, keep_file, test_suite_name, retries)
super().__init__(out_dir, verify, extract_all, keep_file, test_suite_name, retries, archive_path)
self.test_vectors = test_vectors


Expand Down Expand Up @@ -207,33 +211,37 @@ def _download_single_test_vector(ctx: DownloadWork) -> None:
dest_path = os.path.join(dest_dir, os.path.basename(ctx.test_vector.source))
os.makedirs(dest_dir, exist_ok=True)

if (
ctx.verify
and os.path.exists(dest_path)
and ctx.test_vector.source_checksum == utils.file_checksum(dest_path)
):
# Remove file only in case the input file was extractable.
# Otherwise, we'd be removing the original file we want to work
# with every even time we execute the download subcommand.
if utils.is_extractable(dest_path) and not ctx.keep_file:
os.remove(dest_path)
return

print(f"\tDownloading test vector {ctx.test_vector.name} from {ctx.test_vector.source}")
utils.download(ctx.test_vector.source, dest_dir, ctx.retries**ctx.retries)

if ctx.test_vector.source_checksum != "__skip__":
checksum = utils.file_checksum(dest_path)
if ctx.test_vector.source_checksum != checksum:
raise Exception(
f"Checksum mismatch for {ctx.test_vector.name}: {checksum} instead of "
f"{ctx.test_vector.source_checksum}"
)
# When archive_path is provided, the archive was already downloaded
# by the DownloadManager — skip directly to extraction.
if ctx.archive_path and os.path.exists(ctx.archive_path):
source_path = ctx.archive_path
else:
source_path = dest_path

if (
ctx.verify
and os.path.exists(dest_path)
and ctx.test_vector.source_checksum == utils.file_checksum(dest_path)
):
if utils.is_extractable(dest_path) and not ctx.keep_file:
os.remove(dest_path)
return

print(f"\tDownloading test vector {ctx.test_vector.name} from {ctx.test_vector.source}")
utils.download(ctx.test_vector.source, dest_dir, ctx.retries**ctx.retries)

if ctx.test_vector.source_checksum != "__skip__":
checksum = utils.file_checksum(dest_path)
if ctx.test_vector.source_checksum != checksum:
raise Exception(
f"Checksum mismatch for {ctx.test_vector.name}: {checksum} instead of "
f"{ctx.test_vector.source_checksum}"
)

if utils.is_extractable(dest_path):
if utils.is_extractable(source_path):
print(f"\tExtracting test vector {ctx.test_vector.name} to {dest_dir}")
utils.extract(dest_path, dest_dir, file=ctx.test_vector.input_file if not ctx.extract_all else None)
if not ctx.keep_file:
utils.extract(source_path, dest_dir, file=ctx.test_vector.input_file if not ctx.extract_all else None)
if not ctx.keep_file and not ctx.archive_path and os.path.exists(dest_path):
os.remove(dest_path)

@staticmethod
Expand All @@ -244,29 +252,48 @@ def _download_single_archive(ctx: DownloadWorkSingleArchive) -> None:
dest_path = os.path.join(dest_dir, os.path.basename(first_tv.source))
os.makedirs(dest_dir, exist_ok=True)

# Clean up existing corrupt source file
if (
ctx.verify
and os.path.exists(dest_path)
and utils.is_extractable(dest_path)
and first_tv.source_checksum != utils.file_checksum(dest_path)
):
os.remove(dest_path)

print(f"\tDownloading source file from {first_tv.source}")
utils.download(first_tv.source, dest_dir, ctx.retries**ctx.retries)
# When archive_path is provided, the archive was already downloaded
# by the DownloadManager — skip directly to extraction.
if ctx.archive_path and os.path.exists(ctx.archive_path):
archive_path = ctx.archive_path
print(f"\tUsing pre-downloaded archive {os.path.basename(first_tv.source)}")
else:
archive_path = dest_path

# Check that source file was downloaded correctly
if first_tv.source_checksum != "__skip__":
checksum = utils.file_checksum(dest_path)
if first_tv.source_checksum != checksum:
raise Exception(
f"Checksum mismatch for source file {os.path.basename(first_tv.source)}: {checksum} "
f"instead of '{first_tv.source_checksum}'"
# Verify existing file: clean up corrupt, skip if valid
skip_download = False
if os.path.exists(dest_path):
if first_tv.source_checksum == "__skip__":
skip_download = True
else:
checksum = utils.file_checksum(dest_path)
if first_tv.source_checksum == checksum:
skip_download = True
elif ctx.verify and utils.is_extractable(dest_path):
os.remove(dest_path)

if skip_download:
print(f"\tSkipping download of {os.path.basename(first_tv.source)} (already exists)")
else:
print(f"\tDownloading source file from {first_tv.source}")
utils.download(
first_tv.source,
dest_dir,
ctx.retries**ctx.retries,
)

# Check downloaded file
if first_tv.source_checksum != "__skip__":
checksum = utils.file_checksum(dest_path)
if first_tv.source_checksum != checksum:
raise Exception(
f"Checksum mismatch for source file "
f"{os.path.basename(first_tv.source)}: "
f"{checksum} instead of '{first_tv.source_checksum}'"
)

try:
with zipfile.ZipFile(dest_path, "r") as zip_file:
with zipfile.ZipFile(archive_path, "r") as zip_file:
print(f"\tExtracting test vectors from {os.path.basename(first_tv.source)}")
for tv in ctx.test_vectors.values():
if tv.input_file in zip_file.namelist():
Expand All @@ -276,11 +303,18 @@ def _download_single_archive(ctx: DownloadWorkSingleArchive) -> None:
f"WARNING: test vector {tv.input_file} not found inside {os.path.basename(first_tv.source)}"
)
except zipfile.BadZipFile as bad_zip_error:
os.remove(dest_path)
raise Exception(f"{dest_path} could not be opened as zip file. File was deleted") from bad_zip_error

# Remove source file, if applicable
if not ctx.keep_file:
# Only delete the archive if we downloaded it locally, not when
# it is managed by the DownloadManager (another suite may need it).
if not ctx.archive_path and os.path.exists(archive_path):
os.remove(archive_path)
msg = f"{archive_path} could not be opened as zip file"
if not ctx.archive_path:
msg += ". File was deleted"
raise Exception(msg) from bad_zip_error

# Remove source file, if applicable (only if we downloaded it locally,
# not when the archive was provided by the DownloadManager)
if not ctx.keep_file and not ctx.archive_path and os.path.exists(dest_path):
os.remove(dest_path)

def download(
Expand All @@ -291,6 +325,7 @@ def download(
extract_all: bool = False,
keep_file: bool = False,
retries: int = 2,
download_manager: Optional["utils.DownloadManager"] = None,
) -> None:
"""Download the test suite"""
os.makedirs(out_dir, exist_ok=True)
Expand All @@ -303,19 +338,56 @@ def download(
):
# Download test suite of multiple test vectors from a single archive
print(f"Downloading test suite {self.name} using 1 job (single archive)")
first_tv = next(iter(self.test_vectors.values()))
archive_path = None
if download_manager:
dest_dir = os.path.join(out_dir, self.name)
archive_path = download_manager.get(first_tv.source, dest_dir, first_tv.source_checksum)
dwork_single = DownloadWorkSingleArchive(
out_dir, verify, extract_all, keep_file, self.name, self.test_vectors, retries
out_dir,
verify,
extract_all,
keep_file,
self.name,
self.test_vectors,
retries,
archive_path=archive_path,
)
self._download_single_archive(dwork_single)
elif len(unique_sources) == 1 and len(self.test_vectors) == 1:
# Download test suite of single test vector
print(f"Downloading test suite {self.name} using 1 job (single file)")
single_tv = next(iter(self.test_vectors.values()))
dwork = DownloadWork(out_dir, verify, extract_all, keep_file, self.name, retries)
archive_path = None
if download_manager:
dest_dir = os.path.join(out_dir, self.name, single_tv.name)
archive_path = download_manager.get(single_tv.source, dest_dir, single_tv.source_checksum)
dwork = DownloadWork(out_dir, verify, extract_all, keep_file, self.name, retries, archive_path)
dwork.set_test_vector(single_tv)
self._download_single_test_vector(dwork)
else:
# Download test suite of multiple test vectors
# Download test suite of multiple test vectors.
# When a download_manager is provided, pre-download all unique
# source URLs in parallel (deduplicating via the thread-safe
# manager), then dispatch parallel workers that only extract
# from the pre-downloaded archives.
source_paths: Dict[str, str] = {}
if download_manager:
unique_source_list = list(unique_sources)

def _pre_download(url: str) -> Tuple[str, str]:
rep_tv = next(tv for tv in self.test_vectors.values() if tv.source == url)
dest_dir = os.path.join(out_dir, self.name, rep_tv.name)
local_path = download_manager.get(url, dest_dir, rep_tv.source_checksum)
return (url, local_path)

max_workers = max(1, min(jobs, len(unique_source_list)))
with ThreadPoolExecutor(max_workers=max_workers) as dl_pool:
futures = {dl_pool.submit(_pre_download, url): url for url in unique_source_list}
for future in as_completed(futures):
url, local_path = future.result()
source_paths[url] = local_path

print(f"Downloading test suite {self.name} using {jobs} parallel jobs")
error_occurred = False
with Pool(jobs) as pool:
Expand All @@ -328,7 +400,16 @@ def _callback_error(err: Any) -> None:

downloads = []
for tv in self.test_vectors.values():
dwork = DownloadWork(out_dir, verify, extract_all, keep_file, self.name, retries)
archive_path = source_paths.get(tv.source)
dwork = DownloadWork(
out_dir,
verify,
extract_all,
keep_file,
self.name,
retries,
archive_path,
)
dwork.set_test_vector(tv)
downloads.append(
pool.apply_async(
Expand Down
Loading