diff --git a/.coverage b/.coverage
index 75369ae..0e7c2ec 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/tests/test_chunking.py b/tests/test_chunking.py
index 06420b2..b2c26d3 100644
--- a/tests/test_chunking.py
+++ b/tests/test_chunking.py
@@ -589,3 +589,275 @@ def test_partition_key_single_large_group(
     # rows_unique should reduce to 1 row
     assert len(result) == 1
     assert result["patient_id"][0] == "P001"
+
+
+class TestChunkValidationResultStr:
+    """Tests for ChunkValidationResult.__str__() edge cases."""
+
+    def test_str_valid_with_warnings(self) -> None:
+        """Test __str__ for valid result with warnings."""
+        result = ChunkValidationResult(
+            is_valid=True,
+            warnings=["Unknown operation 'custom_op' - assuming chunkable"],
+        )
+        s = str(result)
+        assert "compatible" in s
+        assert "Warnings:" in s
+        assert "custom_op" in s
+
+    def test_str_group_dependent_ops(self) -> None:
+        """Test __str__ for result with group-dependent operations."""
+        result = ChunkValidationResult(
+            is_valid=True,
+            group_dependent_ops=[("rows_unique", ["patient_id"])],
+        )
+        s = str(result)
+        assert "compatible" in s
+        assert "Group-dependent operations" in s
+        assert "rows_unique" in s
+        assert "patient_id" in s
+
+
+class TestChunkedProtocolEdgeCases:
+    """Tests for ChunkedProtocol edge cases."""
+
+    def test_output_hash_empty(self) -> None:
+        """Test output_hash returns empty string when no chunks."""
+        protocol = ChunkedProtocol()
+        assert protocol.output_hash() == ""
+
+    def test_repr(self) -> None:
+        """Test __repr__ output."""
+        protocol = ChunkedProtocol()
+        protocol.add_chunk(
+            ChunkInfo(
+                chunk_index=0,
+                input_rows=100,
+                output_rows=90,
+                input_hash="abc",
+                output_hash="def",
+                elapsed_seconds=0.1,
+            )
+        )
+        r = repr(protocol)
+        assert "ChunkedProtocol" in r
+        assert "1 chunks" in r
+        assert "100 rows" in r
+
+    def test_len(self) -> None:
+        """Test __len__ returns number of chunks."""
+        protocol = ChunkedProtocol()
+        assert len(protocol) == 0
+        protocol.add_chunk(
+            ChunkInfo(
+                chunk_index=0,
+                input_rows=100,
+                output_rows=90,
+                input_hash="abc",
+                output_hash="def",
+                elapsed_seconds=0.1,
+            )
+        )
+        assert len(protocol) == 1
+
+    def test_summary_with_metadata(self) -> None:
+        """Test summary includes metadata."""
+        protocol = ChunkedProtocol()
+        protocol.set_source(path="test.parquet", partition_key=["id"], chunk_size=1000)
+        protocol.set_metadata(author="test_user", project="test_project")
+        protocol.add_chunk(
+            ChunkInfo(
+                chunk_index=0,
+                input_rows=100,
+                output_rows=90,
+                input_hash="abc",
+                output_hash="def",
+                elapsed_seconds=0.1,
+            )
+        )
+        summary = protocol.summary()
+        assert "author: test_user" in summary
+        assert "project: test_project" in summary
+
+    def test_summary_with_row_change(self) -> None:
+        """Test summary shows row change when rows differ."""
+        protocol = ChunkedProtocol()
+        protocol.set_source(path="test.parquet", partition_key=None, chunk_size=1000)
+        protocol.add_chunk(
+            ChunkInfo(
+                chunk_index=0,
+                input_rows=100,
+                output_rows=80,
+                input_hash="abc",
+                output_hash="def",
+                elapsed_seconds=0.1,
+            )
+        )
+        summary = protocol.summary()
+        assert "Row change:" in summary
+        assert "-20" in summary
+
+    def test_print(self, capsys: pytest.CaptureFixture[str]) -> None:
+        """Test print method."""
+        protocol = ChunkedProtocol()
+        protocol.set_source(path="test.parquet", partition_key=None, chunk_size=1000)
+        protocol.print()
+        captured = capsys.readouterr()
+        assert "CHUNKED PROCESSING PROTOCOL" in captured.out
+
+    def test_from_json_string(self) -> None:
+        """Test from_json with JSON string."""
+        original = ChunkedProtocol()
+        original.set_source(path="test.parquet", partition_key=["id"], chunk_size=1000)
+        json_str = original.to_json()
+        restored = ChunkedProtocol.from_json(json_str)
+        assert restored._source_path == "test.parquet"
+
+    def test_chunks_property(self) -> None:
+        """Test chunks property returns list of ChunkInfo."""
+        protocol = ChunkedProtocol()
+        chunk = ChunkInfo(
+            chunk_index=0,
+            input_rows=100,
+            output_rows=90,
+            input_hash="abc",
+            output_hash="def",
+            elapsed_seconds=0.1,
+        )
+        protocol.add_chunk(chunk)
+        chunks = protocol.chunks
+        assert len(chunks) == 1
+        assert chunks[0].input_rows == 100
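+
+
+# test_summary_with_row_change above presumes summary() aggregates row counts
+# across chunks and reports the difference; a sketch of the assumed arithmetic
+# (inferred from the assertions, not the library's actual code):
+#
+#     total_in = sum(c.input_rows for c in chunks)    # 100
+#     total_out = sum(c.output_rows for c in chunks)  # 80
+#     row_change = total_out - total_in               # -20 -> "Row change: -20"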
"""Test from_json with JSON string.""" + original = ChunkedProtocol() + original.set_source(path="test.parquet", partition_key=["id"], chunk_size=1000) + json_str = original.to_json() + restored = ChunkedProtocol.from_json(json_str) + assert restored._source_path == "test.parquet" + + def test_chunks_property(self) -> None: + """Test chunks property returns list of ChunkInfo.""" + protocol = ChunkedProtocol() + chunk = ChunkInfo( + chunk_index=0, + input_rows=100, + output_rows=90, + input_hash="abc", + output_hash="def", + elapsed_seconds=0.1, + ) + protocol.add_chunk(chunk) + chunks = protocol.chunks + assert len(chunks) == 1 + assert chunks[0].input_rows == 100 + + +class TestValidateChunkedPipelineUnknownOp: + """Tests for unknown operations in chunked validation.""" + + def test_unknown_operation_warning(self) -> None: + """Test that unknown operations produce a warning.""" + + # Create a fake operation that's not in the registry + class FakeOp: + @staticmethod + def _custom_unknown_op(df: pl.DataFrame) -> pl.DataFrame: + return df + + fake_ops = [(FakeOp._custom_unknown_op, {})] + result = validate_chunked_pipeline(fake_ops, partition_key=None) + assert result.is_valid # Unknown ops don't block + assert len(result.warnings) == 1 + assert "custom_unknown_op" in result.warnings[0] + + +class TestPartitionChunkIteratorEdgeCases: + """Tests for partition chunk iterator edge cases.""" + + @pytest.fixture + def empty_parquet_with_partition(self, tmp_path: Path) -> Path: + """Create an empty Parquet file with partition column.""" + df = pl.DataFrame( + { + "patient_id": pl.Series([], dtype=pl.Utf8), + "value": pl.Series([], dtype=pl.Float64), + } + ) + path = tmp_path / "empty_partitioned.parquet" + df.write_parquet(path) + return path + + def test_partition_iterator_empty_file( + self, empty_parquet_with_partition: Path + ) -> None: + """Test partition chunk iterator with empty file returns early.""" + plan = TransformPlan().rows_unique(columns=["patient_id"]) + result, protocol = plan.process_chunked( + source=empty_parquet_with_partition, + partition_key="patient_id", + chunk_size=10, + ) + assert len(result) == 0 + assert protocol.num_chunks == 0 + + +class TestValidateChunkedGroupColsNormalization: + """Tests for group_cols string normalization in chunked validation.""" + + def test_group_cols_as_string_normalized_to_list(self) -> None: + """Test that group_cols as string is normalized to list.""" + # When group_by is specified as a string (not a list), it should be normalized + plan = TransformPlan().math_cumsum("value", new_column="cumsum", group_by="id") + result = validate_chunked_pipeline(plan._operations, partition_key="id") + assert result.is_valid + # Check that group_dependent_ops has the operation with group_cols as list + assert len(result.group_dependent_ops) == 1 + op_name, group_cols = result.group_dependent_ops[0] + assert op_name == "math_cumsum" + assert group_cols == ["id"] + + def test_group_cols_string_directly_in_params(self) -> None: + """Test group_cols normalization when param is a string directly.""" + + # Directly create operation tuple with string group_cols to test normalization + def _mock_group_op(df: pl.DataFrame, columns: str) -> pl.DataFrame: + _ = columns # Unused but required in signature + return df + + # Register as GROUP_DEPENDENT with 'columns' param + from transformplan.chunking import ( + OPERATION_CHUNK_REGISTRY, + ChunkMode, + OperationMeta, + ) + + # Temporarily add a mock operation with string columns + 
+
+
+class TestPartitionChunkIteratorEdgeCases:
+    """Tests for partition chunk iterator edge cases."""
+
+    @pytest.fixture
+    def empty_parquet_with_partition(self, tmp_path: Path) -> Path:
+        """Create an empty Parquet file with partition column."""
+        df = pl.DataFrame(
+            {
+                "patient_id": pl.Series([], dtype=pl.Utf8),
+                "value": pl.Series([], dtype=pl.Float64),
+            }
+        )
+        path = tmp_path / "empty_partitioned.parquet"
+        df.write_parquet(path)
+        return path
+
+    def test_partition_iterator_empty_file(
+        self, empty_parquet_with_partition: Path
+    ) -> None:
+        """Test partition chunk iterator with empty file returns early."""
+        plan = TransformPlan().rows_unique(columns=["patient_id"])
+        result, protocol = plan.process_chunked(
+            source=empty_parquet_with_partition,
+            partition_key="patient_id",
+            chunk_size=10,
+        )
+        assert len(result) == 0
+        assert protocol.num_chunks == 0
+
+
+class TestValidateChunkedGroupColsNormalization:
+    """Tests for group_cols string normalization in chunked validation."""
+
+    def test_group_cols_as_string_normalized_to_list(self) -> None:
+        """Test that group_cols as string is normalized to list."""
+        # When group_by is specified as a string (not a list), it should be normalized
+        plan = TransformPlan().math_cumsum("value", new_column="cumsum", group_by="id")
+        result = validate_chunked_pipeline(plan._operations, partition_key="id")
+        assert result.is_valid
+        # Check that group_dependent_ops has the operation with group_cols as list
+        assert len(result.group_dependent_ops) == 1
+        op_name, group_cols = result.group_dependent_ops[0]
+        assert op_name == "math_cumsum"
+        assert group_cols == ["id"]
+
+    def test_group_cols_string_directly_in_params(self) -> None:
+        """Test group_cols normalization when param is a string directly."""
+
+        # Directly create operation tuple with string group_cols to test normalization
+        def _mock_group_op(df: pl.DataFrame, columns: str) -> pl.DataFrame:
+            _ = columns  # Unused but required in signature
+            return df
+
+        # Register as GROUP_DEPENDENT with 'columns' param
+        from transformplan.chunking import (
+            OPERATION_CHUNK_REGISTRY,
+            ChunkMode,
+            OperationMeta,
+        )
+
+        # Temporarily add a mock operation with string columns
+        OPERATION_CHUNK_REGISTRY["mock_group_op"] = OperationMeta(
+            ChunkMode.GROUP_DEPENDENT, group_param="columns"
+        )
+        try:
+            # Create fake operation with string columns (not list)
+            fake_ops = [(_mock_group_op, {"columns": "patient_id"})]
+            result = validate_chunked_pipeline(fake_ops, partition_key="patient_id")
+            assert result.is_valid
+            # The group_cols should be normalized to a list
+            op_name, group_cols = result.group_dependent_ops[0]
+            assert op_name == "mock_group_op"
+            assert group_cols == ["patient_id"]
+        finally:
+            # Cleanup
+            del OPERATION_CHUNK_REGISTRY["mock_group_op"]
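+
+
+# A minimal sketch of the normalization exercised by both tests above (an
+# assumption drawn from the asserted results, not the actual code):
+#
+#     raw = params.get(meta.group_param)  # "patient_id" or ["patient_id"]
+#     group_cols = [raw] if isinstance(raw, str) else list(raw)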
+
+
+class TestProcessChunkedPendingRows:
+    """Tests for pending rows handling in partition chunk iterator."""
+
+    @pytest.fixture
+    def multi_group_parquet(self, tmp_path: Path) -> Path:
+        """Create a Parquet file with multiple groups for testing pending rows."""
+        # Create data where groups don't evenly divide into chunks
+        # Group P001 has 5 rows, P002 has 3 rows, P003 has 2 rows
+        df = pl.DataFrame(
+            {
+                "patient_id": ["P001"] * 5 + ["P002"] * 3 + ["P003"] * 2,
+                "value": list(range(10)),
+            }
+        )
+        path = tmp_path / "multi_group.parquet"
+        df.write_parquet(path)
+        return path
+
+    def test_process_chunked_with_pending_rows(
+        self, multi_group_parquet: Path
+    ) -> None:
+        """Test chunked processing handles pending rows correctly."""
+        # Use a chunk size that will cause groups to be split
+        plan = TransformPlan().rows_unique(columns=["patient_id"])
+        result, protocol = plan.process_chunked(
+            source=multi_group_parquet,
+            partition_key="patient_id",
+            chunk_size=4,  # Will force partial groups
+        )
+        # Should have 3 unique patients
+        assert len(result) == 3
+        assert set(result["patient_id"].to_list()) == {"P001", "P002", "P003"}
+        # Should have processed at least 1 chunk
+        assert protocol.num_chunks >= 1
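+
+
+# Why "pending rows": with chunk_size=4 the first read ends mid-P001 (5 rows),
+# so a partition-aware iterator cannot emit that group yet. A sketch of the
+# assumed buffering, with hypothetical helper names (the real iterator may
+# differ):
+#
+#     complete, pending = split_before_last_group(pending + batch, key)
+#     yield complete  # the trailing, possibly-incomplete group carries over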
diff --git a/tests/test_core.py b/tests/test_core.py
index 3597f98..a9488e3 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -252,3 +252,152 @@ def test_from_dict_unknown_operation(self) -> None:
         }
         with pytest.raises(ValueError, match="Unknown operation"):
             TransformPlan.from_dict(d)
+
+
+class TestFormatParamsAsPython:
+    """Tests for _format_params_as_python() method."""
+
+    def test_format_params_none_value_skipped(self) -> None:
+        """Test that None values are skipped in params formatting."""
+        plan = TransformPlan()
+        # col_add with no expr (None) should skip the None value
+        plan.col_add("new_col", value="test")
+        code = plan.to_python()
+        # The output should not have expr=None in it
+        assert "expr=None" not in code
+        assert 'new_column="new_col"' in code
+
+    def test_format_params_string_value(self) -> None:
+        """Test that string values are formatted with quotes."""
+        plan = TransformPlan().col_rename("old", "new")
+        code = plan.to_python()
+        assert 'column="old"' in code
+        assert 'new_name="new"' in code
+
+    def test_format_params_bool_value(self) -> None:
+        """Test that boolean values are formatted correctly."""
+        plan = TransformPlan().rows_sort("age", descending=True)
+        code = plan.to_python()
+        assert "descending=True" in code
+
+    def test_format_params_int_value(self) -> None:
+        """Test that int values are formatted correctly."""
+        plan = TransformPlan().math_add("value", 10)
+        code = plan.to_python()
+        assert "value=10" in code
+
+    def test_format_params_float_value(self) -> None:
+        """Test that float values are formatted correctly."""
+        plan = TransformPlan().math_multiply("price", 1.5)
+        code = plan.to_python()
+        assert "value=1.5" in code
+
+    def test_format_params_list_value(self) -> None:
+        """Test that list values are formatted correctly."""
+        plan = TransformPlan().col_select(["a", "b", "c"])
+        code = plan.to_python()
+        assert "['a', 'b', 'c']" in code
+
+    def test_format_params_dict_value(self) -> None:
+        """Test that dict values are formatted correctly."""
+        plan = TransformPlan().map_values("status", {"A": "Active", "B": "Blocked"})
+        code = plan.to_python()
+        assert "mapping=" in code
+        assert "'A'" in code
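+
+
+# Taken together, the assertions above imply formatting rules roughly like this
+# hypothetical sketch (None skipped, strings double-quoted, other values via
+# repr); inferred from the tests, not the actual private method:
+#
+#     if value is None:
+#         continue                          # expr=None is dropped
+#     if isinstance(value, str):
+#         parts.append(f'{key}="{value}"')  # column="old"
+#     else:
+#         parts.append(f"{key}={value!r}")  # True, 10, 1.5, ['a', 'b', 'c']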
"[30, 70)" in labels[1] or "[30.0, 70.0)" in labels[1] + assert "[70" in labels[2] + + def test_map_discretize_right_false_multiple_bins(self) -> None: + """Test discretization with right=False and multiple bins.""" + df = pl.DataFrame({"score": [0, 60, 70, 80, 100]}) + plan = TransformPlan().map_discretize( + "score", + bins=[60, 70, 80, 90], + labels=["F", "D", "C", "B", "A"], + new_column="grade", + right=False, + ) + result, _ = plan.process(df) + # With right=False (left-closed): + # 0 -> [-inf, 60) -> F + # 60 -> [60, 70) -> D + # 70 -> [70, 80) -> C + # 80 -> [80, 90) -> B + # 100 -> [90, inf) -> A + assert result["grade"][0] == "F" + assert result["grade"][1] == "D" + assert result["grade"][2] == "C" + assert result["grade"][3] == "B" + assert result["grade"][4] == "A" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a525be5..c7d18c5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -298,3 +298,190 @@ def test_step_output_hash(self, basic_df: pl.DataFrame) -> None: for step in d["steps"]: assert "output_hash" in step assert len(step["output_hash"]) == 16 + + +class TestProtocolOutputHashNoSteps: + """Tests for Protocol.output_hash when no steps.""" + + def test_output_hash_returns_input_hash_when_no_steps( + self, basic_df: pl.DataFrame + ) -> None: + """Test that output_hash returns input_hash when no steps.""" + protocol = Protocol() + input_hash = frame_hash(basic_df) + protocol.set_input(input_hash, basic_df.shape) + # No steps added + assert protocol.output_hash == input_hash + + +class TestProtocolSummaryMetadata: + """Tests for Protocol.summary() with metadata.""" + + def test_summary_includes_metadata(self, basic_df: pl.DataFrame) -> None: + """Test that protocol summary includes metadata.""" + plan = TransformPlan().col_drop("age") + _, protocol = plan.process(basic_df) + protocol.set_metadata(author="test_user", project="test_project") + summary = protocol.summary() + assert "author: test_user" in summary + assert "project: test_project" in summary + + +class TestProtocolFormatFilter: + """Tests for Protocol._format_filter() method.""" + + def test_format_filter_and(self) -> None: + """Test _format_filter for And filter.""" + from transformplan import Col + + df = pl.DataFrame({"age": [25, 35, 45], "id": [1, 2, 3]}) + plan = TransformPlan().rows_filter((Col("age") >= 30) & (Col("id") > 1)) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + # Should contain the & operator formatting + assert "&" in summary + + def test_format_filter_or(self) -> None: + """Test _format_filter for Or filter.""" + from transformplan import Col + + df = pl.DataFrame({"age": [25, 35, 45], "id": [1, 2, 3]}) + plan = TransformPlan().rows_filter((Col("age") >= 30) | (Col("id") == 1)) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + # Should contain the | operator formatting + assert "|" in summary + + def test_format_filter_not(self) -> None: + """Test _format_filter for Not filter.""" + from transformplan import Col + + df = pl.DataFrame({"active": [True, False, True]}) + plan = TransformPlan().rows_filter(~(Col("active") == True)) # noqa: E712 + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + # Should contain the ~ operator formatting + assert "~" in summary + + def test_format_filter_is_in_short(self) -> None: + """Test _format_filter for IsIn with short list.""" + from transformplan import Col + + df = pl.DataFrame({"name": ["Alice", "Bob", "Charlie"]}) + plan = 
TransformPlan().rows_filter(Col("name").is_in(["Alice", "Bob"])) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + # Should show the values + assert "name in" in summary + + def test_format_filter_is_in_long(self) -> None: + """Test _format_filter for IsIn with long list (truncated).""" + from transformplan import Col + + df = pl.DataFrame({"id": [1, 2, 3, 4, 5, 6]}) + plan = TransformPlan().rows_filter(Col("id").is_in([1, 2, 3, 4, 5, 6])) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + # Should truncate long lists + assert "items" in summary + + def test_format_filter_is_null(self, df_with_nulls: pl.DataFrame) -> None: + """Test _format_filter for IsNull.""" + from transformplan import Col + + plan = TransformPlan().rows_filter(Col("name").is_null()) + _, protocol = plan.process(df_with_nulls) + summary = protocol.summary(show_params=True) + assert "is null" in summary + + def test_format_filter_is_not_null(self, df_with_nulls: pl.DataFrame) -> None: + """Test _format_filter for IsNotNull.""" + from transformplan import Col + + plan = TransformPlan().rows_filter(Col("name").is_not_null()) + _, protocol = plan.process(df_with_nulls) + summary = protocol.summary(show_params=True) + assert "is not null" in summary + + def test_format_filter_between(self) -> None: + """Test _format_filter for Between.""" + from transformplan import Col + + df = pl.DataFrame({"age": [25, 35, 45]}) + plan = TransformPlan().rows_filter(Col("age").between(25, 40)) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + assert "between" in summary + + def test_format_filter_str_contains(self) -> None: + """Test _format_filter for StrContains.""" + from transformplan import Col + + df = pl.DataFrame({"email": ["alice@example.com", "bob@test.com"]}) + plan = TransformPlan().rows_filter(Col("email").str_contains("@example")) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + assert "contains" in summary + + def test_format_filter_str_starts_with(self) -> None: + """Test _format_filter for StrStartsWith.""" + from transformplan import Col + + df = pl.DataFrame({"code": ["PRD-001", "TST-002"]}) + plan = TransformPlan().rows_filter(Col("code").str_starts_with("PRD")) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + assert "starts_with" in summary + + def test_format_filter_str_ends_with(self) -> None: + """Test _format_filter for StrEndsWith.""" + from transformplan import Col + + df = pl.DataFrame({"file": ["data.csv", "report.pdf"]}) + plan = TransformPlan().rows_filter(Col("file").str_ends_with(".csv")) + _, protocol = plan.process(df) + summary = protocol.summary(show_params=True) + assert "ends_with" in summary + + def test_format_filter_unknown_type(self) -> None: + """Test _format_filter for unknown filter type (fallback).""" + protocol = Protocol() + # Manually test _format_filter with unknown type + result = protocol._format_filter({"type": "unknown_filter_type"}) + assert "" in result + + +class TestProtocolFormatParams: + """Tests for Protocol._format_params() edge cases.""" + + def test_format_params_long_list(self) -> None: + """Test _format_params truncates long lists.""" + protocol = Protocol() + params = {"values": [1, 2, 3, 4, 5, 6, 7]} + result = protocol._format_params(params) + # Should truncate to show first 2 items and count + assert "items" in result + assert "1, 2" in result + + def test_format_params_long_result(self) -> None: + """Test 
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index a525be5..c7d18c5 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -298,3 +298,190 @@ def test_step_output_hash(self, basic_df: pl.DataFrame) -> None:
         for step in d["steps"]:
             assert "output_hash" in step
             assert len(step["output_hash"]) == 16
+
+
+class TestProtocolOutputHashNoSteps:
+    """Tests for Protocol.output_hash when no steps."""
+
+    def test_output_hash_returns_input_hash_when_no_steps(
+        self, basic_df: pl.DataFrame
+    ) -> None:
+        """Test that output_hash returns input_hash when no steps."""
+        protocol = Protocol()
+        input_hash = frame_hash(basic_df)
+        protocol.set_input(input_hash, basic_df.shape)
+        # No steps added
+        assert protocol.output_hash == input_hash
+
+
+class TestProtocolSummaryMetadata:
+    """Tests for Protocol.summary() with metadata."""
+
+    def test_summary_includes_metadata(self, basic_df: pl.DataFrame) -> None:
+        """Test that protocol summary includes metadata."""
+        plan = TransformPlan().col_drop("age")
+        _, protocol = plan.process(basic_df)
+        protocol.set_metadata(author="test_user", project="test_project")
+        summary = protocol.summary()
+        assert "author: test_user" in summary
+        assert "project: test_project" in summary
+
+
+class TestProtocolFormatFilter:
+    """Tests for Protocol._format_filter() method."""
+
+    def test_format_filter_and(self) -> None:
+        """Test _format_filter for And filter."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"age": [25, 35, 45], "id": [1, 2, 3]})
+        plan = TransformPlan().rows_filter((Col("age") >= 30) & (Col("id") > 1))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        # Should contain the & operator formatting
+        assert "&" in summary
+
+    def test_format_filter_or(self) -> None:
+        """Test _format_filter for Or filter."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"age": [25, 35, 45], "id": [1, 2, 3]})
+        plan = TransformPlan().rows_filter((Col("age") >= 30) | (Col("id") == 1))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        # Should contain the | operator formatting
+        assert "|" in summary
+
+    def test_format_filter_not(self) -> None:
+        """Test _format_filter for Not filter."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"active": [True, False, True]})
+        plan = TransformPlan().rows_filter(~(Col("active") == True))  # noqa: E712
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        # Should contain the ~ operator formatting
+        assert "~" in summary
+
+    def test_format_filter_is_in_short(self) -> None:
+        """Test _format_filter for IsIn with short list."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"name": ["Alice", "Bob", "Charlie"]})
+        plan = TransformPlan().rows_filter(Col("name").is_in(["Alice", "Bob"]))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        # Should show the values
+        assert "name in" in summary
+
+    def test_format_filter_is_in_long(self) -> None:
+        """Test _format_filter for IsIn with long list (truncated)."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"id": [1, 2, 3, 4, 5, 6]})
+        plan = TransformPlan().rows_filter(Col("id").is_in([1, 2, 3, 4, 5, 6]))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        # Should truncate long lists
+        assert "items" in summary
+
+    def test_format_filter_is_null(self, df_with_nulls: pl.DataFrame) -> None:
+        """Test _format_filter for IsNull."""
+        from transformplan import Col
+
+        plan = TransformPlan().rows_filter(Col("name").is_null())
+        _, protocol = plan.process(df_with_nulls)
+        summary = protocol.summary(show_params=True)
+        assert "is null" in summary
+
+    def test_format_filter_is_not_null(self, df_with_nulls: pl.DataFrame) -> None:
+        """Test _format_filter for IsNotNull."""
+        from transformplan import Col
+
+        plan = TransformPlan().rows_filter(Col("name").is_not_null())
+        _, protocol = plan.process(df_with_nulls)
+        summary = protocol.summary(show_params=True)
+        assert "is not null" in summary
+
+    def test_format_filter_between(self) -> None:
+        """Test _format_filter for Between."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"age": [25, 35, 45]})
+        plan = TransformPlan().rows_filter(Col("age").between(25, 40))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        assert "between" in summary
+
+    def test_format_filter_str_contains(self) -> None:
+        """Test _format_filter for StrContains."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"email": ["alice@example.com", "bob@test.com"]})
+        plan = TransformPlan().rows_filter(Col("email").str_contains("@example"))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        assert "contains" in summary
+
+    def test_format_filter_str_starts_with(self) -> None:
+        """Test _format_filter for StrStartsWith."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"code": ["PRD-001", "TST-002"]})
+        plan = TransformPlan().rows_filter(Col("code").str_starts_with("PRD"))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        assert "starts_with" in summary
+
+    def test_format_filter_str_ends_with(self) -> None:
+        """Test _format_filter for StrEndsWith."""
+        from transformplan import Col
+
+        df = pl.DataFrame({"file": ["data.csv", "report.pdf"]})
+        plan = TransformPlan().rows_filter(Col("file").str_ends_with(".csv"))
+        _, protocol = plan.process(df)
+        summary = protocol.summary(show_params=True)
+        assert "ends_with" in summary
+
+    def test_format_filter_unknown_type(self) -> None:
+        """Test _format_filter for unknown filter type (fallback)."""
+        protocol = Protocol()
+        # Manually test _format_filter with unknown type
+        result = protocol._format_filter({"type": "unknown_filter_type"})
+        assert "filter" in result  # fallback placeholder (exact literal assumed)
+
+
+class TestProtocolFormatParams:
+    """Tests for Protocol._format_params() edge cases."""
+
+    def test_format_params_long_list(self) -> None:
+        """Test _format_params truncates long lists."""
+        protocol = Protocol()
+        params = {"values": [1, 2, 3, 4, 5, 6, 7]}
+        result = protocol._format_params(params)
+        # Should truncate to show first 2 items and count
+        assert "items" in result
+        assert "1, 2" in result
+
+    def test_format_params_long_result(self) -> None:
+        """Test _format_params truncates overall result."""
+        protocol = Protocol()
+        # Create params that would produce a very long string
+        params = {
+            "column1": "some_value",
+            "column2": "another_value",
+            "column3": "yet_another",
+            "column4": "and_more",
+            "column5": "even_more",
+        }
+        result = protocol._format_params(params, max_length=30)
+        assert len(result) <= 30
+        assert result.endswith("...")
+
+    def test_format_params_dict_without_type(self) -> None:
+        """Test _format_params with dict that has no 'type' key."""
+        protocol = Protocol()
+        params = {"mapping": {"A": "Active", "B": "Blocked"}}
+        result = protocol._format_params(params)
+        # Should show {...} for dict without type
+        assert "{...}" in result
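+
+
+# The truncation behavior pinned down above, as an assumed sketch (not the
+# actual private helper):
+#
+#     long list -> f"{v[0]}, {v[1]}, ... ({len(v)} items)"  # "1, 2, ... (7 items)"
+#     long text -> text[: max_length - 3] + "..."           # len(result) <= max_length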
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 56a3100..6e96ac9 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -352,3 +352,313 @@ def test_validate_filter_string_operation(self, basic_df: pl.DataFrame) -> None:
         result = plan.validate(basic_df)
         assert not result.is_valid
         assert "cannot use string filter" in str(result.errors[0])
+
+
+class TestDryRunResultEdgeCases:
+    """Tests for DryRunResult edge cases."""
+
+    def test_output_schema_empty_pipeline(self, basic_df: pl.DataFrame) -> None:
+        """Test output_schema for empty pipeline returns input schema."""
+        plan = TransformPlan()
+        result = plan.dry_run(basic_df)
+        assert len(result.output_schema) == len(basic_df.columns)
+
+    def test_summary_validation_failed(self, basic_df: pl.DataFrame) -> None:
+        """Test summary shows validation failed message."""
+        plan = TransformPlan().col_drop("nonexistent")
+        result = plan.dry_run(basic_df)
+        summary = result.summary()
+        assert "FAILED" in summary
+        assert "1 errors" in summary or "1 error" in summary
+
+    def test_summary_with_show_schema(self, basic_df: pl.DataFrame) -> None:
+        """Test summary with show_schema=True."""
+        plan = TransformPlan().col_drop("age")
+        result = plan.dry_run(basic_df)
+        summary = result.summary(show_schema=True)
+        # Should show column types
+        assert "Int64" in summary or "Int" in summary
+
+    def test_summary_with_modified_columns(self, basic_df: pl.DataFrame) -> None:
+        """Test summary shows modified columns correctly."""
+        plan = TransformPlan().col_cast("age", pl.Float64)
+        result = plan.dry_run(basic_df)
+        summary = result.summary()
+        # Should show modified indicator ~
+        assert "~" in summary or "age" in summary
+
+    def test_summary_with_error_marker(self, basic_df: pl.DataFrame) -> None:
+        """Test summary shows error marker for failed steps."""
+        plan = TransformPlan().col_drop("nonexistent")
+        result = plan.dry_run(basic_df)
+        summary = result.summary()
+        # Should show error marker
+        assert "✗" in summary
+
+    def test_repr_invalid(self, basic_df: pl.DataFrame) -> None:
+        """Test repr for invalid result."""
+        plan = TransformPlan().col_drop("nonexistent")
+        result = plan.dry_run(basic_df)
+        r = repr(result)
+        assert "invalid" in r
+        assert "1 errors" in r or "error" in r
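+
+
+# Marker legend assumed by the summary assertions in this file (inferred from
+# the tests, not documented API): "~" marks a modified column, "+" an added
+# column, and "✗" a step that failed validation.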
+
+
+class TestValidationEdgeCases:
+    """Tests for validation edge cases."""
+
+    def test_validate_col_drop_null_with_columns(
+        self, df_with_nulls: pl.DataFrame
+    ) -> None:
+        """Test col_drop_null validation with specified columns."""
+        plan = TransformPlan().col_drop_null(columns=["name", "age"])
+        result = plan.validate(df_with_nulls)
+        assert result.is_valid
+
+    def test_validate_col_drop_null_missing_columns(
+        self, df_with_nulls: pl.DataFrame
+    ) -> None:
+        """Test col_drop_null validation with missing columns."""
+        plan = TransformPlan().col_drop_null(columns=["nonexistent"])
+        result = plan.validate(df_with_nulls)
+        assert not result.is_valid
+
+    def test_validate_col_add_with_expr(self, basic_df: pl.DataFrame) -> None:
+        """Test col_add validation with source expression."""
+        plan = TransformPlan().col_add("name_copy", expr="name")
+        result = plan.validate(basic_df)
+        assert result.is_valid
+
+    def test_validate_col_add_duplicate_column(self, basic_df: pl.DataFrame) -> None:
+        """Test col_add validation when column already exists."""
+        plan = TransformPlan().col_add("name", value="test")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "already exists" in str(result.errors[0])
+
+    def test_validate_col_add_uuid_duplicate(self, basic_df: pl.DataFrame) -> None:
+        """Test col_add_uuid validation when column already exists."""
+        plan = TransformPlan().col_add_uuid("name")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "already exists" in str(result.errors[0])
+
+    def test_validate_col_hash_missing_columns(self, basic_df: pl.DataFrame) -> None:
+        """Test col_hash validation with missing columns."""
+        plan = TransformPlan().col_hash(["name", "nonexistent"], "hash_col")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_col_hash_duplicate_new_column(
+        self, basic_df: pl.DataFrame
+    ) -> None:
+        """Test col_hash validation when new column already exists."""
+        plan = TransformPlan().col_hash(["name", "age"], "name")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "already exists" in str(result.errors[0])
+
+    def test_validate_col_coalesce_missing(self, basic_df: pl.DataFrame) -> None:
+        """Test col_coalesce validation with missing columns."""
+        plan = TransformPlan().col_coalesce(["name", "nonexistent"], "result")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_math_cumsum_missing_group_by(
+        self, basic_df: pl.DataFrame
+    ) -> None:
+        """Test math_cumsum validation with missing group_by columns."""
+        plan = TransformPlan().math_cumsum(
+            "salary", new_column="cumsum", group_by=["nonexistent"]
+        )
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "Group-by columns do not exist" in str(result.errors[0])
+
+    def test_validate_math_rank_missing_group_by(self, basic_df: pl.DataFrame) -> None:
+        """Test math_rank validation with missing group_by columns."""
+        plan = TransformPlan().math_rank(
+            "salary", new_column="rank", group_by=["nonexistent"]
+        )
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_str_split_duplicate_columns(
+        self, string_df: pl.DataFrame
+    ) -> None:
+        """Test str_split validation when new columns already exist."""
+        plan = TransformPlan().str_split("text", " ", new_columns=["text"])
+        result = plan.validate(string_df)
+        assert not result.is_valid
+        assert "already exists" in str(result.errors[0])
+
+    def test_validate_rows_drop_nulls_missing_columns(
+        self, basic_df: pl.DataFrame
+    ) -> None:
+        """Test rows_drop_nulls validation with missing columns."""
+        plan = TransformPlan().rows_drop_nulls(columns=["nonexistent"])
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_rows_unique_missing_columns(
+        self, basic_df: pl.DataFrame
+    ) -> None:
+        """Test rows_unique validation with missing columns."""
+        plan = TransformPlan().rows_unique(columns=["nonexistent"])
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_rows_drop_filter(self, basic_df: pl.DataFrame) -> None:
+        """Test rows_drop validation with filter."""
+        plan = TransformPlan().rows_drop(Col("nonexistent") > 10)
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_rows_flag_filter(self, basic_df: pl.DataFrame) -> None:
+        """Test rows_flag validation with filter."""
+        plan = TransformPlan().rows_flag(Col("nonexistent") > 10, "flag")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_rows_flag_duplicate_column(self, basic_df: pl.DataFrame) -> None:
+        """Test rows_flag validation when new column already exists."""
+        plan = TransformPlan().rows_flag(Col("age") > 30, "name")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "already exists" in str(result.errors[0])
+
+    def test_validate_rows_sort_missing_columns(self, basic_df: pl.DataFrame) -> None:
+        """Test rows_sort validation with missing columns."""
+        plan = TransformPlan().rows_sort("nonexistent")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_rows_deduplicate_missing_sort(
+        self, duplicates_df: pl.DataFrame
+    ) -> None:
+        """Test rows_deduplicate validation with missing sort column."""
+        plan = TransformPlan().rows_deduplicate(["id"], "nonexistent")
+        result = plan.validate(duplicates_df)
+        assert not result.is_valid
+        assert "does not exist" in str(result.errors[0])
+
+    def test_validate_rows_deduplicate_missing_columns(
+        self, duplicates_df: pl.DataFrame
+    ) -> None:
+        """Test rows_deduplicate validation with missing columns."""
+        plan = TransformPlan().rows_deduplicate(["nonexistent"], "timestamp")
+        result = plan.validate(duplicates_df)
+        assert not result.is_valid
+
+    def test_validate_rows_explode_non_list(self, basic_df: pl.DataFrame) -> None:
+        """Test rows_explode validation on non-list column."""
+        plan = TransformPlan().rows_explode("name")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "expected List" in str(result.errors[0])
+
+    def test_validate_rows_melt_missing_id(self, wide_df: pl.DataFrame) -> None:
+        """Test rows_melt validation with missing id columns."""
+        plan = TransformPlan().rows_melt(
+            id_columns=["nonexistent"], value_columns=["q1"]
+        )
+        result = plan.validate(wide_df)
+        assert not result.is_valid
+
+    def test_validate_rows_melt_missing_value(self, wide_df: pl.DataFrame) -> None:
+        """Test rows_melt validation with missing value columns."""
+        plan = TransformPlan().rows_melt(
+            id_columns=["id"], value_columns=["nonexistent"]
+        )
+        result = plan.validate(wide_df)
+        assert not result.is_valid
+
+    def test_validate_rows_pivot_missing_index(self, long_df: pl.DataFrame) -> None:
+        """Test rows_pivot validation with missing index columns."""
+        plan = TransformPlan().rows_pivot(
+            index=["nonexistent"], columns="quarter", values="value"
+        )
+        result = plan.validate(long_df)
+        assert not result.is_valid
+
+    def test_validate_rows_pivot_missing_columns(self, long_df: pl.DataFrame) -> None:
+        """Test rows_pivot validation with missing pivot column."""
+        plan = TransformPlan().rows_pivot(
+            index=["id"], columns="nonexistent", values="value"
+        )
+        result = plan.validate(long_df)
+        assert not result.is_valid
+
+    def test_validate_rows_pivot_missing_values(self, long_df: pl.DataFrame) -> None:
+        """Test rows_pivot validation with missing values column."""
+        plan = TransformPlan().rows_pivot(
+            index=["id"], columns="quarter", values="nonexistent"
+        )
+        result = plan.validate(long_df)
+        assert not result.is_valid
+
+    def test_validate_filter_not_operator(self, basic_df: pl.DataFrame) -> None:
+        """Test filter validation with Not operator."""
+        plan = TransformPlan().rows_filter(~(Col("nonexistent") == 1))
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+
+    def test_validate_col_select_missing_columns(self, basic_df: pl.DataFrame) -> None:
+        """Test col_select validation with missing columns."""
+        plan = TransformPlan().col_select(["name", "nonexistent"])
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "Columns do not exist" in str(result.errors[0])
+
+    def test_validate_col_add_missing_expr(self, basic_df: pl.DataFrame) -> None:
+        """Test col_add validation with missing source expression."""
+        plan = TransformPlan().col_add("new_col", expr="nonexistent")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "does not exist" in str(result.errors[0])
+
+    def test_input_schema_property(self, basic_df: pl.DataFrame) -> None:
+        """Test DryRunResult.input_schema property."""
+        plan = TransformPlan().col_drop("age")
+        result = plan.dry_run(basic_df)
+        schema = result.input_schema
+        assert "age" in schema
+        assert schema["age"] == pl.Int64()
+
+
+class TestDryRunSummaryFormatting:
+    """Tests for DryRunResult.summary() formatting edge cases."""
+
+    def test_summary_columns_added(self, basic_df: pl.DataFrame) -> None:
+        """Test summary shows columns added with + prefix."""
+        plan = TransformPlan().col_add("new_col", value="test")
+        result = plan.dry_run(basic_df)
+        summary = result.summary()
+        assert "+['" in summary or "+" in summary
+
+    def test_summary_filter_param(self, basic_df: pl.DataFrame) -> None:
+        """Test summary shows filter params as <filter>."""
+        plan = TransformPlan().rows_filter(Col("age") > 30)
+        result = plan.dry_run(basic_df)
+        summary = result.summary(show_params=True)
+        assert "<filter>" in summary  # placeholder literal assumed
+
+    def test_summary_long_list_param(self) -> None:
+        """Test summary truncates long list params."""
+        df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12], "e": [13, 14, 15]})
+        plan = TransformPlan().col_select(["a", "b", "c", "d", "e"])
+        result = plan.dry_run(df)
+        summary = result.summary(show_params=True)
+        # List of 5 items should be truncated
+        assert "items" in summary
+
+    def test_summary_long_string_param(self) -> None:
+        """Test summary truncates long string params."""
+        df = pl.DataFrame({"text": ["hello world"]})
+        # Create a plan with a long string parameter
+        long_pattern = "this_is_a_very_long_pattern_that_should_be_truncated"
+        plan = TransformPlan().str_replace("text", long_pattern, "short")
+        result = plan.dry_run(df)
+        summary = result.summary(show_params=True)
+        # Long string should be truncated
+        assert "..." in summary or long_pattern[:17] in summary