From 7559216353b1f483bc19362799b29c3fdc6b5894 Mon Sep 17 00:00:00 2001 From: Luke Schmitt Date: Thu, 7 May 2026 22:58:32 -0500 Subject: [PATCH] Add dataset validate command --- src/trossen_cloud_cli/cli.py | 1 + src/trossen_cloud_cli/commands/datasets.py | 71 +++++++++++++++------- tests/test_validators.py | 49 +++++++++++++++ 3 files changed, 99 insertions(+), 22 deletions(-) diff --git a/src/trossen_cloud_cli/cli.py b/src/trossen_cloud_cli/cli.py index 19a7430..b5e7977 100644 --- a/src/trossen_cloud_cli/cli.py +++ b/src/trossen_cloud_cli/cli.py @@ -52,6 +52,7 @@ def main_callback( [bold]Datasets:[/bold] + trc dataset validate ./my-data trc dataset upload ./my-data --name my-dataset trc dataset import-hf org/dataset-name --name my-dataset trc dataset download ./output diff --git a/src/trossen_cloud_cli/commands/datasets.py b/src/trossen_cloud_cli/commands/datasets.py index 7e14727..248df6d 100644 --- a/src/trossen_cloud_cli/commands/datasets.py +++ b/src/trossen_cloud_cli/commands/datasets.py @@ -54,6 +54,24 @@ def _parse_dataset_type(value: str | None) -> DatasetType | None: ) from None +def _print_validation_warnings(warnings: list[str]) -> None: + """Print a heading and a list of dataset validation warnings to the console.""" + console.print(f"\n[warning]Found {len(warnings)} validation warning(s):[/warning]") + for w in warnings: + print_warning(w) + + +def _validate_or_confirm(path: Path, dataset_type: DatasetType, force: bool) -> None: + """Run validation; if warnings, print them and prompt to continue (unless ``force``).""" + warnings = validate_dataset(path, dataset_type) + if not warnings: + return + _print_validation_warnings(warnings) + console.print() + if not force and not typer.confirm("Continue with upload?"): + raise typer.Exit(0) + + def _resolve_dataset_type(path: Path, dataset_type: DatasetType | None) -> DatasetType: """Auto-detect dataset type if not provided, or exit with an error.""" if dataset_type is not None: @@ -87,6 +105,35 @@ async def resolve_dataset_identifier(client: ApiClient, identifier: str) -> dict return await client.get(f"/datasets/{identifier}") +@app.command("validate") +def validate( + path: Annotated[ + Path, + typer.Argument( + help="Path to the dataset directory or file to validate", + exists=True, + resolve_path=True, + ), + ], + dataset_type_str: _DatasetTypeOption = None, +) -> None: + """ + Validate a local dataset against its type-specific spec without uploading. + + Exit codes: 0 on success, 1 on validation warnings or an undetectable type, + 2 on invalid ``--type`` (raised by Typer). + """ + dataset_type = _resolve_dataset_type(path, _parse_dataset_type(dataset_type_str)) + + warnings = validate_dataset(path, dataset_type) + if not warnings: + print_success(f"Dataset is valid ({dataset_type.value})") + return + + _print_validation_warnings(warnings) + raise typer.Exit(1) + + @app.command("upload") def upload( path: Annotated[ @@ -131,17 +178,7 @@ def upload( print_error("Invalid JSON metadata") raise typer.Exit(1) - # Validate dataset before upload - validation_warnings = validate_dataset(path, dataset_type) - if validation_warnings: - console.print( - f"\n[warning]Found {len(validation_warnings)} validation warning(s):[/warning]" - ) - for w in validation_warnings: - print_warning(w) - console.print() - if not force and not typer.confirm("Continue with upload?"): - raise typer.Exit(0) + _validate_or_confirm(path, dataset_type, force) try: dataset = asyncio.run( @@ -267,17 +304,7 @@ def import_hf( dataset_type = _resolve_dataset_type(local_path, parsed_type) - # Validate dataset before upload - validation_warnings = validate_dataset(local_path, dataset_type) - if validation_warnings: - console.print( - f"\n[warning]Found {len(validation_warnings)} validation warning(s):[/warning]" - ) - for w in validation_warnings: - print_warning(w) - console.print() - if not force and not typer.confirm("Continue with upload?"): - raise typer.Exit(0) + _validate_or_confirm(local_path, dataset_type, force) # Upload to Trossen Cloud dataset = asyncio.run( diff --git a/tests/test_validators.py b/tests/test_validators.py index 5619227..7f1535c 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -856,3 +856,52 @@ def test_import_hf_invalid_type_fails_before_download(self, tmp_path): assert result.exit_code != 0 assert "invalid" in result.output.lower() m.assert_not_called() + + +# ── Validate command tests ────────────────────────────────────────────────── + + +class TestValidateCommand: + def test_valid_dataset_exits_zero(self, tmp_path): + ds = _make_valid_lerobot(tmp_path) + result = runner.invoke(app, ["dataset", "validate", str(ds)]) + assert result.exit_code == 0 + assert "valid" in result.output.lower() + assert "lerobot_v3" in result.output + + def test_warnings_exit_one(self, tmp_path): + ds = tmp_path / "broken" + ds.mkdir() + (ds / "meta").mkdir() + (ds / "meta" / "info.json").write_text("{}") + result = runner.invoke(app, ["dataset", "validate", str(ds)]) + assert result.exit_code == 1 + assert "warning" in result.output.lower() + + def test_undetectable_type_errors(self, tmp_path): + empty = tmp_path / "empty" + empty.mkdir() + result = runner.invoke(app, ["dataset", "validate", str(empty)]) + assert result.exit_code == 1 + assert "detect" in result.output.lower() + + def test_explicit_type_alias(self, tmp_path): + ds = _make_valid_mcap_dataset(tmp_path) + result = runner.invoke(app, ["dataset", "validate", str(ds), "--type", "mcap"]) + assert result.exit_code == 0 + assert "trossenmcap" in result.output + + def test_does_not_require_auth(self, tmp_path): + """validate is purely local and must not require an API token.""" + ds = _make_valid_lerobot(tmp_path) + with patch("trossen_cloud_cli.auth.get_token", return_value=None): + result = runner.invoke(app, ["dataset", "validate", str(ds)]) + assert result.exit_code == 0 + + def test_validates_single_mcap_file(self, tmp_path): + """A single .mcap file is a valid input — validate_mcap handles file paths.""" + f = tmp_path / "episode_000000.mcap" + _make_valid_mcap_file(f) + result = runner.invoke(app, ["dataset", "validate", str(f)]) + assert result.exit_code == 0 + assert "trossenmcap" in result.output