Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ class Project(CoreModel):
created_at: Optional[datetime] = None
backends: List[BackendInfo]
members: List[Member]
is_public: bool = False
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Add ProjectModel.is_public

Revision ID: 35f732ee4cf5
Revises: bca2fdf130bf
Create Date: 2025-06-06 13:04:02.912032

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "35f732ee4cf5"
down_revision = "bca2fdf130bf"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Add is_public column as nullable first
with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.add_column(sa.Column("is_public", sa.Boolean(), nullable=True))

# Set is_public to False for existing projects
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid WHAT comments that duplicate the code. If LLM generated, please remove.

op.execute(sa.sql.text("UPDATE projects SET is_public = FALSE"))

# Make is_public non-nullable with default value
with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.alter_column("is_public", nullable=False, server_default=sa.false())
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Remove is_public column
with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.drop_column("is_public")
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class ProjectModel(BaseModel):
name: Mapped[str] = mapped_column(String(50), unique=True)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)

owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
owner: Mapped[UserModel] = relationship(lazy="joined")
Expand Down
7 changes: 4 additions & 3 deletions src/dstack/_internal/server/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dstack._internal.server.security.permissions import (
Authenticated,
ProjectManager,
ProjectMember,
ProjectMemberOrPublicAccess,
)
from dstack._internal.server.services import projects
from dstack._internal.server.utils.routers import get_base_api_additional_responses
Expand All @@ -36,7 +36,7 @@ async def list_projects(

`members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them.
"""
return await projects.list_user_projects(session=session, user=user)
return await projects.list_user_accessible_projects(session=session, user=user)


@router.post("/create")
Expand All @@ -49,6 +49,7 @@ async def create_project(
session=session,
user=user,
project_name=body.project_name,
is_public=body.is_public,
)


Expand All @@ -68,7 +69,7 @@ async def delete_projects(
@router.post("/{project_name}/get")
async def get_project(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
) -> Project:
_, project = user_project
return projects.project_model_to_project(project)
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/schemas/projects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, List
from typing import Annotated, List, Optional

from pydantic import Field

Expand All @@ -8,6 +8,7 @@

class CreateProjectRequest(CoreModel):
project_name: str
is_public: Optional[bool] = False


class DeleteProjectsRequest(CoreModel):
Expand Down
36 changes: 36 additions & 0 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,42 @@ async def __call__(
return await get_project_member(session, project_name, token.credentials)


class ProjectMemberOrPublicAccess:
"""
Allows access to project for:
- Global admins
- Project members
- Any authenticated user if the project is public
"""

async def __call__(
self,
*,
session: AsyncSession = Depends(get_session),
project_name: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()

project = await get_project_model_by_name(session=session, project_name=project_name)
if project is None:
raise error_not_found()

if user.global_role == GlobalRole.ADMIN:
return user, project

project_role = get_user_project_role(user=user, project=project)
if project_role is not None:
return user, project

if project.is_public:
return user, project

raise error_forbidden()


class OptionalServiceAccount:
def __init__(self, token: Optional[str]) -> None:
self._token = token
Expand Down
55 changes: 54 additions & 1 deletion src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,37 @@ async def list_user_projects(
session: AsyncSession,
user: UserModel,
) -> List[Project]:
"""
Returns projects where the user is a member.
"""
if user.global_role == GlobalRole.ADMIN:
projects = await list_project_models(session=session)
else:
projects = await list_user_project_models(session=session, user=user)

projects = sorted(projects, key=lambda p: p.created_at)
return [
project_model_to_project(p, include_backends=False, include_members=False)
for p in projects
]


async def list_user_accessible_projects(
session: AsyncSession,
user: UserModel,
) -> List[Project]:
"""
Returns all projects accessible to the user:
- For global admins: ALL projects in the system
- For regular users: Projects where user is a member + public projects where user is NOT a member
"""
if user.global_role == GlobalRole.ADMIN:
projects = await list_project_models(session=session)
else:
member_projects = await list_user_project_models(session=session, user=user)
public_projects = await list_public_non_member_project_models(session=session, user=user)
projects = member_projects + public_projects

projects = sorted(projects, key=lambda p: p.created_at)
return [
project_model_to_project(p, include_backends=False, include_members=False)
Expand Down Expand Up @@ -86,6 +113,7 @@ async def create_project(
session: AsyncSession,
user: UserModel,
project_name: str,
is_public: bool = False,
) -> Project:
user_permissions = users.get_user_permissions(user)
if not user_permissions.can_create_projects:
Expand All @@ -100,6 +128,7 @@ async def create_project(
session=session,
owner=user,
project_name=project_name,
is_public=is_public,
)
await add_project_member(
session=session,
Expand Down Expand Up @@ -233,6 +262,9 @@ async def list_user_project_models(
user: UserModel,
include_members: bool = False,
) -> List[ProjectModel]:
"""
List project models for a user where they are a member.
"""
options = []
if include_members:
options.append(joinedload(ProjectModel.members))
Expand All @@ -248,6 +280,25 @@ async def list_user_project_models(
return list(res.scalars().unique().all())


async def list_public_non_member_project_models(
session: AsyncSession,
user: UserModel,
) -> List[ProjectModel]:
"""
List public project models where user is NOT a member.
"""
res = await session.execute(
select(ProjectModel).where(
ProjectModel.deleted == False,
ProjectModel.is_public == True,
ProjectModel.id.notin_(
select(MemberModel.project_id).where(MemberModel.user_id == user.id)
),
)
)
return list(res.scalars().all())


async def list_user_owned_project_models(
session: AsyncSession, user: UserModel, include_deleted: bool = False
) -> List[ProjectModel]:
Expand Down Expand Up @@ -323,7 +374,7 @@ async def get_project_model_by_id_or_error(


async def create_project_model(
session: AsyncSession, owner: UserModel, project_name: str
session: AsyncSession, owner: UserModel, project_name: str, is_public: bool = False
) -> ProjectModel:
private_bytes, public_bytes = await run_async(
generate_rsa_key_pair_bytes, f"{project_name}@dstack"
Expand All @@ -334,6 +385,7 @@ async def create_project_model(
name=project_name,
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
is_public=is_public,
)
session.add(project)
await session.commit()
Expand Down Expand Up @@ -407,6 +459,7 @@ def project_model_to_project(
created_at=project_model.created_at.replace(tzinfo=timezone.utc),
backends=backends,
members=members,
is_public=project_model.is_public,
)


Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ async def create_project(
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
ssh_private_key: str = "",
ssh_public_key: str = "",
is_public: bool = False,
) -> ProjectModel:
if owner is None:
owner = await create_user(session=session, name="test_owner")
Expand All @@ -149,6 +150,7 @@ async def create_project(
created_at=created_at,
ssh_private_key=ssh_private_key,
ssh_public_key=ssh_public_key,
is_public=is_public,
)
session.add(project)
await session.commit()
Expand Down
6 changes: 3 additions & 3 deletions src/dstack/api/server/_projects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from pydantic import parse_obj_as

Expand All @@ -17,8 +17,8 @@ def list(self) -> List[Project]:
resp = self._request("/api/projects/list")
return parse_obj_as(List[Project.__response__], resp.json())

def create(self, project_name: str) -> Project:
body = CreateProjectRequest(project_name=project_name)
def create(self, project_name: str, is_public: Optional[bool] = False) -> Project:
body = CreateProjectRequest(project_name=project_name, is_public=is_public)
resp = self._request("/api/projects/create", body=body.json())
return parse_obj_as(Project.__response__, resp.json())

Expand Down
Loading