Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/dstack/_internal/server/background/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_fleets import process_fleets
from dstack._internal.server.background.tasks.process_gateways import (
process_gateways,
process_gateways_connections,
process_submitted_gateways,
)
from dstack._internal.server.background.tasks.process_instances import (
process_instances,
Expand Down Expand Up @@ -70,9 +70,7 @@ def start_background_tasks() -> AsyncIOScheduler:
)
_scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1)
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
_scheduler.add_job(
process_submitted_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5
)
_scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5)
_scheduler.add_job(
process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
)
Expand Down
74 changes: 46 additions & 28 deletions src/dstack/_internal/server/background/tasks/process_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
gateway_connections_pool,
)
from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker
from dstack._internal.server.services.logging import fmt
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger

Expand All @@ -27,14 +28,14 @@ async def process_gateways_connections():
await _process_active_connections()


async def process_submitted_gateways():
async def process_gateways():
lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__)
async with get_session_ctx() as session:
async with lock:
res = await session.execute(
select(GatewayModel)
.where(
GatewayModel.status == GatewayStatus.SUBMITTED,
GatewayModel.status.in_([GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]),
GatewayModel.id.not_in(lockset),
)
.options(lazyload(GatewayModel.gateway_compute))
Expand All @@ -48,7 +49,25 @@ async def process_submitted_gateways():
lockset.add(gateway_model.id)
try:
gateway_model_id = gateway_model.id
await _process_submitted_gateway(session=session, gateway_model=gateway_model)
initial_status = gateway_model.status
if initial_status == GatewayStatus.SUBMITTED:
await _process_submitted_gateway(session=session, gateway_model=gateway_model)
elif initial_status == GatewayStatus.PROVISIONING:
await _process_provisioning_gateway(session=session, gateway_model=gateway_model)
else:
logger.error(
"%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper()
)
if gateway_model.status != initial_status:
logger.info(
"%s: gateway status has changed %s -> %s%s",
fmt(gateway_model),
initial_status.upper(),
gateway_model.status.upper(),
f": {gateway_model.status_message}" if gateway_model.status_message else "",
)
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
finally:
lockset.difference_update([gateway_model_id])

Expand Down Expand Up @@ -89,7 +108,7 @@ async def _process_connection(conn: GatewayConnection):


async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel):
logger.info("Started gateway %s provisioning", gateway_model.name)
logger.info("%s: started gateway provisioning", fmt(gateway_model))
# Refetch to load related attributes.
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
res = await session.execute(
Expand All @@ -110,8 +129,6 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
except BackendNotAvailable:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Backend not available"
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
return

try:
Expand All @@ -123,53 +140,54 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
)
session.add(gateway_model)
gateway_model.status = GatewayStatus.PROVISIONING
await session.commit()
await session.refresh(gateway_model)
except BackendError as e:
logger.info(
"Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e)
)
logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e)
gateway_model.status = GatewayStatus.FAILED
status_message = f"Backend error: {repr(e)}"
if len(e.args) > 0:
status_message = str(e.args[0])
gateway_model.status_message = status_message
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
return
except Exception as e:
logger.exception(
"Got exception when creating gateway compute for gateway %s", gateway_model.name
)
logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = f"Unexpected error: {repr(e)}"
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
return


async def _process_provisioning_gateway(
session: AsyncSession, gateway_model: GatewayModel
) -> None:
# Refetch to load related attributes.
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
res = await session.execute(
select(GatewayModel)
.where(GatewayModel.id == gateway_model.id)
.execution_options(populate_existing=True)
)
gateway_model = res.unique().scalar_one()

# FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway:
# - cannot delete the gateway before it is provisioned because the DB model is locked
# - connection retry counter is reset on server restart
# - only one server replica is processing the gateway
# Easy to fix by doing only one connection/configuration attempt per processing iteration. The
# main challenge is applying the same provisioning model to the dstack Sky gateway to avoid
# maintaining a different model for Sky.
connection = await gateways_services.connect_to_gateway_with_retry(
gateway_model.gateway_compute
)
if connection is None:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to connect to gateway"
gateway_model.last_processed_at = get_current_datetime()
gateway_model.gateway_compute.deleted = True
await session.commit()
return

try:
await gateways_services.configure_gateway(connection)
except Exception:
logger.exception("Failed to configure gateway %s", gateway_model.name)
logger.exception("%s: failed to configure gateway", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to configure gateway"
gateway_model.last_processed_at = get_current_datetime()
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
gateway_model.gateway_compute.active = False
await session.commit()
return

gateway_model.status = GatewayStatus.RUNNING
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
17 changes: 13 additions & 4 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime
import uuid
from datetime import timedelta, timezone
from functools import partial
from typing import List, Optional, Sequence

import httpx
Expand Down Expand Up @@ -186,6 +187,7 @@ async def create_gateway(
return gateway_model_to_gateway(gateway)


# NOTE: dstack Sky imports and uses this function
async def connect_to_gateway_with_retry(
gateway_compute: GatewayComputeModel,
) -> Optional[GatewayConnection]:
Expand Down Expand Up @@ -380,6 +382,8 @@ async def get_or_add_gateway_connection(
async def init_gateways(session: AsyncSession):
res = await session.execute(
select(GatewayComputeModel).where(
# FIXME: should not include computes related to gateways in the `provisioning` status.
# Causes warnings and delays when restarting the server during gateway provisioning.
GatewayComputeModel.active == True,
GatewayComputeModel.deleted == False,
)
Expand Down Expand Up @@ -421,7 +425,8 @@ async def init_gateways(session: AsyncSession):

for gateway_compute, error in await gather_map_async(
await gateway_connections_pool.all(),
configure_gateway,
# Need several attempts to handle short gateway downtime after update
partial(configure_gateway, attempts=7),
return_exceptions=True,
):
if isinstance(error, Exception):
Expand Down Expand Up @@ -461,15 +466,19 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool:
) > get_current_datetime() - timedelta(seconds=60)


async def configure_gateway(connection: GatewayConnection) -> None:
# NOTE: dstack Sky imports and uses this function
async def configure_gateway(
connection: GatewayConnection,
attempts: int = GATEWAY_CONFIGURE_ATTEMPTS,
) -> None:
"""
Try submitting gateway config several times in case gateway's HTTP server is not
running yet
"""

logger.debug("Configuring gateway %s", connection.ip_address)

for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1):
for attempt in range(attempts - 1):
try:
async with connection.client() as client:
await client.submit_gateway_config()
Expand All @@ -478,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
logger.debug(
"Failed attempt %s/%s at configuring gateway %s: %r",
attempt + 1,
GATEWAY_CONFIGURE_ATTEMPTS,
attempts,
connection.ip_address,
e,
)
Expand Down
6 changes: 4 additions & 2 deletions src/dstack/_internal/server/services/logging.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Union

from dstack._internal.server.models import JobModel, RunModel
from dstack._internal.server.models import GatewayModel, JobModel, RunModel


def fmt(model: Union[RunModel, JobModel]) -> str:
def fmt(model: Union[RunModel, JobModel, GatewayModel]) -> str:
"""Consistent string representation of a model for logging."""
if isinstance(model, RunModel):
return f"run({model.id.hex[:6]}){model.run_name}"
if isinstance(model, JobModel):
return f"job({model.id.hex[:6]}){model.job_name}"
if isinstance(model, GatewayModel):
return f"gateway({model.id.hex[:6]}){model.name}"
return str(model)
Loading
Loading