Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion agentune/core/formatter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import polars as pl
from duckdb import DuckDBPyConnection

from agentune.analyze.join.base import JoinStrategy
from agentune.analyze.join.base import JoinStrategy, TablesWithJoinStrategies
from agentune.core.database import DuckdbTable
from agentune.core.dataset import Dataset
from agentune.core.schema import Schema
Expand Down Expand Up @@ -65,3 +65,55 @@ async def aformat_batch(self, input: Dataset, conn: DuckDBPyConnection) -> pl.Se
...



@attrs.define
class TableFormatter(ABC, UseTypeTag):
"""Abstract base class for formatting database table.

Formats information about table (schemas and sample data) to string representations.
"""

@abstractmethod
def format_table(
self,
sample_data: Dataset,
) -> str:
"""Format schema and sample data for a single table.

Args:
sample_data: Dataset containing the sample data to format
Returns:
String representation of the table with its schema and sample data
"""
...


@attrs.define
class TablesFormatter(ABC, UseTypeTag):
"""Abstract base class for formatting database tables.

Formats information about available database tables (schemas and sample data) to string representations.
"""

primary_table_name: str = 'primary_table'

@abstractmethod
def format_all_tables(
self,
input: Dataset,
tables: TablesWithJoinStrategies,
conn: DuckDBPyConnection,
random_seed: int | None = None,
) -> str:
"""Format the primary dataset and all auxiliary tables with their schemas and sample.

Args:
input: Input dataset (primary table)
tables: Available tables with their join strategies
conn: Database connection to query sample data
random_seed: Random seed for sampling (if applicable)

Returns:
String representation of all tables with their schemas and sample data
"""
...
199 changes: 199 additions & 0 deletions agentune/core/formatter/tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Schema formatting for LLM prompts."""

from typing import override

import attrs
import polars as pl
from duckdb import DuckDBPyConnection

from agentune.analyze.join.base import TablesWithJoinStrategies
from agentune.core import types
from agentune.core.dataset import Dataset
from agentune.core.formatter.base import TableFormatter, TablesFormatter
from agentune.core.sampler.base import DataSampler, RandomSampler, TableSampler
from agentune.core.sampler.table_samples import HeadTableSampler
from agentune.core.schema import Schema


@attrs.frozen
class MarkdownTableFormatter(TableFormatter):
"""Markdown table formatter.

Formats a single table with its schema and sample data using markdown headers.
The schema is displayed as a bulleted list of columns with their DuckDB types,
and the sample data is formatted as CSV for readability.

Args:
markdown_level: The markdown header level to use for sections (default: 3).
max_str: Maximum string length for cell values. Longer values are truncated with '...' (default: 100).

Example:
### Schema:
- id: INTEGER
- name: VARCHAR
- status: ENUM('active', 'inactive')
- score: DOUBLE

### Sample Data:
id,name,status,score
1,Alice,active,95.5
2,Bob,inactive,82.3
"""
markdown_level: int = 3
max_str: int = 100

def _format_schema(self, schema: Schema) -> str:
"""Format schema to human-readable string."""
lines = []
for field in schema.cols:
# Convert Dtype to simple string representation using duckdb_type
dtype_str = repr(field.dtype.duckdb_type)
lines.append(f'- {field.name}: {dtype_str}')

return '\n'.join(lines)

def _format_sample_data(self, dataset: Dataset) -> str:
"""Format sample data rows as table using CSV format."""
# Only truncate string columns using Polars
select_exprs = []
for field in dataset.schema.cols:
col_name = field.name
# Check if column is a string or json type
if field.dtype in (types.string, types.json_dtype):
# Truncate long strings
select_exprs.append(
pl.when(pl.col(col_name).str.len_chars() > self.max_str)
.then(pl.col(col_name).str.slice(0, self.max_str) + '...')
.otherwise(pl.col(col_name))
.alias(col_name)
)
else:
select_exprs.append(pl.col(col_name))

truncated_data = dataset.data.select(select_exprs)
return truncated_data.write_csv()

@override
def format_table(
self,
sample_data: Dataset,
) -> str:
"""Format schema and sample data for a single table.

Includes markdown headers at the specified level. Formats the schema as a list
of columns with their DuckDB types, and formats the sample data as CSV.

Args:
sample_data: Dataset containing the sample data to format

Returns:
String representation of the table with its schema and sample data
"""
markdown_header = '#' * self.markdown_level
# Schema
out = [f'{markdown_header} Schema:']
out.append(self._format_schema(sample_data.schema))
out.append('')

# Sample data
out.append(f'{markdown_header} Sample Data:')
out.append(self._format_sample_data(sample_data))
out.append('')
return '\n'.join(out)


@attrs.frozen
class MarkdownTablesFormatter(TablesFormatter):
"""Markdown tables formatter.

Formats all available tables (primary and secondary) with their schemas and sample data
in markdown format. Each table includes a header with its name, followed by its schema
and sample data sections.

Args:
markdown_level: The markdown header level to use for table sections (default: 2).
num_samples: Number of sample rows to retrieve for each table (default: 5).
max_str: Maximum string length for cell values. Longer values are truncated with '...' (default: 100).
table_formatter: TableFormatter to use for formatting individual tables.
Defaults to MarkdownTableFormatter with markdown_level + 1.
primary_dataset_sampler: DataSampler to use for sampling the primary dataset.
Defaults to RandomSampler for representative sampling.
tables_sampler: TableSampler to use for sampling the secondary tables.
Defaults to HeadTableSampler for consistent sampling.

Example:
## Primary Table: users

### Schema:
- id: INTEGER
- name: VARCHAR
- tier: ENUM('bronze', 'silver', 'gold')

### Sample Data:
id,name,tier
1,Alice,gold
2,Bob,silver

## Table: orders

### Schema:
- order_id: INTEGER
- user_id: INTEGER
- amount: DOUBLE

### Sample Data:
order_id,user_id,amount
101,1,49.99
102,2,125.50
"""
markdown_level: int = 2
num_samples: int = 5
max_str: int = 100
table_formatter: TableFormatter = attrs.field(default=attrs.Factory(lambda self: MarkdownTableFormatter(markdown_level=self.markdown_level + 1, max_str=self.max_str), takes_self=True))
primary_dataset_sampler: DataSampler = RandomSampler()
tables_sampler: TableSampler = HeadTableSampler()

@override
def format_all_tables(
self,
input: Dataset,
tables: TablesWithJoinStrategies,
conn: DuckDBPyConnection,
random_seed: int | None = None,
) -> str:
"""Format all available tables with their schemas and sample data for LLM prompts.

Formats the primary table followed by all secondary tables. Each table is formatted
with a header at the specified markdown level, followed by its schema and sample data
using the configured table formatter.

Args:
input: Primary input dataset to format.
tables: Secondary tables with their join strategies to format.
conn: Database connection for querying sample data from secondary tables.
random_seed: Optional random seed for reproducible sampling.

Returns:
Markdown-formatted string containing all tables with their schemas and sample data.
"""
sections = []
markdown_header = '#' * self.markdown_level
# get sample data for primary table
sample_data = self.primary_dataset_sampler.sample(input, self.num_samples, random_seed=random_seed)
# Format primary table
sections.append(f'{markdown_header} Primary Table: {self.primary_table_name}\n')
sections.append(self.table_formatter.format_table(sample_data))
# Format secondary tables
for table_with_strategies in tables:
# get sample data for the table
sample_data = self.tables_sampler.sample(
table_with_strategies.table.name,
conn,
self.num_samples,
random_seed=random_seed,
)
# Format table section
sections.append(f'{markdown_header} Table: {table_with_strategies.table.name.name}\n')
sections.append(self.table_formatter.format_table(sample_data))

return '\n'.join(sections)
1 change: 1 addition & 0 deletions tests/agentune/core/formatter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for formatter module."""
116 changes: 116 additions & 0 deletions tests/agentune/core/formatter/test_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Tests for table formatting."""

import polars as pl

from agentune.core.dataset import Dataset
from agentune.core.formatter.tables import MarkdownTableFormatter
from agentune.core.schema import Field, Schema
from agentune.core.types import float64, int32, string


class TestMarkdownTableFormatter:
"""Test MarkdownTableFormatter functionality."""

def test_string_truncation(self) -> None:
"""Test that long string values are truncated with '...'."""
formatter = MarkdownTableFormatter(max_str=20)

# Create test data with a long string
data = pl.DataFrame({
'id': [1, 2, 3],
'short': ['abc', 'def', 'ghi'],
'long': [
'this is a very long string that should be truncated',
'another extremely long text value here',
'short'
],
})

schema = Schema((
Field('id', int32),
Field('short', string),
Field('long', string),
))

dataset = Dataset(schema=schema, data=data)

result = formatter.format_table(dataset)
lines = result.split('### Sample Data:')[1].strip().split('\n')

# Verify truncated strings in correct column
assert lines[1].split(',')[2] == 'this is a very long ...'
assert lines[2].split(',')[2] == 'another extremely lo...'
assert lines[3].split(',')[2] == 'short'

# Verify short strings preserved in correct column
assert lines[1].split(',')[1] == 'abc'

def test_non_string_columns_preserved(self) -> None:
"""Test that non-string columns are not truncated."""
formatter = MarkdownTableFormatter(max_str=10)

# Create test data with various types
data = pl.DataFrame({
'id': [1, 2, 3],
'value': [123456789012345, 987654321098765, 111222333444555],
'price': [99.99999999, 12345.6789, 0.123456789],
'name': ['this is a very long name', 'short', 'another long name here'],
})

schema = Schema((
Field('id', int32),
Field('value', int32),
Field('price', float64),
Field('name', string),
))

dataset = Dataset(schema=schema, data=data)

result = formatter.format_table(dataset)
lines = result.split('### Sample Data:')[1].strip().split('\n')

# Verify columns: id preserved, value preserved, price preserved, name truncated
assert lines[1].split(',')[0] == '1'
assert lines[1].split(',')[1] == '123456789012345'
assert lines[1].split(',')[3] == 'this is a ...'
assert lines[2].split(',')[3] == 'short'

def test_custom_max_str_length(self) -> None:
"""Test that max_str parameter controls truncation length."""
data = pl.DataFrame({
'text': ['12345678901234567890', 'short'],
})

schema = Schema((Field('text', string),))
dataset = Dataset(schema=schema, data=data)

# Test with max_str=10
result_10 = MarkdownTableFormatter(max_str=10).format_table(dataset)
assert result_10.split('### Sample Data:')[1].strip().split('\n')[1] == '1234567890...'

# Test with max_str=100 (no truncation)
result_100 = MarkdownTableFormatter(max_str=100).format_table(dataset)
assert result_100.split('### Sample Data:')[1].strip().split('\n')[1] == '12345678901234567890'

def test_empty_dataset(self) -> None:
"""Test formatting an empty dataset."""
formatter = MarkdownTableFormatter(max_str=20)

data = pl.DataFrame({
'id': pl.Series([], dtype=pl.Int32),
'name': pl.Series([], dtype=pl.Utf8),
})

schema = Schema((
Field('id', int32),
Field('name', string),
))

dataset = Dataset(schema=schema, data=data)

result = formatter.format_table(dataset)
csv_lines = result.split('### Sample Data:')[1].strip().split('\n')

# Should have only header, no data rows
assert len(csv_lines) == 1
assert csv_lines[0] == 'id,name'