diff --git a/code-interpreter/app/services/executor_kubernetes.py b/code-interpreter/app/services/executor_kubernetes.py index 387d6c9..65aa6d4 100644 --- a/code-interpreter/app/services/executor_kubernetes.py +++ b/code-interpreter/app/services/executor_kubernetes.py @@ -6,8 +6,9 @@ import tarfile import time import uuid -from collections.abc import Sequence -from contextlib import suppress +from collections.abc import Generator, Sequence +from contextlib import contextmanager, suppress +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -37,6 +38,32 @@ logger = logging.getLogger(__name__) +def _parse_exit_code(error: str) -> int | None: + """Parse the exit code from a Kubernetes exec error channel message.""" + try: + error_dict = eval(error) # noqa: S307 + if isinstance(error_dict, dict) and "status" in error_dict: + if error_dict["status"] == "Success": + return 0 + details = error_dict.get("details", {}) + if isinstance(details, dict) and "exitCode" in details: + return int(details["exitCode"]) + return 1 + except Exception as e: + logger.warning(f"Error occurred when parsing exit code: {e}") + return None + return None + + +@dataclass +class _KubeExecContext: + """Holds the live pod and exec stream for the duration of an execution.""" + + pod_name: str + exec_resp: Any # kubernetes WSClient + start: float + + class KubernetesExecutor(BaseExecutor): def __init__(self) -> None: try: @@ -181,6 +208,147 @@ def _create_tar_archive( return tar_buffer.getvalue() + def _wait_for_pod_ready(self, pod_name: str, timeout_sec: int = 30) -> None: + """Wait for a pod to reach Running state.""" + logger.info(f"Waiting for pod {pod_name} to be ready") + for _ in range(timeout_sec * 10): + pod = self.v1.read_namespaced_pod(pod_name, self.namespace) + if pod.status.phase == "Running": + logger.info(f"Pod {pod_name} is running") + return + time.sleep(0.1) + raise RuntimeError(f"Pod {pod_name} did not become ready in {timeout_sec} seconds") + + 
# KubernetesExecutor methods; written at module indentation because the class
# body is not part of this span.
def _upload_tar_to_pod(self, pod_name: str, tar_archive: bytes) -> None:
    """Upload and extract a tar archive into the pod's workspace.

    Runs ``tar -x -C /workspace`` in the pod, streams ``tar_archive`` to its
    stdin and waits for the command's status on the error channel.

    Args:
        pod_name: Pod to exec into (in ``self.namespace``).
        tar_archive: Raw tar bytes to extract.

    Raises:
        RuntimeError: If no exit status was received before the stream
            closed, or if tar exited with a non-zero code.
    """
    logger.info(f"Uploading tar archive ({len(tar_archive)} bytes) to pod {pod_name}")
    exec_command = ["tar", "-x", "-C", "/workspace"]
    resp = stream.stream(
        self.v1.connect_get_namespaced_pod_exec,
        pod_name,
        self.namespace,
        command=exec_command,
        stderr=True,
        stdin=True,
        stdout=True,
        tty=False,
        _preload_content=False,
    )

    tar_stderr = b""
    tar_exit_code: int | None = None

    # try/finally so the websocket is released even if a write or read fails.
    try:
        resp.write_stdin(tar_archive)
        # Empty write, intended to signal end of input.
        resp.write_stdin(b"")

        while resp.is_open():
            resp.update(timeout=1)
            if resp.peek_stdout():
                stdout_chunk: str = resp.read_stdout()
                logger.debug(f"Tar stdout: {stdout_chunk}")
            if resp.peek_stderr():
                stderr_chunk: str = resp.read_stderr()
                tar_stderr += stderr_chunk.encode("utf-8")
                logger.warning(f"Tar stderr: {stderr_chunk}")

            # A message on the error channel marks command completion.
            error: str = resp.read_channel(ws_client.ERROR_CHANNEL)
            if error:
                logger.debug(f"Tar command error channel: {error}")
                tar_exit_code = _parse_exit_code(error)
                break
    finally:
        resp.close()

    logger.info(f"Tar extraction completed with exit code: {tar_exit_code}")

    if tar_exit_code is None:
        raise RuntimeError("Tar extraction command did not complete")
    if tar_exit_code != 0:
        raise RuntimeError(
            f"Tar extraction failed with exit code {tar_exit_code}. "
            f"stderr: {tar_stderr.decode('utf-8', errors='replace')}"
        )


def _kill_python_process(self, pod_name: str) -> None:
    """Kill the Python process running in the pod.

    Best effort: runs ``pkill -9 python`` in the pod and deliberately
    ignores every exception raised by the exec call.
    """
    with suppress(Exception):
        stream.stream(
            self.v1.connect_get_namespaced_pod_exec,
            pod_name,
            self.namespace,
            command=["pkill", "-9", "python"],
            stderr=False,
            stdin=False,
            stdout=False,
            tty=False,
        )


@contextmanager
def _run_in_pod(
    self,
    *,
    code: str,
    cpu_time_limit_sec: int | None,
    memory_limit_mb: int | None,
    files: Sequence[tuple[str, bytes]] | None,
    last_line_interactive: bool,
) -> Generator[_KubeExecContext, None, None]:
    """Create a pod, stage files, open Python exec stream, and clean up.

    Yields a _KubeExecContext whose exec_resp is ready for stdin/stdout I/O.
    On exit, however the caller leaves the ``with`` body, the exec stream is
    closed (best effort) and the pod is deleted.

    Args:
        code: Python source to run; staged via ``_create_tar_archive``.
        cpu_time_limit_sec: Forwarded to ``_create_pod_manifest``.
        memory_limit_mb: Forwarded to ``_create_pod_manifest``.
        files: Extra files, forwarded to ``_create_tar_archive``.
        last_line_interactive: Forwarded to ``_create_tar_archive``.

    Raises:
        Exception: Any error from pod creation, staging or the caller's
            ``with`` body is logged and re-raised unchanged.
    """
    pod_name = f"code-exec-{uuid.uuid4().hex}"
    logger.info(f"Starting execution in pod {pod_name}")
    logger.debug(
        f"Code to execute: {code[:100]}..." if len(code) > 100 else f"Code to execute: {code}"
    )

    pod_manifest = self._create_pod_manifest(
        pod_name=pod_name,
        memory_limit_mb=memory_limit_mb,
        cpu_time_limit_sec=cpu_time_limit_sec,
    )

    # Tracked outside the try so the finally block can tell whether a stream
    # was ever opened.
    exec_resp: Any = None
    try:
        logger.info(f"Creating pod {pod_name} in namespace {self.namespace}")
        self.v1.create_namespaced_pod(
            namespace=self.namespace,
            body=pod_manifest,
        )

        self._wait_for_pod_ready(pod_name)

        tar_archive = self._create_tar_archive(code, files, last_line_interactive)
        self._upload_tar_to_pod(pod_name, tar_archive)

        logger.info(f"Executing Python code in pod {pod_name}")
        start = time.perf_counter()
        exec_command = ["python", "/workspace/__main__.py"]

        exec_resp = stream.stream(
            self.v1.connect_get_namespaced_pod_exec,
            pod_name,
            self.namespace,
            command=exec_command,
            stderr=True,
            stdin=True,
            stdout=True,
            tty=False,
            _preload_content=False,
        )

        yield _KubeExecContext(
            pod_name=pod_name,
            exec_resp=exec_resp,
            start=start,
        )
    except Exception as e:
        logger.error(f"Error during execution in pod {pod_name}: {e}", exc_info=True)
        raise
    finally:
        # The caller closes the stream on its normal path; this also covers
        # exits by exception. A failure to close must not block pod cleanup.
        if exec_resp is not None:
            with suppress(Exception):
                exec_resp.close()
        logger.info(f"Cleaning up pod {pod_name}")
        self._cleanup_pod(pod_name)
if len(code) > 100 else f"Code to execute: {code}" - ) - - pod_manifest = self._create_pod_manifest( - pod_name=pod_name, - memory_limit_mb=memory_limit_mb, + with self._run_in_pod( + code=code, cpu_time_limit_sec=cpu_time_limit_sec, - ) - - try: - logger.info(f"Creating pod {pod_name} in namespace {self.namespace}") - self.v1.create_namespaced_pod( - namespace=self.namespace, - body=pod_manifest, - ) - - logger.info(f"Waiting for pod {pod_name} to be ready") - max_wait = 30 - for _ in range(max_wait * 10): - pod = self.v1.read_namespaced_pod(pod_name, self.namespace) - if pod.status.phase == "Running": - break - time.sleep(0.1) - else: - raise RuntimeError(f"Pod {pod_name} did not become ready in {max_wait} seconds") - - logger.info(f"Pod {pod_name} is running, creating tar archive") - tar_archive = self._create_tar_archive(code, files, last_line_interactive) - logger.debug(f"Tar archive size: {len(tar_archive)} bytes") - - logger.info(f"Executing tar extraction in pod {pod_name}") - exec_command = ["tar", "-x", "-C", "/workspace"] - resp = stream.stream( - self.v1.connect_get_namespaced_pod_exec, - pod_name, - self.namespace, - command=exec_command, - stderr=True, - stdin=True, - stdout=True, - tty=False, - _preload_content=False, - ) - - # Write tar archive to stdin as raw bytes - logger.debug("Writing tar archive to stdin") - resp.write_stdin(tar_archive) - # Signal end of input - logger.debug("Closing stdin") - resp.write_stdin(b"") - - # Wait for tar extraction to complete by reading until the stream closes - logger.debug("Waiting for tar extraction to complete") - tar_stderr = b"" - tar_stdout = b"" - tar_exit_code: int | None = None - - while resp.is_open(): - resp.update(timeout=1) - if resp.peek_stdout(): - chunk = resp.read_stdout().encode("utf-8") - tar_stdout += chunk - logger.debug(f"Tar stdout: {chunk}") - if resp.peek_stderr(): - chunk = resp.read_stderr().encode("utf-8") - tar_stderr += chunk - logger.warning(f"Tar stderr: {chunk}") - - # Check for 
command completion - error = resp.read_channel(ws_client.ERROR_CHANNEL) - if error: - logger.debug(f"Tar command error channel: {error}") - try: - error_dict = eval(error) # noqa: S307 - if isinstance(error_dict, dict) and "status" in error_dict: - if error_dict["status"] == "Success": - tar_exit_code = 0 - elif "details" in error_dict and "exitCode" in error_dict["details"]: - tar_exit_code = error_dict["details"]["exitCode"] - else: - tar_exit_code = 1 - except Exception as e: # noqa: S110 - logger.error(f"Failed to parse error channel: {e}") - break - - resp.close() - logger.info(f"Tar extraction completed with exit code: {tar_exit_code}") - - # Check if tar extraction failed - if tar_exit_code is None: - raise RuntimeError("Tar extraction command did not complete") - if tar_exit_code != 0: - raise RuntimeError( - f"Tar extraction failed with exit code {tar_exit_code}. " - f"stderr: {tar_stderr.decode('utf-8', errors='replace')}" - ) - - logger.info(f"Executing Python code in pod {pod_name}") - start = time.perf_counter() - exec_command = ["python", "/workspace/__main__.py"] - - exec_resp = stream.stream( - self.v1.connect_get_namespaced_pod_exec, - pod_name, - self.namespace, - command=exec_command, - stderr=True, - stdin=True, - stdout=True, - tty=False, - _preload_content=False, - ) - + memory_limit_mb=memory_limit_mb, + files=files, + last_line_interactive=last_line_interactive, + ) as ctx: if stdin: logger.debug("Writing stdin to Python process") - exec_resp.write_stdin(stdin) + ctx.exec_resp.write_stdin(stdin) stdout_data = b"" stderr_data = b"" @@ -428,72 +487,40 @@ def execute_python( timeout_sec = timeout_ms / 1000.0 end_time = time.time() + timeout_sec - while exec_resp.is_open(): + while ctx.exec_resp.is_open(): remaining = end_time - time.time() if remaining <= 0: timed_out = True break - exec_resp.update(timeout=min(remaining, 1)) + ctx.exec_resp.update(timeout=min(remaining, 1)) - if exec_resp.peek_stdout(): - stdout_data += 
exec_resp.read_stdout().encode("utf-8") + if ctx.exec_resp.peek_stdout(): + stdout_data += ctx.exec_resp.read_stdout().encode("utf-8") - if exec_resp.peek_stderr(): - stderr_data += exec_resp.read_stderr().encode("utf-8") + if ctx.exec_resp.peek_stderr(): + stderr_data += ctx.exec_resp.read_stderr().encode("utf-8") - error = exec_resp.read_channel(ws_client.ERROR_CHANNEL) + error = ctx.exec_resp.read_channel(ws_client.ERROR_CHANNEL) if error: - try: - error_dict = eval(error) # noqa: S307 - if isinstance(error_dict, dict) and "status" in error_dict: - status = error_dict["status"] - if status == "Success": - exit_code = 0 - elif ( - "reason" in error_dict and error_dict["reason"] == "NonZeroExitCode" - ): - if "details" in error_dict and "exitCode" in error_dict["details"]: - exit_code = error_dict["details"]["exitCode"] - else: - exit_code = 1 - break - except Exception: # noqa: S110 - pass - - exec_resp.close() + exit_code = _parse_exit_code(error) + break + + ctx.exec_resp.close() if timed_out: - exec_command = ["pkill", "-9", "python"] - with suppress(Exception): - stream.stream( - self.v1.connect_get_namespaced_pod_exec, - pod_name, - self.namespace, - command=exec_command, - stderr=False, - stdin=False, - stdout=False, - tty=False, - ) + self._kill_python_process(ctx.pod_name) logger.info( f"Python execution completed. 
Exit code: {exit_code}, Timed out: {timed_out}" ) logger.debug(f"stdout length: {len(stdout_data)}, stderr length: {len(stderr_data)}") - logger.info(f"Extracting workspace snapshot from pod {pod_name}") - workspace_snapshot = self._extract_workspace_snapshot(pod_name) + logger.info(f"Extracting workspace snapshot from pod {ctx.pod_name}") + workspace_snapshot = self._extract_workspace_snapshot(ctx.pod_name) logger.debug(f"Workspace snapshot has {len(workspace_snapshot)} entries") - except Exception as e: - logger.error(f"Error during execution in pod {pod_name}: {e}", exc_info=True) - raise - finally: - logger.info(f"Cleaning up pod {pod_name}") - self._cleanup_pod(pod_name) - - duration_ms = int((time.perf_counter() - start) * 1000) + duration_ms = int((time.perf_counter() - ctx.start) * 1000) stdout = self.truncate_output(stdout_data, max_output_bytes) stderr = self.truncate_output(stderr_data, max_output_bytes)