Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/dstack/_internal/server/background/tasks/process_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
group_jobs_by_replica_latest,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.prometheus.client_metrics import run_metrics
from dstack._internal.server.services.runs import (
fmt,
process_terminating_run,
Expand Down Expand Up @@ -329,6 +330,24 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
run_model.status.name,
new_status.name,
)
if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING:
current_time = common.get_current_datetime()
submit_to_provision_duration = (
current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc)
).total_seconds()
logger.info(
"%s: run took %.2f seconds from submision to provisioning.",
fmt(run_model),
submit_to_provision_duration,
)
project_name = run_model.project.name
run_metrics.log_submit_to_provision_duration(
submit_to_provision_duration, project_name, run_spec.configuration.type
)

if new_status == RunStatus.PENDING:
run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type)

run_model.status = new_status
run_model.termination_reason = termination_reason
# While a run goes to pending without provisioning, resubmission_attempt increases.
Expand Down
10 changes: 5 additions & 5 deletions src/dstack/_internal/server/routers/prometheus.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
from typing import Annotated

import prometheus_client
from fastapi import APIRouter, Depends
from fastapi.responses import PlainTextResponse
from prometheus_client import generate_latest
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.server import settings
from dstack._internal.server.db import get_session
from dstack._internal.server.security.permissions import OptionalServiceAccount
from dstack._internal.server.services import prometheus
from dstack._internal.server.services.prometheus import custom_metrics
from dstack._internal.server.utils.routers import error_not_found

_auth = OptionalServiceAccount(os.getenv("DSTACK_PROMETHEUS_AUTH_TOKEN"))
Expand All @@ -27,6 +27,6 @@ async def get_prometheus_metrics(
) -> str:
if not settings.ENABLE_PROMETHEUS_METRICS:
raise error_not_found()
custom_metrics = await prometheus.get_metrics(session=session)
prometheus_metrics = generate_latest()
return custom_metrics + prometheus_metrics.decode()
custom_metrics_ = await custom_metrics.get_metrics(session=session)
client_metrics = prometheus_client.generate_latest().decode()
return custom_metrics_ + client_metrics
Empty file.
52 changes: 52 additions & 0 deletions src/dstack/_internal/server/services/prometheus/client_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from prometheus_client import Counter, Histogram


class RunMetrics:
    """Wrapper class for run-related Prometheus metrics."""

    def __init__(self):
        # Bucket upper bounds (seconds) for the submit-to-provision histogram,
        # chosen to give useful percentile resolution up to 30 minutes.
        bucket_bounds = [
            15,
            30,
            45,
            60,
            90,
            120,
            180,
            240,
            300,
            360,
            420,
            480,
            540,
            600,
            900,
            1200,
            1800,
            float("inf"),
        ]
        self._submit_to_provision_duration = Histogram(
            "dstack_submit_to_provision_duration_seconds",
            "Time from when a run has been submitted and first job provisioning",
            buckets=bucket_bounds,
            labelnames=["project_name", "run_type"],
        )
        self._pending_runs_total = Counter(
            "dstack_pending_runs_total",
            "Number of pending runs",
            labelnames=["project_name", "run_type"],
        )

    def log_submit_to_provision_duration(
        self, duration_seconds: float, project_name: str, run_type: str
    ):
        """Record one submit-to-provision duration observation for the given labels."""
        labeled_histogram = self._submit_to_provision_duration.labels(
            project_name=project_name, run_type=run_type
        )
        labeled_histogram.observe(duration_seconds)

    def increment_pending_runs(self, project_name: str, run_type: str):
        """Increment the pending-runs counter for the given project and run type."""
        labeled_counter = self._pending_runs_total.labels(
            project_name=project_name, run_type=run_type
        )
        labeled_counter.inc()


# Module-level singleton shared by server code that records run metrics.
run_metrics = RunMetrics()
74 changes: 69 additions & 5 deletions src/tests/_internal/server/background/tasks/test_process_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch

import pytest
from freezegun import freeze_time
from pydantic import parse_obj_as
from sqlalchemy.ext.asyncio import AsyncSession

Expand Down Expand Up @@ -30,6 +31,7 @@
get_job_provisioning_data,
get_run_spec,
)
from dstack._internal.utils import common

pytestmark = pytest.mark.usefixtures("image_config_mock")

Expand Down Expand Up @@ -80,10 +82,28 @@ async def make_run(
class TestProcessRuns:
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc))
async def test_submitted_to_provisioning(self, test_db, session: AsyncSession):
run = await make_run(session, status=RunStatus.SUBMITTED)
await create_job(session=session, run=run, status=JobStatus.PROVISIONING)
await process_runs.process_runs()
current_time = common.get_current_datetime()

expected_duration = (
current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc)
).total_seconds()

with patch(
"dstack._internal.server.background.tasks.process_runs.run_metrics"
) as mock_run_metrics:
await process_runs.process_runs()

mock_run_metrics.log_submit_to_provision_duration.assert_called_once()
args = mock_run_metrics.log_submit_to_provision_duration.call_args[0]
assert args[1] == run.project.name
assert args[2] == "service"
# Assert the duration is close to our expected duration (within 0.05 second tolerance)
assert args[0] == expected_duration

await session.refresh(run)
assert run.status == RunStatus.PROVISIONING

Expand All @@ -103,7 +123,14 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession):
run = await make_run(session, status=RunStatus.PROVISIONING)
await create_job(session=session, run=run, status=JobStatus.PULLING)

await process_runs.process_runs()
with patch(
"dstack._internal.server.background.tasks.process_runs.run_metrics"
) as mock_run_metrics:
await process_runs.process_runs()

mock_run_metrics.log_submit_to_provision_duration.assert_not_called()
mock_run_metrics.increment_pending_runs.assert_not_called()

await session.refresh(run)
assert run.status == RunStatus.PROVISIONING

Expand Down Expand Up @@ -161,9 +188,19 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession):
instance=instance,
job_provisioning_data=get_job_provisioning_data(),
)
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
with (
patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock,
patch(
"dstack._internal.server.background.tasks.process_runs.run_metrics"
) as mock_run_metrics,
):
datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3)
await process_runs.process_runs()

mock_run_metrics.increment_pending_runs.assert_called_once_with(
run.project.name, "service"
)

await session.refresh(run)
assert run.status == RunStatus.PENDING

Expand Down Expand Up @@ -205,12 +242,29 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession):
class TestProcessRunsReplicas:
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc))
async def test_submitted_to_provisioning_if_any(self, test_db, session: AsyncSession):
run = await make_run(session, status=RunStatus.SUBMITTED, replicas=2)
await create_job(session=session, run=run, status=JobStatus.SUBMITTED, replica_num=0)
await create_job(session=session, run=run, status=JobStatus.PROVISIONING, replica_num=1)
current_time = common.get_current_datetime()

expected_duration = (
current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc)
).total_seconds()

with patch(
"dstack._internal.server.background.tasks.process_runs.run_metrics"
) as mock_run_metrics:
await process_runs.process_runs()

mock_run_metrics.log_submit_to_provision_duration.assert_called_once()
args = mock_run_metrics.log_submit_to_provision_duration.call_args[0]
assert args[1] == run.project.name
assert args[2] == "service"
assert isinstance(args[0], float)
assert args[0] == expected_duration

await process_runs.process_runs()
await session.refresh(run)
assert run.status == RunStatus.PROVISIONING

Expand Down Expand Up @@ -251,9 +305,19 @@ async def test_all_no_capacity_to_pending(self, test_db, session: AsyncSession):
instance=await create_instance(session, project=run.project, spot=True),
job_provisioning_data=get_job_provisioning_data(),
)
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
with (
patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock,
patch(
"dstack._internal.server.background.tasks.process_runs.run_metrics"
) as mock_run_metrics,
):
datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3)
await process_runs.process_runs()

mock_run_metrics.increment_pending_runs.assert_called_once_with(
run.project.name, "service"
)

await session.refresh(run)
assert run.status == RunStatus.PENDING

Expand Down
4 changes: 2 additions & 2 deletions src/tests/_internal/server/routers/test_prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def enable_metrics(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.usefixtures("image_config_mock", "test_db", "enable_metrics")
class TestGetPrometheusMetrics:
@patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS)
@patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS)
async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient):
user = await create_user(session=session, name="test-user", global_role=GlobalRole.USER)
offer = get_instance_offer_with_availability(
Expand Down Expand Up @@ -335,7 +335,7 @@ async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient)
)
assert response.text.strip() == expected

@patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS)
@patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS)
async def test_returns_empty_response_if_no_runs(self, client: AsyncClient):
response = await client.get("/metrics")
assert response.status_code == 200
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import MagicMock

from dstack._internal.server.services.prometheus.client_metrics import run_metrics


class TestRunMetrics:
    """Unit tests for the RunMetrics wrapper around prometheus_client."""

    def test_log_submit_to_provision_duration(self, monkeypatch):
        # Swap the real histogram for a stub so the call chain can be inspected
        # without touching global Prometheus registry state.
        histogram_stub = MagicMock()
        labeled_stub = MagicMock()
        histogram_stub.labels.return_value = labeled_stub
        monkeypatch.setattr(run_metrics, "_submit_to_provision_duration", histogram_stub)

        run_metrics.log_submit_to_provision_duration(120.5, "test-project", "dev")

        histogram_stub.labels.assert_called_once_with(
            project_name="test-project", run_type="dev"
        )
        labeled_stub.observe.assert_called_once_with(120.5)

    def test_increment_pending_runs(self, monkeypatch):
        counter_stub = MagicMock()
        labeled_stub = MagicMock()
        counter_stub.labels.return_value = labeled_stub

        monkeypatch.setattr(run_metrics, "_pending_runs_total", counter_stub)

        run_metrics.increment_pending_runs("test-project", "train")

        counter_stub.labels.assert_called_once_with(
            project_name="test-project", run_type="train"
        )
        labeled_stub.inc.assert_called_once()

    def test_multiple_calls_to_log_submit_to_provision_duration(self):
        # Smoke test: repeated observations on the real metric must not raise.
        for duration, project, run_type in (
            (60.0, "project1", "dev"),
            (120.0, "project1", "prod"),
            (30.0, "project2", "dev"),
        ):
            run_metrics.log_submit_to_provision_duration(duration, project, run_type)

    def test_multiple_calls_to_increment_pending_runs(self):
        # Smoke test: repeated increments on the real metric must not raise.
        for project, run_type in (
            ("project1", "dev"),
            ("project1", "prod"),
            ("project2", "dev"),
            ("project1", "dev"),
        ):
            run_metrics.increment_pending_runs(project, run_type)