Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions src/trossen_cloud_cli/commands/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,47 @@
app = typer.Typer(help="Manage datasets")


# Friendly CLI aliases for --type, mapped to canonical DatasetType values.
_TYPE_ALIASES = {"lerobot": "lerobot_v3", "mcap": "trossenmcap"}


def _valid_type_names() -> str:
    """Comma-separated list of accepted --type values (canonical names + aliases)."""
    # Join .value explicitly: this works whether or not DatasetType mixes in str,
    # and matches how the values were rendered before aliases were introduced.
    return ", ".join([*(dt.value for dt in DatasetType), *_TYPE_ALIASES])


# Shared --type/-t option annotation for dataset commands. Declared as a plain
# string (not DatasetType) so aliases like "lerobot"/"mcap" can be resolved by
# _parse_dataset_type before enum conversion.
_DatasetTypeOption = Annotated[
    str | None,
    typer.Option(
        "--type",
        "-t",
        help=f"Dataset type ({_valid_type_names()}). Auto-detected if omitted.",
    ),
]


def _parse_dataset_type(value: str | None) -> DatasetType | None:
    """Parse a --type string into a DatasetType, resolving aliases (case-insensitive)."""
    if value is None:
        return None
    # Normalize case first, then map known aliases onto canonical enum values.
    normalized = value.lower()
    canonical = _TYPE_ALIASES.get(normalized, normalized)
    try:
        return DatasetType(canonical)
    except ValueError:
        message = f"Invalid dataset type '{value}'. Valid: {_valid_type_names()}"
        raise typer.BadParameter(message) from None


def _resolve_dataset_type(path: Path, dataset_type: DatasetType | None) -> DatasetType:
    """Auto-detect dataset type if not provided, or exit with an error.

    Args:
        path: Dataset directory to inspect when no explicit type was given.
        dataset_type: Explicit type from --type, or None to auto-detect.

    Raises:
        typer.Exit: When no type was given and detection fails.
    """
    if dataset_type is not None:
        return dataset_type
    detected = detect_dataset_type(path)
    if detected is None:
        # Include aliases in the hint so users see every accepted spelling.
        print_error(
            f"Could not detect dataset type. Use --type to specify ({_valid_type_names()})."
        )
        raise typer.Exit(1)
    print_info(f"Detected dataset type: {detected.value}")
    return detected
Expand Down Expand Up @@ -68,10 +101,7 @@ def upload(
str,
typer.Option("--name", "-n", help="Dataset name"),
],
dataset_type: Annotated[
DatasetType | None,
typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"),
] = None,
dataset_type_str: _DatasetTypeOption = None,
privacy: Annotated[
PrivacyLevel,
typer.Option("--privacy", "-p", help="Privacy level"),
Expand All @@ -88,9 +118,9 @@ def upload(
"""
Upload a dataset to Trossen Cloud.
"""
parsed_type = _parse_dataset_type(dataset_type_str)
require_auth()

dataset_type = _resolve_dataset_type(path, dataset_type)
dataset_type = _resolve_dataset_type(path, parsed_type)

# Parse metadata if provided
metadata_dict = None
Expand Down Expand Up @@ -164,10 +194,7 @@ def import_hf(
str | None,
typer.Option("--name", "-n", help="Dataset name (defaults to HF repo name)"),
] = None,
dataset_type: Annotated[
DatasetType | None,
typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"),
] = None,
dataset_type_str: _DatasetTypeOption = None,
privacy: Annotated[
PrivacyLevel,
typer.Option("--privacy", "-p", help="Privacy level"),
Expand Down Expand Up @@ -197,6 +224,7 @@ def import_hf(
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError

parsed_type = _parse_dataset_type(dataset_type_str)
require_auth()

repo_id = _parse_hf_repo_id(repo)
Expand Down Expand Up @@ -237,7 +265,7 @@ def import_hf(

print_success(f"Downloaded to {local_path}")

dataset_type = _resolve_dataset_type(local_path, dataset_type)
dataset_type = _resolve_dataset_type(local_path, parsed_type)

Comment on lines 266 to 269
# Validate dataset before upload
validation_warnings = validate_dataset(local_path, dataset_type)
Expand Down
123 changes: 123 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import struct
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

Expand Down Expand Up @@ -733,3 +734,125 @@ def test_import_hf_auto_detects_type(self, tmp_path):
assert "lerobot_v3" in result.output
upload_mock.assert_called_once()
assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"


# ── Type alias / parsing tests ──────────────────────────────────────────────


@contextmanager
def _mock_upload():
    """Patch auth + validation + upload for a successful dataset upload invocation."""
    auth_patch = patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN)
    validate_patch = patch(
        "trossen_cloud_cli.commands.datasets.validate_dataset", return_value=[]
    )
    upload_patch = patch(
        "trossen_cloud_cli.commands.datasets.create_and_upload_dataset",
        return_value={"id": "ds-123", "name": "test"},
    )
    # Yield only the upload mock; tests assert on its call args.
    with auth_patch, validate_patch, upload_patch as upload_mock:
        yield upload_mock


class TestTypeAliases:
    """CLI-level coverage for --type alias resolution and validation."""

    def test_upload_accepts_alias_lerobot(self, tmp_path):
        """--type lerobot is accepted as an alias for lerobot_v3."""
        ds = _make_valid_lerobot(tmp_path)
        cli_args = ["dataset", "upload", str(ds), "--name", "test", "--type", "lerobot"]
        with _mock_upload() as upload_mock:
            result = runner.invoke(app, cli_args)
        assert result.exit_code == 0
        upload_mock.assert_called_once()
        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"

    def test_upload_accepts_alias_mcap(self, tmp_path):
        """--type mcap is accepted as an alias for trossenmcap."""
        ds = _make_valid_mcap_dataset(tmp_path)
        cli_args = ["dataset", "upload", str(ds), "--name", "test", "--type", "mcap"]
        with _mock_upload() as upload_mock:
            result = runner.invoke(app, cli_args)
        assert result.exit_code == 0
        upload_mock.assert_called_once()
        assert upload_mock.call_args.kwargs["dataset_type"] == "trossenmcap"

    def test_upload_rejects_invalid_type(self, tmp_path):
        """--type with an invalid value gives a clear error."""
        ds = _make_valid_mcap_dataset(tmp_path)
        cli_args = ["dataset", "upload", str(ds), "--name", "test", "--type", "bogus"]
        with patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN):
            result = runner.invoke(app, cli_args)
        assert result.exit_code != 0
        assert "invalid" in result.output.lower()

    def test_upload_alias_is_case_insensitive(self, tmp_path):
        """Alias resolution lowercases the input — `LeRobot`, `MCAP`, etc. all work."""
        ds = _make_valid_lerobot(tmp_path)
        cli_args = ["dataset", "upload", str(ds), "--name", "test", "--type", "LeRobot"]
        with _mock_upload() as upload_mock:
            result = runner.invoke(app, cli_args)
        assert result.exit_code == 0
        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"

    def test_upload_invalid_type_fails_before_auth(self, tmp_path):
        """Invalid --type must fail fast — before require_auth() runs."""
        ds = _make_valid_mcap_dataset(tmp_path)
        cli_args = ["dataset", "upload", str(ds), "--name", "test", "--type", "bogus"]
        with patch("trossen_cloud_cli.commands.datasets.require_auth") as auth_mock:
            result = runner.invoke(app, cli_args)
        assert result.exit_code != 0
        assert "invalid" in result.output.lower()
        auth_mock.assert_not_called()

    def test_import_hf_accepts_alias(self, tmp_path):
        """import-hf resolves --type aliases like dataset upload does."""
        download_dir = _make_valid_lerobot(tmp_path)
        cli_args = [
            "dataset",
            "import-hf",
            "org/my-dataset",
            "--name",
            "my-dataset",
            "--type",
            "LeRobot",
            "--force",
        ]
        with (
            patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN),
            patch("huggingface_hub.snapshot_download", return_value=str(download_dir)),
            patch("trossen_cloud_cli.commands.datasets.validate_dataset", return_value=[]),
            patch(
                "trossen_cloud_cli.commands.datasets.create_and_upload_dataset",
                return_value={"id": "ds-789", "name": "my-dataset"},
            ) as upload_mock,
        ):
            result = runner.invoke(app, cli_args)
        assert result.exit_code == 0
        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"

    def test_import_hf_invalid_type_fails_before_download(self, tmp_path):
        """Invalid --type on import-hf must fail before snapshot_download runs."""
        cli_args = ["dataset", "import-hf", "org/my-dataset", "--type", "bogus"]
        with (
            patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN),
            patch("huggingface_hub.snapshot_download") as download_mock,
        ):
            result = runner.invoke(app, cli_args)
        assert result.exit_code != 0
        assert "invalid" in result.output.lower()
        download_mock.assert_not_called()