diff --git a/src/trossen_cloud_cli/commands/datasets.py b/src/trossen_cloud_cli/commands/datasets.py
index 228dac4..7e14727 100644
--- a/src/trossen_cloud_cli/commands/datasets.py
+++ b/src/trossen_cloud_cli/commands/datasets.py
@@ -22,14 +22,47 @@
 app = typer.Typer(help="Manage datasets")
 
 
+_TYPE_ALIASES = {"lerobot": "lerobot_v3", "mcap": "trossenmcap"}
+
+
+def _valid_type_names() -> str:
+    """Comma-separated list of accepted --type values (canonical names + aliases)."""
+    return ", ".join([*DatasetType, *_TYPE_ALIASES])
+
+
+_DatasetTypeOption = Annotated[
+    str | None,
+    typer.Option(
+        "--type",
+        "-t",
+        help=f"Dataset type ({_valid_type_names()}). Auto-detected if omitted.",
+    ),
+]
+
+
+def _parse_dataset_type(value: str | None) -> DatasetType | None:
+    """Parse a --type string into a DatasetType, resolving aliases (case-insensitive)."""
+    if value is None:
+        return None
+    lower = value.lower()
+    resolved = _TYPE_ALIASES.get(lower, lower)
+    try:
+        return DatasetType(resolved)
+    except ValueError:
+        raise typer.BadParameter(
+            f"Invalid dataset type '{value}'. Valid: {_valid_type_names()}"
+        ) from None
+
+
 def _resolve_dataset_type(path: Path, dataset_type: DatasetType | None) -> DatasetType:
     """Auto-detect dataset type if not provided, or exit with an error."""
     if dataset_type is not None:
         return dataset_type
     detected = detect_dataset_type(path)
     if detected is None:
-        valid = ", ".join(dt.value for dt in DatasetType)
-        print_error(f"Could not detect dataset type. Use --type to specify ({valid}).")
+        print_error(
+            f"Could not detect dataset type. Use --type to specify ({_valid_type_names()})."
+        )
         raise typer.Exit(1)
     print_info(f"Detected dataset type: {detected.value}")
     return detected
@@ -68,10 +101,7 @@ def upload(
         str,
         typer.Option("--name", "-n", help="Dataset name"),
     ],
-    dataset_type: Annotated[
-        DatasetType | None,
-        typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"),
-    ] = None,
+    dataset_type_str: _DatasetTypeOption = None,
     privacy: Annotated[
         PrivacyLevel,
         typer.Option("--privacy", "-p", help="Privacy level"),
@@ -88,9 +118,9 @@ def upload(
     """
     Upload a dataset to Trossen Cloud.
     """
+    parsed_type = _parse_dataset_type(dataset_type_str)
     require_auth()
-
-    dataset_type = _resolve_dataset_type(path, dataset_type)
+    dataset_type = _resolve_dataset_type(path, parsed_type)
 
     # Parse metadata if provided
     metadata_dict = None
@@ -164,10 +194,7 @@ def import_hf(
         str | None,
         typer.Option("--name", "-n", help="Dataset name (defaults to HF repo name)"),
     ] = None,
-    dataset_type: Annotated[
-        DatasetType | None,
-        typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"),
-    ] = None,
+    dataset_type_str: _DatasetTypeOption = None,
     privacy: Annotated[
         PrivacyLevel,
         typer.Option("--privacy", "-p", help="Privacy level"),
@@ -197,6 +224,7 @@ def import_hf(
     from huggingface_hub import snapshot_download
     from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
 
+    parsed_type = _parse_dataset_type(dataset_type_str)
     require_auth()
 
     repo_id = _parse_hf_repo_id(repo)
@@ -237,7 +265,7 @@ def import_hf(
 
     print_success(f"Downloaded to {local_path}")
 
-    dataset_type = _resolve_dataset_type(local_path, dataset_type)
+    dataset_type = _resolve_dataset_type(local_path, parsed_type)
 
     # Validate dataset before upload
     validation_warnings = validate_dataset(local_path, dataset_type)
diff --git a/tests/test_validators.py b/tests/test_validators.py
index 8f3a42f..5619227 100644
--- a/tests/test_validators.py
+++ b/tests/test_validators.py
@@ -2,6 +2,7 @@
 
 import json
 import struct
+from contextlib import contextmanager
 from pathlib import Path
 from unittest.mock import patch
 
@@ -733,3 +734,125 @@ def test_import_hf_auto_detects_type(self, tmp_path):
         assert "lerobot_v3" in result.output
         upload_mock.assert_called_once()
         assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"
+
+
+# ── Type alias / parsing tests ──────────────────────────────────────────────
+
+
+@contextmanager
+def _mock_upload():
+    """Patch auth + validation + upload for a successful dataset upload invocation."""
+    with (
+        patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN),
+        patch("trossen_cloud_cli.commands.datasets.validate_dataset", return_value=[]),
+        patch(
+            "trossen_cloud_cli.commands.datasets.create_and_upload_dataset",
+            return_value={"id": "ds-123", "name": "test"},
+        ) as upload_mock,
+    ):
+        yield upload_mock
+
+
+class TestTypeAliases:
+    def test_upload_accepts_alias_lerobot(self, tmp_path):
+        """--type lerobot is accepted as an alias for lerobot_v3."""
+        ds = _make_valid_lerobot(tmp_path)
+        with _mock_upload() as upload_mock:
+            result = runner.invoke(
+                app,
+                ["dataset", "upload", str(ds), "--name", "test", "--type", "lerobot"],
+            )
+        assert result.exit_code == 0
+        upload_mock.assert_called_once()
+        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"
+
+    def test_upload_accepts_alias_mcap(self, tmp_path):
+        """--type mcap is accepted as an alias for trossenmcap."""
+        ds = _make_valid_mcap_dataset(tmp_path)
+        with _mock_upload() as upload_mock:
+            result = runner.invoke(
+                app,
+                ["dataset", "upload", str(ds), "--name", "test", "--type", "mcap"],
+            )
+        assert result.exit_code == 0
+        upload_mock.assert_called_once()
+        assert upload_mock.call_args.kwargs["dataset_type"] == "trossenmcap"
+
+    def test_upload_rejects_invalid_type(self, tmp_path):
+        """--type with an invalid value gives a clear error."""
+        ds = _make_valid_mcap_dataset(tmp_path)
+        with patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN):
+            result = runner.invoke(
+                app,
+                ["dataset", "upload", str(ds), "--name", "test", "--type", "bogus"],
+            )
+        assert result.exit_code != 0
+        assert "invalid" in result.output.lower()
+
+    def test_upload_alias_is_case_insensitive(self, tmp_path):
+        """Alias resolution lowercases the input — `LeRobot`, `MCAP`, etc. all work."""
+        ds = _make_valid_lerobot(tmp_path)
+        with _mock_upload() as upload_mock:
+            result = runner.invoke(
+                app,
+                ["dataset", "upload", str(ds), "--name", "test", "--type", "LeRobot"],
+            )
+        assert result.exit_code == 0
+        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"
+
+    def test_upload_invalid_type_fails_before_auth(self, tmp_path):
+        """Invalid --type must fail fast — before require_auth() runs."""
+        ds = _make_valid_mcap_dataset(tmp_path)
+        require_auth_mock = patch("trossen_cloud_cli.commands.datasets.require_auth")
+        with require_auth_mock as m:
+            result = runner.invoke(
+                app,
+                ["dataset", "upload", str(ds), "--name", "test", "--type", "bogus"],
+            )
+        assert result.exit_code != 0
+        assert "invalid" in result.output.lower()
+        m.assert_not_called()
+
+    def test_import_hf_accepts_alias(self, tmp_path):
+        """import-hf resolves --type aliases like dataset upload does."""
+        download_dir = _make_valid_lerobot(tmp_path)
+        upload_result = {"id": "ds-789", "name": "my-dataset"}
+        with (
+            patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN),
+            patch("huggingface_hub.snapshot_download", return_value=str(download_dir)),
+            patch("trossen_cloud_cli.commands.datasets.validate_dataset", return_value=[]),
+            patch(
+                "trossen_cloud_cli.commands.datasets.create_and_upload_dataset",
+                return_value=upload_result,
+            ) as upload_mock,
+        ):
+            result = runner.invoke(
+                app,
+                [
+                    "dataset",
+                    "import-hf",
+                    "org/my-dataset",
+                    "--name",
+                    "my-dataset",
+                    "--type",
+                    "LeRobot",
+                    "--force",
+                ],
+            )
+        assert result.exit_code == 0
+        assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"
+
+    def test_import_hf_invalid_type_fails_before_download(self, tmp_path):
+        """Invalid --type on import-hf must fail before snapshot_download runs."""
+        snapshot_mock = patch("huggingface_hub.snapshot_download")
+        with (
+            patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN),
+            snapshot_mock as m,
+        ):
+            result = runner.invoke(
+                app,
+                ["dataset", "import-hf", "org/my-dataset", "--type", "bogus"],
+            )
+        assert result.exit_code != 0
+        assert "invalid" in result.output.lower()
+        m.assert_not_called()