Skip to content

Commit 94b51e1

Browse files
Update download_osl_hf.py
Add the description and dense description
1 parent 4d2a545 commit 94b51e1

1 file changed

Lines changed: 98 additions & 42 deletions

File tree

test_data/download_osl_hf.py

Lines changed: 98 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from huggingface_hub import hf_hub_download, snapshot_download, HfApi
66

77

def human_size(num: float) -> str:
    """Convert a file size in bytes to a human-readable string (B, KB, MB, GB, TB).

    Args:
        num: Size in bytes (int or float; annotation widened from `int`
             because the value is float-divided below).

    Returns:
        A string like "1.5 KB"; anything >= 1024 TB is reported in PB.
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num < 1024.0:
            return f"{num:.1f} {unit}"
        num /= 1024.0
    # Fell through every unit: express the remainder in petabytes.
    return f"{num:.1f} PB"
1515

1616

def fix_hf_url(hf_url: str) -> str:
    """Convert a HuggingFace 'blob' URL to a 'resolve' URL for direct download."""
    # 'blob' URLs render the web viewer; 'resolve' serves the raw file bytes.
    return "/resolve/".join(hf_url.split("/blob/"))
2020

2121

def parse_hf_url(hf_url: str) -> tuple:
    """
    Parse a Hugging Face dataset file URL (supports 'blob' or 'resolve' forms).
    Returns (repo_id, revision, path_in_repo).

    Example:
        https://huggingface.co/datasets/ORG/REPO/blob/main/annotations_test.json
        -> repo_id="ORG/REPO", revision="main", path_in_repo="annotations_test.json"

    Raises:
        ValueError: if the URL does not match the expected dataset-file layout.
    """
    url = fix_hf_url(hf_url)
    parsed = urlparse(url)
    parts = parsed.path.strip("/").split("/")

    # Remove leading "datasets" if present
    if "datasets" in parts:
        datasets_idx = parts.index("datasets")
        parts = parts[datasets_idx + 1 :]

    # Expected: ORG / REPO / resolve / REVISION / <path...>
    # (>= 5 so that at least one path component follows the revision)
    if len(parts) < 5 or parts[2] != "resolve":
        raise ValueError(f"URL does not look like a valid HuggingFace dataset file URL: {url}")

    repo_id = f"{parts[0]}/{parts[1]}"
    revision = parts[3]
    path_in_repo = "/".join(parts[4:])

    return repo_id, revision, path_in_repo
4348

4449

def get_json_repo_folder(path_in_repo: str) -> str:
    """Return the folder containing the JSON inside the repo, or '' if at root."""
    parent = os.path.dirname(path_in_repo)
    # os.path.dirname yields "" for bare filenames and "." for "./file" —
    # both mean "repo root" here.
    if not parent or parent == ".":
        return ""
    return parent
5154

5255

def parse_types_arg(types_arg: str):
    """
    Parse --types argument.
    - "all" means include any input that has a "path".
    - Otherwise it's a comma-separated list of input types (e.g. "video,captions,features").
    """
    # Empty/None falls back to the default "video"; matching is case-insensitive.
    normalized = (types_arg or "video").strip().lower()
    if normalized in ("all", "*"):
        return "all"
    pieces = (piece.strip() for piece in normalized.split(","))
    return {piece for piece in pieces if piece}
66+
67+
68+
def extract_repo_paths_from_json(osl_json: dict, want_types):
69+
"""
70+
Extract file paths from different OSL / SoccerNetPro JSON schemas.
5671
5772
Supported formats:
58-
- videos[].path
59-
- data[].inputs[].path (where type == "video")
73+
- videos[].path (legacy/simple)
74+
- data[].inputs[].path (OSL v2)
75+
where input has fields: {type, path, ...}
76+
77+
want_types:
78+
- "all" -> any input with a "path"
79+
- set(...) -> only inputs whose inp["type"] is in the set
6080
"""
6181
repo_paths = []
6282

63-
# Legacy / simple format
64-
if "videos" in osl_json:
65-
for v in osl_json.get("videos", []):
66-
if "path" in v:
67-
repo_paths.append(v["path"].lstrip("/"))
83+
# Legacy/simple format
84+
if "videos" in osl_json and isinstance(osl_json.get("videos"), list):
85+
# Only include if caller wants videos
86+
if want_types == "all" or ("video" in want_types):
87+
for v in osl_json.get("videos", []):
88+
if isinstance(v, dict) and "path" in v:
89+
repo_paths.append(str(v["path"]).lstrip("/"))
6890

69-
# SoccerNetPro / OSL v2 format
70-
elif "data" in osl_json:
91+
# OSL v2 format
92+
if "data" in osl_json and isinstance(osl_json.get("data"), list):
7193
for item in osl_json.get("data", []):
7294
for inp in item.get("inputs", []):
73-
if inp.get("type") == "video" and "path" in inp:
74-
repo_paths.append(inp["path"].lstrip("/"))
95+
if not isinstance(inp, dict):
96+
continue
97+
p = inp.get("path")
98+
if not p:
99+
continue
100+
inp_type = str(inp.get("type", "")).strip().lower()
101+
102+
if want_types == "all":
103+
repo_paths.append(str(p).lstrip("/"))
104+
else:
105+
if inp_type in want_types:
106+
repo_paths.append(str(p).lstrip("/"))
75107

76108
if not repo_paths:
77-
raise ValueError("No video paths found in the provided OSL JSON.")
109+
if want_types == "all":
110+
raise ValueError("No file paths found in the provided JSON (no inputs with 'path').")
111+
else:
112+
raise ValueError(
113+
f"No matching file paths found for requested types={sorted(list(want_types))}. "
114+
"Check your JSON schema and --types."
115+
)
78116

79117
return repo_paths
80118

81119

82-
def main(osl_json_url, output_dir="downloaded_data", dry_run=False):
120+
def main(osl_json_url: str, output_dir: str = "downloaded_data", dry_run: bool = False, types_arg: str = "video"):
83121
api = HfApi()
122+
want_types = parse_types_arg(types_arg)
84123

85124
# Parse HuggingFace URL
86125
repo_id, revision, path_in_repo = parse_hf_url(osl_json_url)
87126
repo_json_folder = get_json_repo_folder(path_in_repo)
88127

89-
print(f"⬇️ Downloading OSL JSON from {repo_id}@{revision}: {path_in_repo}")
128+
print(f"⬇️ Downloading JSON from {repo_id}@{revision}: {path_in_repo}")
90129
os.makedirs(output_dir, exist_ok=True)
91130

131+
# Download JSON itself
92132
hf_json_path = hf_hub_download(
93133
repo_id=repo_id,
94134
repo_type="dataset",
@@ -97,28 +137,32 @@ def main(osl_json_url, output_dir="downloaded_data", dry_run=False):
97137
local_dir=output_dir,
98138
local_dir_use_symlinks=False,
99139
)
100-
101-
print(f" → Saved as {hf_json_path}")
140+
print(f" → Saved as: {hf_json_path}")
102141

103142
# Load JSON
104-
with open(hf_json_path, "r") as f:
143+
with open(hf_json_path, "r", encoding="utf-8") as f:
105144
osl = json.load(f)
106145

107-
# Extract video paths (schema-aware)
108-
repo_paths = extract_video_paths(osl)
109-
print(f"Found {len(repo_paths)} video files to download.")
110-
111-
def repo_full_path(rel_path):
112-
if repo_json_folder and not rel_path.startswith(repo_json_folder + "/"):
113-
return os.path.join(repo_json_folder, rel_path)
146+
# Extract repo paths (schema-aware)
147+
repo_paths = extract_repo_paths_from_json(osl, want_types)
148+
print(f"Found {len(repo_paths)} referenced files for types={types_arg}.")
149+
150+
# If JSON file lives in a repo subfolder, some inputs may be relative to that folder.
151+
# We keep your original behavior: if path doesn't start with repo_json_folder, prefix it.
152+
def repo_full_path(rel_path: str) -> str:
153+
rel_path = rel_path.lstrip("/")
154+
if repo_json_folder:
155+
prefix = repo_json_folder.rstrip("/") + "/"
156+
if not rel_path.startswith(prefix):
157+
return prefix + rel_path
114158
return rel_path
115159

116-
# Unique, repo-relative paths
117160
allow_patterns = sorted(set(repo_full_path(p) for p in repo_paths))
118161

119162
if dry_run:
120163
print("Running in DRY-RUN mode (no files will be downloaded).")
121164

165+
# Fetch file sizes via repo metadata (best effort)
122166
try:
123167
info_obj = api.repo_info(
124168
repo_id=repo_id,
@@ -152,9 +196,11 @@ def repo_full_path(rel_path):
152196
print(f"Total estimated storage needed: {human_size(total_size)}")
153197

154198
if missing_files:
155-
print(f"WARNING: {len(missing_files)} files not found in repo:")
156-
for f in missing_files:
199+
print(f"WARNING: {len(missing_files)} files not found in repo metadata:")
200+
for f in missing_files[:50]:
157201
print(f" - {f}")
202+
if len(missing_files) > 50:
203+
print(f" ... and {len(missing_files) - 50} more")
158204

159205
else:
160206
print(f"Downloading {len(allow_patterns)} files using snapshot_download...")
@@ -166,26 +212,36 @@ def repo_full_path(rel_path):
166212
allow_patterns=allow_patterns,
167213
max_workers=8,
168214
)
169-
print(f" All requested files downloaded to: {output_dir}")
215+
print(f"✅ Done. All requested files downloaded to: {output_dir}")
170216

171217

172218
if __name__ == "__main__":
    # CLI entry point: parse arguments and hand off to main().
    parser = argparse.ArgumentParser(
        description="Download files referenced in an OSL JSON from Hugging Face (dataset repo)."
    )
    parser.add_argument(
        "--url",
        required=True,
        help="URL of the OSL JSON file on Hugging Face (blob/resolve both supported)",
    )
    parser.add_argument(
        "--output-dir",
        default="downloaded_data",
        help="Directory to store downloaded files",
    )
    parser.add_argument(
        "--types",
        default="video",
        help=(
            "Comma-separated input types to download from item.inputs (e.g. 'video', 'video,captions', "
            "'video,captions,features'), or 'all' to download all inputs with a path. Default: video"
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="List files to download without downloading them (estimates total size if possible).",
    )

    args = parser.parse_args()
    main(args.url, args.output_dir, dry_run=args.dry_run, types_arg=args.types)

0 commit comments

Comments
 (0)