diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 2dd410cd28..8147249e34 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -4,8 +4,8 @@ from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_fleets import process_fleets from dstack._internal.server.background.tasks.process_gateways import ( + process_gateways, process_gateways_connections, - process_submitted_gateways, ) from dstack._internal.server.background.tasks.process_instances import ( process_instances, @@ -70,9 +70,7 @@ def start_background_tasks() -> AsyncIOScheduler: ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job( - process_submitted_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 - ) + _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) _scheduler.add_job( process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index e2a17aa151..cd6025a1ae 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -16,6 +16,7 @@ gateway_connections_pool, ) from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker +from dstack._internal.server.services.logging import fmt from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -27,14 +28,14 @@ async def process_gateways_connections(): await _process_active_connections() -async def process_submitted_gateways(): +async def process_gateways(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) async with get_session_ctx() as session: async with lock: res = await session.execute( select(GatewayModel) .where( - GatewayModel.status == GatewayStatus.SUBMITTED, + GatewayModel.status.in_([GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]), GatewayModel.id.not_in(lockset), ) .options(lazyload(GatewayModel.gateway_compute)) @@ -48,7 +49,25 @@ async def process_submitted_gateways(): lockset.add(gateway_model.id) try: gateway_model_id = gateway_model.id - await _process_submitted_gateway(session=session, gateway_model=gateway_model) + initial_status = gateway_model.status + if initial_status == GatewayStatus.SUBMITTED: + await _process_submitted_gateway(session=session, gateway_model=gateway_model) + elif initial_status == GatewayStatus.PROVISIONING: + await _process_provisioning_gateway(session=session, gateway_model=gateway_model) + else: + logger.error( + "%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper() + ) + if gateway_model.status != initial_status: + logger.info( + "%s: gateway status has changed %s -> %s%s", + fmt(gateway_model), + initial_status.upper(), + gateway_model.status.upper(), + f": {gateway_model.status_message}" if gateway_model.status_message else "", + ) + gateway_model.last_processed_at = get_current_datetime() + await session.commit() finally: lockset.difference_update([gateway_model_id]) @@ -89,7 +108,7 @@ async def _process_connection(conn: GatewayConnection): async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel): - logger.info("Started gateway %s provisioning", gateway_model.name) + logger.info("%s: started gateway provisioning", fmt(gateway_model)) # Refetch to load related attributes. # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE. res = await session.execute( @@ -110,8 +129,6 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew except BackendNotAvailable: gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Backend not available" - gateway_model.last_processed_at = get_current_datetime() - await session.commit() return try: @@ -123,53 +140,54 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew ) session.add(gateway_model) gateway_model.status = GatewayStatus.PROVISIONING - await session.commit() - await session.refresh(gateway_model) except BackendError as e: - logger.info( - "Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e) - ) + logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e) gateway_model.status = GatewayStatus.FAILED status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) gateway_model.status_message = status_message - gateway_model.last_processed_at = get_current_datetime() - await session.commit() - return except Exception as e: - logger.exception( - "Got exception when creating gateway compute for gateway %s", gateway_model.name - ) + logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = f"Unexpected error: {repr(e)}" - gateway_model.last_processed_at = get_current_datetime() - await session.commit() - return + +async def _process_provisioning_gateway( + session: AsyncSession, gateway_model: GatewayModel +) -> None: + # Refetch to load related attributes. + # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE. + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway_model.id) + .execution_options(populate_existing=True) + ) + gateway_model = res.unique().scalar_one() + + # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: + # - cannot delete the gateway before it is provisioned because the DB model is locked + # - connection retry counter is reset on server restart + # - only one server replica is processing the gateway + # Easy to fix by doing only one connection/configuration attempt per processing iteration. The + # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid + # maintaining a different model for Sky. connection = await gateways_services.connect_to_gateway_with_retry( gateway_model.gateway_compute ) if connection is None: gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to connect to gateway" - gateway_model.last_processed_at = get_current_datetime() gateway_model.gateway_compute.deleted = True - await session.commit() return - try: await gateways_services.configure_gateway(connection) except Exception: - logger.exception("Failed to configure gateway %s", gateway_model.name) + logger.exception("%s: failed to configure gateway", fmt(gateway_model)) gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to configure gateway" - gateway_model.last_processed_at = get_current_datetime() await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) gateway_model.gateway_compute.active = False - await session.commit() return gateway_model.status = GatewayStatus.RUNNING - gateway_model.last_processed_at = get_current_datetime() - await session.commit() diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 1564592576..fb38393168 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -2,6 +2,7 @@ import datetime import uuid from datetime import timedelta, timezone +from functools import partial from typing import List, Optional, Sequence import httpx @@ -186,6 +187,7 @@ async def create_gateway( return gateway_model_to_gateway(gateway) +# NOTE: dstack Sky imports and uses this function async def connect_to_gateway_with_retry( gateway_compute: GatewayComputeModel, ) -> Optional[GatewayConnection]: @@ -380,6 +382,8 @@ async def get_or_add_gateway_connection( async def init_gateways(session: AsyncSession): res = await session.execute( select(GatewayComputeModel).where( + # FIXME: should not include computes related to gateways in the `provisioning` status. + # Causes warnings and delays when restarting the server during gateway provisioning. GatewayComputeModel.active == True, GatewayComputeModel.deleted == False, ) @@ -421,7 +425,8 @@ async def init_gateways(session: AsyncSession): for gateway_compute, error in await gather_map_async( await gateway_connections_pool.all(), - configure_gateway, + # Need several attempts to handle short gateway downtime after update + partial(configure_gateway, attempts=7), return_exceptions=True, ): if isinstance(error, Exception): @@ -461,7 +466,11 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool: ) > get_current_datetime() - timedelta(seconds=60) -async def configure_gateway(connection: GatewayConnection) -> None: +# NOTE: dstack Sky imports and uses this function +async def configure_gateway( + connection: GatewayConnection, + attempts: int = GATEWAY_CONFIGURE_ATTEMPTS, +) -> None: """ Try submitting gateway config several times in case gateway's HTTP server is not running yet @@ -469,7 +478,7 @@ async def configure_gateway(connection: GatewayConnection) -> None: logger.debug("Configuring gateway %s", connection.ip_address) - for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1): + for attempt in range(attempts - 1): try: async with connection.client() as client: await client.submit_gateway_config() @@ -478,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None: logger.debug( "Failed attempt %s/%s at configuring gateway %s: %r", attempt + 1, - GATEWAY_CONFIGURE_ATTEMPTS, + attempts, connection.ip_address, e, ) diff --git a/src/dstack/_internal/server/services/logging.py b/src/dstack/_internal/server/services/logging.py index 1f2d106a54..545067d6ab 100644 --- a/src/dstack/_internal/server/services/logging.py +++ b/src/dstack/_internal/server/services/logging.py @@ -1,12 +1,14 @@ from typing import Union -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import GatewayModel, JobModel, RunModel -def fmt(model: Union[RunModel, JobModel]) -> str: +def fmt(model: Union[RunModel, JobModel, GatewayModel]) -> str: """Consistent string representation of a model for logging.""" if isinstance(model, RunModel): return f"run({model.id.hex[:6]}){model.run_name}" if isinstance(model, JobModel): return f"job({model.id.hex[:6]}){model.job_name}" + if isinstance(model, GatewayModel): + return f"gateway({model.id.hex[:6]}){model.name}" return str(model) diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py index 159547af4b..3460f18cb9 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py @@ -5,56 +5,48 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus -from dstack._internal.server.background.tasks.process_gateways import process_submitted_gateways +from dstack._internal.server.background.tasks.process_gateways import process_gateways from dstack._internal.server.testing.common import ( AsyncContextManager, ComputeMockSpec, create_backend, create_gateway, + create_gateway_compute, create_project, ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestProcessSubmittedGateways: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_provisions_gateway(self, test_db, session: AsyncSession): + async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, + status=GatewayStatus.SUBMITTED, ) - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m, - patch( - "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" - ) as pool_add, - ): + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: aws = Mock() m.return_value = (backend, aws) - pool_add.return_value = MagicMock() - pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( instance_id="i-1234567890", ip_address="2.2.2.2", region="us", ) - await process_submitted_gateways() + await process_gateways() m.assert_called_once() aws.compute.return_value.create_gateway.assert_called_once() - pool_add.assert_called_once() await session.refresh(gateway) - assert gateway.status == GatewayStatus.RUNNING + assert gateway.status == GatewayStatus.PROVISIONING assert gateway.gateway_compute is not None assert gateway.gateway_compute.ip_address == "2.2.2.2" - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_marks_gateway_as_failed_if_gateway_creation_errors( self, test_db, session: AsyncSession ): @@ -64,6 +56,7 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( session=session, project_id=project.id, backend_id=backend.id, + status=GatewayStatus.SUBMITTED, ) with patch( "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" @@ -72,47 +65,57 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( m.return_value = (backend, aws) aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") - await process_submitted_gateways() + await process_gateways() m.assert_called_once() aws.compute.return_value.create_gateway.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Some error" - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestProcessProvisioningGateways: + async def test_provisioning_to_running(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, + ) + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await process_gateways() + pool_add.assert_called_once() + await session.refresh(gateway) + assert gateway.status == GatewayStatus.RUNNING + async def test_marks_gateway_as_failed_if_fails_to_connect( self, test_db, session: AsyncSession ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, ) - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m, - patch( - "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" - ) as connect_to_gateway_with_retry_mock, - ): - aws = Mock() - m.return_value = (backend, aws) + with patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_to_gateway_with_retry_mock: connect_to_gateway_with_retry_mock.return_value = None - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( - instance_id="i-1234567890", - ip_address="2.2.2.2", - region="us", - ) - await process_submitted_gateways() - m.assert_called_once() - aws.compute.return_value.create_gateway.assert_called_once() + await process_gateways() connect_to_gateway_with_retry_mock.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED - assert gateway.gateway_compute is not None - assert gateway.gateway_compute is not None + assert gateway.status_message == "Failed to connect to gateway"