Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions text_to_image/tools/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,31 @@
import tqdm
import requests
import urllib.request
import hashlib
import zipfile
import shutil
from pathlib import Path

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("coco")

# SHA-256 of the pinned COCO 2014 annotations archive
COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = (
"031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009"
)


def _verify_sha256(path: Path, expected: str) -> None:
    """Check that the file at *path* has the SHA-256 digest *expected*.

    The file is hashed in 1 MiB chunks so a large archive is never held in
    memory in one piece.

    Args:
        path: Location of the file to hash.
        expected: Hex-encoded SHA-256 digest. Letter case and surrounding
            whitespace are ignored.

    Raises:
        RuntimeError: If the computed digest differs from ``expected``.
    """
    # hexdigest() always yields lowercase hex; normalise the reference value
    # so an upper-case or newline-terminated digest does not fail spuriously.
    expected = expected.strip().lower()

    hasher = hashlib.sha256()
    with open(str(path), "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            hasher.update(chunk)

    actual = hasher.hexdigest()
    if actual != expected:
        raise RuntimeError(
            f"SHA-256 mismatch for {path}: expected {expected}, got {actual}"
        )


def get_args():
"""Parse commandline."""
Expand Down Expand Up @@ -71,13 +89,24 @@ def get_args():

def download_img(args):
    """Download one image unless it is already present locally.

    Args:
        args: A ``(img_url, target_folder, file_name)`` tuple, packed into a
            single argument. ``target_folder`` must support the ``/``
            operator (e.g. a ``pathlib.Path``).

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the transfer
        fails; in that case no file is left at the destination.
    """
    img_url, target_folder, file_name = args
    # Upgrade http://images.cocodataset.org/ URLs to the S3 origin
    img_url = img_url.replace(
        "http://images.cocodataset.org/",
        "https://s3.amazonaws.com/images.cocodataset.org/",
        1,
    )
    target_path = target_folder / file_name
    if target_path.exists():
        log.warning(f"Image {file_name} found locally, skipping download")
        return

    # Fetch into a sibling ".part" file and rename only after the transfer
    # has finished. Writing straight to the final name would leave a
    # truncated image behind on failure, which the exists() check above
    # would then mistake for a completed download on the next run.
    partial_path = target_path.with_name(target_path.name + ".part")
    try:
        urllib.request.urlretrieve(img_url, str(partial_path))
        partial_path.replace(target_path)
    finally:
        # No-op after a successful rename; removes the leftover on failure.
        partial_path.unlink(missing_ok=True)


def download_file(url: str, output_dir: Path, filename: str | None = None):
def download_file(
url: str,
output_dir: Path,
filename: str | None = None,
expected_sha256: str | None = None,
):
os.makedirs(str(output_dir), exist_ok=True)

if filename is None:
Expand All @@ -99,6 +128,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
f.write(chunk)
pbar.update(len(chunk))

if expected_sha256 is not None:
_verify_sha256(output_path, expected_sha256)

return output_path


Expand Down Expand Up @@ -146,8 +178,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
os.makedirs(str(dataset_dir / "raw"), exist_ok=True)
os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True)
download_file(
url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
url="https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip",
output_dir=dataset_dir / "download_aux",
expected_sha256=COCO_ANNOTATIONS_TRAINVAL2014_SHA256,
)
zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip"
# Unzip file
Expand Down
31 changes: 30 additions & 1 deletion text_to_image/tools/coco_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,29 @@
import os
import tqdm
import urllib.request
import hashlib
import zipfile

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("coco")

# SHA-256 of the pinned COCO 2014 annotations archive
COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = (
"031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009"
)


def _verify_sha256(path: str, expected: str) -> None:
    """Check that the file at *path* has the SHA-256 digest *expected*.

    The file is hashed in 1 MiB chunks so a large archive is never held in
    memory in one piece.

    Args:
        path: Filesystem path of the file to hash.
        expected: Hex-encoded SHA-256 digest. Letter case and surrounding
            whitespace are ignored.

    Raises:
        RuntimeError: If the computed digest differs from ``expected``.
    """
    # hexdigest() always yields lowercase hex; normalise the reference value
    # so an upper-case or newline-terminated digest does not fail spuriously.
    expected = expected.strip().lower()

    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            hasher.update(chunk)

    actual = hasher.hexdigest()
    if actual != expected:
        raise RuntimeError(
            f"SHA-256 mismatch for {path}: expected {expected}, got {actual}"
        )


def get_args():
"""Parse commandline."""
Expand Down Expand Up @@ -44,6 +62,12 @@ def get_args():

def download_img(args):
img_url, target_folder, file_name = args
# Upgrade http://images.cocodataset.org/ URLs to the S3 origin
img_url = img_url.replace(
"http://images.cocodataset.org/",
"https://s3.amazonaws.com/images.cocodataset.org/",
1,
)
if os.path.exists(target_folder + file_name):
log.warning(f"Image {file_name} found locally, skipping download")
else:
Expand Down Expand Up @@ -80,7 +104,12 @@ def download_img(args):
os.makedirs(f"{dataset_dir}/download_aux/", exist_ok=True)
os.system(
f"cd {dataset_dir}/download_aux/ && \
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress"
wget https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress"
)
# Verify archive integrity before unzip
_verify_sha256(
f"{dataset_dir}/download_aux/annotations_trainval2014.zip",
COCO_ANNOTATIONS_TRAINVAL2014_SHA256,
)

# Unzip file
Expand Down
31 changes: 29 additions & 2 deletions text_to_image/tools/coco_generate_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import tqdm
import urllib.request
import hashlib
import zipfile
import shutil
from pathlib import Path
Expand All @@ -15,6 +16,23 @@
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("coco")

# SHA-256 of the pinned COCO 2014 annotations archive
COCO_ANNOTATIONS_TRAINVAL2014_SHA256 = (
"031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009"
)


def _verify_sha256(path: Path, expected: str) -> None:
    """Check that the file at *path* has the SHA-256 digest *expected*.

    The file is hashed in 1 MiB chunks so a large archive is never held in
    memory in one piece.

    Args:
        path: Location of the file to hash.
        expected: Hex-encoded SHA-256 digest. Letter case and surrounding
            whitespace are ignored.

    Raises:
        RuntimeError: If the computed digest differs from ``expected``.
    """
    # hexdigest() always yields lowercase hex; normalise the reference value
    # so an upper-case or newline-terminated digest does not fail spuriously.
    expected = expected.strip().lower()

    hasher = hashlib.sha256()
    with open(str(path), "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            hasher.update(chunk)

    actual = hasher.hexdigest()
    if actual != expected:
        raise RuntimeError(
            f"SHA-256 mismatch for {path}: expected {expected}, got {actual}"
        )


def get_args():
"""Parse commandline."""
Expand Down Expand Up @@ -43,7 +61,12 @@ def get_args():
return args


def download_file(url: str, output_dir: Path, filename: str | None = None):
def download_file(
url: str,
output_dir: Path,
filename: str | None = None,
expected_sha256: str | None = None,
):
os.makedirs(str(output_dir), exist_ok=True)

if filename is None:
Expand All @@ -65,6 +88,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
f.write(chunk)
pbar.update(len(chunk))

if expected_sha256 is not None:
_verify_sha256(output_path, expected_sha256)

return output_path


Expand All @@ -89,8 +115,9 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
os.makedirs(str(dataset_dir / "raw"), exist_ok=True)
os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True)
download_file(
url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
url="https://s3.amazonaws.com/images.cocodataset.org/annotations/annotations_trainval2014.zip",
output_dir=dataset_dir / "download_aux",
expected_sha256=COCO_ANNOTATIONS_TRAINVAL2014_SHA256,
)
# Unzip file
zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip"
Expand Down
Loading