From 5e3682b154355cd331a1db5dba6f8a681c0c3a1b Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Thu, 8 Jan 2026 17:40:15 +0200 Subject: [PATCH 01/14] feat: Implement SchemaFormatter, SimpleSchemaFormatter and HeadSampler for schema formatting and sampling --- agentune/core/formatter/__init__.py | 10 + agentune/core/formatter/base.py | 44 ++- agentune/core/formatter/schema.py | 76 +++++ agentune/core/sampler/base.py | 24 ++ tests/agentune/core/formatter/__init__.py | 1 + tests/agentune/core/formatter/test_schema.py | 309 +++++++++++++++++++ 6 files changed, 463 insertions(+), 1 deletion(-) create mode 100644 agentune/core/formatter/schema.py create mode 100644 tests/agentune/core/formatter/__init__.py create mode 100644 tests/agentune/core/formatter/test_schema.py diff --git a/agentune/core/formatter/__init__.py b/agentune/core/formatter/__init__.py index e69de29b..79792dcc 100644 --- a/agentune/core/formatter/__init__.py +++ b/agentune/core/formatter/__init__.py @@ -0,0 +1,10 @@ +"""Data formatting module.""" + +from agentune.core.formatter.base import DataFormatter, SchemaFormatter +from agentune.core.formatter.schema import SimpleSchemaFormatter + +__all__ = [ + 'DataFormatter', + 'SchemaFormatter', + 'SimpleSchemaFormatter', +] diff --git a/agentune/core/formatter/base.py b/agentune/core/formatter/base.py index c4a6bf9c..e6aeacdf 100644 --- a/agentune/core/formatter/base.py +++ b/agentune/core/formatter/base.py @@ -10,7 +10,7 @@ import polars as pl from duckdb import DuckDBPyConnection -from agentune.analyze.join.base import JoinStrategy +from agentune.analyze.join.base import JoinStrategy, TablesWithJoinStrategies from agentune.core.database import DuckdbTable from agentune.core.dataset import Dataset from agentune.core.schema import Schema @@ -65,3 +65,45 @@ async def aformat_batch(self, input: Dataset, conn: DuckDBPyConnection) -> pl.Se ... +@attrs.define +class SchemaFormatter(ABC, UseTypeTag): + """Abstract base class for schema formatting strategies. + + Formats information about available database tables (schemas and sample data) for use in LLM prompts. + """ + + primary_table_name: str = 'primary_table' + + def _serialize_schema(self, schema: Schema) -> str: + """Serialize schema to human-readable string for LLM prompts.""" + lines = [] + for field in schema.cols: + # Convert Dtype to simple string representation + dtype_str = repr(field.dtype.polars_type) + lines.append(f'- {field.name}: {dtype_str}') + + return '\n'.join(lines) + + def _format_sample_data(self, dataset: Dataset) -> str: + """Format sample data rows as table for LLM prompts.""" + return dataset.data.write_csv() + + @abstractmethod + def format_all_tables( + self, + input: Dataset, + tables: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + random_seed: int | None = None, + ) -> str: + """Format the primary dataset and all auxiliary tables with their schemas and sample data for LLM prompts. + + Args: + input: Input dataset (primary table) + tables: Available tables with their join strategies + conn: Database connection to query sample data + + Returns: + String representation of all tables with their schemas and sample data + """ + ... diff --git a/agentune/core/formatter/schema.py b/agentune/core/formatter/schema.py new file mode 100644 index 00000000..2835ab80 --- /dev/null +++ b/agentune/core/formatter/schema.py @@ -0,0 +1,76 @@ +"""Schema formatting for LLM prompts.""" + +from typing import override + +import attrs +from duckdb import DuckDBPyConnection + +from agentune.analyze.join.base import TablesWithJoinStrategies +from agentune.core.dataset import Dataset +from agentune.core.formatter.base import SchemaFormatter +from agentune.core.sampler.base import DataSampler, HeadSampler + + +@attrs.frozen +class SimpleSchemaFormatter(SchemaFormatter): + """Simple schema formatter for LLM prompts. + + Formats all available tables with their schemas and sample data. + Each table is formatted with: + - Table name + - Schema (list of columns with types) + - Sample data (CSV format) + """ + num_samples: int = 5 + sampler: DataSampler = HeadSampler() + + @override + def format_all_tables( + self, + input: Dataset, + tables: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + random_seed: int | None = None, + ) -> str: + """Format all available tables with their schemas and sample data for LLM prompts. + + Args: + input: Input dataset (primary table) + tables: Available tables with their join strategies + conn: Database connection to query sample data + + Returns: + String representation of all tables with their schemas and sample data + """ + sections = [] + + # Format primary table + sections.append('## Primary Table: ' + self.primary_table_name) + sections.append('') + sections.append('### Schema:') + sections.append(self._serialize_schema(input.schema)) + sections.append('') + sections.append(f'### Sample Data ({self.num_samples} rows):') + sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) + sections.append(self._format_sample_data(sample_data)) + sections.append('') + + # Format secondary tables + for table_with_strategies in tables: + table = table_with_strategies.table + sections.append(f'## Table: {table.name.name}') + sections.append('') + + # Schema + sections.append('### Schema:') + sections.append(self._serialize_schema(table.schema)) + sections.append('') + + # Sample data + sections.append(f'### Sample Data ({self.num_samples} rows):') + dataset = table.as_source().to_dataset(conn) + sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) + sections.append(self._format_sample_data(sample_data)) + sections.append('') + + return '\n'.join(sections) diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 1812dc8b..32f8e5e0 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -61,3 +61,27 @@ def sample(self, dataset: Dataset, sample_size: int, random_seed: int | None = N schema=dataset.schema, data=sampled_df, ) + +@attrs.define +class HeadSampler(DataSampler): + """Simple head sampling. + + This sampler selects the first `sample_size` rows from the dataset. + """ + + @override + def sample(self, dataset: Dataset, sample_size: int, random_seed: int | None = None) -> Dataset: + """Sample data using head sampling.""" + self._validate_inputs(dataset, sample_size) + + # If sample size >= dataset size, return the entire dataset + if sample_size >= dataset.data.height: + return dataset + + # Head sample + sampled_df = dataset.data.head(sample_size) + + return Dataset( + schema=dataset.schema, + data=sampled_df, + ) diff --git a/tests/agentune/core/formatter/__init__.py b/tests/agentune/core/formatter/__init__.py new file mode 100644 index 00000000..85f6b103 --- /dev/null +++ b/tests/agentune/core/formatter/__init__.py @@ -0,0 +1 @@ +"""Tests for formatter modules.""" diff --git a/tests/agentune/core/formatter/test_schema.py b/tests/agentune/core/formatter/test_schema.py new file mode 100644 index 00000000..b9981941 --- /dev/null +++ b/tests/agentune/core/formatter/test_schema.py @@ -0,0 +1,309 @@ +"""Tests for schema formatters.""" + +import polars as pl +import pytest +from duckdb import DuckDBPyConnection + +from agentune.analyze.join.base import ( + TablesWithJoinStrategies, + TableWithJoinStrategies, +) +from agentune.analyze.join.lookup import LookupJoinStrategy +from agentune.core import types +from agentune.core.database import DuckdbName, DuckdbTable +from agentune.core.dataset import Dataset +from agentune.core.formatter.schema import SimpleSchemaFormatter +from agentune.core.sampler.base import RandomSampler +from agentune.core.schema import Field, Schema + + +@pytest.fixture +def primary_dataset() -> Dataset: + """Create a primary dataset for testing.""" + data = pl.DataFrame( + { + 'id': [1, 2, 3, 4, 5], + 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], + 'age': [25, 30, 35, 28, 32], + } + ) + return Dataset.from_polars(data) + + +@pytest.fixture +def secondary_table(conn: DuckDBPyConnection) -> DuckdbTable: + """Create a secondary table in the database.""" + table_name = DuckdbName('orders', 'memory', 'main') + schema = Schema( + cols=( + Field('order_id', types.int32), + Field('customer_id', types.int32), + Field('amount', types.float64), + ) + ) + + table = DuckdbTable(name=table_name, schema=schema) + table.create(conn) + + # Insert sample data + conn.execute( + f'INSERT INTO {table_name} VALUES (101, 1, 100.5), (102, 2, 250.0), (103, 1, 75.25)' + ) + + return table + + +@pytest.fixture +def tables_with_strategies(secondary_table: DuckdbTable) -> TablesWithJoinStrategies: + """Create tables with join strategies for testing.""" + join_strategy: LookupJoinStrategy[int] = LookupJoinStrategy( + name='mock_join', + table=secondary_table, + key_col=Field('customer_id', types.int32), + value_cols=(Field('order_id', types.int32), Field('amount', types.float64)), + ) + table_with_strategies = TableWithJoinStrategies.from_list([join_strategy]) + return TablesWithJoinStrategies.from_list([table_with_strategies]) + + +class TestSimpleSchemaFormatter: + """Tests for SimpleSchemaFormatter.""" + + def test_format_all_tables_basic( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test basic formatting of all tables.""" + formatter = SimpleSchemaFormatter() + result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) + + # Check that the result contains expected sections + assert '## Primary Table: primary_table' in result + assert '### Schema:' in result + assert '### Sample Data' in result + + # Check primary table columns are present + assert '- id:' in result + assert '- name:' in result + assert '- age:' in result + + # Check secondary table is present + assert '## Table: orders' in result + assert '- order_id:' in result + assert '- customer_id:' in result + assert '- amount:' in result + + def test_format_all_tables_custom_primary_name( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with custom primary table name.""" + formatter = SimpleSchemaFormatter(primary_table_name='users') + result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) + + assert '## Primary Table: users' in result + assert '## Primary Table: primary_table' not in result + + def test_format_all_tables_custom_num_samples( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with custom number of samples.""" + formatter = SimpleSchemaFormatter(num_samples=3) + result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) + + # Check that sample size is mentioned + assert '### Sample Data (3 rows):' in result + + # Count the number of data rows in the primary table sample + # The CSV output should have a header and 3 data rows + primary_section = result.split('## Table: orders')[0] + csv_lines = primary_section.split('### Sample Data (3 rows):')[1].strip().split('\n') + # Filter out empty lines + csv_lines = [line for line in csv_lines if line.strip()] + # Should have header + 3 data rows + assert len(csv_lines) == 4 + + def test_format_all_tables_with_random_sampler( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with random sampler.""" + formatter = SimpleSchemaFormatter( + num_samples=3, + sampler=RandomSampler(), + ) + result = formatter.format_all_tables( + primary_dataset, tables_with_strategies, conn, random_seed=42 + ) + + # Should still produce valid output + assert '## Primary Table:' in result + assert '### Sample Data (3 rows):' in result + + def test_format_all_tables_no_secondary_tables( + self, + primary_dataset: Dataset, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with only primary table.""" + empty_tables = TablesWithJoinStrategies.from_list([]) + formatter = SimpleSchemaFormatter() + result = formatter.format_all_tables(primary_dataset, empty_tables, conn) + + # Should have primary table + assert '## Primary Table: primary_table' in result + assert '- id:' in result + + # Should not have secondary table markers + assert result.count('## Table:') == 0 + + def test_serialize_schema(self, primary_dataset: Dataset) -> None: + """Test schema serialization.""" + formatter = SimpleSchemaFormatter() + schema_str = formatter._serialize_schema(primary_dataset.schema) + + # Check that all columns are present + assert '- id:' in schema_str + assert '- name:' in schema_str + assert '- age:' in schema_str + + # Check format + lines = schema_str.strip().split('\n') + assert len(lines) == 3 # Three columns + for line in lines: + assert line.startswith('- ') + assert ': ' in line + + def test_format_sample_data(self, primary_dataset: Dataset) -> None: + """Test sample data formatting.""" + formatter = SimpleSchemaFormatter() + sample_data = formatter.sampler.sample(primary_dataset, 2) + formatted = formatter._format_sample_data(sample_data) + + # Should be CSV format + assert 'id,name,age' in formatted or 'id,' in formatted + lines = formatted.strip().split('\n') + # Header + 2 data rows + assert len(lines) == 3 + + def test_format_all_tables_large_dataset( + self, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with dataset larger than sample size.""" + # Create a larger dataset + large_data = pl.DataFrame( + { + 'id': range(100), + 'value': [f'value_{i}' for i in range(100)], + } + ) + large_dataset = Dataset.from_polars(large_data) + + formatter = SimpleSchemaFormatter(num_samples=5) + result = formatter.format_all_tables(large_dataset, tables_with_strategies, conn) + + # Should only sample 5 rows + assert '### Sample Data (5 rows):' in result + + # The CSV output should have limited rows + primary_section = result.split('## Table:')[0] + assert 'value_99' not in primary_section # Should not contain the last row + + def test_format_all_tables_multiple_secondary_tables( + self, + primary_dataset: Dataset, + conn: DuckDBPyConnection, + ) -> None: + """Test formatting with multiple secondary tables.""" + # Create first secondary table + table1_name = DuckdbName('orders', 'memory', 'main') + table1_schema = Schema(cols=(Field('order_id', types.int32),)) + table1 = DuckdbTable(name=table1_name, schema=table1_schema) + table1.create(conn) + conn.execute(f'INSERT INTO {table1_name} VALUES (1), (2)') + + # Create second secondary table + table2_name = DuckdbName('products', 'memory', 'main') + table2_schema = Schema(cols=(Field('product_id', types.int32),)) + table2 = DuckdbTable(name=table2_name, schema=table2_schema) + table2.create(conn) + conn.execute(f'INSERT INTO {table2_name} VALUES (10), (20)') + + # Create join strategies + strategy1: LookupJoinStrategy[int] = LookupJoinStrategy( + name='join1', + table=table1, + key_col=Field('order_id', types.int32), + value_cols=(), + ) + strategy2: LookupJoinStrategy[int] = LookupJoinStrategy( + name='join2', + table=table2, + key_col=Field('product_id', types.int32), + value_cols=(), + ) + + tables = TablesWithJoinStrategies.from_list([ + TableWithJoinStrategies.from_list([strategy1]), + TableWithJoinStrategies.from_list([strategy2]), + ]) + + formatter = SimpleSchemaFormatter() + result = formatter.format_all_tables(primary_dataset, tables, conn) + + # Check both secondary tables are present + assert '## Table: orders' in result + assert '## Table: products' in result + assert '- order_id:' in result + assert '- product_id:' in result + + def test_format_preserves_order( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test that formatting preserves expected section order.""" + formatter = SimpleSchemaFormatter() + result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) + + # Find positions of key sections + primary_pos = result.find('## Primary Table:') + primary_schema_pos = result.find('### Schema:', primary_pos) + primary_sample_pos = result.find('### Sample Data', primary_schema_pos) + secondary_pos = result.find('## Table:', primary_sample_pos) + + # Verify order + assert primary_pos < primary_schema_pos < primary_sample_pos < secondary_pos + + def test_format_output_example( + self, + primary_dataset: Dataset, + tables_with_strategies: TablesWithJoinStrategies, + conn: DuckDBPyConnection, + ) -> None: + """Test to display the actual formatted output for visual inspection.""" + formatter = SimpleSchemaFormatter(num_samples=3) + result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) + + # Print the actual output for inspection + print('\n' + '=' * 80) + print('FORMATTED OUTPUT:') + print('=' * 80) + print(result) + print('=' * 80) + + # Basic assertions to ensure test passes + assert len(result) > 0 + assert '## Primary Table:' in result From 8bdaf7a18dd5613d1dae8116711086a6d547e80f Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Thu, 8 Jan 2026 17:51:14 +0200 Subject: [PATCH 02/14] remove unnecessary tests --- agentune/core/formatter/__init__.py | 10 - tests/agentune/core/formatter/test_schema.py | 193 ------------------- 2 files changed, 203 deletions(-) diff --git a/agentune/core/formatter/__init__.py b/agentune/core/formatter/__init__.py index 79792dcc..e69de29b 100644 --- a/agentune/core/formatter/__init__.py +++ b/agentune/core/formatter/__init__.py @@ -1,10 +0,0 @@ -"""Data formatting module.""" - -from agentune.core.formatter.base import DataFormatter, SchemaFormatter -from agentune.core.formatter.schema import SimpleSchemaFormatter - -__all__ = [ - 'DataFormatter', - 'SchemaFormatter', - 'SimpleSchemaFormatter', -] diff --git a/tests/agentune/core/formatter/test_schema.py b/tests/agentune/core/formatter/test_schema.py index b9981941..239f1eec 100644 --- a/tests/agentune/core/formatter/test_schema.py +++ b/tests/agentune/core/formatter/test_schema.py @@ -13,7 +13,6 @@ from agentune.core.database import DuckdbName, DuckdbTable from agentune.core.dataset import Dataset from agentune.core.formatter.schema import SimpleSchemaFormatter -from agentune.core.sampler.base import RandomSampler from agentune.core.schema import Field, Schema @@ -95,198 +94,6 @@ def test_format_all_tables_basic( assert '- customer_id:' in result assert '- amount:' in result - def test_format_all_tables_custom_primary_name( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with custom primary table name.""" - formatter = SimpleSchemaFormatter(primary_table_name='users') - result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) - - assert '## Primary Table: users' in result - assert '## Primary Table: primary_table' not in result - - def test_format_all_tables_custom_num_samples( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with custom number of samples.""" - formatter = SimpleSchemaFormatter(num_samples=3) - result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) - - # Check that sample size is mentioned - assert '### Sample Data (3 rows):' in result - - # Count the number of data rows in the primary table sample - # The CSV output should have a header and 3 data rows - primary_section = result.split('## Table: orders')[0] - csv_lines = primary_section.split('### Sample Data (3 rows):')[1].strip().split('\n') - # Filter out empty lines - csv_lines = [line for line in csv_lines if line.strip()] - # Should have header + 3 data rows - assert len(csv_lines) == 4 - - def test_format_all_tables_with_random_sampler( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with random sampler.""" - formatter = SimpleSchemaFormatter( - num_samples=3, - sampler=RandomSampler(), - ) - result = formatter.format_all_tables( - primary_dataset, tables_with_strategies, conn, random_seed=42 - ) - - # Should still produce valid output - assert '## Primary Table:' in result - assert '### Sample Data (3 rows):' in result - - def test_format_all_tables_no_secondary_tables( - self, - primary_dataset: Dataset, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with only primary table.""" - empty_tables = TablesWithJoinStrategies.from_list([]) - formatter = SimpleSchemaFormatter() - result = formatter.format_all_tables(primary_dataset, empty_tables, conn) - - # Should have primary table - assert '## Primary Table: primary_table' in result - assert '- id:' in result - - # Should not have secondary table markers - assert result.count('## Table:') == 0 - - def test_serialize_schema(self, primary_dataset: Dataset) -> None: - """Test schema serialization.""" - formatter = SimpleSchemaFormatter() - schema_str = formatter._serialize_schema(primary_dataset.schema) - - # Check that all columns are present - assert '- id:' in schema_str - assert '- name:' in schema_str - assert '- age:' in schema_str - - # Check format - lines = schema_str.strip().split('\n') - assert len(lines) == 3 # Three columns - for line in lines: - assert line.startswith('- ') - assert ': ' in line - - def test_format_sample_data(self, primary_dataset: Dataset) -> None: - """Test sample data formatting.""" - formatter = SimpleSchemaFormatter() - sample_data = formatter.sampler.sample(primary_dataset, 2) - formatted = formatter._format_sample_data(sample_data) - - # Should be CSV format - assert 'id,name,age' in formatted or 'id,' in formatted - lines = formatted.strip().split('\n') - # Header + 2 data rows - assert len(lines) == 3 - - def test_format_all_tables_large_dataset( - self, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with dataset larger than sample size.""" - # Create a larger dataset - large_data = pl.DataFrame( - { - 'id': range(100), - 'value': [f'value_{i}' for i in range(100)], - } - ) - large_dataset = Dataset.from_polars(large_data) - - formatter = SimpleSchemaFormatter(num_samples=5) - result = formatter.format_all_tables(large_dataset, tables_with_strategies, conn) - - # Should only sample 5 rows - assert '### Sample Data (5 rows):' in result - - # The CSV output should have limited rows - primary_section = result.split('## Table:')[0] - assert 'value_99' not in primary_section # Should not contain the last row - - def test_format_all_tables_multiple_secondary_tables( - self, - primary_dataset: Dataset, - conn: DuckDBPyConnection, - ) -> None: - """Test formatting with multiple secondary tables.""" - # Create first secondary table - table1_name = DuckdbName('orders', 'memory', 'main') - table1_schema = Schema(cols=(Field('order_id', types.int32),)) - table1 = DuckdbTable(name=table1_name, schema=table1_schema) - table1.create(conn) - conn.execute(f'INSERT INTO {table1_name} VALUES (1), (2)') - - # Create second secondary table - table2_name = DuckdbName('products', 'memory', 'main') - table2_schema = Schema(cols=(Field('product_id', types.int32),)) - table2 = DuckdbTable(name=table2_name, schema=table2_schema) - table2.create(conn) - conn.execute(f'INSERT INTO {table2_name} VALUES (10), (20)') - - # Create join strategies - strategy1: LookupJoinStrategy[int] = LookupJoinStrategy( - name='join1', - table=table1, - key_col=Field('order_id', types.int32), - value_cols=(), - ) - strategy2: LookupJoinStrategy[int] = LookupJoinStrategy( - name='join2', - table=table2, - key_col=Field('product_id', types.int32), - value_cols=(), - ) - - tables = TablesWithJoinStrategies.from_list([ - TableWithJoinStrategies.from_list([strategy1]), - TableWithJoinStrategies.from_list([strategy2]), - ]) - - formatter = SimpleSchemaFormatter() - result = formatter.format_all_tables(primary_dataset, tables, conn) - - # Check both secondary tables are present - assert '## Table: orders' in result - assert '## Table: products' in result - assert '- order_id:' in result - assert '- product_id:' in result - - def test_format_preserves_order( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test that formatting preserves expected section order.""" - formatter = SimpleSchemaFormatter() - result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) - - # Find positions of key sections - primary_pos = result.find('## Primary Table:') - primary_schema_pos = result.find('### Schema:', primary_pos) - primary_sample_pos = result.find('### Sample Data', primary_schema_pos) - secondary_pos = result.find('## Table:', primary_sample_pos) - - # Verify order - assert primary_pos < primary_schema_pos < primary_sample_pos < secondary_pos - def test_format_output_example( self, primary_dataset: Dataset, From ff399b4e834a5e47de70d4c1004231a500232017 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Sun, 11 Jan 2026 11:19:09 +0200 Subject: [PATCH 03/14] feat: Refactor schema serialization methods and remove obsolete tests --- agentune/core/formatter/base.py | 2 +- agentune/core/formatter/schema.py | 45 +++---- tests/agentune/core/formatter/__init__.py | 1 - tests/agentune/core/formatter/test_schema.py | 116 ------------------- 4 files changed, 24 insertions(+), 140 deletions(-) delete mode 100644 tests/agentune/core/formatter/__init__.py delete mode 100644 tests/agentune/core/formatter/test_schema.py diff --git a/agentune/core/formatter/base.py b/agentune/core/formatter/base.py index e6aeacdf..77f28e89 100644 --- a/agentune/core/formatter/base.py +++ b/agentune/core/formatter/base.py @@ -74,7 +74,7 @@ class SchemaFormatter(ABC, UseTypeTag): primary_table_name: str = 'primary_table' - def _serialize_schema(self, schema: Schema) -> str: + def _format_schema(self, schema: Schema) -> str: """Serialize schema to human-readable string for LLM prompts.""" lines = [] for field in schema.cols: diff --git a/agentune/core/formatter/schema.py b/agentune/core/formatter/schema.py index 2835ab80..eba03353 100644 --- a/agentune/core/formatter/schema.py +++ b/agentune/core/formatter/schema.py @@ -9,6 +9,7 @@ from agentune.core.dataset import Dataset from agentune.core.formatter.base import SchemaFormatter from agentune.core.sampler.base import DataSampler, HeadSampler +from agentune.core.schema import Schema @attrs.frozen @@ -24,6 +25,19 @@ class SimpleSchemaFormatter(SchemaFormatter): num_samples: int = 5 sampler: DataSampler = HeadSampler() + def _serialize_schema_and_samples(self, schema: Schema, sample_data: Dataset) -> str: + """Serialize schema and sample data for a single table.""" + # Schema + out = ['### Schema:'] + out.append(self._format_schema(schema)) + out.append('') + + # Sample data + out.append(f'### Sample Data ({self.num_samples} rows):') + out.append(self._format_sample_data(sample_data)) + out.append('') + return '\n'.join(out) + @override def format_all_tables( self, @@ -44,33 +58,20 @@ def format_all_tables( """ sections = [] - # Format primary table - sections.append('## Primary Table: ' + self.primary_table_name) - sections.append('') - sections.append('### Schema:') - sections.append(self._serialize_schema(input.schema)) - sections.append('') - sections.append(f'### Sample Data ({self.num_samples} rows):') + # get sample data for primary table sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) - sections.append(self._format_sample_data(sample_data)) - sections.append('') - + # Format primary table + sections.append(f'## Primary Table: {self.primary_table_name}\n') + sections.append(self._serialize_schema_and_samples(input.schema, sample_data)) + # Format secondary tables for table_with_strategies in tables: + # get sample data for the table table = table_with_strategies.table - sections.append(f'## Table: {table.name.name}') - sections.append('') - - # Schema - sections.append('### Schema:') - sections.append(self._serialize_schema(table.schema)) - sections.append('') - - # Sample data - sections.append(f'### Sample Data ({self.num_samples} rows):') dataset = table.as_source().to_dataset(conn) sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) - sections.append(self._format_sample_data(sample_data)) - sections.append('') + # Format table section + sections.append(f'## Table: {table.name.name}\n') + sections.append(self._serialize_schema_and_samples(dataset.schema, sample_data)) return '\n'.join(sections) diff --git a/tests/agentune/core/formatter/__init__.py b/tests/agentune/core/formatter/__init__.py deleted file mode 100644 index 85f6b103..00000000 --- a/tests/agentune/core/formatter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for formatter modules.""" diff --git a/tests/agentune/core/formatter/test_schema.py b/tests/agentune/core/formatter/test_schema.py deleted file mode 100644 index 239f1eec..00000000 --- a/tests/agentune/core/formatter/test_schema.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Tests for schema formatters.""" - -import polars as pl -import pytest -from duckdb import DuckDBPyConnection - -from agentune.analyze.join.base import ( - TablesWithJoinStrategies, - TableWithJoinStrategies, -) -from agentune.analyze.join.lookup import LookupJoinStrategy -from agentune.core import types -from agentune.core.database import DuckdbName, DuckdbTable -from agentune.core.dataset import Dataset -from agentune.core.formatter.schema import SimpleSchemaFormatter -from agentune.core.schema import Field, Schema - - -@pytest.fixture -def primary_dataset() -> Dataset: - """Create a primary dataset for testing.""" - data = pl.DataFrame( - { - 'id': [1, 2, 3, 4, 5], - 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], - 'age': [25, 30, 35, 28, 32], - } - ) - return Dataset.from_polars(data) - - -@pytest.fixture -def secondary_table(conn: DuckDBPyConnection) -> DuckdbTable: - """Create a secondary table in the database.""" - table_name = DuckdbName('orders', 'memory', 'main') - schema = Schema( - cols=( - Field('order_id', types.int32), - Field('customer_id', types.int32), - Field('amount', types.float64), - ) - ) - - table = DuckdbTable(name=table_name, schema=schema) - table.create(conn) - - # Insert sample data - conn.execute( - f'INSERT INTO {table_name} VALUES (101, 1, 100.5), (102, 2, 250.0), (103, 1, 75.25)' - ) - - return table - - -@pytest.fixture -def tables_with_strategies(secondary_table: DuckdbTable) -> TablesWithJoinStrategies: - """Create tables with join strategies for testing.""" - join_strategy: LookupJoinStrategy[int] = LookupJoinStrategy( - name='mock_join', - table=secondary_table, - key_col=Field('customer_id', types.int32), - value_cols=(Field('order_id', types.int32), Field('amount', types.float64)), - ) - table_with_strategies = TableWithJoinStrategies.from_list([join_strategy]) - return TablesWithJoinStrategies.from_list([table_with_strategies]) - - -class TestSimpleSchemaFormatter: - """Tests for SimpleSchemaFormatter.""" - - def test_format_all_tables_basic( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test basic formatting of all tables.""" - formatter = SimpleSchemaFormatter() - result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) - - # Check that the result contains expected sections - assert '## Primary Table: primary_table' in result - assert '### Schema:' in result - assert '### Sample Data' in result - - # Check primary table columns are present - assert '- id:' in result - assert '- name:' in result - assert '- age:' in result - - # Check secondary table is present - assert '## Table: orders' in result - assert '- order_id:' in result - assert '- customer_id:' in result - assert '- amount:' in result - - def test_format_output_example( - self, - primary_dataset: Dataset, - tables_with_strategies: TablesWithJoinStrategies, - conn: DuckDBPyConnection, - ) -> None: - """Test to display the actual formatted output for visual inspection.""" - formatter = SimpleSchemaFormatter(num_samples=3) - result = formatter.format_all_tables(primary_dataset, tables_with_strategies, conn) - - # Print the actual output for inspection - print('\n' + '=' * 80) - print('FORMATTED OUTPUT:') - print('=' * 80) - print(result) - print('=' * 80) - - # Basic assertions to ensure test passes - assert len(result) > 0 - assert '## Primary Table:' in result From 91a6165d6013c7af33a1d28f30eacb3c2442e03a Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Sun, 11 Jan 2026 11:23:43 +0200 Subject: [PATCH 04/14] feat: Rename serialization methods to formatting for clarity --- agentune/core/formatter/base.py | 2 +- agentune/core/formatter/schema.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/agentune/core/formatter/base.py b/agentune/core/formatter/base.py index 77f28e89..a0fde202 100644 --- a/agentune/core/formatter/base.py +++ b/agentune/core/formatter/base.py @@ -75,7 +75,7 @@ class SchemaFormatter(ABC, UseTypeTag): primary_table_name: str = 'primary_table' def _format_schema(self, schema: Schema) -> str: - """Serialize schema to human-readable string for LLM prompts.""" + """Format schema to human-readable string for LLM prompts.""" lines = [] for field in schema.cols: # Convert Dtype to simple string representation diff --git a/agentune/core/formatter/schema.py b/agentune/core/formatter/schema.py index eba03353..1f1d6d3c 100644 --- a/agentune/core/formatter/schema.py +++ b/agentune/core/formatter/schema.py @@ -25,8 +25,8 @@ class SimpleSchemaFormatter(SchemaFormatter): num_samples: int = 5 sampler: DataSampler = HeadSampler() - def _serialize_schema_and_samples(self, schema: Schema, sample_data: Dataset) -> str: - """Serialize schema and sample data for a single table.""" + def _format_schema_and_samples(self, schema: Schema, sample_data: Dataset) -> str: + """Format schema and sample data for a single table.""" # Schema out = ['### Schema:'] out.append(self._format_schema(schema)) @@ -62,7 +62,7 @@ def format_all_tables( sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) # Format primary table sections.append(f'## Primary Table: {self.primary_table_name}\n') - sections.append(self._serialize_schema_and_samples(input.schema, sample_data)) + sections.append(self._format_schema_and_samples(input.schema, sample_data)) # Format secondary tables for table_with_strategies in tables: @@ -72,6 +72,6 @@ def format_all_tables( sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) # Format table section sections.append(f'## Table: {table.name.name}\n') - sections.append(self._serialize_schema_and_samples(dataset.schema, sample_data)) + sections.append(self._format_schema_and_samples(dataset.schema, sample_data)) return '\n'.join(sections) From 24aa3660caefd091478bdad8a2ebe8fa49e40ce9 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Sun, 11 Jan 2026 12:00:29 +0200 Subject: [PATCH 05/14] rename to TablesFormatter and SimpleTablesFormatter --- agentune/core/formatter/base.py | 12 ++++++++++-- agentune/core/formatter/{schema.py => tables.py} | 14 +++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) rename agentune/core/formatter/{schema.py => tables.py} (83%) diff --git a/agentune/core/formatter/base.py b/agentune/core/formatter/base.py index a0fde202..8eac897d 100644 --- a/agentune/core/formatter/base.py +++ b/agentune/core/formatter/base.py @@ -66,8 +66,8 @@ async def aformat_batch(self, input: Dataset, conn: DuckDBPyConnection) -> pl.Se @attrs.define -class SchemaFormatter(ABC, UseTypeTag): - """Abstract base class for schema formatting strategies. +class TablesFormatter(ABC, UseTypeTag): + """Abstract base class for formatting database tables. Formats information about available database tables (schemas and sample data) for use in LLM prompts. """ @@ -88,6 +88,14 @@ def _format_sample_data(self, dataset: Dataset) -> str: """Format sample data rows as table for LLM prompts.""" return dataset.data.write_csv() + @abstractmethod + def format_table( + self, + sample_data: Dataset, + ) -> str: + """Format schema and sample data for a single table.""" + ... + @abstractmethod def format_all_tables( self, diff --git a/agentune/core/formatter/schema.py b/agentune/core/formatter/tables.py similarity index 83% rename from agentune/core/formatter/schema.py rename to agentune/core/formatter/tables.py index 1f1d6d3c..bc4ef99f 100644 --- a/agentune/core/formatter/schema.py +++ b/agentune/core/formatter/tables.py @@ -7,14 +7,14 @@ from agentune.analyze.join.base import TablesWithJoinStrategies from agentune.core.dataset import Dataset -from agentune.core.formatter.base import SchemaFormatter +from agentune.core.formatter.base import TablesFormatter from agentune.core.sampler.base import DataSampler, HeadSampler from agentune.core.schema import Schema @attrs.frozen -class SimpleSchemaFormatter(SchemaFormatter): - """Simple schema formatter for LLM prompts. +class SimpleTablesFormatter(TablesFormatter): + """Simple tables formatter for LLM prompts. Formats all available tables with their schemas and sample data. Each table is formatted with: @@ -25,11 +25,11 @@ class SimpleSchemaFormatter(SchemaFormatter): num_samples: int = 5 sampler: DataSampler = HeadSampler() - def _format_schema_and_samples(self, schema: Schema, sample_data: Dataset) -> str: + def format_table(self, sample_data: Dataset) -> str: """Format schema and sample data for a single table.""" # Schema out = ['### Schema:'] - out.append(self._format_schema(schema)) + out.append(self._format_schema(sample_data.schema)) out.append('') # Sample data @@ -62,7 +62,7 @@ def format_all_tables( sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) # Format primary table sections.append(f'## Primary Table: {self.primary_table_name}\n') - sections.append(self._format_schema_and_samples(input.schema, sample_data)) + sections.append(self._format_table(input.schema, sample_data)) # Format secondary tables for table_with_strategies in tables: @@ -72,6 +72,6 @@ def format_all_tables( sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) # Format table section sections.append(f'## Table: {table.name.name}\n') - sections.append(self._format_schema_and_samples(dataset.schema, sample_data)) + sections.append(self._format_table(dataset.schema, sample_data)) return '\n'.join(sections) From dc0df16636de3a2a2a38cc5c15d2755a2ab43627 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Sun, 11 Jan 2026 12:02:36 +0200 Subject: [PATCH 06/14] mypy fix --- agentune/core/formatter/tables.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index bc4ef99f..073122ef 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -9,7 +9,6 @@ from agentune.core.dataset import Dataset from agentune.core.formatter.base import TablesFormatter from agentune.core.sampler.base import DataSampler, HeadSampler -from agentune.core.schema import Schema @attrs.frozen @@ -62,7 +61,7 @@ def format_all_tables( sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) # Format primary table sections.append(f'## Primary Table: {self.primary_table_name}\n') - sections.append(self._format_table(input.schema, sample_data)) + sections.append(self.format_table(sample_data)) # Format secondary tables for table_with_strategies in tables: @@ -72,6 +71,6 @@ def format_all_tables( sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) # Format table section sections.append(f'## Table: {table.name.name}\n') - sections.append(self._format_table(dataset.schema, sample_data)) + sections.append(self.format_table(sample_data)) return '\n'.join(sections) From 898aa5dc50bfa0560d4072d30b0ade98c18f1624 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 14:44:47 +0200 Subject: [PATCH 07/14] feat: Create TableFormatter and MarkdownTableFormatter for improved table formatting --- agentune/core/formatter/base.py | 44 +++++------ agentune/core/formatter/tables.py | 121 +++++++++++++++++++++++------- agentune/core/sampler/base.py | 24 ------ 3 files changed, 115 insertions(+), 74 deletions(-) diff --git a/agentune/core/formatter/base.py b/agentune/core/formatter/base.py index 8eac897d..02834218 100644 --- a/agentune/core/formatter/base.py +++ b/agentune/core/formatter/base.py @@ -65,36 +65,37 @@ async def aformat_batch(self, input: Dataset, conn: DuckDBPyConnection) -> pl.Se ... + @attrs.define -class TablesFormatter(ABC, UseTypeTag): - """Abstract base class for formatting database tables. +class TableFormatter(ABC, UseTypeTag): + """Abstract base class for formatting database table. - Formats information about available database tables (schemas and sample data) for use in LLM prompts. + Formats information about table (schemas and sample data) to string representations. """ - primary_table_name: str = 'primary_table' - - def _format_schema(self, schema: Schema) -> str: - """Format schema to human-readable string for LLM prompts.""" - lines = [] - for field in schema.cols: - # Convert Dtype to simple string representation - dtype_str = repr(field.dtype.polars_type) - lines.append(f'- {field.name}: {dtype_str}') - - return '\n'.join(lines) - - def _format_sample_data(self, dataset: Dataset) -> str: - """Format sample data rows as table for LLM prompts.""" - return dataset.data.write_csv() - @abstractmethod def format_table( self, sample_data: Dataset, ) -> str: - """Format schema and sample data for a single table.""" + """Format schema and sample data for a single table. + + Args: + sample_data: Dataset containing the sample data to format + Returns: + String representation of the table with its schema and sample data + """ ... + + +@attrs.define +class TablesFormatter(ABC, UseTypeTag): + """Abstract base class for formatting database tables. + + Formats information about available database tables (schemas and sample data) to string representations. + """ + + primary_table_name: str = 'primary_table' @abstractmethod def format_all_tables( @@ -104,12 +105,13 @@ def format_all_tables( conn: DuckDBPyConnection, random_seed: int | None = None, ) -> str: - """Format the primary dataset and all auxiliary tables with their schemas and sample data for LLM prompts. + """Format the primary dataset and all auxiliary tables with their schemas and sample. Args: input: Input dataset (primary table) tables: Available tables with their join strategies conn: Database connection to query sample data + random_seed: Random seed for sampling (if applicable) Returns: String representation of all tables with their schemas and sample data diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 073122ef..21600d05 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -7,35 +7,91 @@ from agentune.analyze.join.base import TablesWithJoinStrategies from agentune.core.dataset import Dataset -from agentune.core.formatter.base import TablesFormatter -from agentune.core.sampler.base import DataSampler, HeadSampler +from agentune.core.formatter.base import TableFormatter, TablesFormatter +from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler +from agentune.core.sampler.table_samples import HeadTableSampler +from agentune.core.schema import Schema @attrs.frozen -class SimpleTablesFormatter(TablesFormatter): - """Simple tables formatter for LLM prompts. +class MarkdownTableFormatter(TableFormatter): + """Markdown table formatter. - Formats all available tables with their schemas and sample data. - Each table is formatted with: - - Table name - - Schema (list of columns with types) - - Sample data (CSV format) + Formats a single table with its schema and sample data using markdown headers. + The schema is displayed as a bulleted list of columns with their DuckDB types, + and the sample data is formatted as CSV for readability. + + Args: + markdown_level: The markdown header level to use for sections (default: 3). """ - num_samples: int = 5 - sampler: DataSampler = HeadSampler() + markdown_level: int = 3 + + def _format_schema(self, schema: Schema) -> str: + """Format schema to human-readable string.""" + lines = [] + for field in schema.cols: + # Convert Dtype to simple string representation using duckdb_type + dtype_str = repr(field.dtype.duckdb_type) + lines.append(f'- {field.name}: {dtype_str}') + + return '\n'.join(lines) + + def _format_sample_data(self, dataset: Dataset) -> str: + """Format sample data rows as table using CSV format.""" + return dataset.data.write_csv() + + @override + def format_table( + self, + sample_data: Dataset, + ) -> str: + """Format schema and sample data for a single table. + + Includes markdown headers at the specified level. Formats the schema as a list + of columns with their DuckDB types, and formats the sample data as CSV. + + Args: + sample_data: Dataset containing the sample data to format - def format_table(self, sample_data: Dataset) -> str: - """Format schema and sample data for a single table.""" + Returns: + String representation of the table with its schema and sample data + """ + markdown_header = '#' * self.markdown_level # Schema - out = ['### Schema:'] + out = [f'{markdown_header} Schema:'] out.append(self._format_schema(sample_data.schema)) out.append('') # Sample data - out.append(f'### Sample Data ({self.num_samples} rows):') + out.append(f'{markdown_header} Sample Data:') out.append(self._format_sample_data(sample_data)) out.append('') return '\n'.join(out) + + +@attrs.frozen +class MarkdownTablesFormatter(TablesFormatter): + """Markdown tables formatter. + + Formats all available tables (primary and secondary) with their schemas and sample data + in markdown format. Each table includes a header with its name, followed by its schema + and sample data sections. + + Args: + markdown_level: The markdown header level to use for table sections (default: 2). + num_samples: Number of sample rows to retrieve for each table (default: 5). + table_formatter: TableFormatter to use for formatting individual tables. + Defaults to MarkdownTableFormatter with markdown_level + 1. + primary_dataset_sampler: DataSampler to use for sampling the primary dataset. + Defaults to RandomSampler for representative sampling. + tables_sampler: TableSampler to use for sampling the secondary tables. + Defaults to HeadTableSampler for consistent sampling. + """ + markdown_level: int = 2 + num_samples: int = 5 + table_formatter: TableFormatter = attrs.field(default=attrs.Factory(lambda self: MarkdownTableFormatter(markdown_level=self.markdown_level + 1), takes_self=True)) + primary_dataset_sampler: DataSampler = RandomSampler() + tables_sampler: TableSampler = HeadTableSampler() @override def format_all_tables( @@ -46,31 +102,38 @@ def format_all_tables( random_seed: int | None = None, ) -> str: """Format all available tables with their schemas and sample data for LLM prompts. + + Formats the primary table followed by all secondary tables. Each table is formatted + with a header at the specified markdown level, followed by its schema and sample data + using the configured table formatter. Args: - input: Input dataset (primary table) - tables: Available tables with their join strategies - conn: Database connection to query sample data + input: Primary input dataset to format. + tables: Secondary tables with their join strategies to format. + conn: Database connection for querying sample data from secondary tables. + random_seed: Optional random seed for reproducible sampling. Returns: - String representation of all tables with their schemas and sample data + Markdown-formatted string containing all tables with their schemas and sample data. """ sections = [] - + markdown_header = '#' * self.markdown_level # get sample data for primary table - sample_data = self.sampler.sample(input, self.num_samples, random_seed=random_seed) + sample_data = self.primary_dataset_sampler.sample(input, self.num_samples, random_seed=random_seed) # Format primary table - sections.append(f'## Primary Table: {self.primary_table_name}\n') - sections.append(self.format_table(sample_data)) - + sections.append(f'{markdown_header} Primary Table: {self.primary_table_name}\n') + sections.append(self.table_formatter.format_table(sample_data)) # Format secondary tables for table_with_strategies in tables: # get sample data for the table - table = table_with_strategies.table - dataset = table.as_source().to_dataset(conn) - sample_data = self.sampler.sample(dataset, self.num_samples, random_seed=random_seed) + sample_data = self.tables_sampler.sample( + table_with_strategies, + conn, + self.num_samples, + random_seed=random_seed, + ) # Format table section - sections.append(f'## Table: {table.name.name}\n') - sections.append(self.format_table(sample_data)) + sections.append(f'{markdown_header} Table: {table_with_strategies.table.name.name}\n') + sections.append(self.table_formatter.format_table(sample_data)) return '\n'.join(sections) diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 32f8e5e0..1812dc8b 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -61,27 +61,3 @@ def sample(self, dataset: Dataset, sample_size: int, random_seed: int | None = N schema=dataset.schema, data=sampled_df, ) - -@attrs.define -class HeadSampler(DataSampler): - """Simple head sampling. - - This sampler selects the first `sample_size` rows from the dataset. - """ - - @override - def sample(self, dataset: Dataset, sample_size: int, random_seed: int | None = None) -> Dataset: - """Sample data using head sampling.""" - self._validate_inputs(dataset, sample_size) - - # If sample size >= dataset size, return the entire dataset - if sample_size >= dataset.data.height: - return dataset - - # Head sample - sampled_df = dataset.data.head(sample_size) - - return Dataset( - schema=dataset.schema, - data=sampled_df, - ) From 1a32253411d58b90c361fefcaa37ab7dcf76bdfb Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 15:50:13 +0200 Subject: [PATCH 08/14] add example to docstring --- agentune/core/formatter/tables.py | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 21600d05..9d3ae63a 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -23,6 +23,18 @@ class MarkdownTableFormatter(TableFormatter): Args: markdown_level: The markdown header level to use for sections (default: 3). + + Example: + ### Schema: + - id: INTEGER + - name: VARCHAR + - status: ENUM('active', 'inactive') + - score: DOUBLE + + ### Sample Data: + id,name,status,score + 1,Alice,active,95.5 + 2,Bob,inactive,82.3 """ markdown_level: int = 3 @@ -86,6 +98,31 @@ class MarkdownTablesFormatter(TablesFormatter): Defaults to RandomSampler for representative sampling. tables_sampler: TableSampler to use for sampling the secondary tables. Defaults to HeadTableSampler for consistent sampling. + + Example: + ## Primary Table: users + + ### Schema: + - id: INTEGER + - name: VARCHAR + - tier: ENUM('bronze', 'silver', 'gold') + + ### Sample Data: + id,name,tier + 1,Alice,gold + 2,Bob,silver + + ## Table: orders + + ### Schema: + - order_id: INTEGER + - user_id: INTEGER + - amount: DOUBLE + + ### Sample Data: + order_id,user_id,amount + 101,1,49.99 + 102,2,125.50 """ markdown_level: int = 2 num_samples: int = 5 From 1f1d72fe28c70c8a8f19b59e1473d0e6c11d6700 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 17:04:03 +0200 Subject: [PATCH 09/14] feat: Enhance MarkdownTableFormatter to truncate long string values using Polars --- agentune/core/formatter/tables.py | 26 ++++- tests/agentune/core/formatter/__init__.py | 1 + tests/agentune/core/formatter/test_tables.py | 116 +++++++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 tests/agentune/core/formatter/__init__.py create mode 100644 tests/agentune/core/formatter/test_tables.py diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 9d3ae63a..21314aa1 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -11,6 +11,7 @@ from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler from agentune.core.sampler.table_samples import HeadTableSampler from agentune.core.schema import Schema +import polars as pl @attrs.frozen @@ -23,6 +24,7 @@ class MarkdownTableFormatter(TableFormatter): Args: markdown_level: The markdown header level to use for sections (default: 3). + max_str: Maximum string length for cell values. Longer values are truncated with '...' (default: 100). Example: ### Schema: @@ -37,6 +39,7 @@ class MarkdownTableFormatter(TableFormatter): 2,Bob,inactive,82.3 """ markdown_level: int = 3 + max_str: int = 100 def _format_schema(self, schema: Schema) -> str: """Format schema to human-readable string.""" @@ -50,7 +53,24 @@ def _format_schema(self, schema: Schema) -> str: def _format_sample_data(self, dataset: Dataset) -> str: """Format sample data rows as table using CSV format.""" - return dataset.data.write_csv() + # Only truncate string columns using Polars + select_exprs = [] + for field in dataset.schema.cols: + col_name = field.name + # Check if column is a string type + if field.dtype.polars_type in (pl.String, pl.Utf8): + # Truncate long strings + select_exprs.append( + pl.when(pl.col(col_name).str.len_bytes() > self.max_str) + .then(pl.col(col_name).str.slice(0, self.max_str) + '...') + .otherwise(pl.col(col_name)) + .alias(col_name) + ) + else: + select_exprs.append(pl.col(col_name)) + + truncated_data = dataset.data.select(select_exprs) + return truncated_data.write_csv() @override def format_table( @@ -92,6 +112,7 @@ class MarkdownTablesFormatter(TablesFormatter): Args: markdown_level: The markdown header level to use for table sections (default: 2). num_samples: Number of sample rows to retrieve for each table (default: 5). + max_str: Maximum string length for cell values. Longer values are truncated with '...' (default: 100). table_formatter: TableFormatter to use for formatting individual tables. Defaults to MarkdownTableFormatter with markdown_level + 1. primary_dataset_sampler: DataSampler to use for sampling the primary dataset. @@ -126,7 +147,8 @@ class MarkdownTablesFormatter(TablesFormatter): """ markdown_level: int = 2 num_samples: int = 5 - table_formatter: TableFormatter = attrs.field(default=attrs.Factory(lambda self: MarkdownTableFormatter(markdown_level=self.markdown_level + 1), takes_self=True)) + max_str: int = 100 + table_formatter: TableFormatter = attrs.field(default=attrs.Factory(lambda self: MarkdownTableFormatter(markdown_level=self.markdown_level + 1, max_str=self.max_str), takes_self=True)) primary_dataset_sampler: DataSampler = RandomSampler() tables_sampler: TableSampler = HeadTableSampler() diff --git a/tests/agentune/core/formatter/__init__.py b/tests/agentune/core/formatter/__init__.py new file mode 100644 index 00000000..33efb0e2 --- /dev/null +++ b/tests/agentune/core/formatter/__init__.py @@ -0,0 +1 @@ +"""Tests for formatter module.""" diff --git a/tests/agentune/core/formatter/test_tables.py b/tests/agentune/core/formatter/test_tables.py new file mode 100644 index 00000000..b428f36d --- /dev/null +++ b/tests/agentune/core/formatter/test_tables.py @@ -0,0 +1,116 @@ +"""Tests for table formatting.""" + +import polars as pl + +from agentune.core.dataset import Dataset +from agentune.core.formatter.tables import MarkdownTableFormatter +from agentune.core.schema import Field, Schema +from agentune.core.types import float64, int32, string + + +class TestMarkdownTableFormatter: + """Test MarkdownTableFormatter functionality.""" + + def test_string_truncation(self) -> None: + """Test that long string values are truncated with '...'.""" + formatter = MarkdownTableFormatter(max_str=20) + + # Create test data with a long string + data = pl.DataFrame({ + 'id': [1, 2, 3], + 'short': ['abc', 'def', 'ghi'], + 'long': [ + 'this is a very long string that should be truncated', + 'another extremely long text value here', + 'short' + ], + }) + + schema = Schema(( + Field('id', int32), + Field('short', string), + Field('long', string), + )) + + dataset = Dataset(schema=schema, data=data) + + result = formatter.format_table(dataset) + lines = result.split('### Sample Data:')[1].strip().split('\n') + + # Verify truncated strings in correct column + assert lines[1].split(',')[2] == 'this is a very long ...' + assert lines[2].split(',')[2] == 'another extremely lo...' + assert lines[3].split(',')[2] == 'short' + + # Verify short strings preserved in correct column + assert lines[1].split(',')[1] == 'abc' + + def test_non_string_columns_preserved(self) -> None: + """Test that non-string columns are not truncated.""" + formatter = MarkdownTableFormatter(max_str=10) + + # Create test data with various types + data = pl.DataFrame({ + 'id': [1, 2, 3], + 'value': [123456789012345, 987654321098765, 111222333444555], + 'price': [99.99999999, 12345.6789, 0.123456789], + 'name': ['this is a very long name', 'short', 'another long name here'], + }) + + schema = Schema(( + Field('id', int32), + Field('value', int32), + Field('price', float64), + Field('name', string), + )) + + dataset = Dataset(schema=schema, data=data) + + result = formatter.format_table(dataset) + lines = result.split('### Sample Data:')[1].strip().split('\n') + + # Verify columns: id preserved, value preserved, price preserved, name truncated + assert lines[1].split(',')[0] == '1' + assert lines[1].split(',')[1] == '123456789012345' + assert lines[1].split(',')[3] == 'this is a ...' + assert lines[2].split(',')[3] == 'short' + + def test_custom_max_str_length(self) -> None: + """Test that max_str parameter controls truncation length.""" + data = pl.DataFrame({ + 'text': ['12345678901234567890', 'short'], + }) + + schema = Schema((Field('text', string),)) + dataset = Dataset(schema=schema, data=data) + + # Test with max_str=10 + result_10 = MarkdownTableFormatter(max_str=10).format_table(dataset) + assert result_10.split('### Sample Data:')[1].strip().split('\n')[1] == '1234567890...' + + # Test with max_str=100 (no truncation) + result_100 = MarkdownTableFormatter(max_str=100).format_table(dataset) + assert result_100.split('### Sample Data:')[1].strip().split('\n')[1] == '12345678901234567890' + + def test_empty_dataset(self) -> None: + """Test formatting an empty dataset.""" + formatter = MarkdownTableFormatter(max_str=20) + + data = pl.DataFrame({ + 'id': pl.Series([], dtype=pl.Int32), + 'name': pl.Series([], dtype=pl.Utf8), + }) + + schema = Schema(( + Field('id', int32), + Field('name', string), + )) + + dataset = Dataset(schema=schema, data=data) + + result = formatter.format_table(dataset) + csv_lines = result.split('### Sample Data:')[1].strip().split('\n') + + # Should have only header, no data rows + assert len(csv_lines) == 1 + assert csv_lines[0] == 'id,name' From 1116742fdd0d095a8ae2fd0c6cf05e166dfe32f7 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 17:05:08 +0200 Subject: [PATCH 10/14] ruff --- agentune/core/formatter/tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 21314aa1..2b88751d 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -3,6 +3,7 @@ from typing import override import attrs +import polars as pl from duckdb import DuckDBPyConnection from agentune.analyze.join.base import TablesWithJoinStrategies @@ -11,7 +12,6 @@ from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler from agentune.core.sampler.table_samples import HeadTableSampler from agentune.core.schema import Schema -import polars as pl @attrs.frozen From 0c6add4eb5ad91335eb300d22641c7b9b74a7e9a Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Tue, 13 Jan 2026 11:04:13 +0200 Subject: [PATCH 11/14] Correct string length check --- agentune/core/formatter/tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 2b88751d..cdc2d851 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -61,7 +61,7 @@ def _format_sample_data(self, dataset: Dataset) -> str: if field.dtype.polars_type in (pl.String, pl.Utf8): # Truncate long strings select_exprs.append( - pl.when(pl.col(col_name).str.len_bytes() > self.max_str) + pl.when(pl.col(col_name).str.len_chars() > self.max_str) .then(pl.col(col_name).str.slice(0, self.max_str) + '...') .otherwise(pl.col(col_name)) .alias(col_name) From 27ae8ee5b2a210ad1e206dce06dc8921705d2e89 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Tue, 13 Jan 2026 19:51:44 +0200 Subject: [PATCH 12/14] update string type check --- agentune/core/formatter/tables.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index cdc2d851..0fa12d1d 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -12,6 +12,7 @@ from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler from agentune.core.sampler.table_samples import HeadTableSampler from agentune.core.schema import Schema +import agentune.core.types as types @attrs.frozen @@ -57,8 +58,8 @@ def _format_sample_data(self, dataset: Dataset) -> str: select_exprs = [] for field in dataset.schema.cols: col_name = field.name - # Check if column is a string type - if field.dtype.polars_type in (pl.String, pl.Utf8): + # Check if column is a string or json type + if field.dtype in (types.string, types.json_dtype): # Truncate long strings select_exprs.append( pl.when(pl.col(col_name).str.len_chars() > self.max_str) From 2fb61a2afda931a9e8817be8875237c65902db3f Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Wed, 14 Jan 2026 11:47:38 +0200 Subject: [PATCH 13/14] update --- agentune/core/formatter/tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 0fa12d1d..225521b5 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -187,7 +187,7 @@ def format_all_tables( for table_with_strategies in tables: # get sample data for the table sample_data = self.tables_sampler.sample( - table_with_strategies, + table_with_strategies.table.name, conn, self.num_samples, random_seed=random_seed, From 021de57a524e0d482e8839153f854f82da882b43 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Wed, 14 Jan 2026 15:07:07 +0200 Subject: [PATCH 14/14] ruff fix --- agentune/core/formatter/tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentune/core/formatter/tables.py b/agentune/core/formatter/tables.py index 225521b5..f8ee2ff6 100644 --- a/agentune/core/formatter/tables.py +++ b/agentune/core/formatter/tables.py @@ -7,12 +7,12 @@ from duckdb import DuckDBPyConnection from agentune.analyze.join.base import TablesWithJoinStrategies +from agentune.core import types from agentune.core.dataset import Dataset from agentune.core.formatter.base import TableFormatter, TablesFormatter from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler from agentune.core.sampler.table_samples import HeadTableSampler from agentune.core.schema import Schema -import agentune.core.types as types @attrs.frozen