# File: agentune/core/formatter/base.py
#
# Only the import section touched by this change and the classes appended at the
# end of the module are shown. The existing formatter definitions that sit between
# them are unchanged and intentionally not reproduced here.

import polars as pl
from duckdb import DuckDBPyConnection

from agentune.analyze.join.base import JoinStrategy, TablesWithJoinStrategies
from agentune.core.database import DuckdbTable
from agentune.core.dataset import Dataset
from agentune.core.schema import Schema


@attrs.define
class TableFormatter(ABC, UseTypeTag):
    """Abstract base class for rendering a single table as text.

    Implementations turn a sample of a table (its schema together with a few rows)
    into a string representation.
    """

    @abstractmethod
    def format_table(self, sample_data: Dataset) -> str:
        """Render the schema and the sample rows of one table.

        Args:
            sample_data: Dataset holding the rows (and schema) to render.

        Returns:
            String representation of the table with its schema and sample data.
        """


@attrs.define
class TablesFormatter(ABC, UseTypeTag):
    """Abstract base class for rendering a set of database tables as text.

    Implementations describe the primary dataset and the auxiliary tables
    (schemas and sample data) in a single string.
    """

    # Display name used for the primary (input) dataset.
    primary_table_name: str = 'primary_table'

    @abstractmethod
    def format_all_tables(
        self,
        input: Dataset,
        tables: TablesWithJoinStrategies,
        conn: DuckDBPyConnection,
        random_seed: int | None = None,
    ) -> str:
        """Render the primary dataset and every auxiliary table.

        Args:
            input: Input dataset (the primary table).
            tables: Available tables together with their join strategies.
            conn: Database connection used to query sample data.
            random_seed: Random seed for sampling (if applicable).

        Returns:
            String representation of all tables with their schemas and sample data.
        """
# ---------------------------------------------------------------------------
# File: agentune/core/formatter/tables.py (new file)
# ---------------------------------------------------------------------------
"""Schema formatting for LLM prompts."""

from typing import override

import attrs
import polars as pl
from duckdb import DuckDBPyConnection

from agentune.analyze.join.base import TablesWithJoinStrategies
from agentune.core import types
from agentune.core.dataset import Dataset
from agentune.core.formatter.base import TableFormatter, TablesFormatter
from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler
from agentune.core.sampler.table_samples import HeadTableSampler
from agentune.core.schema import Schema


@attrs.frozen
class MarkdownTableFormatter(TableFormatter):
    """Markdown table formatter.

    Formats a single table with its schema and sample data using markdown headers.
    The schema is displayed as a bulleted list of columns with their DuckDB types,
    and the sample data is formatted as CSV for readability.

    Args:
        markdown_level: The markdown header level to use for sections (default: 3).
            Must be at least 1; otherwise a ValueError is raised on construction.
        max_str: Maximum string length for cell values. Longer values are truncated
            with '...' (default: 100). Must not be negative; otherwise a ValueError
            is raised on construction.

    Example:
        ### Schema:
        - id: INTEGER
        - name: VARCHAR
        - status: ENUM('active', 'inactive')
        - score: DOUBLE

        ### Sample Data:
        id,name,status,score
        1,Alice,active,95.5
        2,Bob,inactive,82.3
    """

    # A level below 1 would emit a header without any '#', i.e. not a markdown header.
    markdown_level: int = attrs.field(default=3, validator=attrs.validators.ge(1))
    # A negative limit has no meaningful truncation semantics, so reject it up front.
    max_str: int = attrs.field(default=100, validator=attrs.validators.ge(0))

    def _format_schema(self, schema: Schema) -> str:
        """Format the schema as one '- name: type' bullet per column.

        The type shown is the repr of the column dtype's ``duckdb_type``.
        """
        return '\n'.join(f'- {field.name}: {field.dtype.duckdb_type!r}' for field in schema.cols)

    def _truncate_expr(self, col_name: str) -> pl.Expr:
        """Build an expression that cuts values longer than ``max_str`` and appends '...'."""
        column = pl.col(col_name)
        return (
            pl.when(column.str.len_chars() > self.max_str)
            .then(column.str.slice(0, self.max_str) + '...')
            .otherwise(column)
            .alias(col_name)
        )

    def _format_sample_data(self, dataset: Dataset) -> str:
        """Format the sample rows as CSV text, truncating long textual cells.

        Only columns whose dtype is ``types.string`` or ``types.json_dtype`` are
        truncated; every other column is passed through untouched.
        """
        truncatable = (types.string, types.json_dtype)
        select_exprs = [
            self._truncate_expr(field.name) if field.dtype in truncatable else pl.col(field.name)
            for field in dataset.schema.cols
        ]
        return dataset.data.select(select_exprs).write_csv()

    @override
    def format_table(
        self,
        sample_data: Dataset,
    ) -> str:
        """Format schema and sample data for a single table.

        Emits a 'Schema:' section followed by a 'Sample Data:' section, each under a
        markdown header of the configured level. The schema is a list of columns with
        their DuckDB types and the sample data is CSV.

        Args:
            sample_data: Dataset containing the sample data to format

        Returns:
            String representation of the table with its schema and sample data
        """
        markdown_header = '#' * self.markdown_level
        out = [
            f'{markdown_header} Schema:',
            self._format_schema(sample_data.schema),
            '',  # blank line separating the two sections
            f'{markdown_header} Sample Data:',
            self._format_sample_data(sample_data),
            '',
        ]
        return '\n'.join(out)


@attrs.frozen
class MarkdownTablesFormatter(TablesFormatter):
    """Markdown tables formatter.

    Formats all available tables (primary and secondary) with their schemas and sample data
    in markdown format. Each table includes a header with its name, followed by its schema
    and sample data sections. The primary table is labelled with the inherited
    ``primary_table_name`` attribute.

    Args:
        markdown_level: The markdown header level to use for table sections (default: 2).
            Must be at least 1; otherwise a ValueError is raised on construction.
        num_samples: Number of sample rows to retrieve for each table (default: 5).
        max_str: Maximum string length for cell values. Longer values are truncated
            with '...' (default: 100). Must not be negative.
        table_formatter: TableFormatter to use for formatting individual tables.
            Defaults to MarkdownTableFormatter with markdown_level + 1 and the same max_str.
        primary_dataset_sampler: DataSampler to use for sampling the primary dataset.
            Defaults to a new RandomSampler.
        tables_sampler: TableSampler to use for sampling the secondary tables.
            Defaults to a new HeadTableSampler.

    Example:
        ## Primary Table: users

        ### Schema:
        - id: INTEGER
        - name: VARCHAR
        - tier: ENUM('bronze', 'silver', 'gold')

        ### Sample Data:
        id,name,tier
        1,Alice,gold
        2,Bob,silver

        ## Table: orders

        ### Schema:
        - order_id: INTEGER
        - user_id: INTEGER
        - amount: DOUBLE

        ### Sample Data:
        order_id,user_id,amount
        101,1,49.99
        102,2,125.50
    """

    markdown_level: int = attrs.field(default=2, validator=attrs.validators.ge(1))
    num_samples: int = 5
    max_str: int = attrs.field(default=100, validator=attrs.validators.ge(0))
    # Built from the already-initialised fields above, hence takes_self=True.
    table_formatter: TableFormatter = attrs.field(
        default=attrs.Factory(
            lambda self: MarkdownTableFormatter(markdown_level=self.markdown_level + 1, max_str=self.max_str),
            takes_self=True,
        )
    )
    # Factories give every formatter its own sampler instead of one instance created
    # at class-definition time and shared by all formatters.
    primary_dataset_sampler: DataSampler = attrs.field(factory=RandomSampler)
    tables_sampler: TableSampler = attrs.field(factory=HeadTableSampler)

    @override
    def format_all_tables(
        self,
        input: Dataset,
        tables: TablesWithJoinStrategies,
        conn: DuckDBPyConnection,
        random_seed: int | None = None,
    ) -> str:
        """Format all available tables with their schemas and sample data for LLM prompts.

        Formats the primary table followed by all secondary tables. Each table is formatted
        with a header at the specified markdown level, followed by its schema and sample data
        using the configured table formatter.

        Args:
            input: Primary input dataset to format.
            tables: Secondary tables with their join strategies to format.
            conn: Database connection passed to the secondary-table sampler.
            random_seed: Optional random seed forwarded to both samplers.

        Returns:
            Markdown-formatted string containing all tables with their schemas and sample data.
        """
        markdown_header = '#' * self.markdown_level

        # Primary table: sampled directly from the in-memory dataset.
        primary_sample = self.primary_dataset_sampler.sample(input, self.num_samples, random_seed=random_seed)
        sections = [
            f'{markdown_header} Primary Table: {self.primary_table_name}\n',
            self.table_formatter.format_table(primary_sample),
        ]

        # Secondary tables: sampled through the connection, one section per table.
        for table_with_strategies in tables:
            table_name = table_with_strategies.table.name
            table_sample = self.tables_sampler.sample(
                table_name,
                conn,
                self.num_samples,
                random_seed=random_seed,
            )
            sections.append(f'{markdown_header} Table: {table_name.name}\n')
            sections.append(self.table_formatter.format_table(table_sample))

        return '\n'.join(sections)


# ---------------------------------------------------------------------------
# File: tests/agentune/core/formatter/__init__.py (new file)
# Its whole content is the module docstring:
#     """Tests for formatter module."""
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# File: tests/agentune/core/formatter/test_tables.py (new file)
# ---------------------------------------------------------------------------
"""Tests for table formatting."""

import csv
import io

import polars as pl

from agentune.core.dataset import Dataset
from agentune.core.formatter.tables import MarkdownTableFormatter
from agentune.core.schema import Field, Schema
from agentune.core.types import float64, int32, string


def _sample_rows(formatted: str) -> list[list[str]]:
    """Parse the CSV that follows the '### Sample Data:' header into rows of cells.

    A real CSV reader is used instead of ``str.split(',')`` so the assertions keep
    working if a cell ever contains a comma and gets quoted.
    """
    csv_text = formatted.split('### Sample Data:')[1].strip()
    return list(csv.reader(io.StringIO(csv_text)))


class TestMarkdownTableFormatter:
    """Test MarkdownTableFormatter functionality."""

    def test_string_truncation(self) -> None:
        """Test that long string values are truncated with '...'."""
        formatter = MarkdownTableFormatter(max_str=20)

        # Create test data with a long string
        data = pl.DataFrame({
            'id': [1, 2, 3],
            'short': ['abc', 'def', 'ghi'],
            'long': [
                'this is a very long string that should be truncated',
                'another extremely long text value here',
                'short',
            ],
        })

        schema = Schema((
            Field('id', int32),
            Field('short', string),
            Field('long', string),
        ))

        dataset = Dataset(schema=schema, data=data)

        rows = _sample_rows(formatter.format_table(dataset))

        # Verify truncated strings in correct column (row 0 is the CSV header)
        assert rows[1][2] == 'this is a very long ...'
        assert rows[2][2] == 'another extremely lo...'
        assert rows[3][2] == 'short'

        # Verify short strings preserved in correct column
        assert rows[1][1] == 'abc'

    def test_non_string_columns_preserved(self) -> None:
        """Test that non-string columns are not truncated."""
        formatter = MarkdownTableFormatter(max_str=10)

        # Create test data with various types
        data = pl.DataFrame({
            'id': [1, 2, 3],
            'value': [123456789012345, 987654321098765, 111222333444555],
            'price': [99.99999999, 12345.6789, 0.123456789],
            'name': ['this is a very long name', 'short', 'another long name here'],
        })

        # NOTE(review): 'value' is declared int32 although its literals exceed the
        # 32-bit range; a 64-bit integer dtype is probably intended - confirm which
        # name agentune.core.types exposes for it.
        schema = Schema((
            Field('id', int32),
            Field('value', int32),
            Field('price', float64),
            Field('name', string),
        ))

        dataset = Dataset(schema=schema, data=data)

        rows = _sample_rows(formatter.format_table(dataset))

        # Verify columns: id preserved, value preserved, name truncated
        assert rows[1][0] == '1'
        assert rows[1][1] == '123456789012345'
        assert rows[1][3] == 'this is a ...'
        assert rows[2][3] == 'short'

    def test_custom_max_str_length(self) -> None:
        """Test that max_str parameter controls truncation length."""
        data = pl.DataFrame({
            'text': ['12345678901234567890', 'short'],
        })

        schema = Schema((Field('text', string),))
        dataset = Dataset(schema=schema, data=data)

        # Test with max_str=10
        rows_10 = _sample_rows(MarkdownTableFormatter(max_str=10).format_table(dataset))
        assert rows_10[1] == ['1234567890...']

        # Test with max_str=100 (no truncation)
        rows_100 = _sample_rows(MarkdownTableFormatter(max_str=100).format_table(dataset))
        assert rows_100[1] == ['12345678901234567890']

    def test_empty_dataset(self) -> None:
        """Test formatting an empty dataset."""
        formatter = MarkdownTableFormatter(max_str=20)

        data = pl.DataFrame({
            'id': pl.Series([], dtype=pl.Int32),
            'name': pl.Series([], dtype=pl.Utf8),
        })

        schema = Schema((
            Field('id', int32),
            Field('name', string),
        ))

        dataset = Dataset(schema=schema, data=data)

        rows = _sample_rows(formatter.format_table(dataset))

        # Should have only header, no data rows
        assert rows == [['id', 'name']]

    def test_negative_max_str_rejected(self) -> None:
        """Test that a negative max_str is refused at construction time."""
        try:
            MarkdownTableFormatter(max_str=-1)
        except ValueError:
            return
        raise AssertionError('expected ValueError for negative max_str')

    def test_non_positive_markdown_level_rejected(self) -> None:
        """Test that a markdown_level below 1 is refused at construction time."""
        try:
            MarkdownTableFormatter(markdown_level=0)
        except ValueError:
            return
        raise AssertionError('expected ValueError for markdown_level=0')