Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/clabe/xml_rpc/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def wait_for_result(self, job_id: str, timeout: Optional[float] = None) -> JobRe
if result.status == JobStatus.DONE:
return result

# In monitor mode, reset timer if job is still running
if self.settings.monitor and self.is_running(job_id):
# In monitor mode, a RUNNING response proves the job is alive
if self.settings.monitor and result.status == JobStatus.RUNNING:
logger.debug("Job %s is still running; resetting timeout timer", job_id)
start_time = time.time()

Expand Down
7 changes: 3 additions & 4 deletions src/clabe/xml_rpc/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ async def _wait_for_result_async(self, job_id: str):
raise TimeoutError(f"Job {job_id} did not complete within {timeout} seconds")

await asyncio.sleep(poll_interval)
# this is synchronous but should be fast
result = self.client.get_result(job_id)
result = await asyncio.to_thread(self.client.get_result, job_id)
if result.status == JobStatus.DONE:
return result

# In monitor mode, reset timer if job is still running
if self.monitor and self.client.is_running(job_id):
# In monitor mode, a RUNNING response proves the job is alive
if self.monitor and result.status == JobStatus.RUNNING:
logger.debug("Job %s is still running; resetting timeout timer", job_id)
start_time = time.time()
83 changes: 65 additions & 18 deletions src/clabe/xml_rpc/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import secrets
import socket
import subprocess
import threading
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from functools import wraps
from pathlib import Path
from socketserver import ThreadingMixIn
from typing import ClassVar, Optional
from xmlrpc.server import SimpleXMLRPCServer

Expand Down Expand Up @@ -60,18 +63,29 @@ def get_local_ip():
return s.getsockname()[0]


_FINISHED_JOB_TTL = 300 # seconds before unclaimed finished jobs are pruned


class _ThreadedXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    """Thread-per-connection XML-RPC server.

    ThreadingMixIn dispatches each incoming request on its own thread,
    so one long-running RPC call cannot block other clients.
    """

    # Handler threads are daemons: the process can exit without
    # waiting for in-flight requests to finish.
    daemon_threads = True


class XmlRpcServer:
"""XML-RPC server for remote command execution and file transfer."""

def __init__(self, settings: XmlRpcServerSettings):
self.settings = settings
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
self.jobs: dict[str, Future] = {}
self._jobs_lock = threading.Lock()
self._job_done_at: dict[str, float] = {}

# Ensure file transfer directory exists
os.makedirs(settings.file_transfer_dir, exist_ok=True)

server = SimpleXMLRPCServer((str(settings.address), settings.port), allow_none=True)
server = _ThreadedXMLRPCServer((str(settings.address), settings.port), allow_none=True)
server.register_function(self.require_auth(self.submit_command), "run")
server.register_function(self.require_auth(self.get_result), "result")
server.register_function(self.require_auth(self.list_jobs), "jobs")
Expand All @@ -88,6 +102,7 @@ def __init__(self, settings: XmlRpcServerSettings):
logger.info("File transfer directory: %s", settings.file_transfer_dir.resolve())
logger.info("Use the token above to authenticate requests")

threading.Thread(target=self._cleanup_finished_jobs, daemon=True).start()
self.server = server

def authenticate(self, token: str) -> bool:
Expand All @@ -112,6 +127,32 @@ def wrapper(token, *args, **kwargs):

return wrapper

def _cleanup_finished_jobs(self):
    """Background thread: prune unclaimed finished jobs older than _FINISHED_JOB_TTL."""
    while True:
        # Sleep first: nothing can be stale immediately after startup,
        # and the sweep then runs once per minute for the server's lifetime.
        time.sleep(60)
        self._prune_stale_jobs()

def _prune_stale_jobs(self) -> list[str]:
    """Prune unclaimed finished jobs older than _FINISHED_JOB_TTL.

    First pass stamps a completion time on any job that has finished
    since the last sweep; second pass removes jobs whose stamp is older
    than the TTL. Timestamps whose job entry no longer exists are
    dropped as well, so the bookkeeping dict cannot grow without bound.

    Returns:
        The IDs of the jobs that were pruned.
    """
    # Take a single timestamp so "now" and "cutoff" are consistent.
    now = time.time()
    cutoff = now - _FINISHED_JOB_TTL
    with self._jobs_lock:
        # Stamp jobs that finished since the last sweep.
        for jid, fut in self.jobs.items():
            if fut.done() and jid not in self._job_done_at:
                self._job_done_at[jid] = now
        # Defensive: forget stamps for jobs removed by another path
        # (e.g. claimed via get_result) so this dict never leaks.
        for jid in [j for j in self._job_done_at if j not in self.jobs]:
            del self._job_done_at[jid]
        stale = [
            jid
            for jid, done_at in self._job_done_at.items()
            if done_at < cutoff and jid in self.jobs and self.jobs[jid].done()
        ]
        for jid in stale:
            del self.jobs[jid]
            del self._job_done_at[jid]
    if stale:
        logger.info("Pruned %s unclaimed finished jobs", len(stale))
    return stale

@staticmethod
def _normalize_returncode(returncode: int | None) -> int | None:
"""Convert a process exit code to a signed 32-bit integer.
Expand Down Expand Up @@ -151,7 +192,8 @@ def submit_command(self, cmd_args):
job_id = str(uuid.uuid4())
logger.debug("Submitting job %s to executor", job_id)
future = self.executor.submit(self._run_command_sync, cmd_args)
self.jobs[job_id] = future
with self._jobs_lock:
self.jobs[job_id] = future
logger.info("Submitted job %s: %s", job_id, cmd_args)
logger.debug("Active jobs: %s", len(self.jobs))
response = JobSubmissionResponse(success=True, job_id=job_id)
Expand All @@ -160,37 +202,42 @@ def submit_command(self, cmd_args):
def get_result(self, job_id):
    """Fetch the result of a finished command.

    Returns a serialized JobStatusResponse:
    - ERROR when job_id is unknown (or its result was already claimed/pruned),
    - RUNNING while the job's future is not done,
    - DONE with the result payload once finished. The job entry is removed
      on the DONE path, so a result can be claimed exactly once.
    """
    logger.debug("Fetching result for job %s", job_id)
    with self._jobs_lock:
        if job_id not in self.jobs:
            logger.debug("Job %s not found", job_id)
            return JobStatusResponse(
                success=False, error="Invalid job_id", job_id=job_id, status=JobStatus.ERROR
            ).model_dump(mode="json")

        future = self.jobs[job_id]
        if not future.done():
            logger.debug("Job %s still running", job_id)
            return JobStatusResponse(success=True, job_id=job_id, status=JobStatus.RUNNING).model_dump(mode="json")

        # Claim the job while still holding the lock so the cleanup
        # thread cannot prune it between the done() check and removal.
        del self.jobs[job_id]
        self._job_done_at.pop(job_id, None)

    # The future is done, so result() returns immediately; call it
    # outside the lock to keep the critical section minimal.
    result = future.result()
    logger.debug("Job %s completed, cleaning up", job_id)
    return JobStatusResponse(success=True, job_id=job_id, status=JobStatus.DONE, result=result).model_dump(
        mode="json"
    )

def is_running(self, job_id):
    """Check if a job is still running.

    Returns False both for unknown job IDs and for finished jobs;
    only a present, not-yet-done future counts as running.
    """
    with self._jobs_lock:
        if job_id not in self.jobs:
            logger.debug("Job %s not found when checking status", job_id)
            return False
        # Read done() under the lock so the entry cannot be pruned
        # out from under us mid-check.
        is_running = not self.jobs[job_id].done()
    logger.debug("Job %s running status: %s", job_id, is_running)
    return is_running

def list_jobs(self):
    """List all known jobs, split into running and finished IDs."""
    # Snapshot both partitions under the lock so a job cannot appear
    # in neither (or both) lists due to a concurrent state change.
    with self._jobs_lock:
        running_jobs = [job_id for job_id, fut in self.jobs.items() if not fut.done()]
        finished_jobs = [job_id for job_id, fut in self.jobs.items() if fut.done()]
    logger.debug("Listing jobs: %s running, %s finished", len(running_jobs), len(finished_jobs))
    return JobListResponse(success=True, running=running_jobs, finished=finished_jobs).model_dump(mode="json")

Expand Down
6 changes: 2 additions & 4 deletions tests/apps/test_rpc_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,9 @@ async def test_run_async_no_job_id(self, executor):
@pytest.mark.asyncio
async def test_run_async_with_monitor_enabled(self, mock_client):
"""Test asynchronous command execution with monitor mode enabled (default)."""
# Create executor with monitor=True (default)
executor = XmlRpcExecutor(mock_client, monitor=True)

submission_response = JobSubmissionResponse(success=True, job_id="monitor-job")
# Simulate job running initially, then completing
running_result = JobResult(
job_id="monitor-job", status=JobStatus.RUNNING, stdout=None, stderr=None, returncode=None, error=None
)
Expand All @@ -128,9 +126,7 @@ async def test_run_async_with_monitor_enabled(self, mock_client):
)

executor.client.submit_command.return_value = submission_response
# First call returns running, second call returns done
executor.client.get_result.side_effect = [running_result, done_result]
executor.client.is_running.return_value = True

cmd = Command(cmd="long_async_command", output_parser=identity_parser)
result = await executor.run_async(cmd)
Expand All @@ -140,3 +136,5 @@ async def test_run_async_with_monitor_enabled(self, mock_client):
assert result.exit_code == 0

executor.client.submit_command.assert_called_once_with("long_async_command")
executor.client.get_result.assert_called_with("monitor-job")
executor.client.is_running.assert_not_called()
4 changes: 2 additions & 2 deletions tests/xml-rpc/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def test_wait_for_result_with_monitor_default(self, rpc_client: XmlRpcClient):
submission = rpc_client.submit_command([sys.executable, "-c", "import time; time.sleep(0.5); print('done')"])

# The job takes 0.5s but timeout is 0.3s - with default monitor=True,
# the timeout resets each time is_running returns True
# the timeout resets each poll cycle that returns RUNNING status
result = rpc_client.wait_for_result(submission.job_id, timeout=0.3)

assert result.status == JobStatus.DONE
Expand All @@ -305,7 +305,7 @@ def test_wait_for_result_monitor_resets_timeout(self, rpc_client: XmlRpcClient):
[sys.executable, "-c", "import time; time.sleep(0.8); print('finished')"]
)

# With default monitor=True, the timeout resets each poll cycle while job is running
# With default monitor=True, the timeout resets each poll cycle while job is RUNNING
result = rpc_client.wait_for_result(submission.job_id, timeout=0.4)

assert result.status == JobStatus.DONE
Expand Down
90 changes: 90 additions & 0 deletions tests/xml-rpc/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,96 @@ def test_delete_all_files(self, rpc_client):
assert filename not in filenames


class TestJobCleanup:
    """Test background cleanup of unclaimed finished jobs."""

    def test_stale_finished_job_is_pruned(self, rpc_server):
        """Jobs completed longer ago than TTL are removed by _prune_stale_jobs."""
        import sys

        from clabe.xml_rpc._server import _FINISHED_JOB_TTL

        server, settings = rpc_server
        client = ServerProxy(f"http://{settings.address}:{settings.port}")
        token = settings.token.get_secret_value()

        result = client.run(token, [sys.executable, "-c", "print('cleanup test')"])
        job_id = result["job_id"]

        # Wait for the job to finish (polls up to ~2s: 20 x 0.1s)
        for _ in range(20):
            time.sleep(0.1)
            with server._jobs_lock:
                if job_id in server.jobs and server.jobs[job_id].done():
                    break
        else:
            pytest.fail("Job did not complete in time")

        # Backdate _job_done_at to simulate TTL expiry
        with server._jobs_lock:
            server._job_done_at[job_id] = time.time() - (_FINISHED_JOB_TTL + 10)

        pruned = server._prune_stale_jobs()
        assert job_id in pruned

        with server._jobs_lock:
            assert job_id not in server.jobs

        # get_result should now return error
        get_result = client.result(token, job_id)
        assert get_result["success"] is False
        assert get_result["status"] == "error"

    def test_recently_finished_job_is_not_pruned(self, rpc_server):
        """Jobs completed within TTL are left in place by _prune_stale_jobs."""
        import sys

        server, settings = rpc_server
        client = ServerProxy(f"http://{settings.address}:{settings.port}")
        token = settings.token.get_secret_value()

        result = client.run(token, [sys.executable, "-c", "print('keep me')"])
        job_id = result["job_id"]

        # Wait for the job to finish (no backdating here, so it is fresh)
        for _ in range(20):
            time.sleep(0.1)
            with server._jobs_lock:
                if job_id in server.jobs and server.jobs[job_id].done():
                    break
        else:
            pytest.fail("Job did not complete in time")

        pruned = server._prune_stale_jobs()
        assert job_id not in pruned

        # Result should still be retrievable
        get_result = client.result(token, job_id)
        assert get_result["status"] == "done"

    def test_running_job_is_never_pruned(self, rpc_server):
        """Jobs still running are never touched by _prune_stale_jobs."""
        import sys

        from clabe.xml_rpc._server import _FINISHED_JOB_TTL

        server, settings = rpc_server
        client = ServerProxy(f"http://{settings.address}:{settings.port}")
        token = settings.token.get_secret_value()

        result = client.run(token, [sys.executable, "-c", "import time; time.sleep(3)"])
        job_id = result["job_id"]

        # Force an old _job_done_at entry even though the job is still running
        with server._jobs_lock:
            server._job_done_at[job_id] = time.time() - (_FINISHED_JOB_TTL + 10)

        pruned = server._prune_stale_jobs()
        assert job_id not in pruned

        # The running job must survive the sweep untouched
        with server._jobs_lock:
            assert job_id in server.jobs


class TestHelperFunctions:
"""Test helper functions."""

Expand Down
Loading
Loading