From 85e63a3e08b8f11c5e661e17f3fbc9e529193e95 Mon Sep 17 00:00:00 2001 From: Luke Schmitt Date: Wed, 15 Apr 2026 21:50:30 -0500 Subject: [PATCH] Auto-detect dataset types --- README.md | 2 +- src/trossen_cloud_cli/cli.py | 2 +- src/trossen_cloud_cli/commands/datasets.py | 31 +++-- src/trossen_cloud_cli/validators/__init__.py | 41 ++++++- tests/test_validators.py | 121 ++++++++++++++++++- 5 files changed, 186 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 292d656..1eb81d8 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ The token is stored securely in your OS keyring. ```bash # Upload a local dataset -trc dataset upload ./my-data --name my-dataset --type lerobot_v3 +trc dataset upload ./my-data --name my-dataset # Download a dataset trc dataset download ./output diff --git a/src/trossen_cloud_cli/cli.py b/src/trossen_cloud_cli/cli.py index 8b93732..19a7430 100644 --- a/src/trossen_cloud_cli/cli.py +++ b/src/trossen_cloud_cli/cli.py @@ -52,7 +52,7 @@ def main_callback( [bold]Datasets:[/bold] - trc dataset upload ./my-data --name my-dataset --type lerobot_v3 + trc dataset upload ./my-data --name my-dataset trc dataset import-hf org/dataset-name --name my-dataset trc dataset download ./output trc dataset list --mine diff --git a/src/trossen_cloud_cli/commands/datasets.py b/src/trossen_cloud_cli/commands/datasets.py index 462a9ce..228dac4 100644 --- a/src/trossen_cloud_cli/commands/datasets.py +++ b/src/trossen_cloud_cli/commands/datasets.py @@ -17,11 +17,24 @@ from ..output import console, print_error, print_info, print_success, print_warning from ..types import DatasetType, PrivacyLevel from ..upload import UploadError, create_and_upload_dataset -from ..validators import validate_dataset +from ..validators import detect_dataset_type, validate_dataset app = typer.Typer(help="Manage datasets") +def _resolve_dataset_type(path: Path, dataset_type: DatasetType | None) -> DatasetType: + """Auto-detect dataset type if not provided, or exit with an error.""" + if dataset_type is not None: + return dataset_type + detected = detect_dataset_type(path) + if detected is None: + valid = ", ".join(dt.value for dt in DatasetType) + print_error(f"Could not detect dataset type. Use --type to specify ({valid}).") + raise typer.Exit(1) + print_info(f"Detected dataset type: {detected.value}") + return detected + + def is_user_name_format(identifier: str) -> bool: """ Check if identifier is in / format. @@ -56,9 +69,9 @@ def upload( typer.Option("--name", "-n", help="Dataset name"), ], dataset_type: Annotated[ - DatasetType, - typer.Option("--type", "-t", help="Dataset type"), - ], + DatasetType | None, + typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"), + ] = None, privacy: Annotated[ PrivacyLevel, typer.Option("--privacy", "-p", help="Privacy level"), @@ -77,6 +90,8 @@ def upload( """ require_auth() + dataset_type = _resolve_dataset_type(path, dataset_type) + # Parse metadata if provided metadata_dict = None if metadata: @@ -150,9 +165,9 @@ def import_hf( typer.Option("--name", "-n", help="Dataset name (defaults to HF repo name)"), ] = None, dataset_type: Annotated[ - DatasetType, - typer.Option("--type", "-t", help="Dataset type"), - ] = DatasetType.LEROBOT_V3, + DatasetType | None, + typer.Option("--type", "-t", help="Dataset type (auto-detected if omitted)"), + ] = None, privacy: Annotated[ PrivacyLevel, typer.Option("--privacy", "-p", help="Privacy level"), @@ -222,6 +237,8 @@ def import_hf( print_success(f"Downloaded to {local_path}") + dataset_type = _resolve_dataset_type(local_path, dataset_type) + # Validate dataset before upload validation_warnings = validate_dataset(local_path, dataset_type) if validation_warnings: diff --git a/src/trossen_cloud_cli/validators/__init__.py b/src/trossen_cloud_cli/validators/__init__.py index 5a9aa8c..df5d8d0 100644 --- a/src/trossen_cloud_cli/validators/__init__.py +++ b/src/trossen_cloud_cli/validators/__init__.py @@ -1,5 +1,6 @@ -"""Dataset validators for pre-upload structural checks.""" +"""Dataset validators and type detection for pre-upload structural checks.""" +import os from pathlib import Path from ..types import DatasetType @@ -7,6 +8,44 @@ from .mcap import validate_mcap +def _has_visible_mcap(root: Path) -> bool: + """True if ``root`` contains a non-hidden ``.mcap`` file outside any hidden directory. + + Hidden subdirectories (e.g. ``.git``, ``.cache``) are pruned during traversal + rather than walked-then-filtered, so detection stays fast on trees that + contain large hidden directories. + """ + for _dirpath, dirnames, filenames in os.walk(root): + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + if any(name.endswith(".mcap") and not name.startswith(".") for name in filenames): + return True + return False + + +def detect_dataset_type(path: Path) -> DatasetType | None: + """ + Detect the dataset type from its contents. + + Returns the detected DatasetType, or None if the type cannot be determined. + Hidden filenames (those starting with ``.``) are ignored. The treatment of + parent directories depends on the input shape, mirroring ``collect_files``: + + * **Directory input:** hidden subdirectories (e.g. ``.git``, ``.cache``) + are skipped during traversal. + * **Single-file input** (e.g. passing ``some/.cache/foo.mcap`` directly): + only the filename is checked — parent directory names don't matter, + since ``collect_files`` would still upload that file. + """ + if path.is_file() and path.suffix == ".mcap" and not path.name.startswith("."): + return DatasetType.TROSSENMCAP + if path.is_dir(): + if (path / "meta" / "info.json").is_file(): + return DatasetType.LEROBOT_V3 + if _has_visible_mcap(path): + return DatasetType.TROSSENMCAP + return None + + def validate_dataset(path: Path, dataset_type: DatasetType) -> list[str]: """ Validate a dataset directory against its type-specific spec. diff --git a/tests/test_validators.py b/tests/test_validators.py index 7d16ed2..947f0a4 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -9,7 +9,7 @@ from trossen_cloud_cli.cli import app from trossen_cloud_cli.types import DatasetType -from trossen_cloud_cli.validators import validate_dataset +from trossen_cloud_cli.validators import detect_dataset_type, validate_dataset from trossen_cloud_cli.validators.lerobot import validate_lerobot from trossen_cloud_cli.validators.mcap import MCAP_MAGIC, validate_mcap @@ -161,6 +161,56 @@ def test_dispatches_to_mcap(self, tmp_path): assert warnings == [] +# ── Detection tests ───────────────────────────────────────────────────────── + + +class TestDetectDatasetType: + def test_detects_lerobot_from_meta_info(self, tmp_path): + ds = _make_valid_lerobot(tmp_path) + assert detect_dataset_type(ds) == DatasetType.LEROBOT_V3 + + def test_detects_mcap_from_directory(self, tmp_path): + ds = _make_valid_mcap_dataset(tmp_path) + assert detect_dataset_type(ds) == DatasetType.TROSSENMCAP + + def test_detects_mcap_from_single_file(self, tmp_path): + f = tmp_path / "episode_000000.mcap" + _make_valid_mcap_file(f) + assert detect_dataset_type(f) == DatasetType.TROSSENMCAP + + def test_returns_none_for_empty_directory(self, tmp_path): + assert detect_dataset_type(tmp_path) is None + + def test_returns_none_for_nonexistent_path(self, tmp_path): + assert detect_dataset_type(tmp_path / "nope") is None + + def test_lerobot_takes_priority_over_mcap(self, tmp_path): + """If both meta/info.json and .mcap files exist, detect lerobot_v3.""" + ds = _make_valid_lerobot(tmp_path) + _make_valid_mcap_file(ds / "episode_000000.mcap") + assert detect_dataset_type(ds) == DatasetType.LEROBOT_V3 + + def test_ignores_mcap_in_hidden_directories(self, tmp_path): + """Files under hidden dirs (.git, .cache, ...) must not trigger detection, + since upload skips them. Otherwise the user would be told the dataset is + TROSSENMCAP based on files that won't actually be uploaded. + """ + ds = tmp_path / "dataset" + ds.mkdir() + hidden = ds / ".cache" + hidden.mkdir() + _make_valid_mcap_file(hidden / "episode_000000.mcap") + assert detect_dataset_type(ds) is None + + def test_returns_none_for_hidden_single_mcap_file(self, tmp_path): + """A directly-passed hidden .mcap (e.g. .foo.mcap) must not be detected, + since collect_files would skip it and upload would then fail with + 'No files found to upload'.""" + f = tmp_path / ".hidden.mcap" + _make_valid_mcap_file(f) + assert detect_dataset_type(f) is None + + # ── LeRobot v3 validator tests ─────────────────────────────────────────────── @@ -595,6 +645,44 @@ def test_upload_force_skips_confirmation(self, tmp_path): assert result.exit_code == 0 upload_mock.assert_called_once() + def test_upload_auto_detects_type(self, tmp_path): + """Without --type, the CLI auto-detects the dataset type from contents.""" + ds = _make_valid_mcap_dataset(tmp_path) + upload_result = {"id": "ds-123", "name": "test"} + with ( + patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN), + patch( + "trossen_cloud_cli.commands.datasets.validate_dataset", + return_value=[], + ), + patch( + "trossen_cloud_cli.commands.datasets.create_and_upload_dataset", + return_value=upload_result, + ) as upload_mock, + ): + result = runner.invoke( + app, + ["dataset", "upload", str(ds), "--name", "test"], + ) + assert result.exit_code == 0 + assert "Detected dataset type: trossenmcap" in result.output + upload_mock.assert_called_once() + assert upload_mock.call_args.kwargs["dataset_type"] == "trossenmcap" + + def test_upload_auto_detect_fails_for_unrecognizable_dir(self, tmp_path): + """Auto-detection fails with a clear error when the directory has files + but none match a known dataset type.""" + ds = tmp_path / "unrecognizable" + ds.mkdir() + (ds / "random.txt").write_text("hello") + with patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN): + result = runner.invoke( + app, + ["dataset", "upload", str(ds), "--name", "test"], + ) + assert result.exit_code == 1 + assert "could not detect" in result.output.lower() + def test_upload_no_force_prompts_and_aborts(self, tmp_path): """Without --force, validation warnings trigger a prompt; 'n' aborts.""" ds = _make_valid_mcap_dataset(tmp_path) @@ -615,3 +703,34 @@ def test_upload_no_force_prompts_and_aborts(self, tmp_path): ) assert result.exit_code == 0 upload_mock.assert_not_called() + + def test_import_hf_auto_detects_type(self, tmp_path): + """import-hf auto-detects type from downloaded content when --type is omitted.""" + # _make_valid_lerobot creates a "dataset" subdir, and snapshot_download + # returns the path to the downloaded content, so we use that subdir. + download_dir = _make_valid_lerobot(tmp_path) + + upload_result = {"id": "ds-456", "name": "my-dataset"} + with ( + patch("trossen_cloud_cli.auth.get_token", return_value=MOCK_TOKEN), + patch( + "huggingface_hub.snapshot_download", + return_value=str(download_dir), + ), + patch( + "trossen_cloud_cli.commands.datasets.validate_dataset", + return_value=[], + ), + patch( + "trossen_cloud_cli.commands.datasets.create_and_upload_dataset", + return_value=upload_result, + ) as upload_mock, + ): + result = runner.invoke( + app, + ["dataset", "import-hf", "org/my-dataset", "--name", "my-dataset", "--force"], + ) + assert result.exit_code == 0 + assert "lerobot_v3" in result.output + upload_mock.assert_called_once() + assert upload_mock.call_args.kwargs["dataset_type"] == "lerobot_v3"