-
-
Notifications
You must be signed in to change notification settings - Fork 223
Expand file tree
/
Copy path__init__.py
More file actions
404 lines (351 loc) · 13.9 KB
/
__init__.py
File metadata and controls
404 lines (351 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import asyncio
import heapq
import time
from collections.abc import Iterable, Iterator
from typing import Callable, Coroutine, Dict, List, Optional, Tuple
from uuid import UUID
from cachetools import TTLCache
from sqlalchemy import delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.base.configurator import (
Configurator,
StoredBackendRecord,
)
from dstack._internal.core.backends.configurators import (
get_configurator,
list_available_backend_types,
)
from dstack._internal.core.backends.local.backend import LocalBackend
from dstack._internal.core.backends.models import (
AnyBackendConfigWithCreds,
AnyBackendConfigWithoutCreds,
)
from dstack._internal.core.errors import (
BackendAuthError,
BackendError,
BackendInvalidCredentialsError,
BackendNotAvailable,
ResourceExistsError,
ResourceNotExistsError,
ServerClientError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceOfferWithAvailability,
)
from dstack._internal.core.models.runs import Requirements
from dstack._internal.server import settings
from dstack._internal.server.models import BackendModel, DecryptedString, ProjectModel
from dstack._internal.settings import LOCAL_BACKEND_ENABLED
from dstack._internal.utils.common import run_async
from dstack._internal.utils.logging import get_logger
# Logger for this module, obtained from the project's `get_logger` helper.
logger = get_logger(__name__)
async def create_backend(
    session: AsyncSession,
    project: ProjectModel,
    config: AnyBackendConfigWithCreds,
) -> AnyBackendConfigWithCreds:
    """
    Validates `config`, adds a new backend model for `project`, and commits the session.

    Returns the passed `config` unchanged.

    Raises:
        BackendNotAvailable: if no configurator is available for `config.type`.
        ResourceExistsError: if the project already has a stored backend of this type.
    """
    configurator = get_configurator(config.type)
    if configurator is None:
        raise BackendNotAvailable()
    # Check the stored backend models (the same way `update_backend` does) instead of
    # the instantiated backends: a stored backend whose creds are invalid or cannot be
    # decrypted is never instantiated, so it would go unnoticed and a second model of
    # the same type would be added.
    if any(configurator.TYPE == b.type for b in project.backends):
        raise ResourceExistsError()
    backend_model = await validate_and_create_backend_model(
        project=project, configurator=configurator, config=config
    )
    session.add(backend_model)
    await session.commit()
    return config
async def update_backend(
    session: AsyncSession,
    project: ProjectModel,
    config: AnyBackendConfigWithCreds,
) -> AnyBackendConfigWithCreds:
    """
    Validates `config` and overwrites the stored `config` and `auth` of the
    project's existing backend of the same type.

    The update is executed on `session`; this function does not commit it.
    Returns the passed `config` unchanged.

    Raises:
        BackendNotAvailable: if no configurator is available for `config.type`.
        ResourceNotExistsError: if the project has no stored backend of this type.
    """
    configurator = get_configurator(config.type)
    if configurator is None:
        raise BackendNotAvailable()
    if not any(configurator.TYPE == stored.type for stored in project.backends):
        raise ResourceNotExistsError()
    # The freshly built model is only a carrier for the new column values;
    # it is never added to the session.
    new_model = await validate_and_create_backend_model(
        project=project, configurator=configurator, config=config
    )
    statement = (
        update(BackendModel)
        .where(
            BackendModel.project_id == new_model.project_id,
            BackendModel.type == new_model.type,
        )
        .values(
            config=new_model.config,
            auth=new_model.auth,
        )
    )
    # FIXME: potentially long write transaction
    await session.execute(statement)
    return config
async def validate_and_create_backend_model(
    project: ProjectModel,
    configurator: Configurator,
    config: AnyBackendConfigWithCreds,
) -> BackendModel:
    """
    Validates `config` with `configurator`, then asks the configurator to create
    the backend and wraps the resulting record in a new `BackendModel`.

    The returned model is not added to any session here.
    Both configurator calls are executed through `run_async`.
    """
    # Validation runs first; creation is attempted only if it does not raise.
    await run_async(
        configurator.validate_config,
        config,
        default_creds_enabled=settings.DEFAULT_CREDS_ENABLED,
    )
    record = await run_async(
        configurator.create_backend,
        project_name=project.name,
        config=config,
    )
    model = BackendModel(
        project_id=project.id,
        type=configurator.TYPE,
        config=record.config,
        auth=DecryptedString(plaintext=record.auth),
    )
    return model
async def get_backend_config(
    project: ProjectModel,
    backend_type: BackendType,
) -> Optional[AnyBackendConfigWithCreds]:
    """
    Returns the config with creds of the project's backend of type `backend_type`.

    Returns `None` if the project has no such backend or if its creds
    could not be decrypted (a warning is logged in the latter case).

    Raises:
        BackendNotAvailable: if no configurator is available for `backend_type`.
    """
    configurator = get_configurator(backend_type)
    if configurator is None:
        raise BackendNotAvailable()
    for backend_model in project.backends:
        # Filter by type first so that undecryptable creds of unrelated backends
        # do not produce a warning on every lookup.
        if backend_model.type != backend_type:
            continue
        if not backend_model.auth.decrypted:
            logger.warning(
                "Failed to decrypt creds for %s backend. Backend will be ignored.",
                backend_model.type.value,
            )
            continue
        return get_backend_config_with_creds_from_backend_model(configurator, backend_model)
    return None
def get_backend_config_with_creds_from_backend_model(
    configurator: Configurator,
    backend_model: BackendModel,
) -> AnyBackendConfigWithCreds:
    """
    Converts `backend_model` to a stored backend record and lets `configurator`
    build the backend config that includes creds from it.
    """
    return configurator.get_backend_config_with_creds(
        get_stored_backend_record(backend_model)
    )
def get_backend_config_without_creds_from_backend_model(
    configurator: Configurator,
    backend_model: BackendModel,
) -> AnyBackendConfigWithoutCreds:
    """
    Converts `backend_model` to a stored backend record and lets `configurator`
    build the backend config without creds from it.
    """
    return configurator.get_backend_config_without_creds(
        get_stored_backend_record(backend_model)
    )
def get_stored_backend_record(backend_model: BackendModel) -> StoredBackendRecord:
    """
    Builds a `StoredBackendRecord` from the fields of `backend_model`.

    The auth value is taken via `auth.get_plaintext_or_error()`.
    NOTE(review): judging by its name, that call raises when the plaintext
    is not available — confirm against `DecryptedString`.
    """
    record = StoredBackendRecord(
        config=backend_model.config,
        auth=backend_model.auth.get_plaintext_or_error(),
        project_id=backend_model.project_id,
        backend_id=backend_model.id,
    )
    return record
async def delete_backends(
    session: AsyncSession,
    project: ProjectModel,
    backends_types: List[BackendType],
):
    """
    Deletes the project's stored backends whose type is listed in `backends_types`.

    Requested types that the project does not have are ignored; if none of the
    requested types is present, nothing is executed or logged.
    The delete is executed on `session`; this function does not commit it.

    Raises:
        ServerClientError: if `BackendType.DSTACK` is among `backends_types`.
    """
    if BackendType.DSTACK in backends_types:
        raise ServerClientError("Cannot delete dstack backend")
    existing_types = {model.type for model in project.backends}
    types_to_delete = existing_types & set(backends_types)
    if not types_to_delete:
        return
    statement = delete(BackendModel).where(
        BackendModel.type.in_(types_to_delete),
        BackendModel.project_id == project.id,
    )
    # FIXME: potentially long write transaction
    # Not urgent since backend deletion is a rare operation
    await session.execute(statement)
    logger.info(
        "Deleted backends %s in project %s",
        [deleted_type.value for deleted_type in types_to_delete],
        project.name,
    )
# A stored backend model paired with the backend instance created from it.
BackendTuple = Tuple[BackendModel, Backend]
# Per-project asyncio locks keyed by project id. Entries are created lazily by
# `_get_project_cache_lock()`; nothing in the code visible here removes them.
_BACKENDS_CACHE_LOCKS: Dict[UUID, asyncio.Lock] = {}
# Project id -> {backend type -> BackendTuple}. Holds at most 1000 projects;
# an entry expires after ttl=300 (seconds with cachetools' default timer).
_BACKENDS_CACHE = TTLCache[UUID, Dict[BackendType, BackendTuple]](maxsize=1000, ttl=300)
def _get_project_cache_lock(project_id: UUID) -> asyncio.Lock:
    """
    Returns the lock guarding the backends cache entry of `project_id`,
    creating and registering it on first use.
    """
    lock = _BACKENDS_CACHE_LOCKS.get(project_id)
    if lock is None:
        # Create the lock only on a miss: `setdefault(project_id, asyncio.Lock())`
        # would construct a throwaway lock on every call. There is no `await`
        # between the lookup and the insert, so coroutines on the same event loop
        # cannot interleave here.
        lock = _BACKENDS_CACHE_LOCKS[project_id] = asyncio.Lock()
    return lock
async def get_project_backends_with_models(project: ProjectModel) -> List[BackendTuple]:
    """
    Returns `(BackendModel, Backend)` pairs for the backends currently stored in `project`.

    A cached pair is reused while the stored `config` and `auth` of the backend
    are equal to the cached ones; otherwise the backend is instantiated again.
    Backends that cannot be instantiated (no configurator, creds that cannot be
    decrypted, invalid creds) are logged and left out of the result.
    """
    async with _get_project_cache_lock(project.id):
        key = project.id
        cached_backends = _BACKENDS_CACHE.get(key, {})
        # Rebuild the mapping from the project's current backend models instead of
        # adding to the cached one. Otherwise a backend that was deleted from the
        # project, or whose updated config can no longer be instantiated, would keep
        # being returned from the cache (whose TTL is renewed on every call below).
        project_backends: Dict[BackendType, BackendTuple] = {}
        for backend_model in project.backends:
            cached_backend = cached_backends.get(backend_model.type)
            if (
                cached_backend is not None
                and cached_backend[0].config == backend_model.config
                and cached_backend[0].auth == backend_model.auth
            ):
                project_backends[backend_model.type] = cached_backend
                continue
            configurator = get_configurator(backend_model.type)
            if configurator is None:
                logger.warning(
                    "Missing dependencies for %s backend. Backend will be ignored.",
                    backend_model.type.value,
                )
                continue
            if not backend_model.auth.decrypted:
                logger.warning(
                    "Failed to decrypt creds for %s backend. Backend will be ignored.",
                    backend_model.type.value,
                )
                continue
            try:
                backend_record = get_stored_backend_record(backend_model)
                backend = await run_async(configurator.get_backend, backend_record)
            except (BackendInvalidCredentialsError, BackendAuthError):
                logger.warning(
                    "Credentials for %s backend are invalid. Backend will be ignored.",
                    backend_model.type.value,
                )
                continue
            project_backends[backend_model.type] = (backend_model, backend)
        # `__setitem__()` will also expire the cache.
        # Note that there is no global cache lock so a race condition is possible:
        # one coroutine updates/re-assigns backends expired by another coroutine.
        # This is ok since the only effect is that project's cache gets restored.
        _BACKENDS_CACHE[key] = project_backends
        return list(project_backends.values())
# Optional override hook installed by `set_get_project_backend_with_model_by_type()`.
# `None` means that no override is installed.
_get_project_backend_with_model_by_type = None
def set_get_project_backend_with_model_by_type(
    func: Callable[[ProjectModel, BackendType], Coroutine[None, None, Optional[BackendTuple]]],
):
    """
    Installs `func` as the override for `get_project_backend_with_model_by_type`.
    Once installed, `get_project_backend_with_model_by_type` delegates to `func`
    when it is called with `overrides=True`.
    This can be used if a backend needs to be replaced with another backend.
    For example, DstackBackend in dstack Sky can be used in place of other backends.
    """
    global _get_project_backend_with_model_by_type
    _get_project_backend_with_model_by_type = func
async def get_project_backend_with_model_by_type(
    project: ProjectModel,
    backend_type: BackendType,
    overrides: bool = False,
) -> Optional[BackendTuple]:
    """
    Returns the `(BackendModel, Backend)` pair of the project's backend whose
    `TYPE` equals `backend_type`, or `None` if there is no such backend.

    If `overrides` is true and an override function has been installed,
    the lookup is delegated to that function instead.
    """
    override = _get_project_backend_with_model_by_type
    if overrides and override is not None:
        return await override(project, backend_type)
    pairs = await get_project_backends_with_models(project=project)
    matching = ((model, instance) for model, instance in pairs if instance.TYPE == backend_type)
    return next(matching, None)
async def get_project_backend_with_model_by_type_or_error(
    project: ProjectModel,
    backend_type: BackendType,
    overrides: bool = False,
) -> BackendTuple:
    """
    Same as `get_project_backend_with_model_by_type`, but raises
    `BackendNotAvailable` instead of returning `None`.
    """
    found = await get_project_backend_with_model_by_type(
        project=project,
        backend_type=backend_type,
        overrides=overrides,
    )
    if found is not None:
        return found
    raise BackendNotAvailable()
async def get_project_backends(project: ProjectModel) -> List[Backend]:
    """
    Returns the backend instances of `project`, followed by a new `LocalBackend`
    when `LOCAL_BACKEND_ENABLED` is set.
    """
    pairs = await get_project_backends_with_models(project)
    instances = [pair[1] for pair in pairs]
    if not LOCAL_BACKEND_ENABLED:
        return instances
    instances.append(LocalBackend())
    return instances
async def get_project_backend_by_type(
    project: ProjectModel,
    backend_type: BackendType,
    overrides: bool = False,
) -> Optional[Backend]:
    """
    Returns only the backend instance from the pair found by
    `get_project_backend_with_model_by_type`, or `None` if nothing was found.
    """
    found = await get_project_backend_with_model_by_type(
        project=project,
        backend_type=backend_type,
        overrides=overrides,
    )
    return None if found is None else found[1]
async def get_project_backend_by_type_or_error(
    project: ProjectModel,
    backend_type: BackendType,
    overrides: bool = False,
) -> Backend:
    """
    Same as `get_project_backend_by_type`, but raises `BackendNotAvailable`
    instead of returning `None`.
    """
    instance = await get_project_backend_by_type(
        project=project,
        backend_type=backend_type,
        overrides=overrides,
    )
    if instance is not None:
        return instance
    raise BackendNotAvailable()
async def get_project_backend_model_by_type(
    project: ProjectModel, backend_type: BackendType
) -> Optional[BackendModel]:
    """
    Returns the first model in `project.backends` whose `type` equals
    `backend_type`, or `None` if there is none.
    """
    matching = (model for model in project.backends if model.type == backend_type)
    return next(matching, None)
async def get_project_backend_model_by_type_or_error(
    project: ProjectModel, backend_type: BackendType
) -> BackendModel:
    """
    Same as `get_project_backend_model_by_type`, but raises `BackendNotAvailable`
    instead of returning `None`.
    """
    found = await get_project_backend_model_by_type(
        project=project, backend_type=backend_type
    )
    if found is not None:
        return found
    raise BackendNotAvailable()
async def get_backend_offers(
    backends: List[Backend],
    requirements: Requirements,
    exclude_not_available: bool = False,
) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]:
    """
    Requests offers satisfying `requirements` from all `backends` concurrently and
    returns one iterator of `(backend, offer)` pairs merged by offer price.

    The merge keeps each backend's own offer order, so the result is sorted by price
    as long as every backend returns its offers sorted by price.
    Backends whose request failed are logged and contribute no offers.
    If `exclude_not_available` is true, offers for which
    `availability.is_available()` is false are dropped.
    """

    def pair_with_backend(
        owner: Backend,
        backend_offers: Iterable[InstanceOfferWithAvailability],
    ) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]:
        # The backend is passed as an argument so that each lazy generator keeps
        # its own backend rather than the loop variable's last value.
        for candidate in backend_offers:
            if exclude_not_available and not candidate.availability.is_available():
                continue
            yield (owner, candidate)

    logger.debug("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
    outcomes = await asyncio.gather(
        *[run_async(get_offers_tracked, backend, requirements) for backend in backends],
        return_exceptions=True,
    )
    streams = []
    for backend, outcome in zip(backends, outcomes):
        if isinstance(outcome, BackendError):
            logger.warning(
                "Failed to get offers from backend %s: %s",
                backend.TYPE,
                repr(outcome),
            )
        elif isinstance(outcome, BaseException):
            logger.error(
                "Got exception when requesting offers from backend %s",
                backend.TYPE,
                exc_info=outcome,
            )
        else:
            streams.append(pair_with_backend(backend, outcome))
    # Merge preserving order for every backend.
    return heapq.merge(*streams, key=lambda pair: pair[1].price)
def check_backend_type_available(backend_type: BackendType):
    """
    Raises `BackendNotAvailable` with an explanatory message if `backend_type`
    is not among the available backend types; otherwise does nothing.
    """
    if backend_type in list_available_backend_types():
        return
    message = (
        f"Backend {backend_type.value} not available."
        " Ensure that backend dependencies are installed."
        f" Available backends: {[b.value for b in list_available_backend_types()]}."
    )
    raise BackendNotAvailable(message)
def get_offers_tracked(
    backend: Backend, requirements: Requirements
) -> Iterator[InstanceOfferWithAvailability]:
    """
    Returns `backend.compute().get_offers(requirements)` and logs at debug level
    how long that call took.

    NOTE(review): if `get_offers` returns a lazy iterator, only the time needed
    to create it is measured, not the time needed to consume it — confirm.
    """
    # Use a monotonic high-resolution clock: wall-clock time (`time.time()`) can be
    # adjusted while the request is in flight, which would skew the logged duration.
    start = time.perf_counter()
    res = backend.compute().get_offers(requirements)
    duration = time.perf_counter() - start
    logger.debug("Got offers from %s in %.6fs", backend.TYPE.value, duration)
    return res