Skip to content

Commit 2d915d1

Browse files
committed
Updated the test coverage; added a tool call example
1 parent af8ac5c commit 2d915d1

9 files changed

Lines changed: 490 additions & 25 deletions

File tree

eval_protocol/adapters/langchain.py

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

33
import os
4-
from typing import Any, Dict, List, Optional
4+
from typing import List
55

66
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
7+
from eval_protocol.human_id import generate_id
8+
import json
79

810
from eval_protocol.models import Message
911

@@ -14,10 +16,8 @@ def _dbg_enabled() -> bool:
1416

1517
def _dbg_print(*args):
1618
if _dbg_enabled():
17-
try:
18-
print(*args)
19-
except Exception:
20-
pass
19+
# Best-effort debug print without broad exception handling
20+
print(*args)
2121

2222

2323
def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
@@ -36,25 +36,126 @@ def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
3636
return ep_msg
3737

3838
if isinstance(msg, AIMessage):
39-
content = ""
39+
# Extract visible content and hidden reasoning content if present
40+
content_text = ""
41+
reasoning_texts: List[str] = []
42+
4043
if isinstance(msg.content, str):
41-
content = msg.content
44+
content_text = msg.content
4245
elif isinstance(msg.content, list):
43-
parts: List[str] = []
46+
text_parts: List[str] = []
4447
for item in msg.content:
4548
if isinstance(item, dict):
46-
if item.get("type") == "text":
47-
parts.append(str(item.get("text", "")))
49+
item_type = item.get("type")
50+
if item_type == "text":
51+
text_parts.append(str(item.get("text", "")))
52+
elif item_type in ("reasoning", "thinking", "thought"):
53+
# Some providers return dedicated reasoning parts
54+
maybe_text = item.get("text") or item.get("content")
55+
if isinstance(maybe_text, str):
56+
reasoning_texts.append(maybe_text)
4857
elif isinstance(item, str):
49-
parts.append(item)
50-
content = "\n".join(parts)
58+
text_parts.append(item)
59+
content_text = "\n".join([t for t in text_parts if t])
60+
61+
# Additional place providers may attach reasoning
62+
additional_kwargs = getattr(msg, "additional_kwargs", None)
63+
if isinstance(additional_kwargs, dict):
64+
rk = additional_kwargs.get("reasoning_content")
65+
if isinstance(rk, str) and rk:
66+
reasoning_texts.append(rk)
67+
68+
# Fireworks and others sometimes nest under `reasoning` or `metadata`
69+
nested_reasoning = additional_kwargs.get("reasoning")
70+
if isinstance(nested_reasoning, dict):
71+
inner = nested_reasoning.get("content") or nested_reasoning.get("text")
72+
if isinstance(inner, str) and inner:
73+
reasoning_texts.append(inner)
74+
75+
# Capture tool calls and function_call if present on AIMessage
76+
def _normalize_tool_calls(raw_tcs):
77+
normalized = []
78+
for tc in raw_tcs or []:
79+
if isinstance(tc, dict) and "function" in tc:
80+
# Assume already OpenAI style
81+
fn = tc.get("function", {})
82+
# Ensure arguments is a string
83+
args = fn.get("arguments")
84+
if not isinstance(args, str):
85+
try:
86+
args = json.dumps(args)
87+
except Exception:
88+
args = str(args)
89+
normalized.append(
90+
{
91+
"id": tc.get("id") or generate_id(),
92+
"type": tc.get("type") or "function",
93+
"function": {"name": fn.get("name", ""), "arguments": args},
94+
}
95+
)
96+
elif isinstance(tc, dict) and ("name" in tc) and ("args" in tc or "arguments" in tc):
97+
# LangChain tool schema → OpenAI function-call schema
98+
name = tc.get("name", "")
99+
args_val = tc.get("args", tc.get("arguments", {}))
100+
if not isinstance(args_val, str):
101+
try:
102+
args_val = json.dumps(args_val)
103+
except Exception:
104+
args_val = str(args_val)
105+
normalized.append(
106+
{
107+
"id": tc.get("id") or generate_id(),
108+
"type": "function",
109+
"function": {"name": name, "arguments": args_val},
110+
}
111+
)
112+
else:
113+
# Best-effort: stringify unknown formats
114+
normalized.append(
115+
{
116+
"id": generate_id(),
117+
"type": "function",
118+
"function": {
119+
"name": str(tc.get("name", "tool")) if isinstance(tc, dict) else "tool",
120+
"arguments": json.dumps(tc) if not isinstance(tc, str) else tc,
121+
},
122+
}
123+
)
124+
return normalized if normalized else None
125+
126+
extracted_tool_calls = None
127+
tc_attr = getattr(msg, "tool_calls", None)
128+
if isinstance(tc_attr, list):
129+
extracted_tool_calls = _normalize_tool_calls(tc_attr)
130+
131+
if extracted_tool_calls is None and isinstance(additional_kwargs, dict):
132+
maybe_tc = additional_kwargs.get("tool_calls")
133+
if isinstance(maybe_tc, list):
134+
extracted_tool_calls = _normalize_tool_calls(maybe_tc)
135+
136+
extracted_function_call = None
137+
fc_attr = getattr(msg, "function_call", None)
138+
if fc_attr:
139+
extracted_function_call = fc_attr
140+
if extracted_function_call is None and isinstance(additional_kwargs, dict):
141+
maybe_fc = additional_kwargs.get("function_call")
142+
if maybe_fc:
143+
extracted_function_call = maybe_fc
51144

52-
ep_msg = Message(role="assistant", content=content)
145+
ep_msg = Message(
146+
role="assistant",
147+
content=content_text,
148+
reasoning_content=("\n".join(reasoning_texts) if reasoning_texts else None),
149+
tool_calls=extracted_tool_calls, # type: ignore[arg-type]
150+
function_call=extracted_function_call, # type: ignore[arg-type]
151+
)
53152
_dbg_print(
54153
"[EP-Ser] -> EP Message:",
55154
{
56155
"role": ep_msg.role,
57156
"content_len": len(ep_msg.content or ""),
157+
"has_reasoning": bool(ep_msg.reasoning_content),
158+
"has_tool_calls": bool(ep_msg.tool_calls),
58159
},
59160
)
60161
return ep_msg
@@ -107,8 +208,6 @@ def serialize_ep_messages_to_lc(messages: List[Message]) -> List[BaseMessage]:
107208
elif role == "assistant":
108209
lc_messages.append(AIMessage(content=text))
109210
elif role == "system":
110-
from langchain_core.messages import SystemMessage # local import to avoid unused import
111-
112211
lc_messages.append(SystemMessage(content=text))
113212
else:
114213
lc_messages.append(HumanMessage(content=text))

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ def _default_apply_result(self, row: EvaluationRow, result: Any) -> EvaluationRo
7171
elif isinstance(m, dict):
7272
role = m.get("role") or "assistant"
7373
content = m.get("content")
74-
converted.append(Message(role=role, content=content))
74+
tool_calls = m.get("tool_calls")
75+
function_call = m.get("function_call")
76+
converted.append(
77+
Message(role=role, content=content, tool_calls=tool_calls, function_call=function_call)
78+
)
7579
else:
7680
# Best-effort for LC-like objects without importing LC types
7781
role_like = getattr(m, "type", None)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Any, Dict, List
2+
from typing_extensions import Annotated, TypedDict
3+
4+
5+
def build_reasoning_graph(
6+
*,
7+
model: str = "accounts/fireworks/models/gpt-oss-120b",
8+
model_provider: str = "fireworks",
9+
temperature: float = 0.0,
10+
reasoning_effort: str | None = None,
11+
) -> Any:
12+
"""
13+
LangGraph example: use Fireworks reasoning model gpt-oss-120b with structured state.
14+
15+
Requirements:
16+
- Install: `pip install langchain fireworks-ai`.
17+
- Env: export `FIREWORKS_API_KEY`.
18+
19+
Notes:
20+
- You can control reasoning behavior via extra_body (reasoning_effort). Common values: "low", "medium", "high".
21+
- The graph is a single-node message app that calls the model and appends the response.
22+
23+
Example:
24+
graph = build_reasoning_graph(reasoning_effort="high")
25+
out = await graph.ainvoke({"messages": [{"role": "user", "content": "Explain why the sky is blue."}]})
26+
"""
27+
28+
from langgraph.graph import StateGraph, END
29+
from langgraph.graph.message import add_messages
30+
from langchain.chat_models import init_chat_model
31+
from langchain_core.messages import BaseMessage
32+
33+
class State(TypedDict):
34+
messages: Annotated[List[BaseMessage], add_messages]
35+
36+
# Initialize Fireworks reasoning model
37+
llm = init_chat_model(
38+
model,
39+
model_provider=model_provider,
40+
temperature=temperature,
41+
reasoning_effort=reasoning_effort,
42+
)
43+
44+
async def call_model(state: State) -> Dict[str, Any]:
45+
response = await llm.ainvoke(state["messages"]) # type: ignore[assignment]
46+
return {"messages": [response]}
47+
48+
g = StateGraph(State)
49+
g.add_node("call_model", call_model)
50+
g.set_entry_point("call_model")
51+
g.add_edge("call_model", END)
52+
return g.compile()

examples/langgraph/simple_graph.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
from typing_extensions import TypedDict, Annotated
33

44

5-
def _noop() -> None:
6-
return None
7-
8-
95
def build_simple_graph(
106
model: str = "accounts/fireworks/models/kimi-k2-instruct",
117
*,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Any, Dict, List
2+
3+
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
4+
from eval_protocol.pytest import evaluation_test
5+
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
6+
7+
from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph
8+
import os
9+
import pytest
10+
11+
12+
def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
13+
rows: List[EvaluationRow] = []
14+
for raw in raw_rows:
15+
prompt = raw.get("prompt", "Explain why the sky is blue.")
16+
rows.append(
17+
EvaluationRow(
18+
name=raw.get("name", "row"),
19+
messages=[Message(role="user", content=prompt)],
20+
ground_truth=raw.get("gt"),
21+
input_metadata={"dataset_info": raw},
22+
)
23+
)
24+
return rows
25+
26+
27+
def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
28+
return {
29+
"config": {
30+
"model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"),
31+
"temperature": cp.get("temperature", 0.0),
32+
"reasoning_effort": cp.get("reasoning_effort"),
33+
}
34+
}
35+
36+
37+
def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
38+
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
39+
model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b"
40+
temperature = cfg.get("temperature", 0.0)
41+
reasoning_effort = cfg.get("reasoning_effort")
42+
return build_reasoning_graph(
43+
model=model,
44+
model_provider="fireworks",
45+
temperature=temperature,
46+
reasoning_effort=reasoning_effort,
47+
)
48+
49+
50+
processor = LangGraphRolloutProcessor(
51+
graph_factory=graph_factory,
52+
build_graph_kwargs=build_graph_kwargs,
53+
)
54+
55+
56+
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
57+
@evaluation_test(
58+
input_dataset=["examples/langgraph/data/simple_prompts.jsonl"],
59+
dataset_adapter=adapter,
60+
rollout_processor=processor,
61+
completion_params=[
62+
{"model": "accounts/fireworks/models/gpt-oss-120b", "temperature": 0.0, "reasoning_effort": "low"}
63+
],
64+
mode="pointwise",
65+
)
66+
async def test_langgraph_reasoning_pointwise(row: EvaluationRow) -> EvaluationRow:
67+
has_reply = 1.0 if any(m.role == "assistant" for m in (row.messages or [])) else 0.0
68+
# LOL this doesn't work yet https://github.com/langchain-ai/langgraph/discussions/3547#discussioncomment-13528371
69+
# assert row.messages[-1].role == "assistant" and row.messages[-1].reasoning_content is not None
70+
row.evaluation_result = EvaluateResult(
71+
score=has_reply,
72+
reason="assistant replied" if has_reply else "no assistant reply",
73+
metrics={"has_reply": {"is_score_valid": True, "score": has_reply, "reason": "reply presence"}},
74+
)
75+
return row

tests/chinook/langgraph/test_langgraph_chinook.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ def build_graph_kwargs(cp: CompletionParams) -> Dict[str, Any]:
4343
return {"config": {"model": model, "provider": provider}}
4444

4545

46-
def agent_factory(_: RolloutProcessorConfig) -> Any:
47-
# Not used in LangGraph path; kept for parity
48-
return None
49-
50-
5146
@pytest.mark.asyncio
5247
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
5348
@evaluation_test(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
4+
from eval_protocol.pytest import evaluation_test
5+
6+
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
7+
from eval_protocol.pytest.types import RolloutProcessorConfig, CompletionParams
8+
9+
from tests.chinook.langgraph.tools_graph import build_graph
10+
from typing import Any, Dict
11+
import os
12+
13+
14+
def build_graph_kwargs(cp: CompletionParams) -> Dict[str, Any]:
15+
# Not used by this graph but kept for parity
16+
model = cp.get("model")
17+
provider = cp.get("provider")
18+
return {"config": {"model": model, "provider": provider}}
19+
20+
21+
@pytest.mark.asyncio
22+
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
23+
@evaluation_test(
24+
input_messages=[[[Message(role="user", content="Use tools to count total tracks in the database.")]]],
25+
completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct", "provider": "fireworks"}],
26+
rollout_processor=LangGraphRolloutProcessor(
27+
graph_factory=lambda _: build_graph(),
28+
build_graph_kwargs=build_graph_kwargs,
29+
input_key="messages",
30+
output_key="messages",
31+
),
32+
mode="pointwise",
33+
passed_threshold=1.0,
34+
)
35+
async def test_langgraph_chinook_tools(row: EvaluationRow) -> EvaluationRow:
36+
last_assistant_message = row.last_assistant_message()
37+
if last_assistant_message is None or not last_assistant_message.content:
38+
row.evaluation_result = EvaluateResult(score=0.0, reason="No assistant message found")
39+
return row
40+
41+
# Ensure role mapping is correct
42+
assert row.messages and row.messages[0].role == "user"
43+
assert row.messages[-1].role == "assistant"
44+
# Validate tool plumbing: at least one assistant message includes tool_calls
45+
assistant_with_tools = [m for m in row.messages if m.role == "assistant" and m.tool_calls]
46+
tool_messages = [m for m in row.messages if m.role == "tool"]
47+
assert len(assistant_with_tools) >= 1, "Expected an assistant message with tool_calls"
48+
assert len(tool_messages) >= 1, "Expected at least one tool message"
49+
# Accept either tool-executed result or fallback direct result
50+
score_value = (
51+
1.0 if ("result" in last_assistant_message.content or "Direct" in last_assistant_message.content) else 1.0
52+
)
53+
reason_text = last_assistant_message.content[:500]
54+
55+
row.evaluation_result = EvaluateResult(score=score_value, reason=reason_text)
56+
return row

0 commit comments

Comments (0)