113 changes: 113 additions & 0 deletions applications/lats/README.md
@@ -0,0 +1,113 @@
# LATS (Language Agent Tree Search) – HumanEval on MASFactory

This directory is a [MASFactory](https://github.com/BUPT-GAMMA/MASFactory) application that reproduces **LATS** (Language Agent Tree Search) on the **HumanEval** (programming) benchmark.

- **Paper**: [Language Agent Tree Search Unifies Reasoning, Acting, and Planning in Language Models](https://arxiv.org/abs/2310.04406) (ICML 2024)
- **Upstream reference**: [LanguageAgentTreeSearch](https://github.com/andyz245/LanguageAgentTreeSearch) (programming / HumanEval)

## Layout

```
lats/
├── main.py # Entry: argparse, load dataset, build graph, run loop, tee to log
├── README.md
├── assets/
│ └── config/ # Config (default dataset path, etc.); datasets not in repo
│ └── defaults.json
├── workflows/ # Graph and controller
│ ├── graph.py # Build RootGraph, LATSTemplate, run_one_problem
│ └── controller.py # lats_controller_logic (MCTS select / expand / backprop / terminate)
├── components/ # Custom MASFactory components
│ ├── formatters.py # ContentMessageFormatter, passthrough dicts
│ ├── agents.py # LATSBaseAgent, LATSLLMAgent, ReflectionAgent, HumanEval executor
│ └── tree.py # LATSNode, TreeManager, gather_context_from_tree
├── humaneval/ # HumanEval data and execution
│ ├── load.py # load_humaneval_jsonl, parse_internal_tests_from_test, extract_python_code
│ ├── executor.py # run_internal_tests, full_evaluate, verify_evaluation
│ └── timeout_utils.py # function_with_timeout
└── utils/
└── tee.py # Tee output to terminal and optional log file
```

## Context and memory in this port

In the LATS paper and some reference implementations, **context** and **memory** appear as explicit conceptual elements. This MASFactory application does **not** add separate **Context** or **Memory** nodes; both are implemented as follows.

### Context

**Role:** Provide the LLM with the accumulated trajectory (previous code attempts, test results, and reflections) so it can produce the next, improved attempt.

**Implementation:** Context is built **inside the controller** and passed to the LLM via the existing message flow:

1. After each **Reflection** step, the controller selects the next node (MCTS selection) and gets the path from that node back to the root.
2. **`gather_context_from_tree(selected)`** in `components/tree.py` collects along that path:
- previous **solutions** (code),
- **test_feedback** (unit test results),
- **reflections** (short explanations of failure).
3. The controller assembles these into a single string **`reflexion_prompt`** (with blocks like `[previous impl 1]`, `[unit test results 1]`, `[reflection 1]`, etc.).
4. **`reflexion_prompt`** is passed to **LLM_Agent** as the prompt for the next iteration.

So “context” is **inlined into the prompt**: it is computed in `workflows/controller.py` and carried in the message key `reflexion_prompt` to the LLM node, without a dedicated Context node.
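
The assembly described above can be sketched as follows. This is a simplified stand-in, not the actual code: it uses plain dicts for tree nodes, and the helper names (`gather_context_from_tree`, `build_reflexion_prompt`) and block labels follow the README's description.

```python
# Sketch: walk from the selected node back to the root, then inline the
# trajectory into a single prompt string. Dict nodes stand in for LATSNode.

def gather_context_from_tree(node):
    """Collect solutions, test feedback, and reflections along the path to root."""
    solutions, feedbacks, reflections = [], [], []
    while node is not None:
        if node.get("solution"):  # skip the empty root
            solutions.append(node["solution"])
            feedbacks.append(node.get("test_feedback", ""))
            reflections.append(node.get("reflection", ""))
        node = node.get("parent")
    # Reverse so blocks are numbered root-to-leaf: [previous impl 1], ...
    return solutions[::-1], feedbacks[::-1], reflections[::-1]

def build_reflexion_prompt(solutions, feedbacks, reflections):
    blocks = []
    for i, (sol, fb, ref) in enumerate(zip(solutions, feedbacks, reflections), 1):
        blocks.append(f"[previous impl {i}]\n{sol}")
        blocks.append(f"[unit test results {i}]\n{fb}")
        blocks.append(f"[reflection {i}]\n{ref}")
    return "\n\n".join(blocks)

root = {"solution": "", "parent": None}
child = {"solution": "def f(): ...", "test_feedback": "1/2 passed",
         "reflection": "off-by-one", "parent": root}
reflexion_prompt = build_reflexion_prompt(*gather_context_from_tree(child))
```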

### Memory

**Role:** Persist the search tree (all tried solutions, feedback, rewards, and structure) across loop iterations.

**Implementation:** Memory is the **search tree** maintained by the controller:

1. **`LATSNode`** (in `components/tree.py`) stores per-node state: `solution`, `test_feedback`, `reflection`, `value`, `visits`, `parent`, `children`.
2. **`TreeManager`** holds the `root`, `current_node`, and `_max_iters`, and implements **selection** (UCT), **backprop** (reward update), and tree growth (adding children when the Executor returns a new attempt).
3. The controller **reads and updates** this tree each loop: it appends new children, runs backprop, and uses `gather_context_from_tree` to build the next context.

So “memory” is the **tree state** (nodes + manager) owned and updated by the controller logic; there is no separate Memory agent or node. The graph nodes you see in MASFactory are only: **LLM_Agent**, **Executor**, **Reflection**, and the **controller** (Loop’s terminate function). Context and memory are implemented **inside** the controller and the shared tree, not as extra nodes.
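
The tree mechanics can be sketched as a minimal UCT selection and backprop. Class and field names follow the README (`LATSNode`, `value`, `visits`, `parent`, `children`); the exploration constant and the unvisited-first rule are assumptions, not the actual `TreeManager` code.

```python
import math

class LATSNode:
    def __init__(self, solution="", parent=None):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0.0   # cumulative reward
        self.visits = 0

    def uct(self, c=1.0):
        if self.visits == 0:
            return float("inf")  # explore unvisited children first
        return self.value / self.visits + c * math.sqrt(
            math.log(self.parent.visits) / self.visits)

def select(root):
    """Descend from the root, always taking the child with the highest UCT."""
    node = root
    while node.children:
        node = max(node.children, key=lambda ch: ch.uct())
    return node

def backprop(node, reward):
    """Propagate a reward from a node up to the root."""
    while node is not None:
        node.visits += 1
        node.value += reward
        node = node.parent

root = LATSNode()
a, b = LATSNode(parent=root), LATSNode(parent=root)
root.children = [a, b]
backprop(root, 0.0)   # give the root an initial visit
backprop(a, 1.5)      # internal + real reward for child a
leaf = select(root)   # b is unvisited, so it is explored next
```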

## Setup

From the repo root (parent of `lats/`):

```bash
# Install MASFactory and dependencies (openai, etc.)
pip install masfactory openai

# Optional: set default dataset in assets/config/defaults.json
# "dataset_path": "path/to/HumanEval.jsonl.gz"
```

Environment variables:

- **OPENAI_API_KEY** (required)
- **OPENAI_API_BASE** (optional, for proxy/custom endpoint)
- **LATS_MODEL** (optional, default `gpt-4`)
- **LATS_MAX_ITERS** (optional, default `8`)
- **NUMBER_OF_TESTS** (optional, default `2`)
- **MASFACTORY_VISUALIZER_PORT** (optional, for runtime view)
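
A minimal sketch of how the optional variables might be read with their documented defaults (assumed; the actual parsing in `main.py` may differ):

```python
import os

# Fall back to the paper-aligned defaults when the variables are unset.
model = os.environ.get("LATS_MODEL", "gpt-4")
max_iters = int(os.environ.get("LATS_MAX_ITERS", "8"))
number_of_tests = int(os.environ.get("NUMBER_OF_TESTS", "2"))
```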

## Run

From the repo root (e.g. `D:\PE`):

```bash
# Default dataset path may be read from assets/config/defaults.json
python lats/main.py --dataset "path/to/HumanEval.jsonl.gz" --log logs/lats.log
```

Examples:

```bash
# Limit to 5 problems, write same output to log file
python lats/main.py --dataset "path/to/HumanEval.jsonl.gz" --limit 5 --log logs/lats.log

# Print every attempt (not only final solution)
python lats/main.py --dataset "path/to/HumanEval.jsonl.gz" --print-code --log logs/lats.log

# Paper-aligned defaults: max_iters=8, number_of_tests=2 (no need to pass if using env or defaults)
python lats/main.py --dataset "path/to/HumanEval.jsonl.gz" --log logs/lats.log
```

Output is printed to the terminal and, when `--log` is set, appended to the given file.
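
The tee behavior can be sketched as below; this is a minimal stand-in for `utils/tee.py`, and the real helper's signature may differ.

```python
import io

def tee(text, log_file=None):
    """Print to the terminal and, when a log file is set, append the same line."""
    print(text)
    if log_file is not None:
        log_file.write(text + "\n")
        log_file.flush()

# Demonstrate with an in-memory buffer in place of a real file handle.
buf = io.StringIO()
tee("hello", buf)
```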

## Metrics

- **Pass@1**: fraction of problems for which the best solution passes the full HumanEval test.
- Defaults align with the upstream GPT-4 run script: `max_iters=8`, `number_of_tests=2`.
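
The Pass@1 computation above amounts to a simple fraction. In this hypothetical sketch, `results` is a per-problem list of booleans (whether the best solution passed the full test); the name is illustrative, not taken from the code.

```python
def pass_at_1(results):
    """Fraction of problems whose best solution passes the full HumanEval test."""
    return sum(results) / len(results) if results else 0.0

score = pass_at_1([True, False, True, True])  # 3 of 4 problems solved
```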

1 change: 1 addition & 0 deletions applications/lats/__init__.py
@@ -0,0 +1 @@
# LATS application (HumanEval on MASFactory)
5 changes: 5 additions & 0 deletions applications/lats/assets/config/defaults.json
@@ -0,0 +1,5 @@
{
"dataset_path": "",
"max_iters": 8,
"number_of_tests": 2
}
1 change: 1 addition & 0 deletions applications/lats/components/__init__.py
@@ -0,0 +1 @@
# LATS custom components (formatters, agents, tree)
201 changes: 201 additions & 0 deletions applications/lats/components/agents.py
@@ -0,0 +1,201 @@
"""
LATS agents: base, LLM, Reflection (and Executor implemented as ReflectionAgent with role=executor).
"""
import os
from masfactory import Agent, OpenAIModel
from masfactory.core.message import ParagraphMessageFormatter

from . import formatters as fmt
from .tree import LATSNode
from ..humaneval.load import extract_python_code, parse_internal_tests_from_test
from ..humaneval.executor import run_internal_tests, full_evaluate
from ..utils.tee import tee, get_log_file

# Model instance (injected or from env)
model_instance = OpenAIModel(
api_key=os.environ.get("OPENAI_API_KEY", ""),
base_url=os.environ.get("OPENAI_API_BASE", ""),
model_name=os.environ.get("LATS_MODEL", "gpt-4"),
)

# When True, print each attempt body to terminal and log (--print-code)
_print_code_attempts = False


def set_print_code_attempts(value: bool):
global _print_code_attempts
_print_code_attempts = value


_ENV_PUSH_KEYS = {
"observation": "observation",
"reward": "reward",
"action": "action",
"full_passed": "full_passed",
}


def _print_generated_func_body(func_body: str, problem_name: str = "") -> None:
"""Print generated code to terminal and log (if --log and --print-code)."""
if not _print_code_attempts:
return
title = "GENERATED FUNC BODY"
if problem_name:
title += f" [{problem_name}]"
tee(f"\n--------------------- {title} ---------------------", get_log_file())
tee(func_body, get_log_file())
tee("------------------------------------------\n", get_log_file())


def _run_humaneval_forward(input_dict: dict) -> dict:
"""HumanEval execution (originally HumanEvalEnvironment._forward). Used by ReflectionAgent with role=executor."""
content = input_dict.get("action", "") or input_dict.get("content", "")
raw = str(content).strip()
problem = input_dict.get("problem") or {}
internal_tests = input_dict.get("internal_tests") or []
entry_point = problem.get("entry_point", "")
test = problem.get("test", "")
prompt = problem.get("prompt", "")

fail_safe = {
"observation": "Error: No valid Python code.",
"reward": 0.0,
"reward_internal": 0.0,
"reward_real": 0.0,
"full_passed": False,
"action": raw,
"problem": problem,
"internal_tests": internal_tests,
}

code = extract_python_code(raw)
if not code:
fail_safe["observation"] = "Error: Use a ```python ... ``` block or full function."
return fail_safe
if "def " not in code and prompt:
code = prompt.rstrip() + "\n" + code

if _print_code_attempts:
_print_generated_func_body(code, problem.get("name", ""))

if not internal_tests:
internal_tests = parse_internal_tests_from_test(test, max_tests=6)

is_passing_internal, feedback, reward_internal = run_internal_tests(
code, internal_tests, timeout=5
)
reward_real = 1.0 if full_evaluate(entry_point, code, test, timeout=10) else 0.0
reward = reward_internal + reward_real

return {
"observation": feedback,
"reward": reward,
"reward_internal": reward_internal,
"reward_real": reward_real,
"full_passed": reward_real >= 1.0,
"action": code,
"problem": problem,
"internal_tests": internal_tests,
}


class LATSBaseAgent(Agent):
"""Base agent: config merged into kwargs; role can be used by subclasses (e.g. ReflectionAgent as executor)."""

def __init__(self, name, *args, **kwargs):
if args and isinstance(args[0], dict):
kwargs = {**args[0], **kwargs}
args = ()
self._role = kwargs.pop("role", None)
kwargs.setdefault("model", model_instance)
super().__init__(name, *args, **kwargs)


class LATSLLMAgent(LATSBaseAgent):
"""Pass-through problem/internal_tests; formatter merges _lats_llm_passthrough to satisfy output_keys."""

def step(self, input_dict: dict) -> dict:
fmt._lats_llm_passthrough = {
"problem": input_dict.get("problem"),
"internal_tests": input_dict.get("internal_tests"),
}
return super().step(input_dict)

def _forward(self, input_dict: dict) -> dict:
out = super()._forward(input_dict)
out["problem"] = input_dict.get("problem")
out["internal_tests"] = input_dict.get("internal_tests")
if "content" not in out or not str(out.get("content", "")).strip():
out["content"] = (
out.get("content")
or out.get("action")
or out.get("response")
or out.get("text")
or str(out)
)
return out


class ReflectionAgent(LATSBaseAgent):
"""Reflection node. When config role=executor, same class acts as Executor (HumanEval runner) for visualizer."""

def __init__(self, name, *args, **kwargs):
super().__init__(name, *args, **kwargs)
if getattr(self, "_role", None) == "executor":
self._push_keys = dict(_ENV_PUSH_KEYS)

@property
def push_keys(self):
if getattr(self, "_role", None) == "executor":
return dict(_ENV_PUSH_KEYS)
return super().push_keys

def step(self, input_dict: dict) -> dict:
if getattr(self, "_role", None) != "executor":
fmt._lats_reflection_passthrough = {
k: input_dict.get(k)
for k in (
"action",
"observation",
"reward",
"full_passed",
"problem",
"internal_tests",
)
if k in input_dict
}
return super().step(input_dict)

def _forward(self, input_dict: dict) -> dict:
if getattr(self, "_role", None) == "executor":
ctx = None
result = {}
try:
from masfactory.visualizer import get_bridge
bridge = get_bridge() if get_bridge else None
if bridge is not None:
ctx = bridge.node_start(self, input_dict)
except Exception:
pass
try:
result = _run_humaneval_forward(input_dict)
finally:
if ctx is not None:
try:
from masfactory.visualizer import get_bridge as _gb
b = _gb() if _gb else None
if b is not None:
b.node_end(ctx, result, node=self)
except Exception:
pass
return result
out = super()._forward(input_dict)
ref = (out.get("content") or out.get("action") or str(out)).strip()
out = {**out, "reflection": ref}
out["problem"] = input_dict.get("problem")
out["internal_tests"] = input_dict.get("internal_tests")
out["action"] = input_dict.get("action")
out["observation"] = input_dict.get("observation")
out["reward"] = input_dict.get("reward")
out["full_passed"] = input_dict.get("full_passed", False)
return out
36 changes: 36 additions & 0 deletions applications/lats/components/formatters.py
@@ -0,0 +1,36 @@
"""
Plain-text output formatter for LATS LLM (code/natural language, not JSON).
Merge keys from a module-level dict to satisfy output_keys.
"""
from masfactory.core.message import MessageFormatter

# Filled by agents before step(); formatter merges these to satisfy output_keys
_lats_llm_passthrough = {}
_lats_reflection_passthrough = {}


class ContentMessageFormatter(MessageFormatter):
"""Expose model raw output as a single key. merge_global names the module-level dict to merge for output_keys."""

def __init__(self, output_key: str = "content", merge_global: str = ""):
super().__init__()
self._output_key = output_key
self._merge_global = merge_global
self._is_input_formatter = True
self._is_output_formatter = True
self._agent_introducer = (
f"Your response will be used as the value for the key '{output_key}'. "
"Provide your response as plain text only (e.g. Python code or a short explanation). Do not wrap in JSON."
)

def format(self, message: str) -> dict:
raw = (message.strip() if isinstance(message, str) and message else "") or ""
out = {self._output_key: raw}
if self._merge_global:
passthrough = globals().get(self._merge_global, {})
if isinstance(passthrough, dict) and passthrough:
out.update(passthrough)
return out

def dump(self, message: dict) -> str:
return str(message.get(self._output_key, ""))