diff --git a/text_to_image/tools/coco.py b/text_to_image/tools/coco.py index 5831173a1d..1a23d826d1 100644 --- a/text_to_image/tools/coco.py +++ b/text_to_image/tools/coco.py @@ -8,6 +8,7 @@ import tqdm import requests import urllib.request +import hashlib import zipfile import shutil from pathlib import Path @@ -15,6 +16,23 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger("coco") +# SHA-256 of the pinned COCO 2014 annotations archive +COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = ( + "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009" +) + + +def _verify_sha256(path: Path, expected: str) -> None: + hasher = hashlib.sha256() + with open(str(path), "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + hasher.update(chunk) + actual = hasher.hexdigest() + if actual != expected: + raise RuntimeError( + f"SHA-256 mismatch for {path}: expected {expected}, got {actual}" + ) + def get_args(): """Parse commandline.""" @@ -71,13 +89,24 @@ def get_args(): def download_img(args): img_url, target_folder, file_name = args + # Upgrade http://images.cocodataset.org/ URLs to the S3 origin + img_url = img_url.replace( + "http://images.cocodataset.org/", + "https://s3.amazonaws.com/images.cocodataset.org/", + 1, + ) if (target_folder / file_name).exists(): log.warning(f"Image {file_name} found locally, skipping download") else: urllib.request.urlretrieve(img_url, str(target_folder / file_name)) -def download_file(url: str, output_dir: Path, filename: str | None = None): +def download_file( + url: str, + output_dir: Path, + filename: str | None = None, + expected_sha256: str | None = None, +): os.makedirs(str(output_dir), exist_ok=True) if filename is None: @@ -99,6 +128,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None): f.write(chunk) pbar.update(len(chunk)) + if expected_sha256 is not None: + _verify_sha256(output_path, expected_sha256) + return output_path @@ -146,8 +178,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None): os.makedirs(str(dataset_dir / "raw"), exist_ok=True) os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True) download_file( - url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip", + url="https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip", output_dir=dataset_dir / "download_aux", + expected_sha256=COCO_ANNOTATIONS_TRAINVAL2014_SHA256, ) zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip" # Unzip file diff --git a/text_to_image/tools/coco_calibration.py b/text_to_image/tools/coco_calibration.py index dc4f49009d..1c35be8e37 100644 --- a/text_to_image/tools/coco_calibration.py +++ b/text_to_image/tools/coco_calibration.py @@ -7,11 +7,29 @@ import os import tqdm import urllib.request +import hashlib import zipfile logging.basicConfig(level=logging.INFO) log = logging.getLogger("coco") +# SHA-256 of the pinned COCO 2014 annotations archive +COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = ( + "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009" +) + + +def _verify_sha256(path: str, expected: str) -> None: + hasher = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + hasher.update(chunk) + actual = hasher.hexdigest() + if actual != expected: + raise RuntimeError( + f"SHA-256 mismatch for {path}: expected {expected}, got {actual}" + ) + def get_args(): """Parse commandline.""" @@ -44,6 +62,12 @@ def get_args(): def download_img(args): img_url, target_folder, file_name = args + # Upgrade http://images.cocodataset.org/ URLs to the S3 origin + img_url = img_url.replace( + "http://images.cocodataset.org/", + "https://s3.amazonaws.com/images.cocodataset.org/", + 1, + ) if os.path.exists(target_folder + file_name): log.warning(f"Image {file_name} found locally, skipping download") else: @@ -80,7 +104,12 @@ def download_img(args): os.makedirs(f"{dataset_dir}/download_aux/", exist_ok=True) os.system( f"cd {dataset_dir}/download_aux/ && \ - wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress" + wget https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress" + ) + # Verify archive integrity before unzip + _verify_sha256( + f"{dataset_dir}/download_aux/annotations_trainval2014.zip", + COCO_ANNOTATIONS_TRAINVAL2014_SHA256, ) # Unzip file diff --git a/text_to_image/tools/coco_generate_calibration.py b/text_to_image/tools/coco_generate_calibration.py index 096f5e4079..b5a8dcab88 100644 --- a/text_to_image/tools/coco_generate_calibration.py +++ b/text_to_image/tools/coco_generate_calibration.py @@ -7,6 +7,7 @@ import os import tqdm import urllib.request +import hashlib import zipfile import shutil from pathlib import Path @@ -15,6 +16,23 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger("coco") +# SHA-256 of the pinned COCO 2014 annotations archive +COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = ( + "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009" +) + + +def _verify_sha256(path: Path, expected: str) -> None: + hasher = hashlib.sha256() + with open(str(path), "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + hasher.update(chunk) + actual = hasher.hexdigest() + if actual != expected: + raise RuntimeError( + f"SHA-256 mismatch for {path}: expected {expected}, got {actual}" + ) + def get_args(): """Parse commandline.""" @@ -43,7 +61,12 @@ def get_args(): return args -def download_file(url: str, output_dir: Path, filename: str | None = None): +def download_file( + url: str, + output_dir: Path, + filename: str | None = None, + expected_sha256: str | None = None, +): os.makedirs(str(output_dir), exist_ok=True) if filename is None: @@ -65,6 +88,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None): f.write(chunk) pbar.update(len(chunk)) + if expected_sha256 is not None: + _verify_sha256(output_path, expected_sha256) + return output_path @@ -89,8 +115,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None): os.makedirs(str(dataset_dir / "raw"), exist_ok=True) os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True) download_file( - url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip", + url="https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip", output_dir=dataset_dir / "download_aux", + expected_sha256=COCO_ANNOTATIONS_TRAINVAL2014_SHA256, ) # Unzip file zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip"