1 change: 0 additions & 1 deletion pyproject.toml
@@ -45,7 +45,6 @@ requires-python = ">=3.10"
     "vulture==2.14",
     "pytest~=8.4",
     "pytest-cov~=7.0",
-    "pytest-asyncio~=1.2",
     "import-linter~=2.5",
     "pytest-deadfixtures~=2.2",
     "taplo~=0.9.3",
75 changes: 32 additions & 43 deletions src/cloudai/_core/base_runner.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import logging
+import time
 from abc import ABC, abstractmethod
-from asyncio import Task
 from pathlib import Path
 from typing import Dict, List
 
@@ -71,33 +70,33 @@ def __init__(self, mode: str, system: System, test_scenario: TestScenario, outpu
         logging.debug(f"{self.__class__.__name__} initialized")
         self.shutting_down = False
 
-    async def shutdown(self):
+    def shutdown(self):
         """Gracefully shut down the runner, terminating all outstanding jobs."""
         self.shutting_down = True
         logging.info("Terminating all jobs...")
         for job in self.jobs:
             logging.info(f"Terminating job {job.id} for test {job.test_run.name}")
             self.system.kill(job)
-        logging.info("All jobs have been killed.")
+        logging.info("Waiting for all jobs to be killed.")
 
-    async def run(self):
-        """Asynchronously run the test scenario."""
+    def run(self):
+        """Run the test scenario."""
         if self.shutting_down:
             return
 
         total_tests = len(self.test_scenario.test_runs)
         dependency_free_trs = self.find_dependency_free_tests()
         for tr in dependency_free_trs:
-            await self.submit_test(tr)
+            self.submit_test(tr)
 
         logging.debug(f"Total tests: {total_tests}, dependency free tests: {[tr.name for tr in dependency_free_trs]}")
         while self.jobs:
-            await self.check_start_post_init_dependencies()
-            await self.monitor_jobs()
+            self.check_start_post_init_dependencies()
+            self.monitor_jobs()
             logging.debug(f"sleeping for {self.monitor_interval} seconds")
-            await asyncio.sleep(self.monitor_interval)
+            time.sleep(self.monitor_interval)
 
-    async def submit_test(self, tr: TestRun):
+    def submit_test(self, tr: TestRun):
         """
         Start a dependency-free test.
 
@@ -118,7 +117,7 @@ async def submit_test(self, tr: TestRun):
     def on_job_submit(self, tr: TestRun) -> None:
         return
 
-    async def delayed_submit_test(self, tr: TestRun, delay: int = 5):
+    def delayed_submit_test(self, tr: TestRun, delay: int = 5):
         """
         Delay the start of a test based on start_post_comp dependency.
 
@@ -127,8 +126,8 @@ async def delayed_submit_test(self, tr: TestRun, delay: int = 5):
             delay (int): Delay in seconds before starting the test.
         """
         logging.debug(f"Delayed start for test {tr.name} by {delay} seconds.")
-        await asyncio.sleep(delay)
-        await self.submit_test(tr)
+        time.sleep(delay)
+        self.submit_test(tr)
 
     @abstractmethod
     def _submit_test(self, tr: TestRun) -> BaseJob:
@@ -143,7 +142,7 @@ def _submit_test(self, tr: TestRun) -> BaseJob:
         """
         pass
 
-    async def check_start_post_init_dependencies(self):
+    def check_start_post_init_dependencies(self):
         """
         Check and handle start_post_init dependencies.
 
@@ -164,9 +163,9 @@
 
         logging.debug(f"start_post_init for test {tr.name} ({is_running=}, {is_completed=}, {self.mode=})")
         if is_running or is_completed:
-            await self.check_and_schedule_start_post_init_dependent_tests(tr)
+            self.check_and_schedule_start_post_init_dependent_tests(tr)
 
-    async def check_and_schedule_start_post_init_dependent_tests(self, started_test_run: TestRun):
+    def check_and_schedule_start_post_init_dependent_tests(self, started_test_run: TestRun):
         """
         Schedule tests with a start_post_init dependency on the provided started_test.
 
@@ -177,7 +176,7 @@ async def check_and_schedule_start_post_init_dependent_tests(self, started_test_
             if tr not in self.testrun_to_job_map:
                 for dep_type, dep in tr.dependencies.items():
                     if (dep_type == "start_post_init") and (dep.test_run == started_test_run):
-                        await self.delayed_submit_test(tr)
+                        self.delayed_submit_test(tr)
 
     def find_dependency_free_tests(self) -> List[TestRun]:
         """
@@ -229,7 +228,7 @@ def get_job_output_path(self, tr: TestRun) -> Path:
 
         return job_output_path
 
-    async def monitor_jobs(self) -> int:
+    def monitor_jobs(self) -> int:
         """
         Monitor the status of jobs, handle end_post_comp dependencies, and schedule start_post_comp dependent jobs.
 
@@ -248,20 +247,20 @@
 
             if self.mode == "dry-run":
                 successful_jobs_count += 1
-                await self.handle_job_completion(job)
+                self.handle_job_completion(job)
             else:
                 if self.test_scenario.job_status_check:
                     job_status_result = self.get_job_status(job)
                     if job_status_result.is_successful:
                         successful_jobs_count += 1
-                        await self.handle_job_completion(job)
+                        self.handle_job_completion(job)
                     else:
                         error_message = (
                             f"Job {job.id} for test {job.test_run.name} failed: {job_status_result.error_message}"
                         )
                         logging.error(error_message)
-                        await self.handle_job_completion(job)
-                        await self.shutdown()
+                        self.handle_job_completion(job)
+                        self.shutdown()
                         raise JobFailureError(job.test_run.name, error_message, job_status_result.error_message)
                 else:
                     job_status_result = self.get_job_status(job)
@@ -271,7 +270,7 @@
                         )
                         logging.error(error_message)
                         successful_jobs_count += 1
-                        await self.handle_job_completion(job)
+                        self.handle_job_completion(job)
 
         return successful_jobs_count
 
@@ -296,7 +295,7 @@ def get_job_status(self, job: BaseJob) -> JobStatusResult:
                 return workload_run_results
         return JobStatusResult(is_successful=True)
 
-    async def handle_job_completion(self, completed_job: BaseJob):
+    def handle_job_completion(self, completed_job: BaseJob):
         """
         Handle the completion of a job, including dependency management and iteration control.
 
@@ -316,9 +315,9 @@ async def handle_job_completion(self, completed_job: BaseJob):
             completed_job.test_run.current_iteration += 1
             msg = f"Re-running job for iteration {completed_job.test_run.current_iteration}"
             logging.info(msg)
-            await self.submit_test(completed_job.test_run)
+            self.submit_test(completed_job.test_run)
         else:
-            await self.handle_dependencies(completed_job)
+            self.handle_dependencies(completed_job)
 
     def on_job_completion(self, job: BaseJob) -> None:
         """
@@ -332,37 +331,27 @@ def on_job_completion(self, job: BaseJob) -> None:
         """
         return
 
-    async def handle_dependencies(self, completed_job: BaseJob) -> List[Task]:
+    def handle_dependencies(self, completed_job: BaseJob):
         """
         Handle the start_post_comp and end_post_comp dependencies for a completed job.
 
         Args:
             completed_job (BaseJob): The job that has just been completed.
-
-        Returns:
-            List[asyncio.Task]: A list of asyncio.Task objects created for handling the dependencies.
         """
-        tasks = []
-
         # Handling start_post_comp dependencies
         for tr in self.test_scenario.test_runs:
             if tr not in self.testrun_to_job_map:
                 for dep_type, dep in tr.dependencies.items():
                     if dep_type == "start_post_comp" and dep.test_run == completed_job.test_run:
-                        task = await self.delayed_submit_test(tr)
-                        if task:
-                            tasks.append(task)
+                        self.delayed_submit_test(tr)
 
         # Handling end_post_comp dependencies
         for test, dependent_job in self.testrun_to_job_map.items():
             for dep_type, dep in test.dependencies.items():
                 if dep_type == "end_post_comp" and dep.test_run == completed_job.test_run:
-                    task = await self.delayed_kill_job(dependent_job)
-                    tasks.append(task)
-
-        return tasks
+                    self.delayed_kill_job(dependent_job)
 
-    async def delayed_kill_job(self, job: BaseJob, delay: int = 0):
+    def delayed_kill_job(self, job: BaseJob, delay: int = 0):
         """
         Schedule termination of a Standalone job after a specified delay.
 
@@ -371,7 +360,7 @@ async def delayed_kill_job(self, job: BaseJob, delay: int = 0):
             delay (int): Delay in seconds after which the job should be terminated.
         """
         logging.info(f"Scheduling termination of job {job.id} after {delay} seconds.")
-        await asyncio.sleep(delay)
+        time.sleep(delay)
         job.terminated_by_dependency = True
         self.system.kill(job)
 
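Distilled to its core, the scheduling loop after this change is plain synchronous polling. Below is a minimal self-contained sketch of that pattern; the Job stub and the interval value are illustrative stand-ins, not CloudAI's real types:

import time

class Job:
    # Illustrative stub; CloudAI's BaseJob carries much more state.
    def __init__(self, ticks: int) -> None:
        self.ticks = ticks

    def is_done(self) -> bool:
        self.ticks -= 1
        return self.ticks <= 0

def run(jobs: list, monitor_interval: float = 1.0) -> None:
    # Submit-then-poll: no event loop; the thread simply sleeps between checks.
    while jobs:
        jobs = [job for job in jobs if not job.is_done()]  # monitor_jobs() equivalent
        time.sleep(monitor_interval)  # was: await asyncio.sleep(monitor_interval)

run([Job(3), Job(5)])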
30 changes: 4 additions & 26 deletions src/cloudai/_core/runner.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import datetime
 import logging
 from types import FrameType
@@ -80,39 +79,18 @@ def create_runner(self, mode: str, system: System, test_scenario: TestScenario)
 
         return runner_class(mode, system, test_scenario, results_root)
 
-    async def run(self):
+    def run(self):
         """Run the test scenario using the instantiated runner."""
         try:
-            await self.runner.run()
+            self.runner.run()
             logging.debug("All jobs finished successfully.")
-        except asyncio.CancelledError:
-            logging.info("Runner cancelled, performing cleanup...")
-            await self.runner.shutdown()
-            return
         except JobFailureError as exc:
             logging.debug(f"Runner failed JobFailure exception: {exc}", exc_info=True)
 
-    def _cancel_all(self):
-        # the below code might look excessive, this is to address https://docs.astral.sh/ruff/rules/asyncio-dangling-task/
-        shutdown_task = asyncio.create_task(self.runner.shutdown())
-        tasks = {shutdown_task}
-        shutdown_task.add_done_callback(tasks.discard)
-
-        for task in asyncio.all_tasks():
-            if task == shutdown_task:
-                continue
-
-            logging.debug(f"Cancelling task: {task}")
-            try:
-                task.cancel()
-            except asyncio.CancelledError as exc:
-                logging.debug(f"Error cancelling task: {task}, {exc}", exc_info=True)
-                pass
-
     def cancel_on_signal(
         self,
         signum: int,
         frame: Optional[FrameType],  # noqa: Vulture
     ):
         logging.info(f"Signal {signum} received, shutting down...")
-        asyncio.get_running_loop().call_soon_threadsafe(self._cancel_all)
+        self.runner.shutdown()
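With the event loop and _cancel_all gone, signal handling reduces to calling the synchronous shutdown() directly from the handler. A minimal sketch of how such a handler can be wired up; the stub below is illustrative only, and CloudAI's actual Runner wiring may differ:

import signal
import time

class StubRunner:
    # Illustrative stand-in for cloudai's Runner; real signatures may differ.
    def cancel_on_signal(self, signum, frame):
        print(f"Signal {signum} received, shutting down...")
        raise SystemExit(0)

    def run(self):
        while True:
            time.sleep(1.0)  # a plain sleep; signal handlers still interrupt it

runner = StubRunner()
signal.signal(signal.SIGINT, runner.cancel_on_signal)
signal.signal(signal.SIGTERM, runner.cancel_on_signal)
runner.run()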
5 changes: 2 additions & 3 deletions src/cloudai/cli/handlers.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,6 @@
 # limitations under the License.
 
 import argparse
-import asyncio
 import copy
 import logging
 import signal
@@ -192,7 +191,7 @@ def generate_reports(system: System, test_scenario: TestScenario, result_dir: Pa
 
 
 def handle_non_dse_job(runner: Runner, args: argparse.Namespace) -> None:
-    asyncio.run(runner.run())
+    runner.run()
     generate_reports(runner.runner.system, runner.runner.test_scenario, runner.runner.scenario_root)
     logging.info("All jobs are complete.")
 
5 changes: 2 additions & 3 deletions src/cloudai/configurator/cloudai_gym.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import copy
 import csv
 import logging
@@ -113,7 +112,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
         self.runner.testrun_to_job_map.clear()
 
         try:
-            asyncio.run(self.runner.run())
+            self.runner.run()
         except Exception as e:
             logging.error(f"Error running step {self.test_run.step}: {e}")
 
20 changes: 11 additions & 9 deletions src/cloudai/systems/runai/runai_rest_client.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +20,7 @@
 from typing import Any, Dict, Optional
 
 import requests
-import websockets
+from websockets.sync.client import connect as ws_connect
 
 
 class RunAIRestClient:
@@ -496,7 +496,7 @@ def is_cluster_api_available(self, cluster_domain: str) -> bool:
         response = requests.get(url, headers=headers)
         return "OK" in response.text
 
-    async def fetch_training_logs(
+    def fetch_training_logs(
         self, cluster_domain: str, project_name: str, training_task_name: str, output_file_path: Path
     ):
         if not self.is_cluster_api_available(cluster_domain):
@@ -512,9 +512,11 @@ async def fetch_training_logs(
         }
 
         ssl_context = ssl._create_unverified_context()
-        async with websockets.connect(url, extra_headers=headers, ssl=ssl_context) as websocket:
-            with output_file_path.open("w") as log_file:
-                async for message in websocket:
-                    if isinstance(message, bytes):
-                        message = message.decode("utf-8")
-                    log_file.write(str(message))
+        with (
+            ws_connect(url, additional_headers=headers, ssl=ssl_context) as websocket,
+            output_file_path.open("w") as log_file,
+        ):
+            for message in websocket:
+                if isinstance(message, bytes):
+                    message = message.decode("utf-8")
+                log_file.write(str(message))
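For reference, the synchronous websockets client adopted above can be exercised standalone roughly as follows. This assumes a recent websockets release (13+), where the sync connect() accepts additional_headers= and ssl= keywords (the older async client spelled the former extra_headers, as the removed code shows); the endpoint URL is a placeholder:

from websockets.sync.client import connect

# Plain blocking iteration: yields messages until the server closes the socket.
with connect("wss://example.com/ws") as websocket:  # placeholder endpoint
    for message in websocket:
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        print(message)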