-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsession.py
More file actions
216 lines (160 loc) · 6.56 KB
/
session.py
File metadata and controls
216 lines (160 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""Session Management - Handle conversation lifecycle and sandbox persistence."""
import asyncio
from types import TracebackType
import structlog
from backend.src.config.core import CoreConfig
from .mcp_registry import MCPRegistry
from .sandbox import PTCSandbox
# Module-level structlog logger shared by Session and SessionManager below;
# context (e.g. conversation_id) is passed as keyword arguments on each call.
logger = structlog.get_logger(__name__)
class Session:
    """Represents a conversation session with a persistent sandbox.

    A session owns one ``MCPRegistry`` and one ``PTCSandbox``. Both are created
    by :meth:`initialize` (called directly, through :meth:`get_sandbox`, or by
    ``async with``) and are released by :meth:`cleanup` (full teardown) or by
    :meth:`stop` (stop the sandbox so a later session can reconnect to it).
    """

    def __init__(self, conversation_id: str, config: CoreConfig) -> None:
        """Initialize session.

        Args:
            conversation_id: Unique conversation identifier
            config: Application configuration
        """
        self.conversation_id = conversation_id
        self.config = config
        self.sandbox: PTCSandbox | None = None
        self.mcp_registry: MCPRegistry | None = None
        self._initialized = False

        logger.info("Created session", conversation_id=conversation_id)

    @staticmethod
    async def _gather_settled(*aws):
        """Run *aws* concurrently, wait for ALL to finish, then re-raise the first failure.

        Plain ``asyncio.gather`` propagates the first exception immediately while
        the remaining awaitables keep running unobserved; waiting for every one to
        settle guarantees that rollback never races a still-running connect/setup.

        Returns:
            The list of results, in the order the awaitables were passed.
        """
        results = await asyncio.gather(*aws, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                raise result
        return results

    async def _rollback_failed_initialize(self, *, created_sandbox: bool) -> None:
        """Best-effort release of resources acquired by a failed ``initialize()``.

        Errors raised while releasing are logged, not propagated, so that the
        caller can re-raise the original initialization failure.

        Args:
            created_sandbox: True when ``initialize`` was creating a brand-new
                sandbox (it is then cleaned up); False in reconnect mode, where the
                sandbox pre-dates this session and is therefore not torn down.
        """
        sandbox, self.sandbox = self.sandbox, None
        registry, self.mcp_registry = self.mcp_registry, None

        if sandbox is not None and created_sandbox:
            try:
                await sandbox.cleanup()
            except Exception:
                logger.exception(
                    "Failed to clean up sandbox after initialization error",
                    conversation_id=self.conversation_id,
                )
        if registry is not None:
            try:
                await registry.disconnect_all()
            except Exception:
                logger.exception(
                    "Failed to disconnect MCP servers after initialization error",
                    conversation_id=self.conversation_id,
                )

    async def initialize(self, sandbox_id: str | None = None) -> None:
        """Initialize the session (connect MCP servers and set up the sandbox).

        In both modes MCP server connection runs concurrently with sandbox
        start-up. If any step fails, the resources acquired so far are released
        (best effort) and the original exception is re-raised, so a retry starts
        from a clean state.

        Args:
            sandbox_id: Optional existing sandbox ID to reconnect to instead of
                creating a new sandbox. ``None`` or an empty string creates a new
                sandbox.
        """
        if self._initialized:
            logger.warning("Session already initialized", conversation_id=self.conversation_id)
            return

        # Single source of truth for the mode, used for both logging and branching.
        reconnecting = bool(sandbox_id)
        logger.info(
            "Initializing session",
            conversation_id=self.conversation_id,
            reconnecting=reconnecting,
        )

        self.mcp_registry = MCPRegistry(self.config)
        try:
            # The sandbox is constructed (in a worker thread) without a registry;
            # the registry is attached once its connections have been attempted.
            self.sandbox = await asyncio.to_thread(PTCSandbox, self.config, None)

            if reconnecting:
                # RECONNECT MODE: MCP connections and sandbox reconnect in parallel.
                await self._gather_settled(
                    self.mcp_registry.connect_all(),
                    self.sandbox.reconnect(sandbox_id),
                )
                self.sandbox.mcp_registry = self.mcp_registry
                logger.info(
                    "Reconnected to existing sandbox",
                    conversation_id=self.conversation_id,
                    sandbox_id=sandbox_id,
                )
            else:
                # NEW SANDBOX MODE: workspace setup and MCP connect in parallel;
                # tool installation needs both finished and the snapshot name.
                snapshot_name, _ = await self._gather_settled(
                    self.sandbox.setup_sandbox_workspace(),
                    self.mcp_registry.connect_all(),
                )
                self.sandbox.mcp_registry = self.mcp_registry
                await self.sandbox.setup_tools_and_mcp(snapshot_name)
        except BaseException:
            # __aexit__ is never invoked when __aenter__ fails, so roll back here.
            await self._rollback_failed_initialize(created_sandbox=not reconnecting)
            raise

        self._initialized = True
        logger.info("Session initialized", conversation_id=self.conversation_id)

    async def get_sandbox(self) -> PTCSandbox | None:
        """Get the sandbox for this session, initializing the session if needed.

        Returns:
            The session's PTCSandbox instance.
        """
        if not self._initialized:
            await self.initialize()
        return self.sandbox

    async def cleanup(self) -> None:
        """Clean up session resources.

        The MCP registry is disconnected even when sandbox cleanup raises, so a
        failing sandbox teardown does not leak MCP server connections. If sandbox
        cleanup fails, ``self.sandbox`` is kept so the failure can be inspected or
        retried; the exception propagates.
        """
        logger.info("Cleaning up session", conversation_id=self.conversation_id)
        try:
            if self.sandbox:
                await self.sandbox.cleanup()
                self.sandbox = None
        finally:
            if self.mcp_registry:
                await self.mcp_registry.disconnect_all()
                self.mcp_registry = None
            self._initialized = False
        logger.info("Session cleaned up", conversation_id=self.conversation_id)

    async def stop(self) -> None:
        """Stop sandbox for session persistence.

        This is used when persist_session is enabled - stops the sandbox
        so it can be restarted quickly on the next session, rather than
        deleting it entirely.

        Afterwards the session is marked uninitialized, so ``initialize(sandbox_id)``
        reconnects instead of returning early as "already initialized". The
        ``sandbox`` attribute is kept so callers can still read from it; the
        disconnected MCP registry is dropped so ``cleanup()`` does not disconnect
        it a second time.
        """
        logger.info("Stopping session for persistence", conversation_id=self.conversation_id)
        try:
            if self.sandbox:
                await self.sandbox.stop_sandbox()
        finally:
            # Disconnect MCP servers even if stopping the sandbox failed.
            if self.mcp_registry:
                await self.mcp_registry.disconnect_all()
                self.mcp_registry = None
            self._initialized = False
        logger.info("Session stopped", conversation_id=self.conversation_id)

    async def __aenter__(self) -> "Session":
        """Async context manager entry: initialize and return the session."""
        await self.initialize()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Async context manager exit: clean up; exceptions are not suppressed."""
        await self.cleanup()
class SessionManager:
    """Manages multiple conversation sessions.

    Sessions live in a class-level dict, so all callers in the process share one
    registry; every method is a classmethod.
    """

    # Process-wide registry: conversation_id -> Session.
    _sessions: dict[str, Session] = {}

    @classmethod
    def get_session(cls, conversation_id: str, config: CoreConfig) -> Session:
        """Get or create a session for a conversation.

        Args:
            conversation_id: Unique conversation identifier
            config: Application configuration; only used when a new session is
                created (an existing session keeps the config it was built with).

        Returns:
            Session instance (a newly created session is not yet initialized)
        """
        session = cls._sessions.get(conversation_id)
        if session is None:
            logger.info("Creating new session", conversation_id=conversation_id)
            session = cls._sessions[conversation_id] = Session(conversation_id, config)
        else:
            logger.debug("Returning existing session", conversation_id=conversation_id)
        return session

    @classmethod
    async def cleanup_session(cls, conversation_id: str) -> None:
        """Clean up a specific session and remove it from the registry.

        The session is removed *before* its cleanup is awaited. A concurrent call
        for the same id is then a no-op instead of a ``KeyError``, and a concurrent
        ``get_session`` receives a fresh session rather than the one being torn
        down. Unknown ids are ignored. Errors from ``Session.cleanup`` propagate.

        Args:
            conversation_id: Conversation identifier
        """
        session = cls._sessions.pop(conversation_id, None)
        if session is None:
            return
        await session.cleanup()
        logger.info("Session removed", conversation_id=conversation_id)

    @classmethod
    async def cleanup_all(cls) -> None:
        """Clean up all active sessions.

        Every session is attempted even if an earlier one fails; each failure is
        logged and the first one is re-raised once all sessions were processed.

        Raises:
            Exception: the first error raised while cleaning up a session, if any.
        """
        logger.info("Cleaning up all sessions", count=len(cls._sessions))
        first_error: Exception | None = None
        # Snapshot the ids: cleanup_session mutates the dict while we iterate.
        for conversation_id in list(cls._sessions):
            try:
                await cls.cleanup_session(conversation_id)
            except Exception as exc:
                logger.exception("Failed to clean up session", conversation_id=conversation_id)
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
        logger.info("All sessions cleaned up")

    @classmethod
    def get_active_sessions(cls) -> list[str]:
        """Get list of active session IDs.

        Returns:
            List of conversation IDs
        """
        return list(cls._sessions.keys())

    @classmethod
    def get_session_count(cls) -> int:
        """Get count of active sessions.

        Returns:
            Number of active sessions
        """
        return len(cls._sessions)