diff --git a/gator/__main__.py b/gator/__main__.py
index 94f27f3..54838e8 100644
--- a/gator/__main__.py
+++ b/gator/__main__.py
@@ -78,6 +78,13 @@
     default=0,
     help="Maximum number of critical messages before failure",
 )
+@click.option(
+    "--internal",
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Indicates this instance of Gator is running as an internal child job.",
+)
 @click.argument("spec", type=click.Path(exists=True), required=False)
 def main(
     ident: str,
@@ -94,6 +101,7 @@ def main(
     limit_warning: Optional[int],
     limit_error: int,
     limit_critical: int,
+    internal: bool,
     spec: str,
 ) -> None:
     # Determine a tracking directory
@@ -139,6 +147,7 @@ def main(
                 error=limit_error,
                 critical=limit_critical,
             ),
+            internal=internal,
         )
     )
     if not summary.passed:
diff --git a/gator/launch.py b/gator/launch.py
index 76fdbc4..3b766f4 100644
--- a/gator/launch.py
+++ b/gator/launch.py
@@ -20,7 +20,7 @@
 import socket
 from functools import partial
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Dict, Optional, Type, Union, cast
 
 from rich.console import Console
 
@@ -32,6 +32,7 @@
 from .hub.api import HubAPI
 from .scheduler import LocalScheduler
 from .specs import Job, JobArray, JobGroup, Spec
+from .specs.common import SpecBase
 from .tier import Tier
 from .wrapper import Wrapper
 
@@ -40,7 +41,7 @@ async def launch(
     ident: Optional[str] = None,
     hub: Optional[str] = None,
     parent: Optional[str] = None,
-    spec: Optional[Union[Spec, Path]] = None,
+    spec: Optional[Union[SpecBase, Spec, Path]] = None,
     tracking: Optional[Path] = None,
     interval: int = 5,
     quiet: bool = False,
@@ -52,6 +53,7 @@ async def launch(
     sched_opts: Optional[Dict[str, str]] = None,
     glyph: Optional[str] = None,
     limits: Optional[MessageLimits] = None,
+    internal: bool = False,
 ) -> Summary:
     # Glyph only used when progress bar visible
     del glyph
@@ -81,51 +83,70 @@ async def launch(
     )
     # Work out where the spec is coming from
     # - From server (nested call)
+    parsed_spec: SpecBase
     if spec is None and client.linked and ident:
         raw_spec = await client.spec(ident=ident)
-        spec = Spec.parse_str(raw_spec.get("spec", ""))
+        parsed_spec = Spec.parse_str(raw_spec.get("spec", ""))
     # - Passed in directly (when used as a library
     elif spec is not None and isinstance(spec, (Job, JobArray, JobGroup)):
-        pass
+        parsed_spec = cast(SpecBase, spec)
     # - Passed as a file path
     elif spec is not None and isinstance(spec, (Path, str)):
-        spec = Spec.parse(Path(spec))
+        parsed_spec = Spec.parse(Path(spec))
     # - Unknown
     else:
         raise Exception("No specification file provided and no parent server to query")
+
+    # Validate explicitly instead of asserting: asserts are stripped under
+    # `python -O`, and a tuple keeps isinstance() working on Python < 3.10
+    if not isinstance(parsed_spec, (Job, JobArray, JobGroup)):
+        raise TypeError(
+            f"Expected specification to be a Job, JobArray or JobGroup, received "
+            f"{type(parsed_spec).__name__}."
+        )
+
     # If an ident has been provided, override whatever the spec gives
     if ident is not None:
-        spec.ident = ident
+        parsed_spec.ident = ident
+
     # Check the spec object
-    spec.check()
-    # If a JobArray or JobGroup is provided, launch a tier
-    if isinstance(spec, (JobArray, JobGroup)):
-        top = Tier(
-            spec=spec,
+    parsed_spec.check()
+
+    # When user launches a single job, wrap it up in a JobArray so we can
+    # launch it via a common mechanism (which will ensure this job launches via
+    # the specified scheduler)
+    if isinstance(parsed_spec, Job) and not internal:
+        parsed_spec = JobArray(jobs=[parsed_spec])
+
+    if isinstance(parsed_spec, Job):
+        # Internal single job - launch via the wrapper on current machine
+        # as this is the executor instance. I.e. don't use the scheduler
+        top = Wrapper(
+            spec=parsed_spec,
             client=client,
             logger=logger,
             tracking=tracking,
+            interval=interval,
             quiet=quiet and not all_msg,
             all_msg=all_msg,
             heartbeat_cb=heartbeat_cb,
-            scheduler=scheduler,
-            sched_opts=sched_opts,
             limits=limits,
         )
-    # If a Job is provided, launch a wrapper
-    elif isinstance(spec, Job):
-        top = Wrapper(
-            spec=spec,
+    else:
+        # Non-internal single job or a multi-task job - launch via the scheduler
+        top = Tier(
+            spec=parsed_spec,
             client=client,
             logger=logger,
             tracking=tracking,
             interval=interval,
             quiet=quiet and not all_msg,
+            all_msg=all_msg,
+            heartbeat_cb=heartbeat_cb,
+            scheduler=scheduler,
+            sched_opts=sched_opts,
             limits=limits,
         )
-    # Unsupported forms
-    else:
-        raise Exception(f"Unsupported specification object of type {type(spec).__name__}")
 
     # Setup signal handler to capture CTRL+C events
     def _handler(sig: signal, evt_loop: asyncio.BaseEventLoop, top: Union[Tier, Wrapper]):
diff --git a/gator/scheduler/common.py b/gator/scheduler/common.py
index c3d378c..fd4dd9e 100644
--- a/gator/scheduler/common.py
+++ b/gator/scheduler/common.py
@@ -106,7 +106,7 @@ def create_command(self, child: Child, options: Optional[Dict[str, str]] = None)
         return list(
             itertools.chain(
                 self.base_command,
-                ["--id", child.ident, "--tracking", child.tracking.as_posix()],
+                ["--internal", "--id", child.ident, "--tracking", child.tracking.as_posix()],
                 *(["--sched-arg", f"{k}={v}"] for k, v in full_opts.items()),
             )
         )
diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py
index 797f946..494c2c4 100644
--- a/tests/test_local_scheduler.py
+++ b/tests/test_local_scheduler.py
@@ -89,6 +89,7 @@ def _create_proc(*_args, **_kwargs):
                 call(
                     f"python3 -m gator --limit-error=0 --limit-critical=0"
                     " --parent test:1234 --interval 7 --scheduler local --all-msg "
+                    "--internal "
                     f"--id T{x} --tracking {(tmp_path / f'T{x}').as_posix()}"
                     " --sched-arg concurrency=1",
                     stdin=subprocess.DEVNULL,
@@ -104,3 +105,103 @@ def _create_proc(*_args, **_kwargs):
         await sched.wait_for_all()
         # Check all monitors were fired up
         as_mon.assert_has_calls([call(f"T{x}", y) for x, y in zip(range(10), procs)])
+
+    async def test_local_scheduler_default_launch(self, mocker, tmp_path):
+        """Check that launch() without `internal` flag uses Tier/scheduler for a single Job"""
+        from gator.launch import launch
+        from gator.specs import JobArray
+
+        # Patch Console to avoid output during test
+        mocker.patch("gator.launch.Console")
+        # Mock the Tier class to verify it's instantiated
+        mk_tier_cls = mocker.patch("gator.launch.Tier")
+        mk_tier = MagicMock()
+        mk_tier.launch = AsyncMock()
+        mk_tier.summarise = AsyncMock()
+        mk_tier.is_root = True
+        mk_tier_cls.return_value = mk_tier
+        # Define a job specification
+        job = Job("test_scheduler", cwd=tmp_path.as_posix(), command="echo", args=["scheduler"])
+        # Call launch without internal flag (defaults to False)
+        trk_dir = tmp_path / "tracking"
+        await launch(spec=job, tracking=trk_dir, scheduler=LocalScheduler)
+        # Verify Tier was instantiated
+        mk_tier_cls.assert_called_once()
+        call_kwargs = mk_tier_cls.call_args.kwargs
+        # The single Job should be wrapped in a JobArray
+        assert isinstance(call_kwargs["spec"], JobArray)
+        assert len(call_kwargs["spec"].jobs) == 1
+        assert call_kwargs["spec"].jobs[0] is job
+        assert call_kwargs["tracking"] == trk_dir
+        assert call_kwargs["scheduler"] is LocalScheduler
+        # Verify Tier.launch() was called
+        mk_tier.launch.assert_called_once()
+        mk_tier.summarise.assert_called_once()
+
+    async def test_local_scheduler_jobarray_internal_launch(self, mocker, tmp_path):
+        """Check that launch() with JobArray and `internal=True` still uses Tier/scheduler"""
+        from gator.launch import launch
+        from gator.specs import JobArray, JobGroup
+
+        # Patch Console to avoid output during test
+        mocker.patch("gator.launch.Console")
+        # Mock the Tier class to verify it's instantiated
+        mk_tier_cls = mocker.patch("gator.launch.Tier")
+        mk_tier = MagicMock()
+        mk_tier.launch = AsyncMock()
+        mk_tier.summarise = AsyncMock()
+        mk_tier.is_root = True
+        mk_tier_cls.return_value = mk_tier
+        # Define a JobArray specification
+        jobs: list[Job | JobArray | JobGroup] = [
+            Job(f"test_job_{i}", cwd=tmp_path.as_posix(), command="echo", args=[f"job{i}"])
+            for i in range(3)
+        ]
+        job_array = JobArray(jobs=jobs)
+        # Call launch with internal=True (should still use scheduler for multi-job specs)
+        trk_dir = tmp_path / "tracking"
+        await launch(spec=job_array, tracking=trk_dir, scheduler=LocalScheduler, internal=True)
+        # Verify Tier was instantiated (not Wrapper)
+        mk_tier_cls.assert_called_once()
+        call_kwargs = mk_tier_cls.call_args.kwargs
+        # The JobArray should be passed through unchanged
+        assert call_kwargs["spec"] is job_array
+        assert call_kwargs["tracking"] == trk_dir
+        assert call_kwargs["scheduler"] is LocalScheduler
+        # Verify Tier.launch() was called
+        mk_tier.launch.assert_called_once()
+        mk_tier.summarise.assert_called_once()
+
+    async def test_local_scheduler_jobgroup_internal_launch(self, mocker, tmp_path):
+        """Check that launch() with JobGroup and `internal=True` still uses Tier/scheduler"""
+        from gator.launch import launch
+        from gator.specs import JobArray, JobGroup
+
+        # Patch Console to avoid output during test
+        mocker.patch("gator.launch.Console")
+        # Mock the Tier class to verify it's instantiated
+        mk_tier_cls = mocker.patch("gator.launch.Tier")
+        mk_tier = MagicMock()
+        mk_tier.launch = AsyncMock()
+        mk_tier.summarise = AsyncMock()
+        mk_tier.is_root = True
+        mk_tier_cls.return_value = mk_tier
+        # Define a JobGroup specification
+        jobs: list[Job | JobArray | JobGroup] = [
+            Job(f"test_job_{i}", cwd=tmp_path.as_posix(), command="echo", args=[f"job{i}"])
+            for i in range(3)
+        ]
+        job_group = JobGroup(jobs=jobs)
+        # Call launch with internal=True (should still use scheduler for multi-job specs)
+        trk_dir = tmp_path / "tracking"
+        await launch(spec=job_group, tracking=trk_dir, scheduler=LocalScheduler, internal=True)
+        # Verify Tier was instantiated (not Wrapper)
+        mk_tier_cls.assert_called_once()
+        call_kwargs = mk_tier_cls.call_args.kwargs
+        # The JobGroup should be passed through unchanged
+        assert call_kwargs["spec"] is job_group
+        assert call_kwargs["tracking"] == trk_dir
+        assert call_kwargs["scheduler"] is LocalScheduler
+        # Verify Tier.launch() was called
+        mk_tier.launch.assert_called_once()
+        mk_tier.summarise.assert_called_once()
diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py
index f19d95d..ab1b111 100644
--- a/tests/test_wrapper.py
+++ b/tests/test_wrapper.py
@@ -371,3 +371,30 @@ async def test_wrapper_metric(self, tmp_path, mocker) -> None:
         await wrp.stop()
         # Wait for task to complete
         await t_wrp
+
+    async def test_wrapper_internal_launch(self, tmp_path, mocker) -> None:
+        """Check that launch() with `internal=True` uses Wrapper for a single Job"""
+        from gator.launch import launch
+
+        # Patch Console to avoid output during test
+        mocker.patch("gator.launch.Console")
+        # Mock the Wrapper class to verify it's instantiated
+        mk_wrapper_cls = mocker.patch("gator.launch.Wrapper")
+        mk_wrapper = MagicMock()
+        mk_wrapper.launch = AsyncMock()
+        mk_wrapper.summarise = AsyncMock()
+        mk_wrapper.is_root = True
+        mk_wrapper_cls.return_value = mk_wrapper
+        # Define a job specification
+        job = Job("test_internal", cwd=tmp_path.as_posix(), command="echo", args=["internal"])
+        # Call launch with internal=True
+        trk_dir = tmp_path / "tracking"
+        await launch(spec=job, tracking=trk_dir, internal=True)
+        # Verify Wrapper was instantiated with the Job (not wrapped in JobArray)
+        mk_wrapper_cls.assert_called_once()
+        call_kwargs = mk_wrapper_cls.call_args.kwargs
+        assert call_kwargs["spec"] is job
+        assert call_kwargs["tracking"] == trk_dir
+        # Verify Wrapper.launch() was called
+        mk_wrapper.launch.assert_called_once()
+        mk_wrapper.summarise.assert_called_once()