Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 239 additions & 67 deletions EXTENDING_ACE.md

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,30 @@ uv run python -m eval.finance.run \
| `--no_ground_truth` | Don't use ground truth in reflection | False |
| `--use_bulletpoint_analyzer` | Enable bulletpoint analyzer for playbook deduplication and merging | False |
| `--bulletpoint_analyzer_threshold` | Similarity threshold for bulletpoint analyzer (0-1) | 0.9 |
| `--task_prompts_dir` | Path to task-specific prompts directory (see [Custom Prompts](#custom-task-specific-prompts)) | None |

</details>

### Custom Task-Specific Prompts

ACE supports task-specific prompts that override the default prompts. This allows you to customize the Generator, Reflector, and Curator behavior for different domains without modifying the core ACE code.

```bash
# Use custom prompts for a specific task
uv run python -m eval.finance.run \
--task_name finer \
--mode offline \
--save_path results \
--task_prompts_dir ./eval/finance/prompts
```

To create custom prompts, create a `prompts/` directory under your task folder with any of these files:
- `generator.py` - Define `GENERATOR_PROMPT`
- `reflector.py` - Define `REFLECTOR_PROMPT` and/or `REFLECTOR_PROMPT_NO_GT`
- `curator.py` - Define `CURATOR_PROMPT` and/or `CURATOR_PROMPT_NO_GT`

Any prompt you define overrides the corresponding built-in default; prompts you omit fall back to the built-in versions.

## 📈 Results and Outputs

Using offline training as an example, after training, ACE generates:
Expand Down
3 changes: 2 additions & 1 deletion ace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@

from .ace import ACE
from .core import Generator, Reflector, Curator, BulletpointAnalyzer
from .prompts import PromptConfig, load_prompts

__all__ = ['ACE', 'Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer']
__all__ = ['ACE', 'Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'PromptConfig', 'load_prompts']

__version__ = "1.0.0"
30 changes: 24 additions & 6 deletions ace/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Dict, List, Tuple, Optional, Any

from .core import Generator, Reflector, Curator, BulletpointAnalyzer
from .prompts import PromptConfig, load_prompts
from playbook_utils import *
from logger import *
from utils import *
Expand All @@ -39,11 +40,12 @@ def __init__(
max_tokens: int = 4096,
initial_playbook: Optional[str] = None,
use_bulletpoint_analyzer: bool = False,
bulletpoint_analyzer_threshold: float = 0.90
bulletpoint_analyzer_threshold: float = 0.90,
prompt_config: Optional[PromptConfig] = None
):
"""
Initialize the ACE system.

Args:
api_provider: API provider for LLM calls
generator_model: Model name for generator
Expand All @@ -53,14 +55,30 @@ def __init__(
initial_playbook: Initial playbook content (optional)
use_bulletpoint_analyzer: Whether to use bulletpoint analyzer for deduplication
bulletpoint_analyzer_threshold: Similarity threshold for bulletpoint analyzer (0-1)
prompt_config: PromptConfig with custom prompts (optional, uses defaults if None)
"""
# Load default prompts if none provided
if prompt_config is None:
prompt_config = load_prompts()

# Initialize API clients
generator_client, reflector_client, curator_client = initialize_clients(api_provider)

# Initialize the three agents
self.generator = Generator(generator_client, api_provider, generator_model, max_tokens)
self.reflector = Reflector(reflector_client, api_provider, reflector_model, max_tokens)
self.curator = Curator(curator_client, api_provider, curator_model, max_tokens)
# Initialize the three agents with prompts from config
self.generator = Generator(
generator_client, api_provider, generator_model, max_tokens,
generator_prompt=prompt_config.generator_prompt
)
self.reflector = Reflector(
reflector_client, api_provider, reflector_model, max_tokens,
reflector_prompt=prompt_config.reflector_prompt,
reflector_prompt_no_gt=prompt_config.reflector_prompt_no_gt
)
self.curator = Curator(
curator_client, api_provider, curator_model, max_tokens,
curator_prompt=prompt_config.curator_prompt,
curator_prompt_no_gt=prompt_config.curator_prompt_no_gt
)

# Initialize bulletpoint analyzer if requested and available
self.use_bulletpoint_analyzer = use_bulletpoint_analyzer
Expand Down
16 changes: 11 additions & 5 deletions ace/core/curator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,27 @@ class Curator:
Curator agent that manages the playbook by adding, updating,
merging, and deleting bullets based on reflection feedback.
"""

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096):

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096,
             curator_prompt: Optional[str] = None,
             curator_prompt_no_gt: Optional[str] = None):
    """
    Initialize the Curator agent.

    Args:
        api_client: OpenAI client for LLM calls
        api_provider: API provider for LLM calls
        model: Model name to use for curation
        max_tokens: Maximum tokens for curation
        curator_prompt: Custom curator prompt (optional, uses default if None)
        curator_prompt_no_gt: Custom curator prompt without ground truth (optional)
    """
    self.api_client = api_client
    self.api_provider = api_provider
    self.model = model
    self.max_tokens = max_tokens
    # NOTE: `or` (rather than `is None`) means an empty-string prompt also
    # falls back to the built-in default — presumably intentional; confirm.
    self.curator_prompt = curator_prompt or CURATOR_PROMPT
    self.curator_prompt_no_gt = curator_prompt_no_gt or CURATOR_PROMPT_NO_GT

def curate(
self,
Expand Down Expand Up @@ -72,7 +78,7 @@ def curate(

# Select the appropriate prompt
if use_ground_truth:
prompt = CURATOR_PROMPT.format(
prompt = self.curator_prompt.format(
current_step=current_step,
total_samples=total_samples,
token_budget=token_budget,
Expand All @@ -82,7 +88,7 @@ def curate(
question_context=question_context
)
else:
prompt = CURATOR_PROMPT_NO_GT.format(
prompt = self.curator_prompt_no_gt.format(
current_step=current_step,
total_samples=total_samples,
token_budget=token_budget,
Expand Down
16 changes: 12 additions & 4 deletions ace/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@ class Generator:
Generator agent that produces answers to questions using knowledge
from a playbook and previous reflections.
"""

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096):

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096,
             generator_prompt: Optional[str] = None):
    """
    Initialize the Generator agent.

    Args:
        api_client: OpenAI client for LLM calls
        api_provider: API provider for LLM calls
        model: Model name to use for generation
        max_tokens: Maximum tokens for generation
        generator_prompt: Custom generator prompt (optional, uses default if None)
    """
    self.api_client = api_client
    self.api_provider = api_provider
    self.model = model
    self.max_tokens = max_tokens
    # NOTE: `or` (rather than `is None`) means an empty-string prompt also
    # falls back to the built-in default — presumably intentional; confirm.
    self.generator_prompt = generator_prompt or GENERATOR_PROMPT

def generate(
self,
Expand Down Expand Up @@ -56,7 +59,12 @@ def generate(
Tuple of (full_response, bullet_ids_used, call_info)
"""
# Format the prompt
prompt = GENERATOR_PROMPT.format(playbook, reflection, question, context)
prompt = self.generator_prompt.format(
playbook=playbook,
reflection=reflection,
question=question,
context=context
)

response, call_info = timed_llm_call(
self.api_client,
Expand Down
38 changes: 22 additions & 16 deletions ace/core/reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@ class Reflector:
Reflector agent that analyzes the generator's reasoning and tags
bullets as helpful, harmful, or neutral.
"""

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096):

def __init__(self, api_client, api_provider, model: str, max_tokens: int = 4096,
             reflector_prompt: Optional[str] = None,
             reflector_prompt_no_gt: Optional[str] = None):
    """
    Initialize the Reflector agent.

    Args:
        api_client: OpenAI client for LLM calls
        api_provider: API provider for LLM calls
        model: Model name to use for reflection
        max_tokens: Maximum tokens for reflection
        reflector_prompt: Custom reflector prompt (optional, uses default if None)
        reflector_prompt_no_gt: Custom reflector prompt without ground truth (optional)
    """
    self.api_client = api_client
    self.api_provider = api_provider
    self.model = model
    self.max_tokens = max_tokens
    # NOTE: `or` (rather than `is None`) means an empty-string prompt also
    # falls back to the built-in default — presumably intentional; confirm.
    self.reflector_prompt = reflector_prompt or REFLECTOR_PROMPT
    self.reflector_prompt_no_gt = reflector_prompt_no_gt or REFLECTOR_PROMPT_NO_GT

def reflect(
self,
Expand Down Expand Up @@ -63,21 +69,21 @@ def reflect(
"""
# Select the appropriate prompt
if use_ground_truth and ground_truth:
prompt = REFLECTOR_PROMPT.format(
question,
reasoning_trace,
predicted_answer,
ground_truth,
environment_feedback,
bullets_used
prompt = self.reflector_prompt.format(
question=question,
reasoning_trace=reasoning_trace,
predicted_answer=predicted_answer,
ground_truth=ground_truth,
environment_feedback=environment_feedback,
bullets_used=bullets_used
)
else:
prompt = REFLECTOR_PROMPT_NO_GT.format(
question,
reasoning_trace,
predicted_answer,
environment_feedback,
bullets_used
prompt = self.reflector_prompt_no_gt.format(
question=question,
reasoning_trace=reasoning_trace,
predicted_answer=predicted_answer,
environment_feedback=environment_feedback,
bullets_used=bullets_used
)

response, call_info = timed_llm_call(
Expand Down
10 changes: 8 additions & 2 deletions ace/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@
from .generator import *
from .reflector import *
from .curator import *
from .config import PromptConfig
from .loader import load_prompts

__all__ = [
# Generator prompts
'GENERATOR_PROMPT',

# Reflector prompts
'REFLECTOR_PROMPT',
'REFLECTOR_PROMPT_NO_GT',

# Curator prompts
'CURATOR_PROMPT',
'CURATOR_PROMPT_NO_GT',

# Prompt configuration
'PromptConfig',
'load_prompts',
]
16 changes: 16 additions & 0 deletions ace/prompts/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
PromptConfig dataclass for ACE system.
Holds all agent prompts in a single configuration object.
"""

from dataclasses import dataclass


@dataclass
class PromptConfig:
    """Configuration dataclass holding all agent prompts.

    One field per agent prompt; the *_no_gt variants are the prompts used
    when no ground truth is available during reflection/curation.
    """
    generator_prompt: str        # prompt template for the Generator agent
    reflector_prompt: str        # Reflector prompt (ground truth available)
    reflector_prompt_no_gt: str  # Reflector prompt (no ground truth)
    curator_prompt: str          # Curator prompt (ground truth available)
    curator_prompt_no_gt: str    # Curator prompt (no ground truth)
8 changes: 4 additions & 4 deletions ace/prompts/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@


**Playbook:**
{}
{playbook}

**Reflection:**
{}
{reflection}

**Question:**
{}
{question}

**Context:**
{}
{context}

**Answer in this exact JSON format:**
{{
Expand Down
Loading