Skip to content

Commit db40469

Browse files
committed
fix: treat ArgoCD Code message as reconnect signal, not shell exit code
1 parent f29df94 commit db40469

2 files changed

Lines changed: 400 additions & 38 deletions

File tree

centml/cli/shell.py

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _resolve_pod(cclient, deployment_id, pod_name=None):
197197
return running_pods[0]
198198

199199

200-
async def _forward_io(ws, screen, stream):
200+
async def _forward_io(ws, screen, stream, shutdown):
201201
"""Bidirectional forwarding between local stdin/stdout and WebSocket.
202202
203203
Output flows through a pyte terminal emulator so that cursor
@@ -208,16 +208,21 @@ async def _forward_io(ws, screen, stream):
208208
ws: WebSocket connection.
209209
screen: pyte.Screen instance sized to the local terminal.
210210
stream: pyte.Stream attached to *screen*.
211+
shutdown: asyncio.Event set by signal handlers to request exit.
211212
212-
Returns the remote exit code.
213+
Returns:
214+
Tuple of (exit_code, should_reconnect). ``should_reconnect`` is True
215+
when the server sent a ``{"Code": ...}`` reconnect signal (ArgoCD
216+
token refresh), False on normal exit or connection close.
213217
"""
214218
loop = asyncio.get_running_loop()
215219
exit_code = 0
220+
should_reconnect = False
216221
stdin_fd = sys.stdin.fileno()
217222
stdin_closed = asyncio.Event()
218223

219224
async def _read_ws():
220-
nonlocal exit_code
225+
nonlocal exit_code, should_reconnect
221226
try:
222227
async for raw_msg in ws:
223228
msg = json.loads(raw_msg)
@@ -229,8 +234,11 @@ async def _read_ws():
229234
elif msg.get("error"):
230235
stream.feed(f"Error: {msg['error']}\r\n")
231236
_render_dirty(screen, sys.stdout.buffer)
237+
# ArgoCD sends {"Code": ...} as a reconnect signal (token
238+
# refresh), not a shell exit code. Mirror ArgoCD UI behavior:
239+
# disconnect and reconnect with a fresh token.
232240
if "Code" in msg:
233-
exit_code = msg["Code"]
241+
should_reconnect = True
234242
return
235243
except websockets.ConnectionClosed:
236244
# Backend proxy may not send a clean close frame when
@@ -249,7 +257,7 @@ def _on_stdin_ready():
249257

250258
loop.add_reader(stdin_fd, _on_stdin_ready)
251259
try:
252-
while not stdin_closed.is_set():
260+
while not stdin_closed.is_set() and not shutdown.is_set():
253261
try:
254262
data = await asyncio.wait_for(read_queue.get(), timeout=0.5)
255263
except asyncio.TimeoutError:
@@ -270,7 +278,15 @@ def _on_stdin_ready():
270278
finally:
271279
loop.remove_reader(stdin_fd)
272280

273-
tasks = [asyncio.create_task(_read_ws()), asyncio.create_task(_read_stdin())]
281+
async def _watch_shutdown():
282+
while not shutdown.is_set():
283+
await asyncio.sleep(0.2)
284+
285+
tasks = [
286+
asyncio.create_task(_read_ws()),
287+
asyncio.create_task(_read_stdin()),
288+
asyncio.create_task(_watch_shutdown()),
289+
]
274290
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
275291
for t in pending:
276292
t.cancel()
@@ -282,13 +298,19 @@ def _on_stdin_ready():
282298
for t in done:
283299
if t.exception() is not None:
284300
raise t.exception()
285-
return exit_code
301+
return (exit_code, should_reconnect)
286302

287303

288-
async def _interactive_session(ws_url, token):
304+
async def _interactive_session(ws_url, get_token_fn):
289305
"""Run an interactive terminal session over WebSocket.
290306
291-
Enters raw mode, forwards I/O bidirectionally, and restores terminal on exit.
307+
Enters raw mode, forwards I/O bidirectionally, and restores terminal on
308+
exit. Reconnects automatically when the server sends a ``{"Code": ...}``
309+
token-refresh signal (matching ArgoCD UI behavior).
310+
311+
Args:
312+
ws_url: WebSocket URL for the terminal endpoint.
313+
get_token_fn: Callable that returns a fresh bearer token string.
292314
"""
293315
fd = sys.stdin.fileno()
294316
old_settings = termios.tcgetattr(fd)
@@ -303,32 +325,51 @@ async def _interactive_session(ws_url, token):
303325
sys.stdout.buffer.write(b"\033[?1049h\033[2J\033[H")
304326
sys.stdout.buffer.flush()
305327

306-
headers = {"Authorization": f"Bearer {token}"}
307-
async with websockets.connect(
308-
ws_url, additional_headers=headers, close_timeout=2
309-
) as ws:
310-
await ws.send(
311-
json.dumps({"operation": "resize", "rows": rows, "cols": cols})
312-
)
328+
loop = asyncio.get_running_loop()
313329

314-
loop = asyncio.get_running_loop()
330+
shutdown = asyncio.Event()
331+
loop.add_signal_handler(signal.SIGTERM, shutdown.set)
332+
loop.add_signal_handler(signal.SIGHUP, shutdown.set)
315333

316-
def _send_resize():
317-
c, r = shutil.get_terminal_size()
318-
screen.resize(r, c)
319-
screen.dirty.update(range(r))
334+
# _ws_ref holds the current websocket so SIGWINCH can reach it.
335+
_ws_ref = [None]
336+
337+
def _send_resize():
338+
c, r = shutil.get_terminal_size()
339+
screen.resize(r, c)
340+
screen.dirty.update(range(r))
341+
if _ws_ref[0] is not None:
320342
asyncio.ensure_future(
321-
ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))
343+
_ws_ref[0].send(
344+
json.dumps({"operation": "resize", "rows": r, "cols": c})
345+
)
322346
)
323347

324-
loop.add_signal_handler(signal.SIGWINCH, _send_resize)
348+
loop.add_signal_handler(signal.SIGWINCH, _send_resize)
325349

326-
try:
327-
exit_code = await _forward_io(ws, screen, stream)
328-
finally:
329-
loop.remove_signal_handler(signal.SIGWINCH)
330-
331-
return exit_code
350+
try:
351+
while True:
352+
token = get_token_fn()
353+
headers = {"Authorization": f"Bearer {token}"}
354+
async with websockets.connect(
355+
ws_url, additional_headers=headers, close_timeout=2
356+
) as ws:
357+
_ws_ref[0] = ws
358+
await ws.send(
359+
json.dumps(
360+
{"operation": "resize", "rows": rows, "cols": cols}
361+
)
362+
)
363+
exit_code, should_reconnect = await _forward_io(
364+
ws, screen, stream, shutdown
365+
)
366+
_ws_ref[0] = None
367+
if not should_reconnect or shutdown.is_set():
368+
return exit_code
369+
finally:
370+
loop.remove_signal_handler(signal.SIGWINCH)
371+
loop.remove_signal_handler(signal.SIGTERM)
372+
loop.remove_signal_handler(signal.SIGHUP)
332373
finally:
333374
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
334375
# Leave alternate screen buffer, restore cursor and attributes.
@@ -422,9 +463,10 @@ async def _exec_session(ws_url, token, command):
422463
elif msg.get("error"):
423464
sys.stderr.write(f"Error: {msg['error']}\n")
424465
return 1
425-
if is_done or "Code" in msg:
426-
if "Code" in msg:
427-
exit_code = msg["Code"]
466+
if "Code" in msg and not is_done:
467+
sys.stderr.write("Connection interrupted, please retry.\n")
468+
return 1
469+
if is_done:
428470
break
429471
except websockets.ConnectionClosed:
430472
pass
@@ -454,8 +496,7 @@ def shell(deployment_id, pod, shell_type):
454496
ws_url = _build_ws_url(
455497
settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type
456498
)
457-
token = auth.get_centml_token()
458-
exit_code = asyncio.run(_interactive_session(ws_url, token))
499+
exit_code = asyncio.run(_interactive_session(ws_url, auth.get_centml_token))
459500
sys.exit(exit_code)
460501

461502

0 commit comments

Comments
 (0)