-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtypes.py
More file actions
103 lines (84 loc) · 3.03 KB
/
types.py
File metadata and controls
103 lines (84 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from mcp.client.session import ClientSession
class TerminationReason(str, Enum):
    """Enum representing different reasons why a trajectory terminated.

    Members:
        MAX_STEPS: Trajectory ends because we hit the step limit.
        CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane
            signals termination (e.g. env goal reached or failure condition).
        USER_STOP: Trajectory ends because the simulated user signals to stop.
        ERROR: Trajectory ends because of an error.
        STOP: Trajectory ends by the policy (mapped to llm response stop
            reason "stop").
        LENGTH: Trajectory ends by the policy (mapped to llm response stop
            reason "length").
        TOOL_CALLS: Trajectory ends by the policy with a hanging tool call
            response (mapped to llm response stop reason "tool_calls").
    """

    MAX_STEPS = "max_steps"
    CONTROL_PLANE_SIGNAL = "control_plane_signal"
    USER_STOP = "user_stop"
    ERROR = "error"
    STOP = "stop"
    LENGTH = "length"
    TOOL_CALLS = "tool_calls"

    @classmethod
    def from_str(cls, value: str) -> "TerminationReason":
        """Return the member whose value equals ``value``.

        Args:
            value: The string form of a termination reason, e.g. ``"stop"``.

        Returns:
            The matching ``TerminationReason`` member.

        Raises:
            ValueError: If ``value`` does not match any member's value.
        """
        try:
            # Enum value lookup covers every member, so adding a new member
            # does not require touching this method.
            return cls(value)
        except ValueError:
            # Re-raise with this module's own message instead of Enum's.
            raise ValueError(f"Invalid termination reason: {value}") from None
@dataclass
class MCPToolCall:
    """Represents a tool call to be executed via MCP.

    ``tool_name`` and ``arguments`` are required; ``tool_call_id`` is
    optional and defaults to ``None``.
    """

    tool_name: str
    arguments: Dict[str, Any]
    tool_call_id: Optional[str] = None
@dataclass
class DatasetRow:
    """Represents a row from the dataset JSONL.

    All fields except ``user_simulation`` are required; ``user_simulation``
    defaults to ``None`` when the row carries no such entry.
    """

    id: str
    seed: int
    system_prompt: str
    user_prompt_template: str
    environment_context: Dict[str, Any]
    user_simulation: Optional[Dict[str, Any]] = None
@dataclass
class MCPSession:
    """Represents a single MCP session with an environment.

    ``session_id``, ``base_url``, ``seed`` and ``model_id`` must be supplied
    (``seed`` may be ``None``); the remaining fields have defaults.
    """

    session_id: str
    base_url: str
    seed: Optional[int]
    model_id: str
    dataset_row: Optional[DatasetRow] = None
    terminated: bool = False
    last_observation: Any = None

    # Persistent MCP connection components; both start as None.
    _exit_stack: Optional[AsyncExitStack] = None
    _mcp_session: Optional[ClientSession] = None
@dataclass
class Trajectory:
    """Represents a complete rollout trajectory.

    Every field except ``usage`` is required; ``usage`` defaults to a fresh
    empty dict per instance.
    """

    session: MCPSession
    observations: List[Any]
    actions: List[str]
    rewards: List[float]
    terminated: bool
    total_reward: float
    steps: int
    duration: float
    control_plane_steps: List[Dict[str, Any]]
    control_plane_summary: Dict[str, Any]
    # NOTE(review): presumably one of the TerminationReason values; stored as
    # a plain str here - confirm against the code that builds trajectories.
    termination_reason: str
    conversation_history: List[Dict[str, Any]]
    # default_factory gives each instance its own dict, not a shared one.
    usage: Dict[str, int] = field(default_factory=dict)