-
Notifications
You must be signed in to change notification settings - Fork 3
feat: Implement SchemaFormatter #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
leonidb
merged 14 commits into
main
from
feat/431-Create-SchemaFormatter-and-SimpleSchemaFormatter
Jan 14, 2026
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5e3682b
feat: Implement SchemaFormatter, SimpleSchemaFormatter and HeadSample…
yotam319-sparkbeyond 8bdaf7a
remove unnecessary tests
yotam319-sparkbeyond ff399b4
feat: Refactor schema serialization methods and remove obsolete tests
yotam319-sparkbeyond 91a6165
feat: Rename serialization methods to formatting for clarity
yotam319-sparkbeyond 24aa366
rename to TablesFormatter and SimpleTablesFormatter
yotam319-sparkbeyond dc0df16
mypy fix
yotam319-sparkbeyond 898aa5d
feat: Create TableFormatter and MarkdownTableFormatter for improved t…
yotam319-sparkbeyond 1a32253
add example to docstring
yotam319-sparkbeyond 1f1d72f
feat: Enhance MarkdownTableFormatter to truncate long string values u…
yotam319-sparkbeyond 1116742
ruff
yotam319-sparkbeyond 0c6add4
Correct string length check
yotam319-sparkbeyond 27ae8ee
update string type check
yotam319-sparkbeyond 2fb61a2
update
yotam319-sparkbeyond 021de57
ruff fix
yotam319-sparkbeyond File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| """Schema formatting for LLM prompts.""" | ||
|
|
||
| from typing import override | ||
|
|
||
| import attrs | ||
| import polars as pl | ||
| from duckdb import DuckDBPyConnection | ||
|
|
||
| from agentune.analyze.join.base import TablesWithJoinStrategies | ||
| from agentune.core import types | ||
| from agentune.core.dataset import Dataset | ||
| from agentune.core.formatter.base import TableFormatter, TablesFormatter | ||
| from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler | ||
| from agentune.core.sampler.table_samples import HeadTableSampler | ||
| from agentune.core.schema import Schema | ||
|
|
||
|
|
||
@attrs.frozen
class MarkdownTableFormatter(TableFormatter):
    """Render a single table as two markdown sections: schema and sample rows.

    The schema section lists every column as a bullet with the ``repr`` of its
    DuckDB type. The sample section is the dataset's data written as CSV, with
    over-long values in string/json columns shortened first.

    Args:
        markdown_level: Number of '#' characters used for the section headers (default: 3).
        max_str: Maximum number of characters kept for a string/json cell. Longer
            values are cut to this length and suffixed with '...' (default: 100).

    Example:
        ### Schema:
        - id: INTEGER
        - name: VARCHAR
        - status: ENUM('active', 'inactive')
        - score: DOUBLE

        ### Sample Data:
        id,name,status,score
        1,Alice,active,95.5
        2,Bob,inactive,82.3
    """

    markdown_level: int = 3
    max_str: int = 100

    def _format_schema(self, schema: Schema) -> str:
        """Return one '- <name>: <duckdb type>' bullet per column, one per line."""
        return '\n'.join(
            f'- {column.name}: {column.dtype.duckdb_type!r}'
            for column in schema.cols
        )

    def _format_sample_data(self, dataset: Dataset) -> str:
        """Return the dataset rows as CSV text, shortening long string/json cells.

        Columns whose dtype is neither ``types.string`` nor ``types.json_dtype``
        are passed through untouched.
        """
        limit = self.max_str

        def _column_expr(name, dtype) -> pl.Expr:
            # Build the Polars expression that selects (and possibly shortens) one column.
            column = pl.col(name)
            if dtype not in (types.string, types.json_dtype):
                return column
            shortened = column.str.slice(0, limit) + '...'
            return (
                pl.when(column.str.len_chars() > limit)
                .then(shortened)
                .otherwise(column)
                .alias(name)
            )

        projections = [
            _column_expr(column.name, column.dtype)
            for column in dataset.schema.cols
        ]
        return dataset.data.select(projections).write_csv()

    @override
    def format_table(
        self,
        sample_data: Dataset,
    ) -> str:
        """Build the markdown text for one table.

        The output is a 'Schema:' header and the column bullets, a blank line,
        then a 'Sample Data:' header and the CSV rows, followed by a trailing
        blank entry. Both headers use ``markdown_level`` '#' characters.

        Args:
            sample_data: Dataset whose schema and rows are rendered.

        Returns:
            The schema and sample-data sections joined with newlines.
        """
        heading = '#' * self.markdown_level
        parts = [
            f'{heading} Schema:',
            self._format_schema(sample_data.schema),
            '',
            f'{heading} Sample Data:',
            self._format_sample_data(sample_data),
            '',
        ]
        return '\n'.join(parts)
|
|
||
|
|
||
@attrs.frozen
class MarkdownTablesFormatter(TablesFormatter):
    """Render the primary table and every secondary table as markdown.

    Each table gets a header line carrying its name, followed by whatever the
    configured ``table_formatter`` produces for a sample of that table.

    Args:
        markdown_level: Number of '#' characters used for the per-table headers (default: 2).
        num_samples: Number of rows requested from each sampler (default: 5).
        max_str: Maximum string length for cell values; forwarded to the default
            table formatter (default: 100).
        table_formatter: TableFormatter used for each individual table. When not
            supplied, a MarkdownTableFormatter is created with
            ``markdown_level + 1`` and this instance's ``max_str``.
        primary_dataset_sampler: DataSampler used for the primary dataset.
            Defaults to RandomSampler.
        tables_sampler: TableSampler used for the secondary tables.
            Defaults to HeadTableSampler.

    NOTE(review): ``primary_table_name`` is read in ``format_all_tables`` but is
    not declared in this class; presumably it is provided by TablesFormatter —
    confirm.

    Example:
        ## Primary Table: users

        ### Schema:
        - id: INTEGER
        - name: VARCHAR
        - tier: ENUM('bronze', 'silver', 'gold')

        ### Sample Data:
        id,name,tier
        1,Alice,gold
        2,Bob,silver

        ## Table: orders

        ### Schema:
        - order_id: INTEGER
        - user_id: INTEGER
        - amount: DOUBLE

        ### Sample Data:
        order_id,user_id,amount
        101,1,49.99
        102,2,125.50
    """

    markdown_level: int = 2
    num_samples: int = 5
    max_str: int = 100
    # The default is derived from the instance so the nested formatter sits one
    # header level deeper and shares the same truncation limit.
    table_formatter: TableFormatter = attrs.field(
        default=attrs.Factory(
            lambda self: MarkdownTableFormatter(
                markdown_level=self.markdown_level + 1,
                max_str=self.max_str,
            ),
            takes_self=True,
        )
    )
    primary_dataset_sampler: DataSampler = RandomSampler()
    tables_sampler: TableSampler = HeadTableSampler()

    @override
    def format_all_tables(
        self,
        input: Dataset,
        tables: TablesWithJoinStrategies,
        conn: DuckDBPyConnection,
        random_seed: int | None = None,
    ) -> str:
        """Produce the markdown description of all tables for an LLM prompt.

        The primary table comes first, then each secondary table in iteration
        order of ``tables``. Every table contributes a header line and the
        output of the configured table formatter for its sampled rows.

        Args:
            input: Primary input dataset; sampled with ``primary_dataset_sampler``.
            tables: Secondary tables with their join strategies.
            conn: Database connection handed to ``tables_sampler`` for each secondary table.
            random_seed: Optional seed forwarded to both samplers.

        Returns:
            All table sections joined with newlines.
        """
        heading = '#' * self.markdown_level
        render = self.table_formatter.format_table

        # Primary table: sample, then header + formatted body.
        primary_sample = self.primary_dataset_sampler.sample(
            input, self.num_samples, random_seed=random_seed
        )
        sections = [
            f'{heading} Primary Table: {self.primary_table_name}\n',
            render(primary_sample),
        ]

        # Secondary tables: same layout, sampled through the connection.
        for entry in tables:
            secondary_sample = self.tables_sampler.sample(
                entry.table.name,
                conn,
                self.num_samples,
                random_seed=random_seed,
            )
            sections.append(f'{heading} Table: {entry.table.name.name}\n')
            sections.append(render(secondary_sample))

        return '\n'.join(sections)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Tests for formatter module.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| """Tests for table formatting.""" | ||
|
|
||
| import polars as pl | ||
|
|
||
| from agentune.core.dataset import Dataset | ||
| from agentune.core.formatter.tables import MarkdownTableFormatter | ||
| from agentune.core.schema import Field, Schema | ||
| from agentune.core.types import float64, int32, string | ||
|
|
||
|
|
||
class TestMarkdownTableFormatter:
    """Test MarkdownTableFormatter functionality."""

    @staticmethod
    def _sample_lines(formatted: str) -> list[str]:
        """Return the lines that follow the '### Sample Data:' header (CSV header first)."""
        return formatted.split('### Sample Data:')[1].strip().split('\n')

    def test_string_truncation(self) -> None:
        """Long string cells are cut to max_str characters and suffixed with '...'."""
        frame = pl.DataFrame({
            'id': [1, 2, 3],
            'short': ['abc', 'def', 'ghi'],
            'long': [
                'this is a very long string that should be truncated',
                'another extremely long text value here',
                'short'
            ],
        })
        columns = Schema((
            Field('id', int32),
            Field('short', string),
            Field('long', string),
        ))
        dataset = Dataset(schema=columns, data=frame)

        rows = self._sample_lines(MarkdownTableFormatter(max_str=20).format_table(dataset))

        # Third CSV column is 'long': first two rows truncated, last one untouched.
        expected_long = [
            'this is a very long ...',
            'another extremely lo...',
            'short',
        ]
        for row_index, expected in enumerate(expected_long, start=1):
            assert rows[row_index].split(',')[2] == expected

        # Second CSV column is 'short': values under the limit are kept as-is.
        assert rows[1].split(',')[1] == 'abc'

    def test_non_string_columns_preserved(self) -> None:
        """Only string columns are truncated; numeric columns are left alone."""
        # NOTE(review): 'value' is declared int32 below but holds 15-digit integers,
        # which exceed the int32 range — confirm whether a 64-bit type was intended.
        frame = pl.DataFrame({
            'id': [1, 2, 3],
            'value': [123456789012345, 987654321098765, 111222333444555],
            'price': [99.99999999, 12345.6789, 0.123456789],
            'name': ['this is a very long name', 'short', 'another long name here'],
        })
        columns = Schema((
            Field('id', int32),
            Field('value', int32),
            Field('price', float64),
            Field('name', string),
        ))
        dataset = Dataset(schema=columns, data=frame)

        rows = self._sample_lines(MarkdownTableFormatter(max_str=10).format_table(dataset))
        first_row = rows[1].split(',')
        second_row = rows[2].split(',')

        # id and value keep their full digits; name is shortened to 10 characters.
        # NOTE(review): the 'price' column is not asserted on here.
        assert first_row[0] == '1'
        assert first_row[1] == '123456789012345'
        assert first_row[3] == 'this is a ...'
        assert second_row[3] == 'short'

    def test_custom_max_str_length(self) -> None:
        """The max_str argument decides where truncation happens."""
        frame = pl.DataFrame({
            'text': ['12345678901234567890', 'short'],
        })
        dataset = Dataset(schema=Schema((Field('text', string),)), data=frame)

        # A limit of 10 cuts the 20-character value.
        clipped = MarkdownTableFormatter(max_str=10).format_table(dataset)
        assert self._sample_lines(clipped)[1] == '1234567890...'

        # A limit of 100 leaves it intact.
        untouched = MarkdownTableFormatter(max_str=100).format_table(dataset)
        assert self._sample_lines(untouched)[1] == '12345678901234567890'

    def test_empty_dataset(self) -> None:
        """A dataset without rows yields only the CSV header line."""
        frame = pl.DataFrame({
            'id': pl.Series([], dtype=pl.Int32),
            'name': pl.Series([], dtype=pl.Utf8),
        })
        columns = Schema((
            Field('id', int32),
            Field('name', string),
        ))
        dataset = Dataset(schema=columns, data=frame)

        csv_lines = self._sample_lines(MarkdownTableFormatter(max_str=20).format_table(dataset))

        # Header only, no data rows.
        assert len(csv_lines) == 1
        assert csv_lines[0] == 'id,name'
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.