diff --git a/text_to_image/tools/coco_calibration.py b/text_to_image/tools/coco_calibration.py index dc4f49009d..21f8839ccc 100644 --- a/text_to_image/tools/coco_calibration.py +++ b/text_to_image/tools/coco_calibration.py @@ -8,6 +8,9 @@ import tqdm import urllib.request import zipfile +import shutil +from pathlib import Path +import requests logging.basicConfig(level=logging.INFO) log = logging.getLogger("coco") @@ -44,15 +47,42 @@ def get_args(): def download_img(args): img_url, target_folder, file_name = args - if os.path.exists(target_folder + file_name): + target_folder = Path(target_folder) + if (target_folder / file_name).exists(): log.warning(f"Image {file_name} found locally, skipping download") else: - urllib.request.urlretrieve(img_url, target_folder + file_name) + urllib.request.urlretrieve(img_url, str(target_folder / file_name)) + +def download_file(url: str, output_dir: Path, filename: str | None = None): + os.makedirs(str(output_dir), exist_ok=True) + + if filename is None: + filename = os.path.basename(url) + + output_path = output_dir / filename + + with requests.get(url, stream=True) as r: + r.raise_for_status() + total_size = int(r.headers.get("Content-Length", 0)) + with open(str(output_path), "wb") as f, tqdm.tqdm( + total=total_size, + unit="B", + unit_scale=True, + desc=filename, + ) as pbar: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + + return output_path if __name__ == "__main__": args = get_args() dataset_dir = os.path.abspath(args.dataset_dir) + dataset_dir = Path(dataset_dir) + calibration_dir = ( args.calibration_dir if args.calibration_dir is not None @@ -60,45 +90,52 @@ def download_img(args): os.path.dirname(__file__), "..", "..", "calibration", "COCO-2014" ) ) + calibration_dir = Path(calibration_dir) + tsv_path = Path(args.tsv_path) if args.tsv_path is not None else None + # Check if the annotation dataframe is there - if os.path.exists(f"{dataset_dir}/calibration/captions.tsv"): + calibration_captions_path = dataset_dir / "calibration" / "captions.tsv" + if calibration_captions_path.exists(): df_annotations = pd.read_csv( - f"{dataset_dir}/calibration/captions.tsv", sep="\t" + str(calibration_captions_path), sep="\t" ) - elif args.tsv_path is not None and os.path.exists(f"{args.tsv_path}"): + elif tsv_path is not None and tsv_path.exists(): os.makedirs(f"{dataset_dir}/calibration/", exist_ok=True) - os.system(f"cp {args.tsv_path} {dataset_dir}/calibration/") + shutil.copy(tsv_path, str(calibration_captions_path)) df_annotations = pd.read_csv( - f"{dataset_dir}/calibration/captions.tsv", sep="\t" + str(calibration_captions_path), sep="\t" ) else: + + # Check if raw annotations file already exist - if not os.path.exists( - f"{dataset_dir}/raw/annotations/captions_train2014.json"): + if not (dataset_dir / "raw" / "annotations" / "captions_train2014.json").exists(): # Download annotations - os.makedirs(f"{dataset_dir}/raw/", exist_ok=True) - os.makedirs(f"{dataset_dir}/download_aux/", exist_ok=True) - os.system( - f"cd {dataset_dir}/download_aux/ && \ - wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress" + os.makedirs(str(dataset_dir / "raw"), exist_ok=True) + os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True) + download_file( + url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip", + output_dir=dataset_dir / "download_aux", ) - + # Unzip file + zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip" # Unzip file with zipfile.ZipFile( - f"{dataset_dir}/download_aux/annotations_trainval2014.zip", "r" + str(zipfile_path), "r" ) as zip_ref: - zip_ref.extractall(f"{dataset_dir}/raw/") + zip_ref.extractall(str(dataset_dir / "raw/")) # Move captions to target folder - os.makedirs(f"{dataset_dir}/captions/", exist_ok=True) - os.system( - f"mv {dataset_dir}/raw/annotations/captions_train2014.json {dataset_dir}/captions/" - ) + os.makedirs(str(dataset_dir / "captions"), exist_ok=True) + shutil.move( + str(dataset_dir / "raw" / "annotations" / "captions_train2014.json"), + str(dataset_dir / "captions" / "captions_train2014.json")) if not args.keep_raw: - os.system(f"rm -rf {dataset_dir}/raw") - os.system(f"rm -rf {dataset_dir}/download_aux") + shutil.rmtree(str(dataset_dir / "raw")) + shutil.rmtree(str(dataset_dir / "download_aux")) + # Convert to dataframe format and extract the relevant fields - with open(f"{dataset_dir}/captions/captions_train2014.json") as f: + with open(dataset_dir / "captions" / "captions_train2014.json") as f: captions = json.load(f) annotations = captions["annotations"] images = captions["images"] @@ -106,7 +143,7 @@ def download_img(args): df_images = pd.DataFrame(images) # Calibration images - with open(f"{calibration_dir}/coco_cal_captions_list.txt") as f: + with open(calibration_dir / "coco_cal_captions_list.txt") as f: calibration_ids = f.readlines() calibration_ids = [int(id.replace("\n", "")) for id in calibration_ids] @@ -128,12 +165,12 @@ def download_img(args): .reset_index(drop=True) ) # Download images - os.makedirs(f"{dataset_dir}/calibration/", exist_ok=True) + os.makedirs(str(dataset_dir / "calibration"), exist_ok=True) if args.download_images: - os.makedirs(f"{dataset_dir}/calibration/data/", exist_ok=True) + os.makedirs(str(dataset_dir / "calibration" / "data"), exist_ok=True) tasks = [ (row["coco_url"], - f"{dataset_dir}/calibration/data/", + str(dataset_dir / "calibration" / "data" / ""), row["file_name"]) for i, row in df_annotations.iterrows() ] @@ -147,4 +184,4 @@ def download_img(args): # Finalize annotations df_annotations[ ["id", "image_id", "caption", "height", "width", "file_name", "coco_url"] - ].to_csv(f"{dataset_dir}/calibration/captions.tsv", sep="\t", index=False) + ].to_csv(str(dataset_dir / "calibration" / "captions.tsv"), sep="\t", index=False)