diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index af5d30f..4c6ff79 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -23,9 +23,9 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 15 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: 3.11 - name: Install Poetry @@ -43,9 +43,9 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 15 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: 3.11 - name: Install Poetry @@ -78,12 +78,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install Poetry diff --git a/examples/job_logging.yaml b/examples/job_logging.yaml new file mode 100644 index 0000000..765f7e9 --- /dev/null +++ b/examples/job_logging.yaml @@ -0,0 +1,11 @@ +!Job + ident: test_job + env: + test_key_a: hey + test_key_b: you + command: python3 + args: + - examples/log_test.py + resources: + - !Cores [1] + - !Memory [30, MB] diff --git a/examples/log_test.py b/examples/log_test.py new file mode 100644 index 0000000..19a5b46 --- /dev/null +++ b/examples/log_test.py @@ -0,0 +1,25 @@ +import logging +from time import sleep + +from gator.adapters.logging import GatorHandler + +if __name__ == "__main__": + logging.basicConfig( + level="NOTSET", + format="%(message)s", + datefmt="[%X]", + handlers=[GatorHandler()], + ) + log = logging.getLogger("log_test") + log.debug("A debug message!") + sleep(1) + log.getChild("a.b.c.d").info("Hello world!") + sleep(1) + for idx in range(30): + log.getChild("b").info(f"Pass {idx}") + sleep(0.2) + log.getChild("c").warning("A warning message!") + sleep(1) + log.error("An error message!") + sleep(1) + log.info("DONE") diff --git a/gator/__main__.py b/gator/__main__.py index ad66867..94f27f3 100644 --- a/gator/__main__.py +++ b/gator/__main__.py @@ -25,7 +25,8 @@ from . import launch, launch_progress from .common.logger import MessageLimits -from .scheduler import LocalScheduler +from .common.ws_wrapper import WebsocketWrapper +from .scheduler import LocalScheduler, SlurmScheduler from .specs import Spec from .specs.common import SpecError @@ -54,7 +55,7 @@ @click.option( "--scheduler", default="local", - type=click.Choice(("local",), case_sensitive=False), + type=click.Choice(("local", "slurm"), case_sensitive=False), help="Select the scheduler to use for launching jobs", show_default=True, ) @@ -101,7 +102,10 @@ def main( ) tracking.mkdir(parents=True, exist_ok=True) # Select the right scheduler - sched = {"local": LocalScheduler}.get(scheduler.lower()) + sched = { + "local": LocalScheduler, + "slurm": SlurmScheduler, + }.get(scheduler.lower()) # Break apart scheduler options as '=' sched_opts = {} for arg in sched_arg: @@ -115,6 +119,7 @@ def main( key, val = arg.split("=") sched_opts[key.strip()] = val.strip() # Launch with optional progress tracking + exit_code = 0 try: summary = asyncio.run( (launch_progress if progress else launch).launch( @@ -136,6 +141,8 @@ def main( ), ) ) + if not summary.passed: + exit_code = 1 except SpecError as e: console_file = (Path(tracking) / "error.log").open("a") if parent else None con = Console(file=console_file) @@ -146,18 +153,19 @@ def main( if hasattr(e.obj, "jobs"): e.obj.jobs = ["..."] con.log(Spec.dump([e.obj])) - sys.exit(1) + exit_code = 1 except BaseException: console_file = (Path(tracking) / "error.log").open("a") if parent else None con = Console(file=console_file) con.log(traceback.format_exc()) if verbose: con.print_exception(show_locals=True) - sys.exit(1) - - if summary.passed: - sys.exit(0) - sys.exit(1) + exit_code = 1 + finally: + # Stop active websocket wrappers (may be left over if an exception occurs) + asyncio.run(WebsocketWrapper.stop_all()) + # Forward an exception code + sys.exit(exit_code) if __name__ == "__main__": diff --git a/gator/adapters/__init__.py b/gator/adapters/__init__.py new file mode 100644 index 0000000..0bbd577 --- /dev/null +++ b/gator/adapters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/gator/adapters/logging.py b/gator/adapters/logging.py new file mode 100644 index 0000000..ad28d74 --- /dev/null +++ b/gator/adapters/logging.py @@ -0,0 +1,45 @@ +# Copyright 2023, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from datetime import datetime + +from .parent import Parent + + +class GatorHandler(logging.Handler): + """ + Custom handler for Python logging to redirect messages via Gator's logging + API such that severities are correctly recorded. + + :param ws_address: Optional websocket address for the parent tier, otherwise + it will be read from the GATOR_PARENT environment variable + """ + + def __init__(self, ws_address: str | None = None): + super().__init__() + self._parent = Parent(ws_address) + self._do_log("INFO", "Log fowarding via GatorHandler", "root") + + def _do_log(self, severity: str, message: str, hierarchy: str): + self._parent.post( + "log", + timestamp=datetime.now().timestamp(), + hierarchy=hierarchy, + severity=severity, + message=message, + ) + + def emit(self, record: logging.LogRecord): + self._do_log(record.levelname, record.getMessage(), record.name) diff --git a/gator/adapters/parent.py b/gator/adapters/parent.py new file mode 100644 index 0000000..60f7eda --- /dev/null +++ b/gator/adapters/parent.py @@ -0,0 +1,115 @@ +# Copyright 2023, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import json +import logging +import os +import sys +from queue import SimpleQueue +from threading import Event, Thread + +from websockets.exceptions import ConnectionClosed +from websockets.sync.client import connect + + +class TeardownMarker: + pass + + +class Parent: + """ + Thread based wrapper around the Gator websocket interface + + :param ws_address: Optional websocket address for the parent tier, otherwise + it will be read from the GATOR_PARENT environment variable + """ + + def __init__(self, ws_address: str | None = None): + self._ws_address = ws_address or Parent.get_parent_address() + assert self._ws_address, ( + "Websocket address for parent process is not set and could not be " + "determined from the environment" + ) + self._rx_q = SimpleQueue[dict[str, str]] + self._tx_q = SimpleQueue[TeardownMarker | dict[str, str]]() + self._teardown_evt = Event() + self._ws_thread = Thread(target=self._manage_ws, daemon=True) + self._ws_thread.start() + atexit.register(self._teardown_at_exit) + + @staticmethod + def get_parent_address() -> str | None: + return os.environ.get("GATOR_PARENT", None) + + def post(self, action, **payload): + self._tx_q.put( + { + "action": action, + "posted": True, + "payload": payload, + } + ) + + def receive(self) -> dict[str, str]: + return self._rx_q.get() + + def _manage_ws(self): + idx = 0 + + def _receiver(ws, rx_q: SimpleQueue[dict[str, str]]): + try: + for packet in ws: + rx_q.put(json.loads(packet)) + except ConnectionClosed: + pass + + rx_thread = None + try: + with connect( + f"ws://{self._ws_address}", + logger=(logger := logging.getLogger("gator_ws")), + ) as ws: + # Disable log propagation to avoid recursive forwarding + logger.propagate = False + # Setup a receiving thread + rx_thread = Thread(target=_receiver, daemon=True, args=(ws, self._rx_q)) + rx_thread.start() + # Transmit until a teardown is inserted + while True: + packet = self._tx_q.get() + # Check if the process wants us to teardown + if isinstance(packet, TeardownMarker): + break + # Otherwise log the message + ws.send(json.dumps(packet)) + idx += 1 + except ConnectionClosed: + pass + # Wait for the receiver thread to end + rx_thread.join() + # Set the teardown event to signify a clean exit + self._teardown_evt.set() + + def _teardown_at_exit(self): + self._teardown() + + def _teardown(self): + self._tx_q.put(TeardownMarker()) + if not self._teardown_evt.wait(timeout=10): + print( + "Gator timed out waiting for the websocket thread to teardown, " + "some packets may have been missed!", + file=sys.stderr, + ) diff --git a/gator/adapters/pstats.py b/gator/adapters/pstats.py new file mode 100644 index 0000000..30d20f3 --- /dev/null +++ b/gator/adapters/pstats.py @@ -0,0 +1,42 @@ +# Copyright 2023, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from .parent import Parent + + +class ProcessStats: + """ + Custom process statistics gathering for operations that Gator cannot normally + track, for example launched Docker containers. + + :param ws_address: Optional websocket address for the parent tier, otherwise + it will be read from the GATOR_PARENT environment variable + """ + + def __init__(self, ws_address: str | None = None): + super().__init__() + self._parent = Parent(ws_address) + + def record(self, cpu_perc: float, memory: float): + self._parent.post( + "extra_usage", + timestamp=datetime.now().timestamp(), + cpu_perc=cpu_perc, + memory=memory, + ) + + def teardown(self): + self._parent._teardown() diff --git a/gator/babysitter.py b/gator/babysitter.py index 5bcea5b..b9638eb 100644 --- a/gator/babysitter.py +++ b/gator/babysitter.py @@ -23,7 +23,7 @@ import sys from pathlib import Path -with Path(f"log_{socket.gethostname()}_{os.getpid()}.log").open("w", encoding="utf-8") as fh: +with Path(f"log_{socket.getfqdn()}_{os.getpid()}.log").open("w", encoding="utf-8") as fh: fh.write(f"Starting process with arguments: {sys.argv[1:]}\n") fh.flush() proc = subprocess.Popen( diff --git a/gator/common/db_client.py b/gator/common/db_client.py index d8d0fab..7c223d5 100644 --- a/gator/common/db_client.py +++ b/gator/common/db_client.py @@ -171,6 +171,7 @@ async def get_messages(self, after: int = 0, limit: int = 10) -> ApiMessagesResp uid=cast(int, x.db_uid), severity=int(x.severity), message=x.message, + hierarchy=x.hierarchy, timestamp=int(x.timestamp.timestamp()), ) for x in msgs diff --git a/gator/common/layer.py b/gator/common/layer.py index f4621e1..613ca39 100644 --- a/gator/common/layer.py +++ b/gator/common/layer.py @@ -79,6 +79,10 @@ class MetricResponseError(TypedDict): MetricResponse = Union[MetricResponseSuccess, MetricResponseError] +class UsageResponse(TypedDict): + result: Literal["success"] + + class BaseDatabase(Database): async def push_metric(self, metric: Metric): pass @@ -308,7 +312,7 @@ async def setup(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: self.path = result["path"] # Otherwise, register with the parent else: - self.__hub_uid = await HubAPI.register( + self.__hub_uid, hub_url = await HubAPI.register( ident=self.ident, url=server_address, layer=type(self).__name__.lower(), @@ -317,7 +321,7 @@ async def setup(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: if self.__hub_uid is not None: self.uidx = self.root = int(self.__hub_uid) self.path = [] - await self.logger.info(f"Registered with hub with ID {self.__hub_uid}") + await self.logger.info(f"Registered with hub with ID {self.__hub_uid}: {hub_url}") else: self.uidx = self.root = 0 self.path = [] @@ -443,6 +447,7 @@ async def get_messages( uid=x.db_uid, severity=int(x.severity), message=x.message, + hierarchy=x.hierarchy, timestamp=int(x.timestamp.timestamp()), ) for x in msgs diff --git a/gator/common/logger.py b/gator/common/logger.py index bb2274b..3464a13 100644 --- a/gator/common/logger.py +++ b/gator/common/logger.py @@ -41,6 +41,8 @@ class MessageLimits: class Logger: + HIER_WIDTH: typing.ClassVar[int] = 23 + HIER_BALANCE: typing.ClassVar[int] = (HIER_WIDTH - 3) // 2 FORMAT: typing.ClassVar[Dict[LogSeverity, Tuple[str, str]]] = { LogSeverity.DEBUG: ("[bold cyan]", "[/bold cyan]"), LogSeverity.INFO: ("[bold]", "[/bold]"), @@ -119,6 +121,7 @@ async def log( self, severity: LogSeverity, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, @@ -129,6 +132,7 @@ async def log( :param severity: Severity level of the logged message :param message: Text of the message being logged + :param hierarchy: Optional logging hierarchy :param forward: Whether to forward the message onto the parent layer, if this is not provided then it will default to the logger's forward parameter (set during construction) @@ -152,79 +156,102 @@ async def log( if forward and self.ws_cli.linked and severity >= self.verbosity: await self.ws_cli.log( timestamp=int(timestamp.timestamp()), + hierarchy=hierarchy, severity=severity.name, message=message, posted=True, ) + # Generate a truncated version of the hierarchy + short_hier = f"{hierarchy:{Logger.HIER_WIDTH}s}" + if len(short_hier) > Logger.HIER_WIDTH: + short_hier = "...".join( + short_hier[: Logger.HIER_BALANCE], + short_hier[-Logger.HIER_BALANCE :], + ) # If a console is attached, log locally if self.__console and severity >= self.verbosity: prefix, suffix = self.FORMAT.get(severity, ("[bold]", "[/bold]")) - self.__console.log(f"{prefix}[{severity.name:<7s}]{suffix} {escape(message)}") + self.__console.log( + f"{prefix}{severity.name:<7s}{suffix} {escape(short_hier)} " f"{escape(message)}" + ) # Normally don't capture forwarded messages if not forwarded or self.capture_all: # Record to the database if self.__database is not None: await self.__database.push_logentry( - LogEntry(severity=severity, message=message, timestamp=timestamp) + LogEntry( + hierarchy=hierarchy, + severity=severity, + message=message, + timestamp=timestamp, + ) ) # Tee to file if configured - if not forwarded and self.__log_fh is not None: + if self.__log_fh is not None: date = datetime.now().strftime(r"%H:%M:%S") - self.__log_fh.write(f"[{date}] [{severity.name:<7s}] {message}\n") + self.__log_fh.write(f"[{date}] {severity.name:<7s} {short_hier} {message}\n") async def debug( self, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, ) -> None: - await self.log(LogSeverity.DEBUG, message, forward, timestamp, forwarded) + await self.log(LogSeverity.DEBUG, message, hierarchy, forward, timestamp, forwarded) async def info( self, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, ) -> None: - await self.log(LogSeverity.INFO, message, forward, timestamp, forwarded) + await self.log(LogSeverity.INFO, message, hierarchy, forward, timestamp, forwarded) async def warning( self, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, ) -> None: - await self.log(LogSeverity.WARNING, message, forward, timestamp, forwarded) + await self.log(LogSeverity.WARNING, message, hierarchy, forward, timestamp, forwarded) async def error( self, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, ) -> None: - await self.log(LogSeverity.ERROR, message, forward, timestamp, forwarded) + await self.log(LogSeverity.ERROR, message, hierarchy, forward, timestamp, forwarded) async def critical( self, message: str, + hierarchy: str = "root", forward: Optional[bool] = None, timestamp: Optional[datetime] = None, forwarded: bool = False, ) -> None: - await self.log(LogSeverity.CRITICAL, message, forward, timestamp, forwarded) + await self.log(LogSeverity.CRITICAL, message, hierarchy, forward, timestamp, forwarded) @click.command() +@click.option("-H", "--hierarchy", type=str, default="root", help="Log hierarchy") @click.option("-s", "--severity", type=str, default="INFO", help="Severity level") @click.argument("message") -def logger(severity, message): +def logger(hierarchy, severity, message): asyncio.run( Logger(verbosity=LogSeverity.DEBUG).log( - severity=getattr(LogSeverity, severity.upper()), message=message + severity=getattr(LogSeverity, severity.upper()), + message=message, + hierarchy=hierarchy, ) ) diff --git a/gator/common/types.py b/gator/common/types.py index 3d84db0..4459d05 100644 --- a/gator/common/types.py +++ b/gator/common/types.py @@ -73,6 +73,7 @@ class LogEntry(Base): severity: LogSeverity = LogSeverity.INFO message: str = "" + hierarchy: str = "" timestamp: datetime = dataclasses.field(default_factory=datetime.now) diff --git a/gator/common/ws_client.py b/gator/common/ws_client.py index c2f42f1..34a04f7 100644 --- a/gator/common/ws_client.py +++ b/gator/common/ws_client.py @@ -49,10 +49,13 @@ def _teardown() -> None: # For chaining return self + async def stop_ws(self): + await self.stop() + await super().stop_ws() + async def stop(self) -> None: if self.ws is not None: await self.ws.close() - await self.stop_monitor() self.ws = None async def __aenter__(self): diff --git a/gator/common/ws_server.py b/gator/common/ws_server.py index 2e30748..e51880b 100644 --- a/gator/common/ws_server.py +++ b/gator/common/ws_server.py @@ -58,7 +58,7 @@ async def get_address(self) -> str: # Attempt to get the hostname (fully qualified) hostname = socket.getfqdn() if not hostname: - raise Exception("Blank hostname returned from socket.gethostname()") + raise Exception("Blank hostname returned from socket.getfqdn()") # Get all known IP addresses for this host (note this can raise an # exception if the host is unresolvable) _, _, ipaddrs = socket.gethostbyname_ex(hostname) @@ -83,6 +83,7 @@ async def handle_log( timestamp: Optional[str] = None, severity: str = "INFO", message: str = "N/A", + hierarchy: str = "root", **_kwargs, ) -> None: """ @@ -97,7 +98,13 @@ async def handle_log( timestamp = datetime.fromtimestamp(int(timestamp)) severity = getattr(LogSeverity, severity.strip().upper(), LogSeverity.INFO) # Log the message - await self.logger.log(severity, message.strip(), timestamp=timestamp, forwarded=True) + await self.logger.log( + severity, + message.strip(), + hierarchy, + timestamp=timestamp, + forwarded=True, + ) # ========================================================================== # Server diff --git a/gator/common/ws_wrapper.py b/gator/common/ws_wrapper.py index f1253fd..f60a510 100644 --- a/gator/common/ws_wrapper.py +++ b/gator/common/ws_wrapper.py @@ -17,9 +17,10 @@ import dataclasses import itertools import json -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union import websockets +import websockets.exceptions from .ws_router import WebsocketRouter @@ -36,6 +37,8 @@ class WebsocketWrapperError(Exception): class WebsocketWrapper(WebsocketRouter): + WS_WRAPPERS: ClassVar[list["WebsocketWrapper"]] = [] + def __init__(self, ws: Optional[websockets.WebSocketClientProtocol] = None) -> None: super().__init__() self.ws = ws @@ -44,6 +47,7 @@ def __init__(self, ws: Optional[websockets.WebSocketClientProtocol] = None) -> N self.__next_request_id = itertools.count() self.__request_lock = asyncio.Lock() self.__pending = {} + WebsocketWrapper.WS_WRAPPERS.append(self) @property def linked(self): @@ -63,6 +67,14 @@ async def stop_monitor(self) -> None: await self.__monitor_task self.__monitor_task = None + async def stop_ws(self) -> None: + await self.stop_monitor() + + @classmethod + async def stop_all(cls) -> None: + for ws in cls.WS_WRAPPERS: + await ws.stop_ws() + async def send(self, data: Union[str, dict]) -> None: await self.ws.send(data if isinstance(data, str) else json.dumps(data)) @@ -90,6 +102,8 @@ async def monitor(self) -> None: raise WebsocketWrapperError(f"Failed to decode message: {raw}") from e except asyncio.CancelledError: pass + except websockets.exceptions.ConnectionClosedError: + print("WEBSOCKET CLOSED UNEXPECTEDLY") def __getattr__(self, key: str) -> Any: # Attempt to resolve diff --git a/gator/hub/api.py b/gator/hub/api.py index 8187894..233c262 100644 --- a/gator/hub/api.py +++ b/gator/hub/api.py @@ -24,9 +24,10 @@ class _HubAPI(HTTPAPI): COMPLETE = "job/{job_id}/complete" HEARTBEAT = "job/{job_id}/heartbeat" - async def register(self, ident: str, url: str, layer: str, owner: str) -> str: + async def register(self, ident: str, url: str, layer: str, owner: str) -> tuple[str, str]: response = await self.post(self.REGISTER, ident=ident, url=url, layer=layer, owner=owner) - return response.get("uid", None) + uid = response.get("uid", None) + return uid, f"http://{self.url}/?path={uid}" async def complete(self, uid: str, db_file: str, result: JobResult) -> None: await self.post(self.COMPLETE.format(job_id=uid), db_file=db_file, result=int(result)) diff --git a/gator/launch.py b/gator/launch.py index 8501494..76fdbc4 100644 --- a/gator/launch.py +++ b/gator/launch.py @@ -14,7 +14,10 @@ import asyncio import math +import os +import platform import signal +import socket from functools import partial from pathlib import Path from typing import Dict, Optional, Type, Union @@ -70,6 +73,12 @@ async def launch( forward=all_msg, ) logger.set_console(console) + # Log the machine's details + uname = platform.uname() + await logger.info( + f"Running on {socket.getfqdn()} as PID {os.getpid()} under {Path.cwd()} " + f"(architecture: {uname.processor}, OS: {uname.system} {uname.release})" + ) # Work out where the spec is coming from # - From server (nested call) if spec is None and client.linked and ident: diff --git a/gator/scheduler/__init__.py b/gator/scheduler/__init__.py index 78d01b1..03ef370 100644 --- a/gator/scheduler/__init__.py +++ b/gator/scheduler/__init__.py @@ -14,5 +14,6 @@ from .common import SchedulerError from .local import LocalScheduler +from .slurm import SlurmScheduler -assert all((LocalScheduler, SchedulerError)) +assert all((LocalScheduler, SchedulerError, SlurmScheduler)) diff --git a/gator/scheduler/common.py b/gator/scheduler/common.py index 144161d..c3d378c 100644 --- a/gator/scheduler/common.py +++ b/gator/scheduler/common.py @@ -15,6 +15,7 @@ import abc import functools import itertools +from pathlib import Path from typing import Any, Dict, List, Optional, Type from ..common.child import Child @@ -30,6 +31,7 @@ class BaseScheduler: def __init__( self, + tracking: Path, parent: str, interval: int = 5, quiet: bool = True, @@ -37,6 +39,7 @@ def __init__( options: Optional[Dict[str, str]] = None, limits: Optional[MessageLimits] = None, ) -> None: + self.tracking = tracking self.parent = parent self.interval = interval self.quiet = quiet @@ -88,7 +91,7 @@ def base_command(self) -> List[str]: ] return cmd - def create_command(self, child: Child, options: Optional[Dict[str, str]] = None) -> str: + def create_command(self, child: Child, options: Optional[Dict[str, str]] = None) -> List[str]: """ Build a command for launching a job on the compute infrastructure using details from the child object. @@ -100,7 +103,7 @@ def create_command(self, child: Child, options: Optional[Dict[str, str]] = None) full_opts = self.options.copy() full_opts.update(options or {}) - return " ".join( + return list( itertools.chain( self.base_command, ["--id", child.ident, "--tracking", child.tracking.as_posix()], diff --git a/gator/scheduler/local.py b/gator/scheduler/local.py index e22595f..7dfee01 100644 --- a/gator/scheduler/local.py +++ b/gator/scheduler/local.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from pathlib import Path from typing import Dict, List, Optional import websockets.exceptions @@ -28,6 +29,7 @@ class LocalScheduler(BaseScheduler): def __init__( self, + tracking: Path, parent: str, interval: int = 5, quiet: bool = True, @@ -35,7 +37,7 @@ def __init__( options: Optional[Dict[str, str]] = None, limits: Optional[MessageLimits] = None, ) -> None: - super().__init__(parent, interval, quiet, logger, options, limits) + super().__init__(tracking, parent, interval, quiet, logger, options, limits) self.launch_task = None self.update_lock = asyncio.Lock() self.launched_processes = {} @@ -102,7 +104,7 @@ async def _inner(): # Launch jobs self.slots[task.ident] = granted self.launched_processes[task.ident] = await asyncio.create_subprocess_shell( - self.create_command(task, {"concurrency": granted}), + " ".join(self.create_command(task, {"concurrency": granted})), stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT, diff --git a/gator/scheduler/slurm.py b/gator/scheduler/slurm.py new file mode 100644 index 0000000..42d0017 --- /dev/null +++ b/gator/scheduler/slurm.py @@ -0,0 +1,251 @@ +# Copyright 2023, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import getpass +import os +import subprocess +from datetime import datetime, timedelta +from enum import IntEnum +from pathlib import Path +from typing import ClassVar, Dict, List, Optional + +import aiohttp + +from ..common.child import Child +from ..common.logger import Logger, MessageLimits +from ..specs.jobs import Job +from .common import BaseScheduler, SchedulerError + + +class SlurmErrorCodes(IntEnum): + """Enumerates common Slurm error codes""" + + INVALID_TRES_SPEC: int = 2115 + """Invalid Trackable RESource (TRES) specification""" + SLURMDB_CONN_FAIL: int = 7000 + """Unable to connect to database (slurmdb connection failure)""" + + +class SlurmScheduler(BaseScheduler): + """Executes tasks via a Slurm cluster""" + + RETRY_ON_ERROR: ClassVar[set[int]] = { + SlurmErrorCodes.SLURMDB_CONN_FAIL, + } + + def __init__( + self, + tracking: Path, + parent: str, + interval: int = 5, + quiet: bool = True, + logger: Optional[Logger] = None, + options: Optional[Dict[str, str]] = None, + limits: Optional[MessageLimits] = None, + ) -> None: + super().__init__(tracking, parent, interval, quiet, logger, options, limits) + self._username: str = getpass.getuser() + self._api_root: str = self.get_option("api_root", "http://127.0.0.1:6820/") + self._api_version: str | None = None + self._token: str | None = None + self._expiry: datetime | None = None + self._interval: int = int(self.get_option("jwt_interval", 60)) + self._queue: str = self.get_option("queue", "generalq") + self._job_ids: list[int] = [] + self._stdout_dirx: Path = self.tracking / "slurm" + self._stdout_dirx.mkdir(exist_ok=True, parents=True) + + @property + def expired(self) -> bool: + return (self._expiry is None) or (self._expiry >= datetime.now()) + + @property + def token(self) -> str: + if self.expired: + result = subprocess.run( + [ + "scontrol", + "token", + f"lifespan={int(self._interval*1.1)}", + f"username={self._username}", + ], + capture_output=True, + timeout=5, + check=True, + ) + stdout = result.stdout.decode("utf-8").strip() + if not stdout.startswith("SLURM_JWT="): + raise SchedulerError(f"Failed to extract Slurm JWT from STDOUT: {stdout}") + self._token = stdout.split("SLURM_JWT=")[1].strip() + self._expiry = datetime.now() + timedelta(seconds=self._interval) + return self._token + + def clear_token(self): + self._token = None + self._expiry = None + + def get_session(self) -> aiohttp.ClientSession: + return aiohttp.ClientSession( + base_url=self._api_root + (f"/slurm/{self._api_version}/" if self._api_version else ""), + headers={ + "X-SLURM-USER-NAME": self._username, + "X-SLURM-USER-TOKEN": self.token, + }, + ) + + async def _retry_post( + self, + route: str, + payload: dict[str, str], + retries: int = 3, + backoff: float = 1.0, + ) -> dict[str, str]: + for idx in range(retries): + async with self.get_session() as session: + async with session.post(route, json=payload) as resp: + data = await resp.json() + err_nums = [x.get("error_number", None) for x in data.get("errors", [])] + # Check for a known error + if set(err_nums).intersection(self.RETRY_ON_ERROR): + # Log what happened + await self.logger.debug( + f"Slurm API error on attempt {idx}/{retries}, retrying " + f"in {backoff} seconds (with forced token refresh)" + ) + # Force a token expiry + self.clear_token() + # Wait a little + await asyncio.sleep(backoff) + # Retry + continue + # If no known error, return the data + return data + else: + raise SchedulerError(f"Post request to {route} failed {retries} times: {data}") + + async def _retry_get( + self, + route: str, + retries: int = 3, + backoff: float = 1.0, + ) -> dict[str, str]: + for idx in range(retries): + async with self.get_session() as session: + async with session.get(route) as resp: + data = await resp.json() + err_nums = [x.get("error_number", None) for x in data.get("errors", [])] + # Check for a known error + if set(err_nums).intersection(self.RETRY_ON_ERROR): + # Log what happened + await self.logger.debug( + f"Slurm API error on attempt {idx}/{retries}, retrying " + f"in {backoff} seconds (with forced token refresh)" + ) + # Force a token expiry + self.clear_token() + # Wait a little + await asyncio.sleep(backoff) + # Retry + continue + # If no known error, return the data + return data + else: + raise SchedulerError(f"Post request to {route} failed {retries} times: {data}") + + async def launch(self, tasks: List[Child]) -> None: + # Figure out the active API version of Slurm REST interface + if not self._api_version: + async with self.get_session() as session: + async with session.get("openapi/v3") as resp: + data = await resp.json() + slurm_roots = [x for x in data["paths"] if x.startswith("/slurm/")] + self._api_version = Path(slurm_roots[0]).parts[2] + await self.logger.info(f"Slurm scheduler using REST API version {self._api_version}") + + # Ping to check connection/authentication to Slurm + data = await self._retry_get("ping") + ping = data["pings"][0]["latency"] + await self.logger.debug(f"Slurm REST latency {ping}") + + # For each task... + for task in tasks: + # Figure out the requested resources + tres_per_job = [] + if isinstance(task.spec, Job): + tres_per_job += [ + f"cpu={int(task.spec.requested_cores)}", + f"mem={int(task.spec.requested_memory)}", + *[f"license/{k}={v}" for k, v in task.spec.requested_licenses.items()], + *[f"gres/{k}={v}" for k, v in task.spec.requested_features.items()], + ] + + # Submit the payload to Slurm + stdout = self._stdout_dirx / f"{task.ident}.log" + data = await self._retry_post( + "job/submit", + { + "job": { + "name": task.ident, + "script": "\n".join( + [ + "#!/bin/bash", + " ".join(self.create_command(task)), + "", + ] + ), + "tres_per_job": ",".join(tres_per_job), + "partition": self._queue, + "current_working_directory": Path.cwd().as_posix(), + "user_id": str(os.getuid()), + "group_id": str(os.getgid()), + "environment": [f"{k}={v}" for k, v in os.environ.items()], + "standard_output": stdout.as_posix(), + "standard_error": stdout.as_posix(), + } + }, + ) + + # Check for an invalid request + err_codes = { + x.get("error_number", 0) + for x in data.get("errors", []) + if (x.get("error_number", 0) != 0) + } + if err_codes.intersection({SlurmErrorCodes.INVALID_TRES_SPEC}): + raise SchedulerError( + f"Gator generated an unsupported resource request to Slurm " + f"({data['errors'][0]['error']}): {tres_per_job}" + ) + elif len(err_codes) > 0: + raise SchedulerError( + "Gator received unexpected error(s) when submitting a job " + "to Slurm: " + + ", ".join(f"{x['error']} ({x['error_number']})" for x in data["errors"]) + ) + + # Track the job ID + self._job_ids.append(job_id := data["result"]["job_id"]) + await self.logger.debug(f"Scheduled Slurm job {job_id}") + + async def wait_for_all(self): + for job_id in self._job_ids: + while True: + states = [] + data = await self._retry_get(f"job/{job_id}") + for job in data["jobs"]: + states += job["job_state"] + if len([x for x in states if x.lower() in ("pending", "running")]) == 0: + break + await asyncio.sleep(5) diff --git a/gator/specs/__init__.py b/gator/specs/__init__.py index 2462a9e..29c94e7 100644 --- a/gator/specs/__init__.py +++ b/gator/specs/__init__.py @@ -18,9 +18,9 @@ from .common import Dumper, Loader, SpecBase from .jobs import Job, JobArray, JobGroup -from .resource import Cores, License, Memory +from .resource import Cores, Feature, License, Memory -assert all((Job, JobArray, JobGroup, Cores, License, Memory)) +assert all((Job, JobArray, JobGroup, Cores, License, Memory, Feature)) class Spec: diff --git a/gator/specs/jobs.py b/gator/specs/jobs.py index 0d0be88..de2709f 100644 --- a/gator/specs/jobs.py +++ b/gator/specs/jobs.py @@ -18,7 +18,7 @@ from typing import Dict, List, Optional, Union from .common import SpecBase, SpecError -from .resource import Cores, License, Memory +from .resource import Cores, Feature, License, Memory @dataclass @@ -26,11 +26,12 @@ class Job(SpecBase): yaml_tag = "!Job" ident: Optional[str] = None + extend_env: bool = True env: Optional[Dict[str, str]] = field(default_factory=dict) cwd: Optional[str] = None command: Optional[str] = None args: Optional[List[str]] = field(default_factory=list) - resources: Optional[List[Union[Cores, License, Memory]]] = field(default_factory=list) + resources: Optional[List[Union[Cores, License, Memory, Feature]]] = field(default_factory=list) on_done: Optional[List[str]] = field(default_factory=list) on_fail: Optional[List[str]] = field(default_factory=list) on_pass: Optional[List[str]] = field(default_factory=list) @@ -61,9 +62,16 @@ def requested_licenses(self) -> Dict[str, int]: """Return a summary of all of the licenses requested""" return {x.name: x.count for x in self.resources if isinstance(x, License)} + @functools.cached_property + def requested_features(self) -> Dict[str, int]: + """Return a summary of all of the features requested""" + return {x.name: x.count for x in self.resources if isinstance(x, Feature)} + def check(self) -> None: if self.ident is not None and not isinstance(self.ident, str): raise SpecError(self, "ident", "ident must be a string") + if not isinstance(self.extend_env, bool): + raise SpecError(self, "extend_env", "Environment extend must be boolean") if not isinstance(self.env, dict): raise SpecError(self, "env", "Environment must be a dictionary") if set(map(type, self.env.keys())).difference({str}): @@ -80,12 +88,14 @@ def check(self) -> None: raise SpecError(self, "args", "Arguments must be strings or integers") if not isinstance(self.resources, list): raise SpecError(self, "resources", "Resources must be a list") - if set(map(type, self.resources)).difference({Cores, Memory, License}): + if set(map(type, self.resources)).difference({Cores, Memory, License, Feature}): raise SpecError( self, "resources", - "Resources must be !Cores, !Memory, or !License", + "Resources must be !Cores, !Memory, !License, or !Feature", ) + for resource in self.resources: + resource.check() type_count = Counter(type(x) for x in self.resources) if type_count[Cores] > 1: raise SpecError(self, "resources", "More than one !Cores resource request") @@ -100,6 +110,15 @@ def check(self) -> None: "resources", f"More than one entry for license '{name}'", ) + # NOTE: Any number of features may be specified + feat_name_count = Counter(x.name for x in self.resources if isinstance(x, Feature)) + for name, count in feat_name_count.items(): + if count > 1: + raise SpecError( + self, + "resources", + f"More than one entry for feature '{name}'", + ) for condition in ("on_done", "on_fail", "on_pass"): value = getattr(self, condition) if not isinstance(value, list): @@ -115,6 +134,7 @@ class JobArray(SpecBase): ident: Optional[str] = None repeats: Optional[int] = 1 jobs: Optional[List[Union[Job, "JobArray", "JobGroup"]]] = field(default_factory=list) + extend_env: bool = True env: Optional[Dict[str, str]] = field(default_factory=dict) cwd: Optional[str] = None on_fail: Optional[List[str]] = field(default_factory=list) @@ -152,6 +172,8 @@ def check(self) -> None: "jobs", f"Duplicated keys for jobs: {', '.join(duplicated)}", ) + if not isinstance(self.extend_env, bool): + raise SpecError(self, "extend_env", "Environment extend must be boolean") if not isinstance(self.env, dict): raise SpecError(self, "env", "Environment must be a dictionary") if set(map(type, self.env.keys())).difference({str}): @@ -177,6 +199,7 @@ class JobGroup(SpecBase): ident: Optional[str] = None jobs: Optional[List[Union[Job, "JobArray", "JobGroup"]]] = field(default_factory=list) + extend_env: bool = True env: Optional[Dict[str, str]] = field(default_factory=dict) cwd: Optional[str] = None on_fail: Optional[List[str]] = field(default_factory=list) @@ -212,6 +235,8 @@ def check(self) -> None: "jobs", f"Duplicated keys for jobs: {', '.join(duplicated)}", ) + if not isinstance(self.extend_env, bool): + raise SpecError(self, "extend_env", "Environment extend must be boolean") if not isinstance(self.env, dict): raise SpecError(self, "env", "Environment must be a dictionary") if set(map(type, self.env.keys())).difference({str}): diff --git a/gator/specs/resource.py b/gator/specs/resource.py index 72d4c92..7ea1cb2 100644 --- a/gator/specs/resource.py +++ b/gator/specs/resource.py @@ -12,16 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +import platform +from dataclasses import dataclass, field from .common import SpecBase, SpecError +ARCH_ALIASES = { + # x86 + "x86": "x86_64", + "x86_64": "x86_64", + "amd64": "x86_64", + # Arm + "arm": "aarch64", + "arm64": "aarch64", + "aarch64": "aarch64", + # RISC-V + "riscv": "riscv64", + "riscv64": "riscv64", +} + @dataclass class Cores(SpecBase): + """ + Specifies the count and optionally the architecture of the CPU cores to + execute on + """ + yaml_tag = "!Cores" count: int + arch: str | None = field(default_factory=lambda: ARCH_ALIASES[platform.uname().machine]) def check(self) -> None: if not isinstance(self.count, int): @@ -30,13 +51,24 @@ def check(self) -> None: # NOTE: Zero is valid - if a job doesn't consume much resource then # it may be desirable to run it without blocking others raise SpecError(self, "count", "Count must be zero or greater") + if self.arch is not None: + if not isinstance(self.arch, str): + raise SpecError(self, "arch", "Architecture must be a string") + self.arch = self.arch.lower().strip() + if self.arch not in ARCH_ALIASES: + raise SpecError( + self, "arch", f"Architecture must be one of {', '.join(ARCH_ALIASES)}" + ) + self.arch = ARCH_ALIASES[self.arch] @dataclass class Memory(SpecBase): + """Specifies the quantity of memory (RAM) required for the job to execute""" + yaml_tag = "!Memory" - size: int + size: int | float unit: str = "MB" @property @@ -45,8 +77,8 @@ def in_megabytes(self) -> int: return self.size * mapping def check(self) -> None: - if not isinstance(self.size, int): - raise SpecError(self, "size", "Size must be an integer") + if not isinstance(self.size, (int, float)): + raise SpecError(self, "size", "Size must be an int or float") if self.size < 0: # NOTE: Zero is valid - if a job doesn't consume much resource then # it may be desirable to run it without blocking others @@ -59,6 +91,11 @@ def check(self) -> None: @dataclass class License(SpecBase): + """ + Specifies a floating license required for a job to run, if the license is + node-locked then a !Feature should be used instead. + """ + yaml_tag = "!License" name: str @@ -73,3 +110,26 @@ def check(self) -> None: # NOTE: Zero is valid - if a job doesn't consume much resource then # it may be desirable to run it without blocking others raise SpecError(self, "count", "Count must be zero or greater") + + +@dataclass +class Feature(SpecBase): + """ + Specifies a feature of a machine required for a job to run, this can be used + for describing node-locked licenses or accelerators. + """ + + yaml_tag = "!Feature" + + name: str + count: int = 1 + + def check(self) -> None: + if not isinstance(self.name, str): + raise SpecError(self, "name", "Name must be a string") + if not isinstance(self.count, int): + raise SpecError(self, "count", "Count must be an integer") + if self.count < 0: + # NOTE: Zero is valid - if a job doesn't consume much resource then + # it may be desirable to run it without blocking others + raise SpecError(self, "count", "Count must be zero or greater") diff --git a/gator/tier.py b/gator/tier.py index e8cd038..b8144b9 100644 --- a/gator/tier.py +++ b/gator/tier.py @@ -13,8 +13,9 @@ # limitations under the License. import asyncio +import os from collections import defaultdict -from copy import copy, deepcopy +from copy import deepcopy from datetime import datetime from typing import Dict, List, Optional, Type @@ -85,6 +86,7 @@ async def launch(self, *args, **kwargs) -> Summary: # Create a scheduler try: self.scheduler = self.sched_cls( + tracking=self.tracking, parent=await self.server.get_address(), quiet=not self.all_msg, logger=self.logger, @@ -470,9 +472,12 @@ async def __launch(self): Logger.error(f"Unexpected job object type {type(job).__name__}") continue # Propagate environment variables from parent to child - merged = copy(self.spec.env) - merged.update(job.env) - job.env = merged + env = {} + if self.spec.extend_env: + env.update(os.environ) + env.update(self.spec.env) + env.update(job.env) + job.env = env # Propagate working directory from parent to child job.cwd = job.cwd or self.spec.cwd # Vary behaviour depending if this a job array or not @@ -483,7 +488,7 @@ async def __launch(self): child_dir = base_trk_dir if is_jarr: job_cp = deepcopy(job) - job_cp.env["GATOR_ARRAY_INDEX"] = idx_jarr + job_cp.env["GATOR_ARRAY_INDEX"] = str(idx_jarr) child_id += f"_{idx_jarr}" child_dir = base_trk_dir / str(idx_jarr) else: diff --git a/gator/wrapper.py b/gator/wrapper.py index 137a774..cea0c51 100644 --- a/gator/wrapper.py +++ b/gator/wrapper.py @@ -17,16 +17,14 @@ import shlex import socket import subprocess -from collections import defaultdict from datetime import datetime from pathlib import Path import expandvars -import plotly.graph_objects as pg import psutil from tabulate import tabulate -from .common.layer import BaseLayer, MetricResponse +from .common.layer import BaseLayer, MetricResponse, UsageResponse from .common.summary import Summary from .common.types import Attribute, JobResult, LogSeverity, ProcStat @@ -34,17 +32,16 @@ class Wrapper(BaseLayer): """Wraps a single process and tracks logging & process statistics""" - def __init__(self, *args, plotting: bool = False, summary: bool = False, **kwargs) -> None: + def __init__(self, *args, summary: bool = False, **kwargs) -> None: """ Initialise the wrapper, launch it and monitor it until completion. - :param plotting: Plot the resource usage once the job completes. :param summary: Display a tabulated summary of resource usage """ super().__init__(*args, **kwargs) - self.plotting = plotting self.summary = summary self.proc = None + self.extra_usage = None # Capture forwarded messages from the wrapped job if self.logger: self.logger.capture_all = True @@ -53,6 +50,7 @@ async def launch(self, *args, **kwargs) -> None: await self.setup(*args, **kwargs) # Register endpoint for metrics self.server.add_route("metric", self.__handle_metric) + self.server.add_route("extra_usage", self.__handle_extra_usage) # Register additional data types await self.db.register(Attribute) await self.db.register(ProcStat) @@ -105,6 +103,21 @@ async def __handle_metric(self, name: str, value: int, **_) -> MetricResponse: # Return success return {"result": "success"} + async def __handle_extra_usage( + self, timestamp: int, cpu_perc: float, memory: float, **_ + ) -> UsageResponse: + """ + Handle additional resource usage information being reported from a child. + + Example: { "timestamp": 12345678, "cpu_perc": 0.4, "memory": 1234.2 } + """ + self.extra_usage = (timestamp, cpu_perc, memory) + await self.logger.debug( + f"Process reported extra usage - CPU: {cpu_perc:.01f}%, Memory: {memory:.01f} MB" + ) + # Return success + return {"result": "success"} + async def __monitor_stdio( self, proc: asyncio.subprocess.Process, @@ -122,7 +135,7 @@ async def _monitor(pipe, severity): log_fh.write(line) clean = line.rstrip() if len(clean) > 0: - await self.logger.log(severity, clean) + await self.logger.log(severity, clean, "stdio") t_stdout = asyncio.create_task(_monitor(stdout, LogSeverity.INFO)) t_stderr = asyncio.create_task(_monitor(stderr, LogSeverity.ERROR)) @@ -149,7 +162,6 @@ async def __monitor_usage( try: # Capture statistics with ps.oneshot(): - await self.logger.debug(f"Capturing statistics for {proc.pid}") nproc = 1 cpu_perc = ps.cpu_percent() mem_stat = ps.memory_info() @@ -167,26 +179,41 @@ async def __monitor_usage( vms += c_mem_stat.vms # if io_count is not None: # io_count += ps.io_counters() if hasattr(ps, "io_counters") else None + # Convert RSS and VMS into MB + rss_mb = rss / (1024 * 1024) + vms_mb = vms / (1024 * 1024) + # Take account of 'extra' usage reported by the process + if self.extra_usage is not None: + _ts, ex_cpu_perc, ex_memory = self.extra_usage + cpu_perc += ex_cpu_perc + rss_mb += ex_memory + await self.logger.debug( + f"Resource usage of {proc.pid} - CPU: {cpu_perc:.01f}%, " + f"Memory: {rss_mb:.01f} MB" + ) # Push statistics to the database await self.db.push_procstat( ProcStat( timestamp=datetime.now(), nproc=nproc, cpu=cpu_perc, - mem=rss, - vmem=vms, + mem=rss_mb, + vmem=vms_mb, ) ) # Check if exceeding the limits - now_exceeding = (cpu_cores > 0 and cpu_perc > (100 * cpu_cores)) or ( - memory_mb > 0 and (rss / 1e6) > memory_mb + now_exceeding = any( + ( + (cpu_cores > 0 and cpu_perc > (100 * cpu_cores)), + (memory_mb > 0 and rss_mb > memory_mb), + ) ) if now_exceeding and not exceeding: await self.logger.warning( f"Job has exceed it's requested resources of " f"{cpu_cores} CPU cores and {memory_mb} MB of RAM - " - f"current usage is {cpu_perc / 100:.01f} CPU cores and " - f"{rss / 1E6:0.1f} MB of RAM" + f"current usage is {cpu_perc:.01f}% CPU and " + f"{rss_mb:0.1f} MB of RAM" ) exceeding = now_exceeding except psutil.NoSuchProcess: @@ -205,7 +232,13 @@ async def __launch(self) -> None: Launch the process and pipe STDIN, STDOUT, and STDERR with line buffering """ # Overlay any custom variables on the environment - env = {str(k): str(v) for k, v in (self.spec.env or os.environ).items()} + env = {} + if self.spec.extend_env: + env.update(os.environ) + env.update(self.spec.env) + if "PYTHONPATH" in env: + env["PYTHONPATH"] += ":" + env["PYTHONPATH"] = env.get("PYTHONPATH", "") + Path(__file__).parent.parent.as_posix() env["GATOR_PARENT"] = await self.server.get_address() env["PYTHONUNBUFFERED"] = "1" # Determine the working directory @@ -236,7 +269,7 @@ async def __launch(self) -> None: # Setup initial attributes await self.db.push_attribute(Attribute(name="cmd", value=full_cmd)) await self.db.push_attribute(Attribute(name="cwd", value=working_dir.as_posix())) - await self.db.push_attribute(Attribute(name="host", value=socket.gethostname())) + await self.db.push_attribute(Attribute(name="host", value=socket.getfqdn())) await self.db.push_attribute(Attribute(name="req_cores", value=str(cpu_cores))) await self.db.push_attribute(Attribute(name="req_memory", value=str(memory_mb))) await self.db.push_attribute( @@ -293,21 +326,6 @@ async def __report(self) -> None: pid = await self.db.get_attribute(name="pid") started_at = datetime.fromtimestamp(self.started) stopped_at = datetime.fromtimestamp(self.stopped) - # If plotting enabled, draw the plot - if self.plotting: - dates = [] - series = defaultdict(list) - for entry in data: - dates.append(entry.timestamp) - series["Processes"].append(entry.nproc) - series["CPU %"].append(entry.cpu) - series["Memory (MB)"].append(entry.mem / (1024**3)) - series["VMemory (MB)"].append(entry.vmem / (1024**3)) - fig = pg.Figure() - for key, vals in series.items(): - fig.add_trace(pg.Scatter(x=dates, y=vals, mode="lines", name=key)) - fig.update_layout(title=f"Resource Usage for {pid[0].value}", xaxis_title="Time") - fig.write_image(self.plotting.as_posix(), format="png") # Summarise process usage if self.summary: max_nproc = max(x.nproc for x in data) if data else 0 diff --git a/mkdocs.yml b/mkdocs.yml index 28bfe80..122f5e6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,4 +1,4 @@ -site_name: Gator Aid +site_name: Gator repo_name: intuity/gator repo_url: https://github.com/intuity/Gator theme: diff --git a/pyproject.toml b/pyproject.toml index 95eb053..1e1089c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gator-eda" -version = "0.1" +version = "1.0" description = "Hierarchical job execution and logging" authors = ["Peter Birch "] license = "Apache-2.0" @@ -8,37 +8,42 @@ readme = "README.md" packages = [{ include = "gator", from = "." }] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.11" click = "^8.1.3" -plotly = "^5.14.1" psutil = "^5.9.4" -kaleido = "0.2.1" -requests = "^2.28.2" rich = "^13.3.4" tabulate = "^0.9.0" pyyaml = "^6.0" -uwsgi = "^2.0.21" expandvars = "^0.9.0" websockets = "^11.0.2" aiosqlite = "^0.19.0" +aiohttp = "^3.12.13" + +[tool.poetry.group.hub] +optional = true + +[tool.poetry.group.hub.dependencies] +uwsgi = "^2.0.21" quart = "^0.18.4" # This is a secondary dependency but needs to be specified. https://stackoverflow.com/a/77214086 Werkzeug = "2.3.7" piccolo = { extras = ["orjson", "postgres", "uvloop"], version = "^0.111.1" } -aiohttp = "^3.8.4" + +[tool.poetry.group.dev] +optional = true [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" pytest-cov = "^4.0.0" pytest-mock = "^3.10.0" -mkdocs = "^1.4.2" -mkdocs-material = "^9.1.6" +mkdocs = "^1.6.1" +mkdocs-material = "^9.6.7" poethepoet = "^0.19.0" pytest-asyncio = "^0.21.0" ruff = "^0.6.8" -# mkdocstrings and griffe pinned at versions that work together -mkdocstrings = { extras = ["python"], version = "0.21.2" } -griffe = "0.25.5" +mkdocstrings = { extras = ["python"], version = "^0.28.2" } +griffe = "^1.6.0" +pre-commit = "^4.2.0" [tool.poetry.scripts] gator = "gator.__main__:main" diff --git a/tests/common/test_logger.py b/tests/common/test_logger.py index bf9b3ba..f2975b3 100644 --- a/tests/common/test_logger.py +++ b/tests/common/test_logger.py @@ -49,6 +49,9 @@ def logger_linked(logger_local) -> Logger: class TestLogger: + # Root hierarchy string + ROOT_STR = f"{'root':{Logger.HIER_WIDTH}s}" + @pytest.mark.asyncio async def test_unlinked(self, logger): """Local logging goes to the console""" @@ -75,32 +78,36 @@ async def test_local(self, logger_local): # Raw await logger.log(LogSeverity.INFO, "Testing info") assert not logger.ws_cli.log.called - logger._Logger__console.log.assert_called_with("[bold][INFO ][/bold] Testing info") + logger._Logger__console.log.assert_called_with( + "[bold]INFO [/bold] " + TestLogger.ROOT_STR + " Testing info" + ) logger._Logger__console.log.reset_mock() # Debug await logger.debug("Testing debug") assert not logger.ws_cli.log.called logger._Logger__console.log.assert_called_with( - "[bold cyan][DEBUG ][/bold cyan] Testing debug" + "[bold cyan]DEBUG [/bold cyan] " + TestLogger.ROOT_STR + " Testing debug" ) logger._Logger__console.log.reset_mock() # Info await logger.info("Testing info") assert not logger.ws_cli.log.called - logger._Logger__console.log.assert_called_with("[bold][INFO ][/bold] Testing info") + logger._Logger__console.log.assert_called_with( + "[bold]INFO [/bold] " + TestLogger.ROOT_STR + " Testing info" + ) logger._Logger__console.log.reset_mock() # Warning await logger.warning("Testing warning") assert not logger.ws_cli.log.called logger._Logger__console.log.assert_called_with( - "[bold yellow][WARNING][/bold yellow] Testing warning" + "[bold yellow]WARNING[/bold yellow] " + TestLogger.ROOT_STR + " Testing warning" ) logger._Logger__console.log.reset_mock() # Error await logger.error("Testing error") assert not logger.ws_cli.log.called logger._Logger__console.log.assert_called_with( - "[bold red][ERROR ][/bold red] Testing error" + "[bold red]ERROR [/bold red] " + TestLogger.ROOT_STR + " Testing error" ) logger._Logger__console.log.reset_mock() @@ -112,42 +119,48 @@ async def test_linked(self, logger_linked): # Raw await logger.log(LogSeverity.INFO, "Testing info") logger.ws_cli.log.assert_called_with( - timestamp=1234, severity="INFO", message="Testing info", posted=True + timestamp=1234, hierarchy="root", severity="INFO", message="Testing info", posted=True + ) + logger._Logger__console.log.assert_called_with( + "[bold]INFO [/bold] " + TestLogger.ROOT_STR + " Testing info" ) - logger._Logger__console.log.assert_called_with("[bold][INFO ][/bold] Testing info") logger.ws_cli.log.reset_mock() logger._Logger__console.log.reset_mock() # Debug await logger.debug("Testing debug") logger.ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="DEBUG", message="Testing debug", posted=True, ) logger._Logger__console.log.assert_called_with( - "[bold cyan][DEBUG ][/bold cyan] Testing debug" + "[bold cyan]DEBUG [/bold cyan] " + TestLogger.ROOT_STR + " Testing debug" ) logger.ws_cli.log.reset_mock() logger._Logger__console.log.reset_mock() # Info await logger.info("Testing info") logger.ws_cli.log.assert_called_with( - timestamp=1234, severity="INFO", message="Testing info", posted=True + timestamp=1234, hierarchy="root", severity="INFO", message="Testing info", posted=True + ) + logger._Logger__console.log.assert_called_with( + "[bold]INFO [/bold] " + TestLogger.ROOT_STR + " Testing info" ) - logger._Logger__console.log.assert_called_with("[bold][INFO ][/bold] Testing info") logger.ws_cli.log.reset_mock() logger._Logger__console.log.reset_mock() # Warning await logger.warning("Testing warning") logger.ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="WARNING", message="Testing warning", posted=True, ) logger._Logger__console.log.assert_called_with( - "[bold yellow][WARNING][/bold yellow] Testing warning" + "[bold yellow]WARNING[/bold yellow] " + TestLogger.ROOT_STR + " Testing warning" ) logger.ws_cli.log.reset_mock() logger._Logger__console.log.reset_mock() @@ -155,12 +168,13 @@ async def test_linked(self, logger_linked): await logger.error("Testing error") logger.ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="ERROR", message="Testing error", posted=True, ) logger._Logger__console.log.assert_called_with( - "[bold red][ERROR ][/bold red] Testing error" + "[bold red]ERROR [/bold red] " + TestLogger.ROOT_STR + " Testing error" ) logger.ws_cli.log.reset_mock() logger._Logger__console.log.reset_mock() @@ -177,6 +191,7 @@ def test_cli(self, mocker): runner.invoke(gator.common.logger.logger, ["This is a test"]) ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="INFO", message="This is a test", posted=True, @@ -189,6 +204,7 @@ def test_cli(self, mocker): ) ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="DEBUG", message="This is a debug test", posted=True, @@ -201,6 +217,7 @@ def test_cli(self, mocker): ) ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="INFO", message="This is an info test", posted=True, @@ -213,6 +230,7 @@ def test_cli(self, mocker): ) ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="WARNING", message="This is a warning test", posted=True, @@ -225,6 +243,7 @@ def test_cli(self, mocker): ) ws_cli.log.assert_called_with( timestamp=1234, + hierarchy="root", severity="ERROR", message="This is an error test", posted=True, diff --git a/tests/specs/test_job.py b/tests/specs/test_job.py index fd94ac4..5e66b7d 100644 --- a/tests/specs/test_job.py +++ b/tests/specs/test_job.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import platform + import pytest from gator.specs import Spec from gator.specs.common import SpecError from gator.specs.jobs import Job -from gator.specs.resource import Cores, License, Memory +from gator.specs.resource import ARCH_ALIASES, Cores, Feature, License, Memory def test_spec_job_positional(): """A job should preserve all positional arguments provided to it""" job = Job( "id_123", + True, {"key_a": 2345, "key_b": False}, "/path/to/working/dir", "echo", @@ -96,13 +99,16 @@ def test_spec_job_parse(tmp_path): " ident: id_123\n" " env:\n" " key_a: 2345\n" - " key_b: False\n" + " key_b: hello\n" " cwd: /path/to/working/dir\n" + " extend_env: true\n" " command: echo\n" " args:\n" " - String to print\n" " resources:\n" - " - !Cores [3]\n" + " - !Cores\n" + " arch: x86\n" + " count: 3\n" " - !License [A, 2]\n" " - !Memory [1, GB]\n" " on_done:\n" @@ -113,13 +119,15 @@ def test_spec_job_parse(tmp_path): " - job_2\n" ) job = Spec.parse(spec_file) + job.check() assert isinstance(job, Job) assert job.ident == "id_123" - assert job.env == {"key_a": 2345, "key_b": False} + assert job.env == {"key_a": 2345, "key_b": "hello"} assert job.cwd == "/path/to/working/dir" assert job.command == "echo" assert job.args == ["String to print"] assert isinstance(job.resources[0], Cores) + assert job.resources[0].arch == "x86_64" assert job.resources[0].count == 3 assert job.requested_cores == 3 assert isinstance(job.resources[1], License) @@ -204,6 +212,7 @@ def test_spec_job_dump(): "env:\n" " key_a: 2345\n" " key_b: false\n" + "extend_env: true\n" "ident: id_123\n" "on_done:\n" "- job_0\n" @@ -213,6 +222,7 @@ def test_spec_job_dump(): "- job_2\n" "resources:\n" "- !Cores\n" + " arch: " + ARCH_ALIASES[platform.uname().machine] + "\n" " count: 3\n" "- !License\n" " count: 2\n" @@ -284,7 +294,7 @@ def test_spec_job_bad_fields(): # Check bad resources (non-YAML tags) with pytest.raises(SpecError) as exc: Job(resources=["hello", 2]).check() - assert str(exc.value) == "Resources must be !Cores, !Memory, or !License" + assert str(exc.value) == "Resources must be !Cores, !Memory, !License, or !Feature" assert exc.value.field == "resources" # Check duplicate entries for !Cores with pytest.raises(SpecError) as exc: @@ -301,6 +311,11 @@ def test_spec_job_bad_fields(): Job(resources=[Cores(2), License("A"), License("B"), License("B")]).check() assert str(exc.value) == "More than one entry for license 'B'" assert exc.value.field == "resources" + # Check duplicate entries of a particular feature + with pytest.raises(SpecError) as exc: + Job(resources=[Cores(2), Feature("A"), Feature("B"), Feature("B")]).check() + assert str(exc.value) == "More than one entry for feature 'B'" + assert exc.value.field == "resources" # Check on done/fail/pass for field in ("on_done", "on_fail", "on_pass"): # Check non-list diff --git a/tests/specs/test_job_array.py b/tests/specs/test_job_array.py index 69abe18..a075a96 100644 --- a/tests/specs/test_job_array.py +++ b/tests/specs/test_job_array.py @@ -165,6 +165,7 @@ def test_spec_job_array_dump(): "!JobArray\n" "cwd: null\n" "env: {}\n" + "extend_env: true\n" "ident: arr_123\n" "jobs:\n" "- !Job\n" @@ -175,6 +176,7 @@ def test_spec_job_array_dump(): " env:\n" " key_a: 2345\n" " key_b: false\n" + " extend_env: true\n" " ident: id_123\n" " on_done: []\n" " on_fail: []\n" diff --git a/tests/specs/test_job_group.py b/tests/specs/test_job_group.py index f9e74f0..b031a33 100644 --- a/tests/specs/test_job_group.py +++ b/tests/specs/test_job_group.py @@ -157,6 +157,7 @@ def test_spec_job_group_dump(): "!JobGroup\n" "cwd: null\n" "env: {}\n" + "extend_env: true\n" "ident: grp_123\n" "jobs:\n" "- !Job\n" @@ -167,6 +168,7 @@ def test_spec_job_group_dump(): " env:\n" " key_a: 2345\n" " key_b: false\n" + " extend_env: true\n" " ident: id_123\n" " on_done: []\n" " on_fail: []\n" diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py index 8b06b68..797f946 100644 --- a/tests/test_local_scheduler.py +++ b/tests/test_local_scheduler.py @@ -40,7 +40,13 @@ async def setup_teardown(self, mocker) -> None: async def test_local_scheduling(self, mocker, tmp_path): """Launch a number of tasks""" # Create an scheduler - sched = LocalScheduler(parent="test:1234", interval=7, quiet=False, logger=self.logger) + sched = LocalScheduler( + tracking=tmp_path / "tracking", + parent="test:1234", + interval=7, + quiet=False, + logger=self.logger, + ) assert sched.parent == "test:1234" assert sched.interval == 7 assert sched.quiet is False diff --git a/tests/test_tier.py b/tests/test_tier.py index cae538b..b3e1bbf 100644 --- a/tests/test_tier.py +++ b/tests/test_tier.py @@ -339,13 +339,13 @@ async def test_tier_get_tree(self, tmp_path) -> None: script_a = tmp_path / "a.sh" script_b = tmp_path / "b.sh" script_c = tmp_path / "c.sh" - script_a.write_text(f"touch {touch_a.as_posix()}\nsleep 30\n") - script_b.write_text(f"touch {touch_b.as_posix()}\nsleep 30\n") - script_c.write_text(f"touch {touch_c.as_posix()}\nsleep 30\n") + script_a.write_text(f"touch {touch_a.as_posix()}\nsleep 5\n") + script_b.write_text(f"touch {touch_b.as_posix()}\nsleep 5\n") + script_c.write_text(f"touch {touch_c.as_posix()}\nsleep 5\n") # Define job specification - job_a = Job("a", command="sh", args=[script_a.as_posix()]) - job_b = Job("b", command="sh", args=[script_b.as_posix()]) - job_c = Job("c", command="sh", args=[script_c.as_posix()]) + job_a = Job("a", command="bash", args=[script_a.as_posix()]) + job_b = Job("b", command="bash", args=[script_b.as_posix()]) + job_c = Job("c", command="bash", args=[script_c.as_posix()]) grp_low = JobGroup("low", jobs=[job_a]) grp_mid = JobGroup("mid", jobs=[job_b, grp_low]) grp_top = JobGroup("top", jobs=[job_c, grp_mid]) @@ -382,5 +382,7 @@ async def test_tier_get_tree(self, tmp_path) -> None: } # Stop the jobs await tier.stop() + # Disconnect the websocket client + await ws_cli.stop() # Wait for the jobs to stop await t_launch diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 3644109..f19d95d 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -83,7 +83,6 @@ async def test_wrapper_basic(self, tmp_path) -> None: assert not wrp.quiet assert not wrp.all_msg assert wrp.heartbeat_cb is None - assert not wrp.plotting assert not wrp.summary assert wrp.proc is None assert not wrp.complete @@ -116,7 +115,7 @@ async def test_wrapper_basic(self, tmp_path) -> None: ("started", None), ("cmd", "echo hi"), ("cwd", tmp_path.as_posix()), - ("host", socket.gethostname()), + ("host", socket.getfqdn()), ("req_cores", "2"), ("req_memory", "1500.0"), ("req_licenses", "A=1,B=3"), @@ -259,35 +258,6 @@ async def test_wrapper_terminate(self, tmp_path) -> None: assert wrp.terminated assert wrp.code == 255 - async def test_wrapper_plotting(self, tmp_path) -> None: - """Check a plot is drawn if requested""" - # Mock datetime to always return one value - self.mk_wrp_dt.now.side_effect = None - self.mk_wrp_dt.now.return_value = datetime.fromtimestamp(12345) - # Define a job specification - job = Job("test", cwd=tmp_path.as_posix(), command="echo", args=["hi"]) - # Mock procstats returned by DB - self.mk_db.get_procstat.return_value = [ - ProcStat(db_uid=0, nproc=1, cpu=0.1, mem=11 * (1024**3)) - ] * 5 - # Create a wrapper - trk_dir = tmp_path / "tracking" - plt_path = tmp_path / "plot.png" - wrp = Wrapper( - spec=job, - client=self.client, - tracking=trk_dir, - logger=self.logger, - interval=1, - plotting=plt_path, - ) - # Check no plot exists - assert not plt_path.exists() - # Run the job - await wrp.launch() - # Check plot has been written out - assert plt_path.exists() - async def test_wrapper_summary(self, tmp_path, mocker) -> None: """Check that a process summary table is produced""" # Patch tabulate and print