Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/trossen_cloud_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def main_callback(

[bold]Datasets:[/bold]

trc dataset validate ./my-data
trc dataset upload ./my-data --name my-dataset
trc dataset import-hf org/dataset-name --name my-dataset
trc dataset download <dataset-id> ./output
Expand Down
71 changes: 49 additions & 22 deletions src/trossen_cloud_cli/commands/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def _parse_dataset_type(value: str | None) -> DatasetType | None:
) from None


def _print_validation_warnings(warnings: list[str]) -> None:
    """Display a summary heading, then each dataset validation warning on its own line."""
    count = len(warnings)
    console.print(f"\n[warning]Found {count} validation warning(s):[/warning]")
    for message in warnings:
        print_warning(message)


def _validate_or_confirm(path: Path, dataset_type: DatasetType, force: bool) -> None:
    """Validate the dataset; when warnings surface, show them and ask before proceeding.

    With ``force`` set, warnings are still printed but no confirmation prompt is
    shown. Exits with status 0 when the user declines to continue.
    """
    issues = validate_dataset(path, dataset_type)
    if issues:
        _print_validation_warnings(issues)
        console.print()
        # ``force`` short-circuits the prompt, mirroring the original guard.
        proceed = force or typer.confirm("Continue with upload?")
        if not proceed:
            raise typer.Exit(0)


def _resolve_dataset_type(path: Path, dataset_type: DatasetType | None) -> DatasetType:
"""Auto-detect dataset type if not provided, or exit with an error."""
if dataset_type is not None:
Expand Down Expand Up @@ -87,6 +105,35 @@ async def resolve_dataset_identifier(client: ApiClient, identifier: str) -> dict
return await client.get(f"/datasets/{identifier}")


@app.command("validate")
def validate(
    path: Annotated[
        Path,
        typer.Argument(
            help="Path to the dataset directory or file to validate",
            exists=True,
            resolve_path=True,
        ),
    ],
    dataset_type_str: _DatasetTypeOption = None,
) -> None:
    """
    Validate a local dataset against its type-specific spec without uploading.

    Exit codes: 0 on success, 1 on validation warnings or an undetectable type,
    2 on invalid ``--type`` (raised by Typer).
    """
    # Parse the optional --type flag first; auto-detection kicks in when absent.
    parsed = _parse_dataset_type(dataset_type_str)
    dataset_type = _resolve_dataset_type(path, parsed)

    issues = validate_dataset(path, dataset_type)
    if issues:
        _print_validation_warnings(issues)
        raise typer.Exit(1)

    print_success(f"Dataset is valid ({dataset_type.value})")


@app.command("upload")
def upload(
path: Annotated[
Expand Down Expand Up @@ -131,17 +178,7 @@ def upload(
print_error("Invalid JSON metadata")
raise typer.Exit(1)

# Validate dataset before upload
validation_warnings = validate_dataset(path, dataset_type)
if validation_warnings:
console.print(
f"\n[warning]Found {len(validation_warnings)} validation warning(s):[/warning]"
)
for w in validation_warnings:
print_warning(w)
console.print()
if not force and not typer.confirm("Continue with upload?"):
raise typer.Exit(0)
_validate_or_confirm(path, dataset_type, force)

try:
dataset = asyncio.run(
Expand Down Expand Up @@ -267,17 +304,7 @@ def import_hf(

dataset_type = _resolve_dataset_type(local_path, parsed_type)

# Validate dataset before upload
validation_warnings = validate_dataset(local_path, dataset_type)
if validation_warnings:
console.print(
f"\n[warning]Found {len(validation_warnings)} validation warning(s):[/warning]"
)
for w in validation_warnings:
print_warning(w)
console.print()
if not force and not typer.confirm("Continue with upload?"):
raise typer.Exit(0)
_validate_or_confirm(local_path, dataset_type, force)

# Upload to Trossen Cloud
dataset = asyncio.run(
Expand Down
49 changes: 49 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,52 @@ def test_import_hf_invalid_type_fails_before_download(self, tmp_path):
assert result.exit_code != 0
assert "invalid" in result.output.lower()
m.assert_not_called()


# ── Validate command tests ──────────────────────────────────────────────────


class TestValidateCommand:
    """CLI-level tests for ``trc dataset validate``."""

    @staticmethod
    def _run(*cli_args):
        # Single choke point for invoking the subcommand under test.
        return runner.invoke(app, ["dataset", "validate", *cli_args])

    def test_valid_dataset_exits_zero(self, tmp_path):
        dataset_dir = _make_valid_lerobot(tmp_path)
        outcome = self._run(str(dataset_dir))
        assert outcome.exit_code == 0
        assert "valid" in outcome.output.lower()
        assert "lerobot_v3" in outcome.output

    def test_warnings_exit_one(self, tmp_path):
        # Minimal meta/info.json with no content triggers validation warnings.
        dataset_dir = tmp_path / "broken"
        dataset_dir.mkdir()
        (dataset_dir / "meta").mkdir()
        (dataset_dir / "meta" / "info.json").write_text("{}")
        outcome = self._run(str(dataset_dir))
        assert outcome.exit_code == 1
        assert "warning" in outcome.output.lower()

    def test_undetectable_type_errors(self, tmp_path):
        # An empty directory cannot be auto-detected as any dataset type.
        bare_dir = tmp_path / "empty"
        bare_dir.mkdir()
        outcome = self._run(str(bare_dir))
        assert outcome.exit_code == 1
        assert "detect" in outcome.output.lower()

    def test_explicit_type_alias(self, tmp_path):
        dataset_dir = _make_valid_mcap_dataset(tmp_path)
        outcome = self._run(str(dataset_dir), "--type", "mcap")
        assert outcome.exit_code == 0
        assert "trossenmcap" in outcome.output

    def test_does_not_require_auth(self, tmp_path):
        """validate is purely local and must not require an API token."""
        dataset_dir = _make_valid_lerobot(tmp_path)
        with patch("trossen_cloud_cli.auth.get_token", return_value=None):
            outcome = self._run(str(dataset_dir))
        assert outcome.exit_code == 0

    def test_validates_single_mcap_file(self, tmp_path):
        """A single .mcap file is a valid input — validate_mcap handles file paths."""
        mcap_path = tmp_path / "episode_000000.mcap"
        _make_valid_mcap_file(mcap_path)
        outcome = self._run(str(mcap_path))
        assert outcome.exit_code == 0
        assert "trossenmcap" in outcome.output