diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index ed7b4bdb0f..105f9e8c9d 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -27,6 +27,7 @@ group_jobs_by_replica_latest, ) from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.prometheus.client_metrics import run_metrics from dstack._internal.server.services.runs import ( fmt, process_terminating_run, @@ -329,6 +330,24 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): run_model.status.name, new_status.name, ) + if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING: + current_time = common.get_current_datetime() + submit_to_provision_duration = ( + current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc) + ).total_seconds() + logger.info( + "%s: run took %.2f seconds from submission to provisioning.", + fmt(run_model), + submit_to_provision_duration, + ) + project_name = run_model.project.name + run_metrics.log_submit_to_provision_duration( + submit_to_provision_duration, project_name, run_spec.configuration.type + ) + + if new_status == RunStatus.PENDING: + run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type) + run_model.status = new_status run_model.termination_reason = termination_reason # While a run goes to pending without provisioning, resubmission_attempt increases. 
diff --git a/src/dstack/_internal/server/routers/prometheus.py b/src/dstack/_internal/server/routers/prometheus.py index 28546e1eee..a5538edfec 100644 --- a/src/dstack/_internal/server/routers/prometheus.py +++ b/src/dstack/_internal/server/routers/prometheus.py @@ -1,15 +1,15 @@ import os from typing import Annotated +import prometheus_client from fastapi import APIRouter, Depends from fastapi.responses import PlainTextResponse -from prometheus_client import generate_latest from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.server import settings from dstack._internal.server.db import get_session from dstack._internal.server.security.permissions import OptionalServiceAccount -from dstack._internal.server.services import prometheus +from dstack._internal.server.services.prometheus import custom_metrics from dstack._internal.server.utils.routers import error_not_found _auth = OptionalServiceAccount(os.getenv("DSTACK_PROMETHEUS_AUTH_TOKEN")) @@ -27,6 +27,6 @@ async def get_prometheus_metrics( ) -> str: if not settings.ENABLE_PROMETHEUS_METRICS: raise error_not_found() - custom_metrics = await prometheus.get_metrics(session=session) - prometheus_metrics = generate_latest() - return custom_metrics + prometheus_metrics.decode() + custom_metrics_ = await custom_metrics.get_metrics(session=session) + client_metrics = prometheus_client.generate_latest().decode() + return custom_metrics_ + client_metrics diff --git a/src/dstack/_internal/server/services/prometheus/__init__.py b/src/dstack/_internal/server/services/prometheus/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/server/services/prometheus/client_metrics.py b/src/dstack/_internal/server/services/prometheus/client_metrics.py new file mode 100644 index 0000000000..d25971ef78 --- /dev/null +++ b/src/dstack/_internal/server/services/prometheus/client_metrics.py @@ -0,0 +1,52 @@ +from prometheus_client import Counter, Histogram + + +class RunMetrics: + 
"""Wrapper class for run-related Prometheus metrics.""" + + def __init__(self): + self._submit_to_provision_duration = Histogram( + "dstack_submit_to_provision_duration_seconds", + "Time from when a run was submitted to when its first job started provisioning", + # Buckets optimized for percentile calculation + buckets=[ + 15, + 30, + 45, + 60, + 90, + 120, + 180, + 240, + 300, + 360, + 420, + 480, + 540, + 600, + 900, + 1200, + 1800, + float("inf"), + ], + labelnames=["project_name", "run_type"], + ) + + self._pending_runs_total = Counter( + "dstack_pending_runs_total", + "Total number of run transitions to the pending status", + labelnames=["project_name", "run_type"], + ) + + def log_submit_to_provision_duration( + self, duration_seconds: float, project_name: str, run_type: str + ): + self._submit_to_provision_duration.labels( + project_name=project_name, run_type=run_type + ).observe(duration_seconds) + + def increment_pending_runs(self, project_name: str, run_type: str): + self._pending_runs_total.labels(project_name=project_name, run_type=run_type).inc() + + +run_metrics = RunMetrics() diff --git a/src/dstack/_internal/server/services/prometheus.py b/src/dstack/_internal/server/services/prometheus/custom_metrics.py similarity index 100% rename from src/dstack/_internal/server/services/prometheus.py rename to src/dstack/_internal/server/services/prometheus/custom_metrics.py diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py index d616494b64..9f138ec2a3 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/tasks/test_process_runs.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest +from freezegun import freeze_time from pydantic import parse_obj_as from sqlalchemy.ext.asyncio import AsyncSession @@ -30,6 +31,7 @@ get_job_provisioning_data, get_run_spec, ) +from dstack._internal.utils import common pytestmark = 
pytest.mark.usefixtures("image_config_mock") @@ -80,10 +82,28 @@ class TestProcessRuns: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc)) async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): run = await make_run(session, status=RunStatus.SUBMITTED) await create_job(session=session, run=run, status=JobStatus.PROVISIONING) - await process_runs.process_runs() + current_time = common.get_current_datetime() + + expected_duration = ( + current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc) + ).total_seconds() + + with patch( + "dstack._internal.server.background.tasks.process_runs.run_metrics" + ) as mock_run_metrics: + await process_runs.process_runs() + + mock_run_metrics.log_submit_to_provision_duration.assert_called_once() + args = mock_run_metrics.log_submit_to_provision_duration.call_args[0] + assert args[1] == run.project.name + assert args[2] == "service" + # Time is frozen via freeze_time, so the duration is deterministic and must match exactly + assert args[0] == expected_duration + await session.refresh(run) assert run.status == RunStatus.PROVISIONING @@ -103,7 +123,14 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession): run = await make_run(session, status=RunStatus.PROVISIONING) await create_job(session=session, run=run, status=JobStatus.PULLING) - await process_runs.process_runs() + with patch( + "dstack._internal.server.background.tasks.process_runs.run_metrics" + ) as mock_run_metrics: + await process_runs.process_runs() + + mock_run_metrics.log_submit_to_provision_duration.assert_not_called() + mock_run_metrics.increment_pending_runs.assert_not_called() + await session.refresh(run) assert run.status == RunStatus.PROVISIONING @@ -161,9 +188,19 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession): instance=instance, 
job_provisioning_data=get_job_provisioning_data(), ) - with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + with ( + patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, + patch( + "dstack._internal.server.background.tasks.process_runs.run_metrics" + ) as mock_run_metrics, + ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) await process_runs.process_runs() + + mock_run_metrics.increment_pending_runs.assert_called_once_with( + run.project.name, "service" + ) + await session.refresh(run) assert run.status == RunStatus.PENDING @@ -205,12 +242,29 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession): class TestProcessRunsReplicas: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc)) async def test_submitted_to_provisioning_if_any(self, test_db, session: AsyncSession): run = await make_run(session, status=RunStatus.SUBMITTED, replicas=2) await create_job(session=session, run=run, status=JobStatus.SUBMITTED, replica_num=0) await create_job(session=session, run=run, status=JobStatus.PROVISIONING, replica_num=1) + current_time = common.get_current_datetime() + + expected_duration = ( + current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc) + ).total_seconds() + + with patch( + "dstack._internal.server.background.tasks.process_runs.run_metrics" + ) as mock_run_metrics: + await process_runs.process_runs() + + mock_run_metrics.log_submit_to_provision_duration.assert_called_once() + args = mock_run_metrics.log_submit_to_provision_duration.call_args[0] + assert args[1] == run.project.name + assert args[2] == "service" + assert isinstance(args[0], float) + assert args[0] == expected_duration - await process_runs.process_runs() await session.refresh(run) assert run.status == RunStatus.PROVISIONING @@ -251,9 +305,19 @@ async def 
test_all_no_capacity_to_pending(self, test_db, session: AsyncSession): instance=await create_instance(session, project=run.project, spot=True), job_provisioning_data=get_job_provisioning_data(), ) - with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + with ( + patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, + patch( + "dstack._internal.server.background.tasks.process_runs.run_metrics" + ) as mock_run_metrics, + ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) await process_runs.process_runs() + + mock_run_metrics.increment_pending_runs.assert_called_once_with( + run.project.name, "service" + ) + await session.refresh(run) assert run.status == RunStatus.PENDING diff --git a/src/tests/_internal/server/routers/test_prometheus.py b/src/tests/_internal/server/routers/test_prometheus.py index 4cb7c01c2f..1f7e1e274b 100644 --- a/src/tests/_internal/server/routers/test_prometheus.py +++ b/src/tests/_internal/server/routers/test_prometheus.py @@ -100,7 +100,7 @@ def enable_metrics(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @pytest.mark.usefixtures("image_config_mock", "test_db", "enable_metrics") class TestGetPrometheusMetrics: - @patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS) + @patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS) async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient): user = await create_user(session=session, name="test-user", global_role=GlobalRole.USER) offer = get_instance_offer_with_availability( @@ -335,7 +335,7 @@ async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient) ) assert response.text.strip() == expected - @patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS) + @patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS) 
async def test_returns_empty_response_if_no_runs(self, client: AsyncClient): response = await client.get("/metrics") assert response.status_code == 200 diff --git a/src/tests/_internal/server/services/prometheus/__init__.py b/src/tests/_internal/server/services/prometheus/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/services/prometheus/test_client_metrics.py b/src/tests/_internal/server/services/prometheus/test_client_metrics.py new file mode 100644 index 0000000000..9d21ff5360 --- /dev/null +++ b/src/tests/_internal/server/services/prometheus/test_client_metrics.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +from dstack._internal.server.services.prometheus.client_metrics import run_metrics + + +class TestRunMetrics: + def test_log_submit_to_provision_duration(self, monkeypatch): + mock_histogram = MagicMock() + mock_labels = MagicMock() + mock_histogram.labels.return_value = mock_labels + monkeypatch.setattr(run_metrics, "_submit_to_provision_duration", mock_histogram) + + duration = 120.5 + project_name = "test-project" + run_type = "dev" + + run_metrics.log_submit_to_provision_duration(duration, project_name, run_type) + + mock_histogram.labels.assert_called_once_with(project_name=project_name, run_type=run_type) + mock_labels.observe.assert_called_once_with(duration) + + def test_increment_pending_runs(self, monkeypatch): + mock_counter = MagicMock() + mock_labels = MagicMock() + mock_counter.labels.return_value = mock_labels + + monkeypatch.setattr(run_metrics, "_pending_runs_total", mock_counter) + + project_name = "test-project" + run_type = "train" + + run_metrics.increment_pending_runs(project_name, run_type) + mock_counter.labels.assert_called_once_with(project_name=project_name, run_type=run_type) + mock_labels.inc.assert_called_once() + + def test_multiple_calls_to_log_submit_to_provision_duration(self): + run_metrics.log_submit_to_provision_duration(60.0, "project1", "dev") + 
run_metrics.log_submit_to_provision_duration(120.0, "project1", "prod") + run_metrics.log_submit_to_provision_duration(30.0, "project2", "dev") + + def test_multiple_calls_to_increment_pending_runs(self): + run_metrics.increment_pending_runs("project1", "dev") + run_metrics.increment_pending_runs("project1", "prod") + run_metrics.increment_pending_runs("project2", "dev") + run_metrics.increment_pending_runs("project1", "dev")