Binary file modified: .coverage (binary content not shown)
tests/test_chunking.py: 272 additions, 0 deletions
@@ -589,3 +589,275 @@ def test_partition_key_single_large_group(
# rows_unique should reduce to 1 row
assert len(result) == 1
assert result["patient_id"][0] == "P001"


class TestChunkValidationResultStr:
"""Tests for ChunkValidationResult.__str__() edge cases."""

def test_str_valid_with_warnings(self) -> None:
"""Test __str__ for valid result with warnings."""
result = ChunkValidationResult(
is_valid=True,
warnings=["Unknown operation 'custom_op' - assuming chunkable"],
)
s = str(result)
assert "compatible" in s
assert "Warnings:" in s
assert "custom_op" in s

def test_str_group_dependent_ops(self) -> None:
"""Test __str__ for result with group-dependent operations."""
result = ChunkValidationResult(
is_valid=True,
group_dependent_ops=[("rows_unique", ["patient_id"])],
)
s = str(result)
assert "compatible" in s
assert "Group-dependent operations" in s
assert "rows_unique" in s
assert "patient_id" in s


class TestChunkedProtocolEdgeCases:
"""Tests for ChunkedProtocol edge cases."""

def test_output_hash_empty(self) -> None:
"""Test output_hash returns empty string when no chunks."""
protocol = ChunkedProtocol()
assert protocol.output_hash() == ""

def test_repr(self) -> None:
"""Test __repr__ output."""
protocol = ChunkedProtocol()
protocol.add_chunk(
ChunkInfo(
chunk_index=0,
input_rows=100,
output_rows=90,
input_hash="abc",
output_hash="def",
elapsed_seconds=0.1,
)
)
r = repr(protocol)
assert "ChunkedProtocol" in r
assert "1 chunks" in r
assert "100 rows" in r

def test_len(self) -> None:
"""Test __len__ returns number of chunks."""
protocol = ChunkedProtocol()
assert len(protocol) == 0
protocol.add_chunk(
ChunkInfo(
chunk_index=0,
input_rows=100,
output_rows=90,
input_hash="abc",
output_hash="def",
elapsed_seconds=0.1,
)
)
assert len(protocol) == 1

def test_summary_with_metadata(self) -> None:
"""Test summary includes metadata."""
protocol = ChunkedProtocol()
protocol.set_source(path="test.parquet", partition_key=["id"], chunk_size=1000)
protocol.set_metadata(author="test_user", project="test_project")
protocol.add_chunk(
ChunkInfo(
chunk_index=0,
input_rows=100,
output_rows=90,
input_hash="abc",
output_hash="def",
elapsed_seconds=0.1,
)
)
summary = protocol.summary()
assert "author: test_user" in summary
assert "project: test_project" in summary

def test_summary_with_row_change(self) -> None:
"""Test summary shows row change when rows differ."""
protocol = ChunkedProtocol()
protocol.set_source(path="test.parquet", partition_key=None, chunk_size=1000)
protocol.add_chunk(
ChunkInfo(
chunk_index=0,
input_rows=100,
output_rows=80,
input_hash="abc",
output_hash="def",
elapsed_seconds=0.1,
)
)
summary = protocol.summary()
assert "Row change:" in summary
assert "-20" in summary

def test_print(self, capsys: pytest.CaptureFixture[str]) -> None:
"""Test print method."""
protocol = ChunkedProtocol()
protocol.set_source(path="test.parquet", partition_key=None, chunk_size=1000)
protocol.print()
captured = capsys.readouterr()
assert "CHUNKED PROCESSING PROTOCOL" in captured.out

def test_from_json_string(self) -> None:
"""Test from_json with JSON string."""
original = ChunkedProtocol()
original.set_source(path="test.parquet", partition_key=["id"], chunk_size=1000)
json_str = original.to_json()
restored = ChunkedProtocol.from_json(json_str)
assert restored._source_path == "test.parquet"

def test_chunks_property(self) -> None:
"""Test chunks property returns list of ChunkInfo."""
protocol = ChunkedProtocol()
chunk = ChunkInfo(
chunk_index=0,
input_rows=100,
output_rows=90,
input_hash="abc",
output_hash="def",
elapsed_seconds=0.1,
)
protocol.add_chunk(chunk)
chunks = protocol.chunks
assert len(chunks) == 1
assert chunks[0].input_rows == 100
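
# Putting the pieces exercised above together, a minimal usage sketch of the
# protocol API (every call appears in the tests; only the sequencing is an
# assumption):
#
#     protocol = ChunkedProtocol()
#     protocol.set_source(path="data.parquet", partition_key=["id"], chunk_size=1000)
#     protocol.set_metadata(author="me", project="demo")
#     protocol.add_chunk(ChunkInfo(chunk_index=0, input_rows=100, output_rows=90,
#                                  input_hash="abc", output_hash="def",
#                                  elapsed_seconds=0.1))
#     protocol.print()                                          # full report
#     restored = ChunkedProtocol.from_json(protocol.to_json())  # round-trip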


class TestValidateChunkedPipelineUnknownOp:
"""Tests for unknown operations in chunked validation."""

def test_unknown_operation_warning(self) -> None:
"""Test that unknown operations produce a warning."""

# Create a fake operation that's not in the registry
class FakeOp:
@staticmethod
def _custom_unknown_op(df: pl.DataFrame) -> pl.DataFrame:
return df

fake_ops = [(FakeOp._custom_unknown_op, {})]
result = validate_chunked_pipeline(fake_ops, partition_key=None)
assert result.is_valid # Unknown ops don't block
assert len(result.warnings) == 1
assert "custom_unknown_op" in result.warnings[0]


class TestPartitionChunkIteratorEdgeCases:
"""Tests for partition chunk iterator edge cases."""

@pytest.fixture
def empty_parquet_with_partition(self, tmp_path: Path) -> Path:
"""Create an empty Parquet file with partition column."""
df = pl.DataFrame(
{
"patient_id": pl.Series([], dtype=pl.Utf8),
"value": pl.Series([], dtype=pl.Float64),
}
)
path = tmp_path / "empty_partitioned.parquet"
df.write_parquet(path)
return path

def test_partition_iterator_empty_file(
self, empty_parquet_with_partition: Path
) -> None:
"""Test partition chunk iterator with empty file returns early."""
plan = TransformPlan().rows_unique(columns=["patient_id"])
result, protocol = plan.process_chunked(
source=empty_parquet_with_partition,
partition_key="patient_id",
chunk_size=10,
)
assert len(result) == 0
assert protocol.num_chunks == 0


class TestValidateChunkedGroupColsNormalization:
"""Tests for group_cols string normalization in chunked validation."""

def test_group_cols_as_string_normalized_to_list(self) -> None:
"""Test that group_cols as string is normalized to list."""
# When group_by is specified as a string (not a list), it should be normalized
plan = TransformPlan().math_cumsum("value", new_column="cumsum", group_by="id")
result = validate_chunked_pipeline(plan._operations, partition_key="id")
assert result.is_valid
# Check that group_dependent_ops has the operation with group_cols as list
assert len(result.group_dependent_ops) == 1
op_name, group_cols = result.group_dependent_ops[0]
assert op_name == "math_cumsum"
assert group_cols == ["id"]

def test_group_cols_string_directly_in_params(self) -> None:
"""Test group_cols normalization when param is a string directly."""

# Directly create operation tuple with string group_cols to test normalization
def _mock_group_op(df: pl.DataFrame, columns: str) -> pl.DataFrame:
_ = columns # Unused but required in signature
return df

# Register as GROUP_DEPENDENT with 'columns' param
from transformplan.chunking import (
OPERATION_CHUNK_REGISTRY,
ChunkMode,
OperationMeta,
)

# Temporarily add a mock operation with string columns
OPERATION_CHUNK_REGISTRY["mock_group_op"] = OperationMeta(
ChunkMode.GROUP_DEPENDENT, group_param="columns"
)
try:
# Create fake operation with string columns (not list)
fake_ops = [(_mock_group_op, {"columns": "patient_id"})]
result = validate_chunked_pipeline(fake_ops, partition_key="patient_id")
assert result.is_valid
# The group_cols should be normalized to a list
op_name, group_cols = result.group_dependent_ops[0]
assert op_name == "mock_group_op"
assert group_cols == ["patient_id"]
finally:
# Cleanup
del OPERATION_CHUNK_REGISTRY["mock_group_op"]
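
# A minimal sketch (an assumption, not the library's actual code) of the
# string-to-list normalization both tests in the class above rely on:
def _expected_group_cols(value: str | list[str]) -> list[str]:
    """Normalize a bare string to a one-element list; pass lists through."""
    return [value] if isinstance(value, str) else list(value)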


class TestProcessChunkedPendingRows:
"""Tests for pending rows handling in partition chunk iterator."""

@pytest.fixture
def multi_group_parquet(self, tmp_path: Path) -> Path:
"""Create a Parquet file with multiple groups for testing pending rows."""
# Create data where groups don't evenly divide into chunks
# Group P001 has 5 rows, P002 has 3 rows, P003 has 2 rows
df = pl.DataFrame(
{
"patient_id": ["P001"] * 5 + ["P002"] * 3 + ["P003"] * 2,
"value": list(range(10)),
}
)
path = tmp_path / "multi_group.parquet"
df.write_parquet(path)
return path

def test_process_chunked_with_pending_rows(
self, multi_group_parquet: Path
) -> None:
"""Test chunked processing handles pending rows correctly."""
# Use a chunk size that will cause groups to be split
plan = TransformPlan().rows_unique(columns=["patient_id"])
result, protocol = plan.process_chunked(
source=multi_group_parquet,
partition_key="patient_id",
chunk_size=4, # Will force partial groups
)
# Should have 3 unique patients
assert len(result) == 3
assert set(result["patient_id"].to_list()) == {"P001", "P002", "P003"}
# Should have processed at least 1 chunk
assert protocol.num_chunks >= 1
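
# The pending-rows behavior tested above presumably follows the classic
# partition-aware buffering scheme (an assumption; the iterator's source is
# not shown in this diff): when a chunk boundary splits a partition_key
# group, the trailing partial group is held back and prepended to the next
# chunk, so no group is ever processed across two chunks.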