diff --git a/docs/docs/concepts/secrets.md b/docs/docs/concepts/secrets.md
new file mode 100644
index 0000000000..d541f281a6
--- /dev/null
+++ b/docs/docs/concepts/secrets.md
@@ -0,0 +1,125 @@
+# Secrets
+
+Secrets allow centralized management of sensitive values such as API keys and credentials. They are project-scoped, managed by project admins, and can be referenced in run configurations to pass sensitive values to runs in a secure manner.
+
+!!! info "Secrets encryption"
+ By default, secrets are stored in plaintext in the DB.
+ Configure [server encryption](../guides/server-deployment.md#encryption) to store secrets encrypted.
+
+## Manage secrets
+
+### Set
+
+Use the `dstack secret set` command to create a new secret:
+
+
+
+```shell
+$ dstack secret set my_secret some_secret_value
+OK
+```
+
+
+
+The same command can be used to update an existing secret:
+
+
+
+```shell
+$ dstack secret set my_secret another_secret_value
+OK
+```
+
+
+
+### List
+
+Use the `dstack secret list` command to list all secrets set in a project:
+
+
+
+```shell
+$ dstack secret
+ NAME VALUE
+ hf_token ******
+ my_secret ******
+
+```
+
+
+
+### Get
+
+The `dstack secret list` does not show secret values. To see a secret value, use the `dstack secret get` command:
+
+
+
+```shell
+$ dstack secret get my_secret
+ NAME VALUE
+ my_secret some_secret_value
+
+```
+
+
+
+### Delete
+
+Secrets can be deleted using the `dstack secret delete` command:
+
+
+
+```shell
+$ dstack secret delete my_secret
+Delete the secret my_secret? [y/n]: y
+OK
+```
+
+
+
+## Use secrets
+
+You can use the `${{ secrets. }}` syntax to reference secrets in run configurations. Currently, secrets interpolation is supported in `env` and `registry_auth` properties.
+
+### `env`
+
+Suppose you need to pass a sensitive environment variable to a run such as `HF_TOKEN`. You'd first create a secret holding the environment variable value:
+
+
+
+```shell
+$ dstack secret set hf_token {hf_token_value}
+OK
+```
+
+
+
+and then reference the secret in `env`:
+
+
+
+```yaml
+type: service
+env:
+ - HF_TOKEN=${{ secrets.hf_token }}
+commands:
+ ...
+```
+
+
+
+### `registry_auth`
+
+If you need to pull a private Docker image, you can store registry credentials as secrets and reference them in `registry_auth`:
+
+
+
+```yaml
+type: service
+image: nvcr.io/nim/deepseek-ai/deepseek-r1-distill-llama-8b
+registry_auth:
+ username: $oauthtoken
+ password: ${{ secrets.ngc_api_key }}
+```
+
+
diff --git a/docs/docs/reference/cli/dstack/secret.md b/docs/docs/reference/cli/dstack/secret.md
new file mode 100644
index 0000000000..9044cc37f3
--- /dev/null
+++ b/docs/docs/reference/cli/dstack/secret.md
@@ -0,0 +1,61 @@
+# dstack secret
+
+The `dstack secret` commands allow managing [Secrets](../../../concepts/secrets.md).
+
+## dstack secret set
+
+The `dstack secret set` command creates a new secret or updates an existing one.
+
+##### Usage
+
+
+
+```shell
+$ dstack secret set --help
+#GENERATE#
+```
+
+
+
+## dstack secret list
+
+The `dstack secret list` command lists all secrets set in a project.
+##### Usage
+
+
+
+```shell
+$ dstack secret list --help
+#GENERATE#
+```
+
+
+
+## dstack secret get
+
+The `dstack secret get` command show the value of a specified secret.
+##### Usage
+
+
+
+```shell
+$ dstack secret get --help
+#GENERATE#
+```
+
+
+
+## dstack secret delete
+
+The `dstack secret delete` command deletes the specified secret.
+
+##### Usage
+
+
+
+```shell
+$ dstack secret delete --help
+#GENERATE#
+```
+
+
diff --git a/mkdocs.yml b/mkdocs.yml
index 2f2b7b4d06..ab7666f8b9 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -221,6 +221,7 @@ nav:
- Fleets: docs/concepts/fleets.md
- Volumes: docs/concepts/volumes.md
- Repos: docs/concepts/repos.md
+ - Secrets: docs/concepts/secrets.md
- Projects: docs/concepts/projects.md
- Gateways: docs/concepts/gateways.md
- Guides:
@@ -254,6 +255,7 @@ nav:
- dstack offer: docs/reference/cli/dstack/offer.md
- dstack volume: docs/reference/cli/dstack/volume.md
- dstack gateway: docs/reference/cli/dstack/gateway.md
+ - dstack secret: docs/reference/cli/dstack/secret.md
- API:
- Python API: docs/reference/api/python/index.md
- REST API: docs/reference/api/rest/index.md
diff --git a/src/dstack/_internal/cli/commands/secrets.py b/src/dstack/_internal/cli/commands/secrets.py
new file mode 100644
index 0000000000..9d42201ed1
--- /dev/null
+++ b/src/dstack/_internal/cli/commands/secrets.py
@@ -0,0 +1,92 @@
+import argparse
+
+from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import SecretNameCompleter
+from dstack._internal.cli.utils.common import (
+ confirm_ask,
+ console,
+)
+from dstack._internal.cli.utils.secrets import print_secrets_table
+
+
+class SecretCommand(APIBaseCommand):
+ NAME = "secret"
+ DESCRIPTION = "Manage secrets"
+
+ def _register(self):
+ super()._register()
+ self._parser.set_defaults(subfunc=self._list)
+ subparsers = self._parser.add_subparsers(dest="action")
+
+ list_parser = subparsers.add_parser(
+ "list", help="List secrets", formatter_class=self._parser.formatter_class
+ )
+ list_parser.set_defaults(subfunc=self._list)
+
+ get_parser = subparsers.add_parser(
+ "get", help="Get secret value", formatter_class=self._parser.formatter_class
+ )
+ get_parser.add_argument(
+ "name",
+ help="The name of the secret",
+ ).completer = SecretNameCompleter()
+ get_parser.set_defaults(subfunc=self._get)
+
+ set_parser = subparsers.add_parser(
+ "set", help="Set secret", formatter_class=self._parser.formatter_class
+ )
+ set_parser.add_argument(
+ "name",
+ help="The name of the secret",
+ )
+ set_parser.add_argument(
+ "value",
+ help="The value of the secret",
+ )
+ set_parser.set_defaults(subfunc=self._set)
+
+ delete_parser = subparsers.add_parser(
+ "delete",
+ help="Delete secrets",
+ formatter_class=self._parser.formatter_class,
+ )
+ delete_parser.add_argument(
+ "name",
+ help="The name of the secret",
+ ).completer = SecretNameCompleter()
+ delete_parser.add_argument(
+ "-y", "--yes", help="Don't ask for confirmation", action="store_true"
+ )
+ delete_parser.set_defaults(subfunc=self._delete)
+
+ def _command(self, args: argparse.Namespace):
+ super()._command(args)
+ args.subfunc(args)
+
+ def _list(self, args: argparse.Namespace):
+ secrets = self.api.client.secrets.list(self.api.project)
+ print_secrets_table(secrets)
+
+ def _get(self, args: argparse.Namespace):
+ secret = self.api.client.secrets.get(self.api.project, name=args.name)
+ print_secrets_table([secret])
+
+ def _set(self, args: argparse.Namespace):
+ self.api.client.secrets.create_or_update(
+ self.api.project,
+ name=args.name,
+ value=args.value,
+ )
+ console.print("[grey58]OK[/]")
+
+ def _delete(self, args: argparse.Namespace):
+ if not args.yes and not confirm_ask(f"Delete the secret [code]{args.name}[/]?"):
+ console.print("\nExiting...")
+ return
+
+ with console.status("Deleting secret..."):
+ self.api.client.secrets.delete(
+ project_name=self.api.project,
+ names=[args.name],
+ )
+ console.print("[grey58]OK[/]")
diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py
index 345a9a16e6..c91d0f2feb 100644
--- a/src/dstack/_internal/cli/main.py
+++ b/src/dstack/_internal/cli/main.py
@@ -17,6 +17,7 @@
from dstack._internal.cli.commands.offer import OfferCommand
from dstack._internal.cli.commands.project import ProjectCommand
from dstack._internal.cli.commands.ps import PsCommand
+from dstack._internal.cli.commands.secrets import SecretCommand
from dstack._internal.cli.commands.server import ServerCommand
from dstack._internal.cli.commands.stats import StatsCommand
from dstack._internal.cli.commands.stop import StopCommand
@@ -72,6 +73,7 @@ def main():
MetricsCommand.register(subparsers)
ProjectCommand.register(subparsers)
PsCommand.register(subparsers)
+ SecretCommand.register(subparsers)
ServerCommand.register(subparsers)
StatsCommand.register(subparsers)
StopCommand.register(subparsers)
diff --git a/src/dstack/_internal/cli/services/completion.py b/src/dstack/_internal/cli/services/completion.py
index ed8ce26ad0..ed63d82783 100644
--- a/src/dstack/_internal/cli/services/completion.py
+++ b/src/dstack/_internal/cli/services/completion.py
@@ -75,6 +75,11 @@ def fetch_resource_names(self, api: Client) -> Iterable[str]:
return [r.name for r in api.client.gateways.list(api.project)]
+class SecretNameCompleter(BaseAPINameCompleter):
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ return [r.name for r in api.client.secrets.list(api.project)]
+
+
class ProjectNameCompleter(BaseCompleter):
"""
Completer for local project names.
diff --git a/src/dstack/_internal/cli/utils/secrets.py b/src/dstack/_internal/cli/utils/secrets.py
new file mode 100644
index 0000000000..5fcbb5a99a
--- /dev/null
+++ b/src/dstack/_internal/cli/utils/secrets.py
@@ -0,0 +1,25 @@
+from typing import List
+
+from rich.table import Table
+
+from dstack._internal.cli.utils.common import add_row_from_dict, console
+from dstack._internal.core.models.secrets import Secret
+
+
+def print_secrets_table(secrets: List[Secret]) -> None:
+ console.print(get_secrets_table(secrets))
+ console.print()
+
+
+def get_secrets_table(secrets: List[Secret]) -> Table:
+ table = Table(box=None)
+ table.add_column("NAME", no_wrap=True)
+ table.add_column("VALUE")
+
+ for secret in secrets:
+ row = {
+ "NAME": secret.name,
+ "VALUE": secret.value or "*" * 6,
+ }
+ add_row_from_dict(table, row)
+ return table
diff --git a/src/dstack/_internal/core/models/secrets.py b/src/dstack/_internal/core/models/secrets.py
index 86c9f93781..ab3f411290 100644
--- a/src/dstack/_internal/core/models/secrets.py
+++ b/src/dstack/_internal/core/models/secrets.py
@@ -1,9 +1,16 @@
+from typing import Optional
+from uuid import UUID
+
from dstack._internal.core.models.common import CoreModel
class Secret(CoreModel):
+ id: UUID
name: str
- value: str
+ value: Optional[str] = None
def __str__(self) -> str:
- return f'Secret(name="{self.name}", value={"*" * len(self.value)})'
+ displayed_value = "*"
+ if self.value is not None:
+ displayed_value = "*" * len(self.value)
+ return f'Secret(name="{self.name}", value={displayed_value})'
diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
index ef07becdfd..5249efe5d2 100644
--- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -70,9 +70,10 @@
from dstack._internal.server.services.runs import (
run_model_to_run,
)
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.storage import get_default_storage
from dstack._internal.utils import common as common_utils
-from dstack._internal.utils.interpolator import VariablesInterpolator
+from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
logger = get_logger(__name__)
@@ -181,7 +182,17 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
common_utils.get_or_error(job_model.instance)
)
- secrets = {} # TODO secrets
+ secrets = await get_project_secrets_mapping(session=session, project=project)
+
+ try:
+ _interpolate_secrets(secrets, job.job_spec)
+ except InterpolatorError as e:
+ logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
+ job_model.status = JobStatus.TERMINATING
+ job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
+ job_model.termination_reason_message = e.args[0]
+ job_model.last_processed_at = common_utils.get_current_datetime()
+ return
repo_creds_model = await get_repo_creds(session=session, repo=repo_model, user=run_model.user)
repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
@@ -218,7 +229,6 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
job_model,
job_provisioning_data,
volumes,
- secrets,
job.job_spec.registry_auth,
public_keys,
ssh_user,
@@ -327,8 +337,9 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
else:
if job_model.termination_reason:
logger.warning(
- "%s: failed because shim/runner returned an error, age=%s",
+ "%s: failed due to %s, age=%s",
fmt(job_model),
+ job_model.termination_reason.value,
job_submission.age,
)
job_model.status = JobStatus.TERMINATING
@@ -471,7 +482,6 @@ def _process_provisioning_with_shim(
job_model: JobModel,
job_provisioning_data: JobProvisioningData,
volumes: List[Volume],
- secrets: Dict[str, str],
registry_auth: Optional[RegistryAuth],
public_keys: List[str],
ssh_user: str,
@@ -497,10 +507,8 @@ def _process_provisioning_with_shim(
registry_username = ""
registry_password = ""
if registry_auth is not None:
- logger.debug("%s: authenticating to the registry...", fmt(job_model))
- interpolate = VariablesInterpolator({"secrets": secrets}).interpolate
- registry_username = interpolate(registry_auth.username)
- registry_password = interpolate(registry_auth.password)
+ registry_username = registry_auth.username
+ registry_password = registry_auth.password
volume_mounts: List[VolumeMountPoint] = []
instance_mounts: List[InstanceMountPoint] = []
@@ -957,7 +965,9 @@ def _submit_job_to_runner(
run=run,
job=job,
cluster_info=cluster_info,
- secrets=secrets,
+ # Do not send all the secrets since interpolation is already done by the server.
+ # TODO: Passing secrets may be necessary for filtering out secret values from logs.
+ secrets={},
repo_credentials=repo_credentials,
instance_env=instance_env,
)
@@ -975,6 +985,16 @@ def _submit_job_to_runner(
return True
+def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec):
+ interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error
+ job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()}
+ if job_spec.registry_auth is not None:
+ job_spec.registry_auth = RegistryAuth(
+ username=interpolate(job_spec.registry_auth.username),
+ password=interpolate(job_spec.registry_auth.password),
+ )
+
+
def _get_instance_specific_mounts(
backend_type: BackendType, instance_type_name: str
) -> List[InstanceMountPoint]:
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index 105f9e8c9d..5a2d596664 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -35,6 +35,7 @@
run_model_to_run,
scale_run_replicas,
)
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.services import update_service_desired_replica_count
from dstack._internal.utils import common
from dstack._internal.utils.logging import get_logger
@@ -404,7 +405,11 @@ async def _handle_run_replicas(
)
return
- await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
+ await _update_jobs_to_new_deployment_in_place(
+ session=session,
+ run_model=run_model,
+ run_spec=run_spec,
+ )
if _has_out_of_date_replicas(run_model):
non_terminated_replica_count = len(
{j.replica_num for j in run_model.jobs if not j.status.is_finished()}
@@ -444,18 +449,25 @@ async def _handle_run_replicas(
)
-async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
+async def _update_jobs_to_new_deployment_in_place(
+ session: AsyncSession, run_model: RunModel, run_spec: RunSpec
+) -> None:
"""
Bump deployment_num for jobs that do not require redeployment.
"""
-
+ secrets = await get_project_secrets_mapping(
+ session=session,
+ project=run_model.project,
+ )
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
if all(j.status.is_finished() for j in job_models):
continue
if all(j.deployment_num == run_model.deployment_num for j in job_models):
continue
+ # FIXME: Handle getting image configuration errors or skip it.
new_job_specs = await get_job_specs_from_run_spec(
run_spec=run_spec,
+ secrets=secrets,
replica_num=replica_num,
)
assert len(new_job_specs) == len(job_models), (
diff --git a/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py b/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py
new file mode 100644
index 0000000000..6563e92dc9
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py
@@ -0,0 +1,49 @@
+"""Add SecretModel
+
+Revision ID: 644b8a114187
+Revises: 5f1707c525d2
+Create Date: 2025-06-30 11:00:04.326290
+
+"""
+
+import sqlalchemy as sa
+import sqlalchemy_utils
+from alembic import op
+
+import dstack._internal.server.models
+
+# revision identifiers, used by Alembic.
+revision = "644b8a114187"
+down_revision = "5f1707c525d2"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table(
+ "secrets",
+ sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
+ sa.Column(
+ "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False
+ ),
+ sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False),
+ sa.Column("updated_at", dstack._internal.server.models.NaiveDateTime(), nullable=False),
+ sa.Column("name", sa.String(length=200), nullable=False),
+ sa.Column("value", dstack._internal.server.models.EncryptedString(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["project_id"],
+ ["projects.id"],
+ name=op.f("fk_secrets_project_id_projects"),
+ ondelete="CASCADE",
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_secrets")),
+ sa.UniqueConstraint("project_id", "name", name="uq_secrets_project_id_name"),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table("secrets")
+ # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index dd8fd5f125..d39d07be10 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -726,3 +726,21 @@ class JobPrometheusMetrics(BaseModel):
collected_at: Mapped[datetime] = mapped_column(NaiveDateTime)
# Raw Prometheus text response
text: Mapped[str] = mapped_column(Text)
+
+
+class SecretModel(BaseModel):
+ __tablename__ = "secrets"
+ __table_args__ = (UniqueConstraint("project_id", "name", name="uq_secrets_project_id_name"),)
+
+ id: Mapped[uuid.UUID] = mapped_column(
+ UUIDType(binary=False), primary_key=True, default=uuid.uuid4
+ )
+
+ project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
+ project: Mapped["ProjectModel"] = relationship()
+
+ created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
+ updated_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
+
+ name: Mapped[str] = mapped_column(String(200))
+ value: Mapped[DecryptedString] = mapped_column(EncryptedString())
diff --git a/src/dstack/_internal/server/routers/secrets.py b/src/dstack/_internal/server/routers/secrets.py
index 3732a0b1ca..bbfa26be93 100644
--- a/src/dstack/_internal/server/routers/secrets.py
+++ b/src/dstack/_internal/server/routers/secrets.py
@@ -1,15 +1,19 @@
-from typing import List
+from typing import List, Tuple
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
-from dstack._internal.core.models.runs import Run
+from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.secrets import Secret
+from dstack._internal.server.db import get_session
+from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.secrets import (
- AddSecretRequest,
+ CreateOrUpdateSecretRequest,
DeleteSecretsRequest,
- GetSecretsRequest,
- ListSecretsRequest,
+ GetSecretRequest,
)
+from dstack._internal.server.security.permissions import ProjectAdmin
+from dstack._internal.server.services import secrets as secrets_services
router = APIRouter(
prefix="/api/project/{project_name}/secrets",
@@ -18,20 +22,58 @@
@router.post("/list")
-async def list_secrets(project_name: str, body: ListSecretsRequest) -> List[Run]:
- pass
+async def list_secrets(
+ session: AsyncSession = Depends(get_session),
+ user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
+) -> List[Secret]:
+ _, project = user_project
+ return await secrets_services.list_secrets(
+ session=session,
+ project=project,
+ )
@router.post("/get")
-async def get_secret(project_name: str, body: GetSecretsRequest) -> Secret:
- pass
+async def get_secret(
+ body: GetSecretRequest,
+ session: AsyncSession = Depends(get_session),
+ user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
+) -> Secret:
+ _, project = user_project
+ secret = await secrets_services.get_secret(
+ session=session,
+ project=project,
+ name=body.name,
+ )
+ if secret is None:
+ raise ResourceNotExistsError()
+ return secret
-@router.post("/add")
-async def add_or_update_secret(project_name: str, body: AddSecretRequest) -> Secret:
- pass
+@router.post("/create_or_update")
+async def create_or_update_secret(
+ body: CreateOrUpdateSecretRequest,
+ session: AsyncSession = Depends(get_session),
+ user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
+) -> Secret:
+ _, project = user_project
+ return await secrets_services.create_or_update_secret(
+ session=session,
+ project=project,
+ name=body.name,
+ value=body.value,
+ )
@router.post("/delete")
-async def delete_secrets(project_name: str, body: DeleteSecretsRequest):
- pass
+async def delete_secrets(
+ body: DeleteSecretsRequest,
+ session: AsyncSession = Depends(get_session),
+ user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
+):
+ _, project = user_project
+ await secrets_services.delete_secrets(
+ session=session,
+ project=project,
+ names=body.secrets_names,
+ )
diff --git a/src/dstack/_internal/server/schemas/secrets.py b/src/dstack/_internal/server/schemas/secrets.py
index 769c87052c..a8d78ea071 100644
--- a/src/dstack/_internal/server/schemas/secrets.py
+++ b/src/dstack/_internal/server/schemas/secrets.py
@@ -1,20 +1,16 @@
from typing import List
-from dstack._internal.core.models.secrets import Secret
-from dstack._internal.server.schemas.common import RepoRequest
+from dstack._internal.core.models.common import CoreModel
-class ListSecretsRequest(RepoRequest):
- pass
+class GetSecretRequest(CoreModel):
+ name: str
-class GetSecretsRequest(RepoRequest):
- pass
+class CreateOrUpdateSecretRequest(CoreModel):
+ name: str
+ value: str
-class AddSecretRequest(RepoRequest):
- secret: Secret
-
-
-class DeleteSecretsRequest(RepoRequest):
+class DeleteSecretsRequest(CoreModel):
secrets_names: List[str]
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 979daa44c2..41aa496be7 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -65,15 +65,23 @@
logger = get_logger(__name__)
-async def get_jobs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[Job]:
+async def get_jobs_from_run_spec(
+ run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+) -> List[Job]:
return [
Job(job_spec=s, job_submissions=[])
- for s in await get_job_specs_from_run_spec(run_spec, replica_num)
+ for s in await get_job_specs_from_run_spec(
+ run_spec=run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
+ )
]
-async def get_job_specs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[JobSpec]:
- job_configurator = _get_job_configurator(run_spec)
+async def get_job_specs_from_run_spec(
+ run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+) -> List[JobSpec]:
+ job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
return job_specs
@@ -159,10 +167,10 @@ def delay_job_instance_termination(job_model: JobModel):
job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)
-def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator:
+def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator:
configuration_type = RunConfigurationType(run_spec.configuration.type)
configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
- return configurator_class(run_spec)
+ return configurator_class(run_spec=run_spec, secrets=secrets)
_job_configurator_classes = [
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 465a24fb0d..349f2bcf02 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -68,8 +68,13 @@ class JobConfigurator(ABC):
# JobSSHKey should be shared for all jobs in a replica for inter-node communication.
_job_ssh_key: Optional[JobSSHKey] = None
- def __init__(self, run_spec: RunSpec):
+ def __init__(
+ self,
+ run_spec: RunSpec,
+ secrets: Optional[Dict[str, str]] = None,
+ ):
self.run_spec = run_spec
+ self.secrets = secrets or {}
async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1)
@@ -98,10 +103,20 @@ def _ports(self) -> List[PortMapping]:
async def _get_image_config(self) -> ImageConfig:
if self._image_config is not None:
return self._image_config
+ interpolate = VariablesInterpolator({"secrets": self.secrets}).interpolate_or_error
+ registry_auth = self.run_spec.configuration.registry_auth
+ if registry_auth is not None:
+ try:
+ registry_auth = RegistryAuth(
+ username=interpolate(registry_auth.username),
+ password=interpolate(registry_auth.password),
+ )
+ except InterpolatorError as e:
+ raise ServerClientError(e.args[0])
image_config = await run_async(
_get_image_config,
self._image_name(),
- self.run_spec.configuration.registry_auth,
+ registry_auth,
)
self._image_config = image_config
return image_config
diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py
index a2d96cf1a4..a10922ef79 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/dev.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Dict, List, Optional
from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
@@ -17,7 +17,7 @@
class DevEnvironmentJobConfigurator(JobConfigurator):
TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT
- def __init__(self, run_spec: RunSpec):
+ def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]):
if run_spec.configuration.ide == "vscode":
__class = VSCodeDesktop
elif run_spec.configuration.ide == "cursor":
@@ -29,7 +29,7 @@ def __init__(self, run_spec: RunSpec):
version=run_spec.configuration.version,
extensions=["ms-python.python", "ms-toolsai.jupyter"],
)
- super().__init__(run_spec)
+ super().__init__(run_spec=run_spec, secrets=secrets)
def _shell_commands(self) -> List[str]:
commands = self.ide.get_install_commands()
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 385660f944..bb775bdf29 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -82,6 +82,7 @@
from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.services.projects import list_project_models, list_user_project_models
from dstack._internal.server.services.resources import set_resources_defaults
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.users import get_user_model_by_name
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.random_names import generate_name
@@ -311,7 +312,12 @@ async def get_plan(
):
action = ApplyAction.UPDATE
- jobs = await get_jobs_from_run_spec(effective_run_spec, replica_num=0)
+ secrets = await get_project_secrets_mapping(session=session, project=project)
+ jobs = await get_jobs_from_run_spec(
+ run_spec=effective_run_spec,
+ secrets=secrets,
+ replica_num=0,
+ )
volumes = await get_job_configured_volumes(
session=session,
@@ -462,6 +468,10 @@ async def submit_run(
project=project,
run_spec=run_spec,
)
+ secrets = await get_project_secrets_mapping(
+ session=session,
+ project=project,
+ )
lock_namespace = f"run_names_{project.name}"
if get_db().dialect_name == "sqlite":
@@ -513,7 +523,11 @@ async def submit_run(
await services.register_service(session, run_model, run_spec)
for replica_num in range(replicas):
- jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
+ jobs = await get_jobs_from_run_spec(
+ run_spec=run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
+ )
for job in jobs:
job_model = create_job_model_for_new_submission(
run_model=run_model,
@@ -1068,10 +1082,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
scheduled_replicas += 1
+ secrets = await get_project_secrets_mapping(
+ session=session,
+ project=run_model.project,
+ )
+
for replica_num in range(
len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
):
- jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
+ # FIXME: Handle getting image configuration errors or skip it.
+ jobs = await get_jobs_from_run_spec(
+ run_spec=run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
+ )
for job in jobs:
job_model = create_job_model_for_new_submission(
run_model=run_model,
@@ -1084,8 +1108,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
async def retry_run_replica_jobs(
session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
):
+ # FIXME: Handle getting image configuration errors or skip it.
+ secrets = await get_project_secrets_mapping(
+ session=session,
+ project=run_model.project,
+ )
new_jobs = await get_jobs_from_run_spec(
- RunSpec.__response__.parse_raw(run_model.run_spec),
+ run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+ secrets=secrets,
replica_num=latest_jobs[0].replica_num,
)
assert len(new_jobs) == len(latest_jobs), (
diff --git a/src/dstack/_internal/server/services/secrets.py b/src/dstack/_internal/server/services/secrets.py
new file mode 100644
index 0000000000..0439f9e688
--- /dev/null
+++ b/src/dstack/_internal/server/services/secrets.py
@@ -0,0 +1,204 @@
+import re
+from typing import Dict, List, Optional
+
+import sqlalchemy.exc
+from sqlalchemy import delete, select, update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import (
+ ResourceExistsError,
+ ResourceNotExistsError,
+ ServerClientError,
+)
+from dstack._internal.core.models.secrets import Secret
+from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$"
+_SECRET_VALUE_MAX_LENGTH = 2000
+
+
+async def list_secrets(
+ session: AsyncSession,
+ project: ProjectModel,
+) -> List[Secret]:
+ secret_models = await list_project_secret_models(session=session, project=project)
+ return [secret_model_to_secret(s, include_value=False) for s in secret_models]
+
+
+async def get_project_secrets_mapping(
+ session: AsyncSession,
+ project: ProjectModel,
+) -> Dict[str, str]:
+ secret_models = await list_project_secret_models(session=session, project=project)
+ return {s.name: s.value.get_plaintext_or_error() for s in secret_models}
+
+
+async def get_secret(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+) -> Optional[Secret]:
+ secret_model = await get_project_secret_model_by_name(
+ session=session,
+ project=project,
+ name=name,
+ )
+ if secret_model is None:
+ return None
+ return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def create_or_update_secret(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+ value: str,
+) -> Secret:
+ _validate_secret(name=name, value=value)
+ try:
+ secret_model = await create_secret(
+ session=session,
+ project=project,
+ name=name,
+ value=value,
+ )
+ except ResourceExistsError:
+ secret_model = await update_secret(
+ session=session,
+ project=project,
+ name=name,
+ value=value,
+ )
+ return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def delete_secrets(
+ session: AsyncSession,
+ project: ProjectModel,
+ names: List[str],
+):
+ existing_secrets_query = await session.execute(
+ select(SecretModel).where(
+ SecretModel.project_id == project.id,
+ SecretModel.name.in_(names),
+ )
+ )
+ existing_names = [s.name for s in existing_secrets_query.scalars().all()]
+ missing_names = set(names) - set(existing_names)
+ if missing_names:
+ raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}")
+
+ await session.execute(
+ delete(SecretModel).where(
+ SecretModel.project_id == project.id,
+ SecretModel.name.in_(names),
+ )
+ )
+ await session.commit()
+ logger.info("Deleted secrets %s in project %s", names, project.name)
+
+
+def secret_model_to_secret(secret_model: SecretModel, include_value: bool = False) -> Secret:
+ value = None
+ if include_value:
+ value = secret_model.value.get_plaintext_or_error()
+ return Secret(
+ id=secret_model.id,
+ name=secret_model.name,
+ value=value,
+ )
+
+
+async def list_project_secret_models(
+ session: AsyncSession,
+ project: ProjectModel,
+) -> List[SecretModel]:
+ res = await session.execute(
+ select(SecretModel)
+ .where(
+ SecretModel.project_id == project.id,
+ )
+ .order_by(SecretModel.created_at.desc())
+ )
+ secret_models = list(res.scalars().all())
+ return secret_models
+
+
+async def get_project_secret_model_by_name(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+) -> Optional[SecretModel]:
+ res = await session.execute(
+ select(SecretModel).where(
+ SecretModel.project_id == project.id,
+ SecretModel.name == name,
+ )
+ )
+ return res.scalar_one_or_none()
+
+
+async def create_secret(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+ value: str,
+) -> SecretModel:
+ secret_model = SecretModel(
+ project_id=project.id,
+ name=name,
+ value=DecryptedString(plaintext=value),
+ )
+ try:
+ async with session.begin_nested():
+ session.add(secret_model)
+ except sqlalchemy.exc.IntegrityError:
+ raise ResourceExistsError()
+ await session.commit()
+ return secret_model
+
+
+async def update_secret(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+ value: str,
+) -> SecretModel:
+ await session.execute(
+ update(SecretModel)
+ .where(
+ SecretModel.project_id == project.id,
+ SecretModel.name == name,
+ )
+ .values(
+ value=DecryptedString(plaintext=value),
+ )
+ )
+ await session.commit()
+ secret_model = await get_project_secret_model_by_name(
+ session=session,
+ project=project,
+ name=name,
+ )
+ if secret_model is None:
+ raise ResourceNotExistsError()
+ return secret_model
+
+
+def _validate_secret(name: str, value: str):
+ _validate_secret_name(name)
+ _validate_secret_value(value)
+
+
+def _validate_secret_name(name: str):
+ if re.match(_SECRET_NAME_REGEX, name) is None:
+ raise ServerClientError(f"Secret name should match regex '{_SECRET_NAME_REGEX}")
+
+
+def _validate_secret_value(value: str):
+ if len(value) > _SECRET_VALUE_MAX_LENGTH:
+ raise ServerClientError(f"Secret value length must not exceed {_SECRET_VALUE_MAX_LENGTH}")
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 7aadb48979..4bcb95404e 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -90,6 +90,7 @@
RepoCredsModel,
RepoModel,
RunModel,
+ SecretModel,
UserModel,
VolumeAttachmentModel,
VolumeModel,
@@ -332,7 +333,9 @@ async def create_job(
if deployment_num is None:
deployment_num = run.deployment_num
run_spec = RunSpec.parse_raw(run.run_spec)
- job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
+ job_spec = (
+ await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num)
+ )[0]
job_spec.job_num = job_num
job = JobModel(
project_id=run.project_id,
@@ -934,6 +937,22 @@ async def create_job_prometheus_metrics(
return metrics
+async def create_secret(
+ session: AsyncSession,
+ project: ProjectModel,
+ name: str,
+ value: str,
+):
+ secret_model = SecretModel(
+ project=project,
+ name=name,
+ value=DecryptedString(plaintext=value),
+ )
+ session.add(secret_model)
+ await session.commit()
+ return secret_model
+
+
def get_private_key_string() -> str:
return """
-----BEGIN RSA PRIVATE KEY-----
diff --git a/src/dstack/api/server/_secrets.py b/src/dstack/api/server/_secrets.py
index adba4081a9..9a2a2763f1 100644
--- a/src/dstack/api/server/_secrets.py
+++ b/src/dstack/api/server/_secrets.py
@@ -4,33 +4,33 @@
from dstack._internal.core.models.secrets import Secret
from dstack._internal.server.schemas.secrets import (
- AddSecretRequest,
+ CreateOrUpdateSecretRequest,
DeleteSecretsRequest,
- GetSecretsRequest,
- ListSecretsRequest,
+ GetSecretRequest,
)
from dstack.api.server._group import APIClientGroup
class SecretsAPIClient(APIClientGroup):
- def list(self, project_name: str, repo_id: str) -> List[Secret]:
- body = ListSecretsRequest(repo_id=repo_id)
- resp = self._request(f"/api/project/{project_name}/secrets/list", body=body.json())
+ def list(self, project_name: str) -> List[Secret]:
+ resp = self._request(f"/api/project/{project_name}/secrets/list")
return parse_obj_as(List[Secret.__response__], resp.json())
- def get(self, project_name: str, repo_id: str, secret_name: str) -> Secret:
- raise NotImplementedError()
- body = GetSecretsRequest(repo_id=repo_id)
+ def get(self, project_name: str, name: str) -> Secret:
+ body = GetSecretRequest(name=name)
resp = self._request(f"/api/project/{project_name}/secrets/get", body=body.json())
return parse_obj_as(Secret, resp.json())
- def add(self, project_name: str, repo_id: str, secret_name: str, secret_value: str) -> Secret:
- body = AddSecretRequest(
- repo_id=repo_id, secret=Secret(name=secret_name, value=secret_value)
+ def create_or_update(self, project_name: str, name: str, value: str) -> Secret:
+ body = CreateOrUpdateSecretRequest(
+ name=name,
+ value=value,
+ )
+ resp = self._request(
+ f"/api/project/{project_name}/secrets/create_or_update", body=body.json()
)
- resp = self._request(f"/api/project/{project_name}/secrets/add", body=body.json())
return parse_obj_as(Secret.__response__, resp.json())
- def delete(self, project_name: str, repo_id: str, secrets_names: List[str]):
- body = DeleteSecretsRequest(repo_id=repo_id, secrets_names=secrets_names)
+ def delete(self, project_name: str, names: List[str]):
+ body = DeleteSecretsRequest(secrets_names=names)
self._request(f"/api/project/{project_name}/secrets/delete", body=body.json())
diff --git a/src/tests/_internal/server/routers/test_secrets.py b/src/tests/_internal/server/routers/test_secrets.py
new file mode 100644
index 0000000000..d029dc4ee0
--- /dev/null
+++ b/src/tests/_internal/server/routers/test_secrets.py
@@ -0,0 +1,271 @@
+import pytest
+from httpx import AsyncClient
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.models.users import GlobalRole, ProjectRole
+from dstack._internal.server.models import SecretModel
+from dstack._internal.server.services.projects import add_project_member
+from dstack._internal.server.testing.common import (
+ create_project,
+ create_secret,
+ create_user,
+ get_auth_headers,
+)
+
+
+class TestListSecrets:
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_403_if_not_admin(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/list",
+ headers=get_auth_headers(user.token),
+ json={},
+ )
+ assert response.status_code == 403
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_lists_secrets(self, test_db, session: AsyncSession, client: AsyncClient):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ secret1 = await create_secret(
+ session=session, project=project, name="secret1", value="123456"
+ )
+ secret2 = await create_secret(
+ session=session, project=project, name="secret2", value="123456"
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/list",
+ headers=get_auth_headers(user.token),
+ json={},
+ )
+ assert response.status_code == 200
+ assert response.json() == [
+ {
+ "id": str(secret2.id),
+ "name": "secret2",
+ "value": None,
+ },
+ {
+ "id": str(secret1.id),
+ "name": "secret1",
+ "value": None,
+ },
+ ]
+
+
+class TestGetSecret:
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_403_if_not_admin(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/get",
+ headers=get_auth_headers(user.token),
+ json={"name": "my_secret"},
+ )
+ assert response.status_code == 403
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_secret_with_value(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ secret = await create_secret(
+ session=session, project=project, name="secret1", value="123456"
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/get",
+ headers=get_auth_headers(user.token),
+ json={"name": "secret1"},
+ )
+ assert response.status_code == 200
+ assert response.json() == {
+ "id": str(secret.id),
+ "name": "secret1",
+ "value": "123456",
+ }
+
+
+class TestCreateOrUpdateSecret:
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_403_if_not_admin(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/create_or_update",
+ headers=get_auth_headers(user.token),
+ json={"name": "my_secret"},
+ )
+ assert response.status_code == 403
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_creates_secret(self, test_db, session: AsyncSession, client: AsyncClient):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/create_or_update",
+ headers=get_auth_headers(user.token),
+ json={"name": "secret1", "value": "123456"},
+ )
+ assert response.status_code == 200
+ res = await session.execute(select(SecretModel))
+ secret_model = res.scalar()
+ assert secret_model is not None
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_updates_secret(self, test_db, session: AsyncSession, client: AsyncClient):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ secret = await create_secret(
+ session=session, project=project, name="secret1", value="old_value"
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/create_or_update",
+ headers=get_auth_headers(user.token),
+ json={"name": "secret1", "value": "new_value"},
+ )
+ assert response.status_code == 200
+ await session.refresh(secret)
+ assert secret.value.get_plaintext_or_error() == "new_value"
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ @pytest.mark.parametrize(
+ "name, value",
+ [
+ ("too_long_secret_value", "a" * 2001),
+ ("", "empty_name"),
+ ("@7&.", "wierd_name_chars"),
+ ],
+ )
+ async def test_rejects_bad_names_values(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ name: str,
+ value,
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/create_or_update",
+ headers=get_auth_headers(user.token),
+ json={"name": name, "value": value},
+ )
+ assert response.status_code == 400
+ res = await session.execute(select(SecretModel))
+ secret_model = res.scalar()
+ assert secret_model is None
+
+
+class TestDeleteSecrets:
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_403_if_not_admin(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/delete",
+ headers=get_auth_headers(user.token),
+ json={"secrets_names": ["my_secret"]},
+ )
+ assert response.status_code == 403
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_deletes_secrets(self, test_db, session: AsyncSession, client: AsyncClient):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ # Create two secrets
+ await create_secret(session=session, project=project, name="secret1", value="123456")
+ await create_secret(session=session, project=project, name="secret2", value="abcdef")
+
+ # Verify both secrets exist
+ res = await session.execute(
+ select(SecretModel).where(SecretModel.project_id == project.id)
+ )
+ secrets = res.scalars().all()
+ assert len(secrets) == 2
+
+ # Delete one secret
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/delete",
+ headers=get_auth_headers(user.token),
+ json={"secrets_names": ["secret1"]},
+ )
+ assert response.status_code == 200
+
+ # Verify only one secret remains
+ res = await session.execute(
+ select(SecretModel).where(SecretModel.project_id == project.id)
+ )
+ secrets = res.scalars().all()
+ assert len(secrets) == 1
+ assert secrets[0].name == "secret2"
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_delete_nonexistent_secret_raises_error(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ response = await client.post(
+ f"/api/project/{project.name}/secrets/delete",
+ headers=get_auth_headers(user.token),
+ json={"secrets_names": ["nonexistent_secret"]},
+ )
+ assert response.status_code == 400 # ResourceNotExistsError should return 404
diff --git a/src/tests/_internal/server/services/test_runs.py b/src/tests/_internal/server/services/test_runs.py
index 7d3f2f4595..d463f0e4fb 100644
--- a/src/tests/_internal/server/services/test_runs.py
+++ b/src/tests/_internal/server/services/test_runs.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Union
import pytest
from pydantic import parse_obj_as
@@ -30,7 +30,7 @@ async def make_run(
session: AsyncSession,
replicas_statuses: List[JobStatus],
status: RunStatus = RunStatus.RUNNING,
- replicas: str = 1,
+ replicas: Union[str, int] = 1,
) -> RunModel:
project = await create_project(session=session)
user = await create_user(session=session)
@@ -70,7 +70,7 @@ async def make_run(
status=job_status,
replica_num=replica_num,
)
- await session.refresh(run)
+ await session.refresh(run, attribute_names=["project", "jobs"])
return run