diff --git a/docs/docs/concepts/secrets.md b/docs/docs/concepts/secrets.md new file mode 100644 index 0000000000..d541f281a6 --- /dev/null +++ b/docs/docs/concepts/secrets.md @@ -0,0 +1,125 @@ +# Secrets + +Secrets allow centralized management of sensitive values such as API keys and credentials. They are project-scoped, managed by project admins, and can be referenced in run configurations to pass sensitive values to runs in a secure manner. + +!!! info "Secrets encryption" + By default, secrets are stored in plaintext in the DB. + Configure [server encryption](../guides/server-deployment.md#encryption) to store secrets encrypted. + +## Manage secrets + +### Set + +Use the `dstack secret set` command to create a new secret: + +
+ +```shell +$ dstack secret set my_secret some_secret_value +OK +``` + +
+ +The same command can be used to update an existing secret: + +
+ +```shell +$ dstack secret set my_secret another_secret_value +OK +``` + +
+ +### List + +Use the `dstack secret list` command to list all secrets set in a project: + +
+ +```shell +$ dstack secret + NAME VALUE + hf_token ****** + my_secret ****** + +``` + +
+ +### Get + +The `dstack secret list` does not show secret values. To see a secret value, use the `dstack secret get` command: + +
+ +```shell +$ dstack secret get my_secret + NAME VALUE + my_secret some_secret_value + +``` + +
+ +### Delete + +Secrets can be deleted using the `dstack secret delete` command: + +
+ +```shell +$ dstack secret delete my_secret +Delete the secret my_secret? [y/n]: y +OK +``` + +
+ +## Use secrets + +You can use the `${{ secrets. }}` syntax to reference secrets in run configurations. Currently, secrets interpolation is supported in `env` and `registry_auth` properties. + +### `env` + +Suppose you need to pass a sensitive environment variable to a run such as `HF_TOKEN`. You'd first create a secret holding the environment variable value: + +
+ +```shell +$ dstack secret set hf_token {hf_token_value} +OK +``` + +
+ +and then reference the secret in `env`: + +
+ +```yaml +type: service +env: + - HF_TOKEN=${{ secrets.hf_token }} +commands: + ... +``` + +
+ +### `registry_auth` + +If you need to pull a private Docker image, you can store registry credentials as secrets and reference them in `registry_auth`: + +
+ +```yaml +type: service +image: nvcr.io/nim/deepseek-ai/deepseek-r1-distill-llama-8b +registry_auth: + username: $oauthtoken + password: ${{ secrets.ngc_api_key }} +``` + +
diff --git a/docs/docs/reference/cli/dstack/secret.md b/docs/docs/reference/cli/dstack/secret.md new file mode 100644 index 0000000000..9044cc37f3 --- /dev/null +++ b/docs/docs/reference/cli/dstack/secret.md @@ -0,0 +1,61 @@ +# dstack secret + +The `dstack secret` commands allow managing [Secrets](../../../concepts/secrets.md). + +## dstack secret set + +The `dstack secret set` command creates a new secret or updates an existing one. + +##### Usage + +
+ +```shell +$ dstack secret set --help +#GENERATE# +``` + +
+ +## dstack secret list + +The `dstack secret list` command lists all secrets set in a project. +##### Usage + +
+ +```shell +$ dstack secret list --help +#GENERATE# +``` + +
+ +## dstack secret get + +The `dstack secret get` command show the value of a specified secret. +##### Usage + +
+ +```shell +$ dstack secret get --help +#GENERATE# +``` + +
+ +## dstack secret delete + +The `dstack secret delete` command deletes the specified secret. + +##### Usage + +
+ +```shell +$ dstack secret delete --help +#GENERATE# +``` + +
diff --git a/mkdocs.yml b/mkdocs.yml index 2f2b7b4d06..ab7666f8b9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -221,6 +221,7 @@ nav: - Fleets: docs/concepts/fleets.md - Volumes: docs/concepts/volumes.md - Repos: docs/concepts/repos.md + - Secrets: docs/concepts/secrets.md - Projects: docs/concepts/projects.md - Gateways: docs/concepts/gateways.md - Guides: @@ -254,6 +255,7 @@ nav: - dstack offer: docs/reference/cli/dstack/offer.md - dstack volume: docs/reference/cli/dstack/volume.md - dstack gateway: docs/reference/cli/dstack/gateway.md + - dstack secret: docs/reference/cli/dstack/secret.md - API: - Python API: docs/reference/api/python/index.md - REST API: docs/reference/api/rest/index.md diff --git a/src/dstack/_internal/cli/commands/secrets.py b/src/dstack/_internal/cli/commands/secrets.py new file mode 100644 index 0000000000..9d42201ed1 --- /dev/null +++ b/src/dstack/_internal/cli/commands/secrets.py @@ -0,0 +1,92 @@ +import argparse + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import SecretNameCompleter +from dstack._internal.cli.utils.common import ( + confirm_ask, + console, +) +from dstack._internal.cli.utils.secrets import print_secrets_table + + +class SecretCommand(APIBaseCommand): + NAME = "secret" + DESCRIPTION = "Manage secrets" + + def _register(self): + super()._register() + self._parser.set_defaults(subfunc=self._list) + subparsers = self._parser.add_subparsers(dest="action") + + list_parser = subparsers.add_parser( + "list", help="List secrets", formatter_class=self._parser.formatter_class + ) + list_parser.set_defaults(subfunc=self._list) + + get_parser = subparsers.add_parser( + "get", help="Get secret value", formatter_class=self._parser.formatter_class + ) + get_parser.add_argument( + "name", + help="The name of the secret", + ).completer = SecretNameCompleter() + get_parser.set_defaults(subfunc=self._get) + + set_parser = subparsers.add_parser( + "set", help="Set secret", formatter_class=self._parser.formatter_class + ) + set_parser.add_argument( + "name", + help="The name of the secret", + ) + set_parser.add_argument( + "value", + help="The value of the secret", + ) + set_parser.set_defaults(subfunc=self._set) + + delete_parser = subparsers.add_parser( + "delete", + help="Delete secrets", + formatter_class=self._parser.formatter_class, + ) + delete_parser.add_argument( + "name", + help="The name of the secret", + ).completer = SecretNameCompleter() + delete_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) + delete_parser.set_defaults(subfunc=self._delete) + + def _command(self, args: argparse.Namespace): + super()._command(args) + args.subfunc(args) + + def _list(self, args: argparse.Namespace): + secrets = self.api.client.secrets.list(self.api.project) + print_secrets_table(secrets) + + def _get(self, args: argparse.Namespace): + secret = self.api.client.secrets.get(self.api.project, name=args.name) + print_secrets_table([secret]) + + def _set(self, args: argparse.Namespace): + self.api.client.secrets.create_or_update( + self.api.project, + name=args.name, + value=args.value, + ) + console.print("[grey58]OK[/]") + + def _delete(self, args: argparse.Namespace): + if not args.yes and not confirm_ask(f"Delete the secret [code]{args.name}[/]?"): + console.print("\nExiting...") + return + + with console.status("Deleting secret..."): + self.api.client.secrets.delete( + project_name=self.api.project, + names=[args.name], + ) + console.print("[grey58]OK[/]") diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 345a9a16e6..c91d0f2feb 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -17,6 +17,7 @@ from dstack._internal.cli.commands.offer import OfferCommand from dstack._internal.cli.commands.project import ProjectCommand from dstack._internal.cli.commands.ps import PsCommand +from dstack._internal.cli.commands.secrets import SecretCommand from dstack._internal.cli.commands.server import ServerCommand from dstack._internal.cli.commands.stats import StatsCommand from dstack._internal.cli.commands.stop import StopCommand @@ -72,6 +73,7 @@ def main(): MetricsCommand.register(subparsers) ProjectCommand.register(subparsers) PsCommand.register(subparsers) + SecretCommand.register(subparsers) ServerCommand.register(subparsers) StatsCommand.register(subparsers) StopCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/completion.py b/src/dstack/_internal/cli/services/completion.py index ed8ce26ad0..ed63d82783 100644 --- a/src/dstack/_internal/cli/services/completion.py +++ b/src/dstack/_internal/cli/services/completion.py @@ -75,6 +75,11 @@ def fetch_resource_names(self, api: Client) -> Iterable[str]: return [r.name for r in api.client.gateways.list(api.project)] +class SecretNameCompleter(BaseAPINameCompleter): + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.client.secrets.list(api.project)] + + class ProjectNameCompleter(BaseCompleter): """ Completer for local project names. diff --git a/src/dstack/_internal/cli/utils/secrets.py b/src/dstack/_internal/cli/utils/secrets.py new file mode 100644 index 0000000000..5fcbb5a99a --- /dev/null +++ b/src/dstack/_internal/cli/utils/secrets.py @@ -0,0 +1,25 @@ +from typing import List + +from rich.table import Table + +from dstack._internal.cli.utils.common import add_row_from_dict, console +from dstack._internal.core.models.secrets import Secret + + +def print_secrets_table(secrets: List[Secret]) -> None: + console.print(get_secrets_table(secrets)) + console.print() + + +def get_secrets_table(secrets: List[Secret]) -> Table: + table = Table(box=None) + table.add_column("NAME", no_wrap=True) + table.add_column("VALUE") + + for secret in secrets: + row = { + "NAME": secret.name, + "VALUE": secret.value or "*" * 6, + } + add_row_from_dict(table, row) + return table diff --git a/src/dstack/_internal/core/models/secrets.py b/src/dstack/_internal/core/models/secrets.py index 86c9f93781..ab3f411290 100644 --- a/src/dstack/_internal/core/models/secrets.py +++ b/src/dstack/_internal/core/models/secrets.py @@ -1,9 +1,16 @@ +from typing import Optional +from uuid import UUID + from dstack._internal.core.models.common import CoreModel class Secret(CoreModel): + id: UUID name: str - value: str + value: Optional[str] = None def __str__(self) -> str: - return f'Secret(name="{self.name}", value={"*" * len(self.value)})' + displayed_value = "*" + if self.value is not None: + displayed_value = "*" * len(self.value) + return f'Secret(name="{self.name}", value={displayed_value})' diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index ef07becdfd..5249efe5d2 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -70,9 +70,10 @@ from dstack._internal.server.services.runs import ( run_model_to_run, ) +from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.storage import get_default_storage from dstack._internal.utils import common as common_utils -from dstack._internal.utils.interpolator import VariablesInterpolator +from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -181,7 +182,17 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): common_utils.get_or_error(job_model.instance) ) - secrets = {} # TODO secrets + secrets = await get_project_secrets_mapping(session=session, project=project) + + try: + _interpolate_secrets(secrets, job.job_spec) + except InterpolatorError as e: + logger.info("%s: terminating due to secrets interpolation error", fmt(job_model)) + job_model.status = JobStatus.TERMINATING + job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER + job_model.termination_reason_message = e.args[0] + job_model.last_processed_at = common_utils.get_current_datetime() + return repo_creds_model = await get_repo_creds(session=session, repo=repo_model, user=run_model.user) repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds @@ -218,7 +229,6 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job_model, job_provisioning_data, volumes, - secrets, job.job_spec.registry_auth, public_keys, ssh_user, @@ -327,8 +337,9 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): else: if job_model.termination_reason: logger.warning( - "%s: failed because shim/runner returned an error, age=%s", + "%s: failed due to %s, age=%s", fmt(job_model), + job_model.termination_reason.value, job_submission.age, ) job_model.status = JobStatus.TERMINATING @@ -471,7 +482,6 @@ def _process_provisioning_with_shim( job_model: JobModel, job_provisioning_data: JobProvisioningData, volumes: List[Volume], - secrets: Dict[str, str], registry_auth: Optional[RegistryAuth], public_keys: List[str], ssh_user: str, @@ -497,10 +507,8 @@ def _process_provisioning_with_shim( registry_username = "" registry_password = "" if registry_auth is not None: - logger.debug("%s: authenticating to the registry...", fmt(job_model)) - interpolate = VariablesInterpolator({"secrets": secrets}).interpolate - registry_username = interpolate(registry_auth.username) - registry_password = interpolate(registry_auth.password) + registry_username = registry_auth.username + registry_password = registry_auth.password volume_mounts: List[VolumeMountPoint] = [] instance_mounts: List[InstanceMountPoint] = [] @@ -957,7 +965,9 @@ def _submit_job_to_runner( run=run, job=job, cluster_info=cluster_info, - secrets=secrets, + # Do not send all the secrets since interpolation is already done by the server. + # TODO: Passing secrets may be necessary for filtering out secret values from logs. + secrets={}, repo_credentials=repo_credentials, instance_env=instance_env, ) @@ -975,6 +985,16 @@ def _submit_job_to_runner( return True +def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec): + interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error + job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} + if job_spec.registry_auth is not None: + job_spec.registry_auth = RegistryAuth( + username=interpolate(job_spec.registry_auth.username), + password=interpolate(job_spec.registry_auth.password), + ) + + def _get_instance_specific_mounts( backend_type: BackendType, instance_type_name: str ) -> List[InstanceMountPoint]: diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 105f9e8c9d..5a2d596664 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -35,6 +35,7 @@ run_model_to_run, scale_run_replicas, ) +from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count from dstack._internal.utils import common from dstack._internal.utils.logging import get_logger @@ -404,7 +405,11 @@ async def _handle_run_replicas( ) return - await _update_jobs_to_new_deployment_in_place(run_model, run_spec) + await _update_jobs_to_new_deployment_in_place( + session=session, + run_model=run_model, + run_spec=run_spec, + ) if _has_out_of_date_replicas(run_model): non_terminated_replica_count = len( {j.replica_num for j in run_model.jobs if not j.status.is_finished()} @@ -444,18 +449,25 @@ async def _handle_run_replicas( ) -async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None: +async def _update_jobs_to_new_deployment_in_place( + session: AsyncSession, run_model: RunModel, run_spec: RunSpec +) -> None: """ Bump deployment_num for jobs that do not require redeployment. """ - + secrets = await get_project_secrets_mapping( + session=session, + project=run_model.project, + ) for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + # FIXME: Handle getting image configuration errors or skip it. new_job_specs = await get_job_specs_from_run_spec( run_spec=run_spec, + secrets=secrets, replica_num=replica_num, ) assert len(new_job_specs) == len(job_models), ( diff --git a/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py b/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py new file mode 100644 index 0000000000..6563e92dc9 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py @@ -0,0 +1,49 @@ +"""Add SecretModel + +Revision ID: 644b8a114187 +Revises: 5f1707c525d2 +Create Date: 2025-06-30 11:00:04.326290 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "644b8a114187" +down_revision = "5f1707c525d2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "secrets", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column("updated_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column("name", sa.String(length=200), nullable=False), + sa.Column("value", dstack._internal.server.models.EncryptedString(), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_secrets_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_secrets")), + sa.UniqueConstraint("project_id", "name", name="uq_secrets_project_id_name"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("secrets") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index dd8fd5f125..d39d07be10 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -726,3 +726,21 @@ class JobPrometheusMetrics(BaseModel): collected_at: Mapped[datetime] = mapped_column(NaiveDateTime) # Raw Prometheus text response text: Mapped[str] = mapped_column(Text) + + +class SecretModel(BaseModel): + __tablename__ = "secrets" + __table_args__ = (UniqueConstraint("project_id", "name", name="uq_secrets_project_id_name"),) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship() + + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + updated_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + + name: Mapped[str] = mapped_column(String(200)) + value: Mapped[DecryptedString] = mapped_column(EncryptedString()) diff --git a/src/dstack/_internal/server/routers/secrets.py b/src/dstack/_internal/server/routers/secrets.py index 3732a0b1ca..bbfa26be93 100644 --- a/src/dstack/_internal/server/routers/secrets.py +++ b/src/dstack/_internal/server/routers/secrets.py @@ -1,15 +1,19 @@ -from typing import List +from typing import List, Tuple -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import Run +from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.secrets import Secret +from dstack._internal.server.db import get_session +from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.secrets import ( - AddSecretRequest, + CreateOrUpdateSecretRequest, DeleteSecretsRequest, - GetSecretsRequest, - ListSecretsRequest, + GetSecretRequest, ) +from dstack._internal.server.security.permissions import ProjectAdmin +from dstack._internal.server.services import secrets as secrets_services router = APIRouter( prefix="/api/project/{project_name}/secrets", @@ -18,20 +22,58 @@ @router.post("/list") -async def list_secrets(project_name: str, body: ListSecretsRequest) -> List[Run]: - pass +async def list_secrets( + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> List[Secret]: + _, project = user_project + return await secrets_services.list_secrets( + session=session, + project=project, + ) @router.post("/get") -async def get_secret(project_name: str, body: GetSecretsRequest) -> Secret: - pass +async def get_secret( + body: GetSecretRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> Secret: + _, project = user_project + secret = await secrets_services.get_secret( + session=session, + project=project, + name=body.name, + ) + if secret is None: + raise ResourceNotExistsError() + return secret -@router.post("/add") -async def add_or_update_secret(project_name: str, body: AddSecretRequest) -> Secret: - pass +@router.post("/create_or_update") +async def create_or_update_secret( + body: CreateOrUpdateSecretRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> Secret: + _, project = user_project + return await secrets_services.create_or_update_secret( + session=session, + project=project, + name=body.name, + value=body.value, + ) @router.post("/delete") -async def delete_secrets(project_name: str, body: DeleteSecretsRequest): - pass +async def delete_secrets( + body: DeleteSecretsRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +): + _, project = user_project + await secrets_services.delete_secrets( + session=session, + project=project, + names=body.secrets_names, + ) diff --git a/src/dstack/_internal/server/schemas/secrets.py b/src/dstack/_internal/server/schemas/secrets.py index 769c87052c..a8d78ea071 100644 --- a/src/dstack/_internal/server/schemas/secrets.py +++ b/src/dstack/_internal/server/schemas/secrets.py @@ -1,20 +1,16 @@ from typing import List -from dstack._internal.core.models.secrets import Secret -from dstack._internal.server.schemas.common import RepoRequest +from dstack._internal.core.models.common import CoreModel -class ListSecretsRequest(RepoRequest): - pass +class GetSecretRequest(CoreModel): + name: str -class GetSecretsRequest(RepoRequest): - pass +class CreateOrUpdateSecretRequest(CoreModel): + name: str + value: str -class AddSecretRequest(RepoRequest): - secret: Secret - - -class DeleteSecretsRequest(RepoRequest): +class DeleteSecretsRequest(CoreModel): secrets_names: List[str] diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 979daa44c2..41aa496be7 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -65,15 +65,23 @@ logger = get_logger(__name__) -async def get_jobs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[Job]: +async def get_jobs_from_run_spec( + run_spec: RunSpec, secrets: Dict[str, str], replica_num: int +) -> List[Job]: return [ Job(job_spec=s, job_submissions=[]) - for s in await get_job_specs_from_run_spec(run_spec, replica_num) + for s in await get_job_specs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + ) ] -async def get_job_specs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[JobSpec]: - job_configurator = _get_job_configurator(run_spec) +async def get_job_specs_from_run_spec( + run_spec: RunSpec, secrets: Dict[str, str], replica_num: int +) -> List[JobSpec]: + job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets) job_specs = await job_configurator.get_job_specs(replica_num=replica_num) return job_specs @@ -159,10 +167,10 @@ def delay_job_instance_termination(job_model: JobModel): job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15) -def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator: +def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator: configuration_type = RunConfigurationType(run_spec.configuration.type) configurator_class = _configuration_type_to_configurator_class_map[configuration_type] - return configurator_class(run_spec) + return configurator_class(run_spec=run_spec, secrets=secrets) _job_configurator_classes = [ diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 465a24fb0d..349f2bcf02 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -68,8 +68,13 @@ class JobConfigurator(ABC): # JobSSHKey should be shared for all jobs in a replica for inter-node communication. _job_ssh_key: Optional[JobSSHKey] = None - def __init__(self, run_spec: RunSpec): + def __init__( + self, + run_spec: RunSpec, + secrets: Optional[Dict[str, str]] = None, + ): self.run_spec = run_spec + self.secrets = secrets or {} async def get_job_specs(self, replica_num: int) -> List[JobSpec]: job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1) @@ -98,10 +103,20 @@ def _ports(self) -> List[PortMapping]: async def _get_image_config(self) -> ImageConfig: if self._image_config is not None: return self._image_config + interpolate = VariablesInterpolator({"secrets": self.secrets}).interpolate_or_error + registry_auth = self.run_spec.configuration.registry_auth + if registry_auth is not None: + try: + registry_auth = RegistryAuth( + username=interpolate(registry_auth.username), + password=interpolate(registry_auth.password), + ) + except InterpolatorError as e: + raise ServerClientError(e.args[0]) image_config = await run_async( _get_image_config, self._image_name(), - self.run_spec.configuration.registry_auth, + registry_auth, ) self._image_config = image_config return image_config diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index a2d96cf1a4..a10922ef79 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType @@ -17,7 +17,7 @@ class DevEnvironmentJobConfigurator(JobConfigurator): TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT - def __init__(self, run_spec: RunSpec): + def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]): if run_spec.configuration.ide == "vscode": __class = VSCodeDesktop elif run_spec.configuration.ide == "cursor": @@ -29,7 +29,7 @@ def __init__(self, run_spec: RunSpec): version=run_spec.configuration.version, extensions=["ms-python.python", "ms-toolsai.jupyter"], ) - super().__init__(run_spec) + super().__init__(run_spec=run_spec, secrets=secrets) def _shell_commands(self) -> List[str]: commands = self.ide.get_install_commands() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 385660f944..bb775bdf29 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -82,6 +82,7 @@ from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.services.resources import set_resources_defaults +from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.users import get_user_model_by_name from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -311,7 +312,12 @@ async def get_plan( ): action = ApplyAction.UPDATE - jobs = await get_jobs_from_run_spec(effective_run_spec, replica_num=0) + secrets = await get_project_secrets_mapping(session=session, project=project) + jobs = await get_jobs_from_run_spec( + run_spec=effective_run_spec, + secrets=secrets, + replica_num=0, + ) volumes = await get_job_configured_volumes( session=session, @@ -462,6 +468,10 @@ async def submit_run( project=project, run_spec=run_spec, ) + secrets = await get_project_secrets_mapping( + session=session, + project=project, + ) lock_namespace = f"run_names_{project.name}" if get_db().dialect_name == "sqlite": @@ -513,7 +523,11 @@ async def submit_run( await services.register_service(session, run_model, run_spec) for replica_num in range(replicas): - jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num) + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + ) for job in jobs: job_model = create_job_model_for_new_submission( run_model=run_model, @@ -1068,10 +1082,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) scheduled_replicas += 1 + secrets = await get_project_secrets_mapping( + session=session, + project=run_model.project, + ) + for replica_num in range( len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff ): - jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num) + # FIXME: Handle getting image configuration errors or skip it. + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + ) for job in jobs: job_model = create_job_model_for_new_submission( run_model=run_model, @@ -1084,8 +1108,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica async def retry_run_replica_jobs( session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool ): + # FIXME: Handle getting image configuration errors or skip it. + secrets = await get_project_secrets_mapping( + session=session, + project=run_model.project, + ) new_jobs = await get_jobs_from_run_spec( - RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + secrets=secrets, replica_num=latest_jobs[0].replica_num, ) assert len(new_jobs) == len(latest_jobs), ( diff --git a/src/dstack/_internal/server/services/secrets.py b/src/dstack/_internal/server/services/secrets.py new file mode 100644 index 0000000000..0439f9e688 --- /dev/null +++ b/src/dstack/_internal/server/services/secrets.py @@ -0,0 +1,204 @@ +import re +from typing import Dict, List, Optional + +import sqlalchemy.exc +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ( + ResourceExistsError, + ResourceNotExistsError, + ServerClientError, +) +from dstack._internal.core.models.secrets import Secret +from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$" +_SECRET_VALUE_MAX_LENGTH = 2000 + + +async def list_secrets( + session: AsyncSession, + project: ProjectModel, +) -> List[Secret]: + secret_models = await list_project_secret_models(session=session, project=project) + return [secret_model_to_secret(s, include_value=False) for s in secret_models] + + +async def get_project_secrets_mapping( + session: AsyncSession, + project: ProjectModel, +) -> Dict[str, str]: + secret_models = await list_project_secret_models(session=session, project=project) + return {s.name: s.value.get_plaintext_or_error() for s in secret_models} + + +async def get_secret( + session: AsyncSession, + project: ProjectModel, + name: str, +) -> Optional[Secret]: + secret_model = await get_project_secret_model_by_name( + session=session, + project=project, + name=name, + ) + if secret_model is None: + return None + return secret_model_to_secret(secret_model, include_value=True) + + +async def create_or_update_secret( + session: AsyncSession, + project: ProjectModel, + name: str, + value: str, +) -> Secret: + _validate_secret(name=name, value=value) + try: + secret_model = await create_secret( + session=session, + project=project, + name=name, + value=value, + ) + except ResourceExistsError: + secret_model = await update_secret( + session=session, + project=project, + name=name, + value=value, + ) + return secret_model_to_secret(secret_model, include_value=True) + + +async def delete_secrets( + session: AsyncSession, + project: ProjectModel, + names: List[str], +): + existing_secrets_query = await session.execute( + select(SecretModel).where( + SecretModel.project_id == project.id, + SecretModel.name.in_(names), + ) + ) + existing_names = [s.name for s in existing_secrets_query.scalars().all()] + missing_names = set(names) - set(existing_names) + if missing_names: + raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}") + + await session.execute( + delete(SecretModel).where( + SecretModel.project_id == project.id, + SecretModel.name.in_(names), + ) + ) + await session.commit() + logger.info("Deleted secrets %s in project %s", names, project.name) + + +def secret_model_to_secret(secret_model: SecretModel, include_value: bool = False) -> Secret: + value = None + if include_value: + value = secret_model.value.get_plaintext_or_error() + return Secret( + id=secret_model.id, + name=secret_model.name, + value=value, + ) + + +async def list_project_secret_models( + session: AsyncSession, + project: ProjectModel, +) -> List[SecretModel]: + res = await session.execute( + select(SecretModel) + .where( + SecretModel.project_id == project.id, + ) + .order_by(SecretModel.created_at.desc()) + ) + secret_models = list(res.scalars().all()) + return secret_models + + +async def get_project_secret_model_by_name( + session: AsyncSession, + project: ProjectModel, + name: str, +) -> Optional[SecretModel]: + res = await session.execute( + select(SecretModel).where( + SecretModel.project_id == project.id, + SecretModel.name == name, + ) + ) + return res.scalar_one_or_none() + + +async def create_secret( + session: AsyncSession, + project: ProjectModel, + name: str, + value: str, +) -> SecretModel: + secret_model = SecretModel( + project_id=project.id, + name=name, + value=DecryptedString(plaintext=value), + ) + try: + async with session.begin_nested(): + session.add(secret_model) + except sqlalchemy.exc.IntegrityError: + raise ResourceExistsError() + await session.commit() + return secret_model + + +async def update_secret( + session: AsyncSession, + project: ProjectModel, + name: str, + value: str, +) -> SecretModel: + await session.execute( + update(SecretModel) + .where( + SecretModel.project_id == project.id, + SecretModel.name == name, + ) + .values( + value=DecryptedString(plaintext=value), + ) + ) + await session.commit() + secret_model = await get_project_secret_model_by_name( + session=session, + project=project, + name=name, + ) + if secret_model is None: + raise ResourceNotExistsError() + return secret_model + + +def _validate_secret(name: str, value: str): + _validate_secret_name(name) + _validate_secret_value(value) + + +def _validate_secret_name(name: str): + if re.match(_SECRET_NAME_REGEX, name) is None: + raise ServerClientError(f"Secret name should match regex '{_SECRET_NAME_REGEX}") + + +def _validate_secret_value(value: str): + if len(value) > _SECRET_VALUE_MAX_LENGTH: + raise ServerClientError(f"Secret value length must not exceed {_SECRET_VALUE_MAX_LENGTH}") diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 7aadb48979..4bcb95404e 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -90,6 +90,7 @@ RepoCredsModel, RepoModel, RunModel, + SecretModel, UserModel, VolumeAttachmentModel, VolumeModel, @@ -332,7 +333,9 @@ async def create_job( if deployment_num is None: deployment_num = run.deployment_num run_spec = RunSpec.parse_raw(run.run_spec) - job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0] + job_spec = ( + await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num) + )[0] job_spec.job_num = job_num job = JobModel( project_id=run.project_id, @@ -934,6 +937,22 @@ async def create_job_prometheus_metrics( return metrics +async def create_secret( + session: AsyncSession, + project: ProjectModel, + name: str, + value: str, +): + secret_model = SecretModel( + project=project, + name=name, + value=DecryptedString(plaintext=value), + ) + session.add(secret_model) + await session.commit() + return secret_model + + def get_private_key_string() -> str: return """ -----BEGIN RSA PRIVATE KEY----- diff --git a/src/dstack/api/server/_secrets.py b/src/dstack/api/server/_secrets.py index adba4081a9..9a2a2763f1 100644 --- a/src/dstack/api/server/_secrets.py +++ b/src/dstack/api/server/_secrets.py @@ -4,33 +4,33 @@ from dstack._internal.core.models.secrets import Secret from dstack._internal.server.schemas.secrets import ( - AddSecretRequest, + CreateOrUpdateSecretRequest, DeleteSecretsRequest, - GetSecretsRequest, - ListSecretsRequest, + GetSecretRequest, ) from dstack.api.server._group import APIClientGroup class SecretsAPIClient(APIClientGroup): - def list(self, project_name: str, repo_id: str) -> List[Secret]: - body = ListSecretsRequest(repo_id=repo_id) - resp = self._request(f"/api/project/{project_name}/secrets/list", body=body.json()) + def list(self, project_name: str) -> List[Secret]: + resp = self._request(f"/api/project/{project_name}/secrets/list") return parse_obj_as(List[Secret.__response__], resp.json()) - def get(self, project_name: str, repo_id: str, secret_name: str) -> Secret: - raise NotImplementedError() - body = GetSecretsRequest(repo_id=repo_id) + def get(self, project_name: str, name: str) -> Secret: + body = GetSecretRequest(name=name) resp = self._request(f"/api/project/{project_name}/secrets/get", body=body.json()) return parse_obj_as(Secret, resp.json()) - def add(self, project_name: str, repo_id: str, secret_name: str, secret_value: str) -> Secret: - body = AddSecretRequest( - repo_id=repo_id, secret=Secret(name=secret_name, value=secret_value) + def create_or_update(self, project_name: str, name: str, value: str) -> Secret: + body = CreateOrUpdateSecretRequest( + name=name, + value=value, + ) + resp = self._request( + f"/api/project/{project_name}/secrets/create_or_update", body=body.json() ) - resp = self._request(f"/api/project/{project_name}/secrets/add", body=body.json()) return parse_obj_as(Secret.__response__, resp.json()) - def delete(self, project_name: str, repo_id: str, secrets_names: List[str]): - body = DeleteSecretsRequest(repo_id=repo_id, secrets_names=secrets_names) + def delete(self, project_name: str, names: List[str]): + body = DeleteSecretsRequest(secrets_names=names) self._request(f"/api/project/{project_name}/secrets/delete", body=body.json()) diff --git a/src/tests/_internal/server/routers/test_secrets.py b/src/tests/_internal/server/routers/test_secrets.py new file mode 100644 index 0000000000..d029dc4ee0 --- /dev/null +++ b/src/tests/_internal/server/routers/test_secrets.py @@ -0,0 +1,271 @@ +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.models import SecretModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_project, + create_secret, + create_user, + get_auth_headers, +) + + +class TestListSecrets: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_403_if_not_admin( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/secrets/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_lists_secrets(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + secret1 = await create_secret( + session=session, project=project, name="secret1", value="123456" + ) + secret2 = await create_secret( + session=session, project=project, name="secret2", value="123456" + ) + response = await client.post( + f"/api/project/{project.name}/secrets/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(secret2.id), + "name": "secret2", + "value": None, + }, + { + "id": str(secret1.id), + "name": "secret1", + "value": None, + }, + ] + + +class TestGetSecret: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_403_if_not_admin( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/secrets/get", + headers=get_auth_headers(user.token), + json={"name": "my_secret"}, + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_secret_with_value( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + secret = await create_secret( + session=session, project=project, name="secret1", value="123456" + ) + response = await client.post( + f"/api/project/{project.name}/secrets/get", + headers=get_auth_headers(user.token), + json={"name": "secret1"}, + ) + assert response.status_code == 200 + assert response.json() == { + "id": str(secret.id), + "name": "secret1", + "value": "123456", + } + + +class TestCreateOrUpdateSecret: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_403_if_not_admin( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/secrets/create_or_update", + headers=get_auth_headers(user.token), + json={"name": "my_secret"}, + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_creates_secret(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = await client.post( + f"/api/project/{project.name}/secrets/create_or_update", + headers=get_auth_headers(user.token), + json={"name": "secret1", "value": "123456"}, + ) + assert response.status_code == 200 + res = await session.execute(select(SecretModel)) + secret_model = res.scalar() + assert secret_model is not None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_updates_secret(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + secret = await create_secret( + session=session, project=project, name="secret1", value="old_value" + ) + response = await client.post( + f"/api/project/{project.name}/secrets/create_or_update", + headers=get_auth_headers(user.token), + json={"name": "secret1", "value": "new_value"}, + ) + assert response.status_code == 200 + await session.refresh(secret) + assert secret.value.get_plaintext_or_error() == "new_value" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "name, value", + [ + ("too_long_secret_value", "a" * 2001), + ("", "empty_name"), + ("@7&.", "wierd_name_chars"), + ], + ) + async def test_rejects_bad_names_values( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + name: str, + value, + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = await client.post( + f"/api/project/{project.name}/secrets/create_or_update", + headers=get_auth_headers(user.token), + json={"name": name, "value": value}, + ) + assert response.status_code == 400 + res = await session.execute(select(SecretModel)) + secret_model = res.scalar() + assert secret_model is None + + +class TestDeleteSecrets: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_403_if_not_admin( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/secrets/delete", + headers=get_auth_headers(user.token), + json={"secrets_names": ["my_secret"]}, + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_secrets(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + # Create two secrets + await create_secret(session=session, project=project, name="secret1", value="123456") + await create_secret(session=session, project=project, name="secret2", value="abcdef") + + # Verify both secrets exist + res = await session.execute( + select(SecretModel).where(SecretModel.project_id == project.id) + ) + secrets = res.scalars().all() + assert len(secrets) == 2 + + # Delete one secret + response = await client.post( + f"/api/project/{project.name}/secrets/delete", + headers=get_auth_headers(user.token), + json={"secrets_names": ["secret1"]}, + ) + assert response.status_code == 200 + + # Verify only one secret remains + res = await session.execute( + select(SecretModel).where(SecretModel.project_id == project.id) + ) + secrets = res.scalars().all() + assert len(secrets) == 1 + assert secrets[0].name == "secret2" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_delete_nonexistent_secret_raises_error( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = await client.post( + f"/api/project/{project.name}/secrets/delete", + headers=get_auth_headers(user.token), + json={"secrets_names": ["nonexistent_secret"]}, + ) + assert response.status_code == 400 # ResourceNotExistsError should return 404 diff --git a/src/tests/_internal/server/services/test_runs.py b/src/tests/_internal/server/services/test_runs.py index 7d3f2f4595..d463f0e4fb 100644 --- a/src/tests/_internal/server/services/test_runs.py +++ b/src/tests/_internal/server/services/test_runs.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Union import pytest from pydantic import parse_obj_as @@ -30,7 +30,7 @@ async def make_run( session: AsyncSession, replicas_statuses: List[JobStatus], status: RunStatus = RunStatus.RUNNING, - replicas: str = 1, + replicas: Union[str, int] = 1, ) -> RunModel: project = await create_project(session=session) user = await create_user(session=session) @@ -70,7 +70,7 @@ async def make_run( status=job_status, replica_num=replica_num, ) - await session.refresh(run) + await session.refresh(run, attribute_names=["project", "jobs"]) return run