diff --git a/app/config.py b/app/config.py index 84c12cd3..0f936800 100644 --- a/app/config.py +++ b/app/config.py @@ -7,6 +7,12 @@ from app.db.types import StorageType +# size aliases +KB = 1024 +MB = KB * 1024 +GB = MB * 1024 +TB = GB * 1024 + class Settings(BaseSettings): model_config = SettingsConfigDict( @@ -46,18 +52,26 @@ class Settings(BaseSettings): AUTH_CACHE_MAX_TTL: int = 300 # seconds AUTH_CACHE_INFO: bool = False - S3_PRESIGNED_URL_NETLOC: str | None = None # to override the presigned url hostname and port - S3_MULTIPART_THRESHOLD: int = 5 * 1024**2 # bytes # TODO: decide an appropriate value + # to override the presigned url hostname and port when running locally + S3_PRESIGNED_URL_NETLOC: str | None = None S3_PRESIGNED_URL_EXPIRATION: int = 6 * 3600 # 6 hours - - S3_MULTIPART_UPLOAD_MAX_SIZE: int = 1024**4 # 1TB - S3_MULTIPART_UPLOAD_MIN_PART_SIZE: int = 5 * 1024**2 - S3_MULTIPART_UPLOAD_MAX_PART_SIZE: int = 5 * 1024**3 + # upload_fileobj: data flows through the service + S3_MULTIPART_UPLOAD_THRESHOLD: int = 100 * MB + S3_MULTIPART_UPLOAD_CHUNKSIZE: int = 10 * MB + S3_MULTIPART_UPLOAD_MAX_CONCURRENCY: int = 10 + # copy: server-side, data stays in S3 + S3_MULTIPART_COPY_THRESHOLD: int = 5 * GB + S3_MULTIPART_COPY_CHUNKSIZE: int = 1 * GB + S3_MULTIPART_COPY_MAX_CONCURRENCY: int = 10 + + S3_MULTIPART_UPLOAD_MAX_SIZE: int = 1 * TB + S3_MULTIPART_UPLOAD_MIN_PART_SIZE: int = 5 * MB + S3_MULTIPART_UPLOAD_MAX_PART_SIZE: int = 5 * GB S3_MULTIPART_UPLOAD_MIN_PARTS: int = 1 S3_MULTIPART_UPLOAD_MAX_PARTS: int = 10_000 S3_MULTIPART_UPLOAD_DEFAULT_PARTS: int = 100 - API_ASSET_POST_MAX_SIZE: int = 150 * 1024**2 # bytes # TODO: decide an appropriate value + API_ASSET_POST_MAX_SIZE: int = 150 * MB PAGINATION_DEFAULT_PAGE_SIZE: int = 30 PAGINATION_MAX_PAGE_SIZE: int = 200 diff --git a/app/db/utils.py b/app/db/utils.py index 1d09bf32..5bfa40ea 100644 --- a/app/db/utils.py +++ b/app/db/utils.py @@ -6,17 +6,29 @@ from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, RelationshipProperty from app.db.model import ( + Activity, Base, CellMorphologyProtocol, Entity, + ETypeClassification, Identifiable, LocationMixin, MeasurableEntityMixin, + MTypeClassification, ) from app.db.types import CellMorphologyGenerationType, EntityType, ResourceType from app.logger import L from app.schemas.utils import NOT_SET +PublishableBaseModel = Activity | Entity | ETypeClassification | MTypeClassification + +PUBLISHABLE_BASE_CLASSES: list[type[PublishableBaseModel]] = [ + Activity, + Entity, + ETypeClassification, + MTypeClassification, +] + MEASURABLE_ENTITIES: dict[str, type[Entity]] = { mapper.class_.__tablename__: mapper.class_ for mapper in Base.registry.mappers diff --git a/app/routers/admin.py b/app/routers/admin.py index f99c3f74..7f717aba 100644 --- a/app/routers/admin.py +++ b/app/routers/admin.py @@ -1,19 +1,33 @@ import uuid +from typing import Annotated -from fastapi import APIRouter +from fastapi import APIRouter, Query +from app.config import storages +from app.db.types import StorageType from app.db.utils import RESOURCE_TYPE_TO_CLASS from app.dependencies.common import PaginationQuery from app.dependencies.db import RepoGroupDep, SessionDep +from app.dependencies.s3 import StorageClientFactoryDep from app.filters.asset import AssetFilterDep from app.queries.common import router_admin_delete_one from app.schemas.asset import ( AssetRead, ) +from app.schemas.publish import ChangeProjectVisibilityResponse from app.schemas.routers import DeleteResponse from app.schemas.types import ListResponse -from app.service import admin as admin_service, asset as asset_service -from app.utils.routers import EntityRoute, ResourceRoute, entity_route_to_type, route_to_type +from app.service import ( + admin as admin_service, + asset as asset_service, + publish as publish_service, +) +from app.utils.routers import ( + EntityRoute, + ResourceRoute, + entity_route_to_type, + route_to_type, +) router = APIRouter( prefix="/admin", @@ -89,3 +103,71 @@ def delete_entity_asset( ) # Note: Asset storage object is deleted via app.db.events return asset + + +@router.post("/publish-project/{project_id}") +def publish_project( + db: SessionDep, + storage_client_factory: StorageClientFactoryDep, + *, + project_id: uuid.UUID, + max_assets: Annotated[ + int | None, Query(description="Limit the number of assets to be made public.") + ] = None, + dry_run: Annotated[ + bool, Query(description="Simulate the operation without making any change.") + ], +) -> ChangeProjectVisibilityResponse: + """Publish the content of a project. + + This endpoint is used to make public the resources in a project. + + It's recommended to call the endpoint with dry_run=true before running it with dry_run=false. + + If max_assets is specified, the endpoint should be called multiple times until the response + says that the operation is completed. + """ + storage = storages[StorageType.aws_s3_internal] + s3_client = storage_client_factory(storage) + return publish_service.set_project_visibility( + db=db, + s3_client=s3_client, + project_id=project_id, + storage=storage, + max_assets=max_assets, + dry_run=dry_run, + public=True, + ) + + +@router.post("/unpublish-project/{project_id}") +def unpublish_project( + db: SessionDep, + storage_client_factory: StorageClientFactoryDep, + project_id: uuid.UUID, + *, + max_assets: Annotated[ + int | None, Query(description="Limit the number of assets to be made private.") + ] = None, + dry_run: bool, +) -> ChangeProjectVisibilityResponse: + """Unpublish the content of a project. + + This endpoint is used to make private the resources in a project. + + It's recommended to call the endpoint with dry_run=true before running it with dry_run=false. + + If max_assets is specified, the endpoint should be called multiple times until the response + says that the operation is completed. + """ + storage = storages[StorageType.aws_s3_internal] + s3_client = storage_client_factory(storage) + return publish_service.set_project_visibility( + db=db, + s3_client=s3_client, + project_id=project_id, + storage=storage, + max_assets=max_assets, + dry_run=dry_run, + public=False, + ) diff --git a/app/schemas/publish.py b/app/schemas/publish.py new file mode 100644 index 00000000..33cc6d03 --- /dev/null +++ b/app/schemas/publish.py @@ -0,0 +1,66 @@ +import uuid +from typing import Annotated + +from pydantic import BaseModel, Field + + +class MoveFileResult(BaseModel): + size: Annotated[int, Field(description="Size of the file")] + error: str | None = None + + +class MoveDirectoryResult(BaseModel): + size: Annotated[int, Field(description="Size of moved files in the directory")] = 0 + file_count: Annotated[int, Field(description="Number of moved files in the directory")] = 0 + errors: list[str] = [] + + def update_from_file_result(self, file_result: MoveFileResult) -> None: + self.size += file_result.size + self.file_count += 1 + if file_result.error: + self.errors.append(file_result.error) + + +class MoveAssetsResult(BaseModel): + total_size: Annotated[int, Field(description="Total size of moved files")] = 0 + file_count: Annotated[int, Field(description="Number of moved files")] = 0 + asset_count: Annotated[int, Field(description="Number of updated assets")] = 0 + errors: list[str] = [] + + def update_from_file_result(self, file_result: MoveFileResult) -> None: + self.total_size += file_result.size + self.file_count += 1 + self.asset_count += 1 + if file_result.error: + self.errors.append(file_result.error) + + def update_from_directory_result(self, directory_result: MoveDirectoryResult) -> None: + self.total_size += directory_result.size + self.file_count += directory_result.file_count + self.asset_count += 1 + self.errors.extend(directory_result.errors) + + +class ChangeProjectVisibilityResponse(BaseModel): + """Successful response to the publish or unpublish operation.""" + + message: Annotated[str, Field(description="A human-readable message describing the result")] + project_id: Annotated[uuid.UUID, Field(description="ID of the project")] + public: Annotated[bool, Field(description="Whether the content is now public or private")] + resource_count: Annotated[ + int, + Field(description="Number of updated resources (activities, entities, classifications)"), + ] + move_assets_result: Annotated[ + MoveAssetsResult, Field(description="Result of the assets movement") + ] + dry_run: Annotated[bool, Field(description="True if the operation has been simulated only")] + completed: Annotated[ + bool, + Field( + description=( + "Whether the assets have been fully updated. It may be False if `max_assets` " + "have been specified, and there are still assets to be moved." + ) + ), + ] diff --git a/app/service/publish.py b/app/service/publish.py new file mode 100644 index 00000000..2e50c143 --- /dev/null +++ b/app/service/publish.py @@ -0,0 +1,187 @@ +import uuid +from itertools import batched + +import sqlalchemy as sa +from sqlalchemy.orm import Session +from types_boto3_s3 import S3Client + +from app.config import StorageUnion +from app.db.model import Asset, Entity +from app.db.types import AssetStatus, StorageType +from app.db.utils import PUBLISHABLE_BASE_CLASSES, PublishableBaseModel +from app.logger import L +from app.schemas.publish import ChangeProjectVisibilityResponse, MoveAssetsResult +from app.utils.s3 import ( + convert_s3_path_visibility, + get_s3_path_prefix, + move_directory, + move_file, +) + +BATCH_SIZE = 500 + + +def _set_base_class_visibility( + db: Session, + project_id: uuid.UUID, + db_model_class: type[PublishableBaseModel], + *, + public: bool, +) -> int: + """Update authorized_public in all the resources of a single base class. + + Rows are updated directly without loading the ORM models, so it's not possible to fire any + SQLAlchemy event, but it is more efficient as it avoids loading all the resources in memory. + + Returns the number of updated resources. + """ + result = db.execute( + sa.update(db_model_class) + .where( + db_model_class.authorized_project_id == project_id, + db_model_class.authorized_public.is_(not public), + ) + .values( + authorized_public=public, + update_date=db_model_class.update_date, # preserve update_date + ) + ) + return result.rowcount # type: ignore[attr-defined] + + +def _set_assets_visibility( + db: Session, + *, + s3_client: S3Client, + project_id: uuid.UUID, + bucket_name: str, + storage_type: StorageType, + max_assets: int | None, + dry_run: bool, + public: bool, +) -> MoveAssetsResult: + """Move assets from private to public in S3 or vice versa, and update their path in the db. + + Rows are updated in batches directly after loading the ORM models for better efficiency. + + This function must be called after the entities have been converted to public (private). + It ignores any private (public) entity added concurrently, because the query applies + a filter on `Entity.authorized_public`. + + Returns the total number of assets and files moved, and their total size. + """ + old_prefix = get_s3_path_prefix(public=not public) + private_assets = db.execute( + sa.select(Asset) + .join(Entity, Entity.id == Asset.entity_id) + .where( + Entity.authorized_project_id == project_id, + Entity.authorized_public.is_(public), + Asset.storage_type == storage_type, + Asset.status == AssetStatus.CREATED, + Asset.full_path.like(f"{old_prefix}%"), + ) + .with_for_update(of=Asset) + .limit(max_assets) + ).scalars() + move_result = MoveAssetsResult() + for batch in batched(private_assets, BATCH_SIZE): + path_mapping: dict[uuid.UUID, str] = {} + L.info("Processing batch of {} assets [dry_run={}]", len(batch), dry_run) + for asset in batch: + src_key = asset.full_path + dst_key = convert_s3_path_visibility(asset.full_path, public=public) + if asset.is_directory: + move_result.update_from_directory_result( + move_directory( + s3_client, + src_bucket_name=bucket_name, + dst_bucket_name=bucket_name, + src_key=src_key, + dst_key=dst_key, + dry_run=dry_run, + ) + ) + else: + move_result.update_from_file_result( + move_file( + s3_client, + src_bucket_name=bucket_name, + dst_bucket_name=bucket_name, + src_key=src_key, + dst_key=dst_key, + size=asset.size, + dry_run=dry_run, + ) + ) + path_mapping[asset.id] = dst_key + db.expunge(asset) # free memory from session's identity map + db.execute( + sa.update(Asset) + .where(Asset.id.in_(path_mapping)) + .values( + full_path=sa.case(path_mapping, value=Asset.id), + update_date=Asset.update_date, # preserve update_date + ) + ) + return move_result + + +def set_project_visibility( + db: Session, + *, + s3_client: S3Client, + project_id: uuid.UUID, + storage: StorageUnion, + max_assets: int | None, + dry_run: bool, + public: bool = True, +) -> ChangeProjectVisibilityResponse: + """Change the visibility of entities, activities, classifications, and assets in a project. + + If public is True, all the resources are made public. + If public is False, all the resources are made private if possible, or it should fail + if any resource has been used in other projects. + + The function can be called multiple times sequentially, to update max_assets per request. + """ + savepoint = db.begin_nested() + description = "public" if public else "private" + resource_count = 0 + for db_model_class in PUBLISHABLE_BASE_CLASSES: + L.info( + "Updating table {} to {} for project {} [dry_run={}]", + db_model_class.__tablename__, + description, + project_id, + dry_run, + ) + resource_count += _set_base_class_visibility( + db=db, + project_id=project_id, + db_model_class=db_model_class, + public=public, + ) + L.info("Updating assets to {} for project {} [dry_run={}]", description, project_id, dry_run) + move_result = _set_assets_visibility( + db=db, + s3_client=s3_client, + project_id=project_id, + bucket_name=storage.bucket, + storage_type=storage.type, + max_assets=max_assets, + dry_run=dry_run, + public=public, + ) + if dry_run: + savepoint.rollback() + completed = max_assets is None or move_result.asset_count < max_assets + return ChangeProjectVisibilityResponse( + message=f"Project resources successfully made {description}", + project_id=project_id, + resource_count=resource_count, + move_assets_result=move_result, + dry_run=dry_run, + public=public, + completed=completed, + ) diff --git a/app/utils/s3.py b/app/utils/s3.py index 4f29ef6a..2c0e2196 100644 --- a/app/utils/s3.py +++ b/app/utils/s3.py @@ -12,19 +12,39 @@ from botocore.exceptions import ClientError from fastapi import HTTPException from types_boto3_s3 import S3Client -from types_boto3_s3.type_defs import PaginatorConfigTypeDef +from types_boto3_s3.type_defs import ( + CopySourceTypeDef, + DeleteObjectRequestTypeDef, + PaginatorConfigTypeDef, +) from app.config import StorageUnion, settings, storages from app.db.types import EntityType, StorageType from app.logger import L from app.schemas.asset import validate_path +from app.schemas.publish import MoveDirectoryResult, MoveFileResult from app.utils.common import clip +PUBLIC_ASSET_PREFIX = "public/" +PRIVATE_ASSET_PREFIX = "private/" + class StorageClientFactory(Protocol): def __call__(self, storage: StorageUnion) -> S3Client: ... +def ensure_directory_prefix(prefix: str) -> str: + """Return the prefix with a trailing '/' if it's missing.""" + if not prefix.endswith("/"): + prefix += "/" + return prefix + + +def get_s3_path_prefix(*, public: bool) -> str: + """Return the S3 path prefix for public or private assets.""" + return PUBLIC_ASSET_PREFIX if public else PRIVATE_ASSET_PREFIX + + def build_s3_path( *, vlab_id: UUID, @@ -35,8 +55,23 @@ def build_s3_path( is_public: bool, ) -> str: """Return the key used to store the file on S3.""" - prefix = "public" if is_public else "private" - return f"{prefix}/{vlab_id}/{proj_id}/assets/{entity_type.name}/{entity_id}/{filename}" + prefix = get_s3_path_prefix(public=is_public) + return f"{prefix}{vlab_id}/{proj_id}/assets/{entity_type.name}/{entity_id}/{filename}" + + +def convert_s3_path_visibility(s3_path: str, *, public: bool) -> str: + """Convert a private S3 path to a public one, or vice versa. + + Args: + s3_path: the original S3 path. + public: whether the returned path should be public or private. + """ + old_prefix = get_s3_path_prefix(public=not public) + new_prefix = get_s3_path_prefix(public=public) + if not s3_path.startswith(old_prefix): + msg = f"S3 path must start with {old_prefix!r}." + raise ValueError(msg) + return new_prefix + s3_path.removeprefix(old_prefix) def validate_filename(filename: str) -> bool: @@ -97,13 +132,16 @@ def upload_to_s3( bucket_name: name of the S3 bucket. s3_key: S3 object key (destination path in the bucket). """ - config = TransferConfig(multipart_threshold=settings.S3_MULTIPART_THRESHOLD) try: s3_client.upload_fileobj( file_obj, Bucket=bucket_name, Key=s3_key, - Config=config, + Config=TransferConfig( + multipart_threshold=settings.S3_MULTIPART_UPLOAD_THRESHOLD, + multipart_chunksize=settings.S3_MULTIPART_UPLOAD_CHUNKSIZE, + max_concurrency=settings.S3_MULTIPART_UPLOAD_MAX_CONCURRENCY, + ), ) except Exception: # noqa: BLE001 L.exception("Error while uploading file to s3://{}/{}", bucket_name, s3_key) @@ -273,9 +311,8 @@ def list_directory_with_details( pagination_config: PaginatorConfigTypeDef | None = None, ) -> dict: # with `prefix="foo/asdf" argument will match all `foo/asdf/` and `foo/asdf_asdf/, - # insure we have a ending / to prevent being promiscuous - if not prefix.endswith("/"): - prefix += "/" + # ensure we have a ending / to prevent being promiscuous + prefix = ensure_directory_prefix(prefix) paginator = s3_client.get_paginator("list_objects_v2") files = {} @@ -323,3 +360,122 @@ def check_object( return {"exists": False} raise return {"exists": True, "type": object_type} + + +def copy_file( + s3_client: S3Client, + *, + src_bucket_name: str, + dst_bucket_name: str, + src_key: str, + dst_key: str, +) -> bool: + """Copy a file in S3, using multipart copy for large objects. + + See https://docs.aws.amazon.com/boto3/latest/reference/services/s3/client/copy.html + """ + copy_source: CopySourceTypeDef = { + "Bucket": src_bucket_name, + "Key": src_key, + } + try: + s3_client.copy( + CopySource=copy_source, + Bucket=dst_bucket_name, + Key=dst_key, + Config=TransferConfig( + multipart_threshold=settings.S3_MULTIPART_COPY_THRESHOLD, + multipart_chunksize=settings.S3_MULTIPART_COPY_CHUNKSIZE, + max_concurrency=settings.S3_MULTIPART_COPY_MAX_CONCURRENCY, + ), + ) + except Exception: # noqa: BLE001 + L.exception( + "Error while copying file from s3://{}/{} to s3://{}/{}", + src_bucket_name, + src_key, + dst_bucket_name, + dst_key, + ) + return False + return True + + +def move_file( + s3_client: S3Client, + *, + src_bucket_name: str, + dst_bucket_name: str, + src_key: str, + dst_key: str, + size: int, + dry_run: bool, +) -> MoveFileResult: + """Move a file in S3 by copying it to the new location and deleting the original.""" + if (src_bucket_name, src_key) == (dst_bucket_name, dst_key): + msg = "Source and destination cannot be the same." + raise ValueError(msg) + if dry_run: + return MoveFileResult(size=size, error=None) + try: + # check if the source object exists and get its metadata + src_head = s3_client.head_object(Bucket=src_bucket_name, Key=src_key) + except ClientError as e: + if e.response.get("Error", {}).get("Code") != "404": + raise + try: + # check if the destination object already exists + s3_client.head_object(Bucket=dst_bucket_name, Key=dst_key) + except ClientError: + msg = f"Failed to get object s3://{src_bucket_name}/{src_key}" + L.warning(msg) + return MoveFileResult(size=size, error=msg) + L.warning("Source already moved: s3://{}/{}", src_bucket_name, src_key) + return MoveFileResult(size=size, error=None) + if not copy_file( + s3_client, + src_bucket_name=src_bucket_name, + dst_bucket_name=dst_bucket_name, + src_key=src_key, + dst_key=dst_key, + ): + msg = ( + f"Failed to copy object from s3://{src_bucket_name}/{src_key} " + f"to s3://{dst_bucket_name}/{dst_key}" + ) + L.warning(msg) + return MoveFileResult(size=size, error=msg) + # delete the original object without leaving a delete marker when versioning is enabled + delete_kwargs: DeleteObjectRequestTypeDef = {"Bucket": src_bucket_name, "Key": src_key} + if src_version_id := src_head.get("VersionId"): + delete_kwargs["VersionId"] = src_version_id + s3_client.delete_object(**delete_kwargs) + return MoveFileResult(size=size, error=None) + + +def move_directory( + s3_client: S3Client, + *, + src_bucket_name: str, + dst_bucket_name: str, + src_key: str, + dst_key: str, + dry_run: bool, +) -> MoveDirectoryResult: + """Move a directory in S3 by copying it to the new location and deleting the original.""" + src_key = ensure_directory_prefix(src_key) + dst_key = ensure_directory_prefix(dst_key) + objects = list_directory_with_details(s3_client, bucket_name=src_bucket_name, prefix=src_key) + move_directory_result = MoveDirectoryResult(size=0, file_count=0) + for obj in objects.values(): + move_file_result = move_file( + s3_client, + src_bucket_name=src_bucket_name, + dst_bucket_name=dst_bucket_name, + src_key=f"{src_key}{obj['name']}", + dst_key=f"{dst_key}{obj['name']}", + size=obj["size"], + dry_run=dry_run, + ) + move_directory_result.update_from_file_result(move_file_result) + return move_directory_result diff --git a/tests/conftest.py b/tests/conftest.py index 55fe9de9..78045eb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from sqlalchemy import text from sqlalchemy.orm import Session +from types_boto3_s3 import S3Client from app.application import app from app.config import storages @@ -123,8 +124,12 @@ def s3_open_bucket(): @pytest.fixture(scope="session") -def _create_buckets(s3, s3_internal_bucket, s3_open_bucket): +def _create_buckets(s3: S3Client, s3_internal_bucket, s3_open_bucket): s3.create_bucket(Bucket=s3_internal_bucket) + s3.put_bucket_versioning( + Bucket=s3_internal_bucket, + VersioningConfiguration={"Status": "Enabled"}, + ) s3.create_bucket(Bucket=s3_open_bucket, ACL="public-read") diff --git a/tests/schemas/test_publish.py b/tests/schemas/test_publish.py new file mode 100644 index 00000000..65321297 --- /dev/null +++ b/tests/schemas/test_publish.py @@ -0,0 +1,61 @@ +from app.schemas.publish import MoveAssetsResult, MoveDirectoryResult, MoveFileResult + + +def test_move_directory_result_update_from_file_result_with_error(): + result = MoveDirectoryResult() + result.update_from_file_result(MoveFileResult(size=10, error="some error")) + + assert result.size == 10 + assert result.file_count == 1 + assert result.errors == ["some error"] + + +def test_move_directory_result_update_from_file_result_without_error(): + result = MoveDirectoryResult() + result.update_from_file_result(MoveFileResult(size=10, error=None)) + + assert result.size == 10 + assert result.file_count == 1 + assert result.errors == [] + + +def test_move_assets_result_update_from_file_result_with_error(): + result = MoveAssetsResult() + result.update_from_file_result(MoveFileResult(size=5, error="file error")) + + assert result.total_size == 5 + assert result.file_count == 1 + assert result.asset_count == 1 + assert result.errors == ["file error"] + + +def test_move_assets_result_update_from_file_result_without_error(): + result = MoveAssetsResult() + result.update_from_file_result(MoveFileResult(size=5, error=None)) + + assert result.total_size == 5 + assert result.file_count == 1 + assert result.asset_count == 1 + assert result.errors == [] + + +def test_move_assets_result_update_from_directory_result_with_errors(): + result = MoveAssetsResult() + dir_result = MoveDirectoryResult(size=20, file_count=3, errors=["err1", "err2"]) + result.update_from_directory_result(dir_result) + + assert result.total_size == 20 + assert result.file_count == 3 + assert result.asset_count == 1 + assert result.errors == ["err1", "err2"] + + +def test_move_assets_result_update_from_directory_result_without_errors(): + result = MoveAssetsResult() + dir_result = MoveDirectoryResult(size=20, file_count=3) + result.update_from_directory_result(dir_result) + + assert result.total_size == 20 + assert result.file_count == 3 + assert result.asset_count == 1 + assert result.errors == [] diff --git a/tests/test_publish.py b/tests/test_publish.py new file mode 100644 index 00000000..44ca86e4 --- /dev/null +++ b/tests/test_publish.py @@ -0,0 +1,303 @@ +import io +import uuid + +import pytest +import sqlalchemy as sa + +from app.config import storages +from app.db.model import Asset, Entity +from app.db.types import EntityType, StorageType +from app.utils.s3 import PRIVATE_ASSET_PREFIX, PUBLIC_ASSET_PREFIX, build_s3_path + +from tests.utils import ( + PROJECT_ID, + VIRTUAL_LAB_ID, + add_db, + assert_request, + s3_key_exists, + upload_entity_asset, +) + +PUBLISH_URL = "/admin/publish-project" +UNPUBLISH_URL = "/admin/unpublish-project" + + +@pytest.fixture +def private_morphology_with_asset(client, subject_id, brain_region_id, tmp_path): + entity_id = assert_request( + client.post, + url="/cell-morphology", + json={ + "name": "Private Morphology", + "description": "desc", + "brain_region_id": str(brain_region_id), + "subject_id": str(subject_id), + "location": {"x": 0, "y": 0, "z": 0}, + "authorized_public": False, + }, + ).json()["id"] + + filepath = tmp_path / "morph.asc" + filepath.write_bytes(b"morphology data") + upload_entity_asset( + client=client, + entity_type=EntityType.cell_morphology, + entity_id=entity_id, + files={"file": ("morph.asc", filepath.read_bytes(), "application/asc")}, + label="morphology", + ) + return uuid.UUID(entity_id) + + +@pytest.fixture +def private_circuit_with_directory_asset(db, s3, circuit, person_id): + s3_path = build_s3_path( + vlab_id=VIRTUAL_LAB_ID, + proj_id=PROJECT_ID, + entity_type=EntityType.circuit, + entity_id=circuit.id, + filename="my-directory", + is_public=False, + ) + bucket = storages[StorageType.aws_s3_internal].bucket + directory_files = ["circuit_config.json", "nodes.h5", "edges.h5"] + for fname in directory_files: + s3.upload_fileobj(io.BytesIO(b"data"), Bucket=bucket, Key=f"{s3_path}/{fname}") + + asset = add_db( + db, + Asset( + path="my-directory", + full_path=s3_path, + status="created", + is_directory=True, + content_type="application/vnd.directory", + size=0, + sha256_digest=None, + meta={}, + entity_id=circuit.id, + created_by_id=person_id, + updated_by_id=person_id, + label="sonata_circuit", + storage_type=StorageType.aws_s3_internal, + ), + ) + return circuit.id, asset.id, directory_files + + +def _get_entity(db, entity_id): + return db.execute(sa.select(Entity).where(Entity.id == entity_id)).scalar_one() + + +def _get_asset(db, entity_id): + return db.execute(sa.select(Asset).where(Asset.entity_id == entity_id)).scalar_one() + + +def _publish(client, project_id, *, dry_run=False, max_assets=None): + params = {"dry_run": dry_run} + if max_assets is not None: + params["max_assets"] = max_assets + return client.post(f"{PUBLISH_URL}/{project_id}", params=params) + + +def _unpublish(client, project_id, *, dry_run=False, max_assets=None): + params = {"dry_run": dry_run} + if max_assets is not None: + params["max_assets"] = max_assets + return client.post(f"{UNPUBLISH_URL}/{project_id}", params=params) + + +def test_publish(db, client_admin, s3, private_morphology_with_asset): + entity_id = private_morphology_with_asset + + asset_before = _get_asset(db, entity_id) + old_path = asset_before.full_path + assert old_path.startswith(PRIVATE_ASSET_PREFIX) + assert s3_key_exists(s3, key=old_path) + + response = _publish(client_admin, PROJECT_ID, dry_run=False) + assert response.status_code == 200 + data = response.json() + + assert data["public"] is True + assert data["dry_run"] is False + # resource_count can be > 1 in case there are other resources besides the morphology + assert data["resource_count"] >= 1 + assert data["move_assets_result"]["asset_count"] == 1 + assert data["move_assets_result"]["file_count"] == 1 + assert data["move_assets_result"]["total_size"] > 0 + assert data["completed"] is True + + db.expire_all() + entity = _get_entity(db, entity_id) + assert entity.authorized_public is True + + asset = _get_asset(db, entity_id) + assert asset.full_path.startswith(PUBLIC_ASSET_PREFIX) + assert s3_key_exists(s3, key=asset.full_path) + assert not s3_key_exists(s3, key=old_path) + + +def test_publish_dry_run(db, client_admin, private_morphology_with_asset): + entity_id = private_morphology_with_asset + + response = _publish(client_admin, PROJECT_ID, dry_run=True) + assert response.status_code == 200 + data = response.json() + + assert data["project_id"] == PROJECT_ID + assert data["public"] is True + assert data["dry_run"] is True + # resource_count can be > 1 in case there are other resources besides the morphology + assert data["resource_count"] >= 1 + assert data["completed"] is True + + db.expire_all() + entity = _get_entity(db, entity_id) + assert entity.authorized_public is False + + asset = _get_asset(db, entity_id) + assert asset.full_path.startswith(PRIVATE_ASSET_PREFIX) + + +@pytest.mark.usefixtures("private_morphology_with_asset") +def test_publish_with_max_assets(client_admin): + response = _publish(client_admin, PROJECT_ID, dry_run=False, max_assets=1) + assert response.status_code == 200 + data = response.json() + assert data["move_assets_result"]["asset_count"] == 1 + assert data["completed"] is False + + response = _publish(client_admin, PROJECT_ID, dry_run=False, max_assets=1) + assert response.status_code == 200 + data = response.json() + assert data["move_assets_result"]["asset_count"] == 0 + assert data["completed"] is True + + +def test_publish_no_resources(client_admin): + empty_project_id = str(uuid.uuid4()) + response = _publish(client_admin, empty_project_id, dry_run=False) + assert response.status_code == 200 + data = response.json() + assert data["resource_count"] == 0 + assert data["move_assets_result"]["asset_count"] == 0 + assert data["completed"] is True + + +def test_unpublish(db, client_admin, s3, private_morphology_with_asset): + entity_id = private_morphology_with_asset + + _publish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + + asset_before = _get_asset(db, entity_id) + old_path = asset_before.full_path + assert old_path.startswith(PUBLIC_ASSET_PREFIX) + assert s3_key_exists(s3, key=old_path) + + response = _unpublish(client_admin, PROJECT_ID, dry_run=False) + assert response.status_code == 200 + data = response.json() + + assert data["public"] is False + assert data["dry_run"] is False + # resource_count can be > 1 in case there are other resources besides the morphology + assert data["resource_count"] >= 1 + assert data["move_assets_result"]["asset_count"] == 1 + assert data["completed"] is True + + db.expire_all() + entity = _get_entity(db, entity_id) + assert entity.authorized_public is False + + asset = _get_asset(db, entity_id) + assert asset.full_path.startswith(PRIVATE_ASSET_PREFIX) + assert s3_key_exists(s3, key=asset.full_path) + assert not s3_key_exists(s3, key=old_path) + + +def test_unpublish_dry_run(db, client_admin, private_morphology_with_asset): + entity_id = private_morphology_with_asset + + _publish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + + response = _unpublish(client_admin, PROJECT_ID, dry_run=True) + assert response.status_code == 200 + data = response.json() + + assert data["public"] is False + assert data["dry_run"] is True + + db.expire_all() + entity = _get_entity(db, entity_id) + assert entity.authorized_public is True + + asset = _get_asset(db, entity_id) + assert asset.full_path.startswith(PUBLIC_ASSET_PREFIX) + + +def test_publish_then_unpublish(db, client_admin, s3, private_morphology_with_asset): + entity_id = private_morphology_with_asset + + asset_before = _get_asset(db, entity_id) + original_path = asset_before.full_path + + _publish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + assert _get_entity(db, entity_id).authorized_public is True + assert _get_asset(db, entity_id).full_path.startswith(PUBLIC_ASSET_PREFIX) + + _unpublish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + assert _get_entity(db, entity_id).authorized_public is False + + asset_after = _get_asset(db, entity_id) + assert asset_after.full_path.startswith(PRIVATE_ASSET_PREFIX) + assert asset_after.full_path == original_path + assert s3_key_exists(s3, key=asset_after.full_path) + + +def test_publish_directory_asset(db, client_admin, s3, private_circuit_with_directory_asset): + _entity_id, asset_id, directory_files = private_circuit_with_directory_asset + + asset_before = db.get(Asset, asset_id) + old_path = asset_before.full_path + assert old_path.startswith(PRIVATE_ASSET_PREFIX) + for fname in directory_files: + assert s3_key_exists(s3, key=f"{old_path}/{fname}") + + response = _publish(client_admin, PROJECT_ID, dry_run=False) + assert response.status_code == 200 + data = response.json() + assert data["move_assets_result"]["file_count"] == len(directory_files) + + db.expire_all() + asset_after = db.get(Asset, asset_id) + assert asset_after.full_path.startswith(PUBLIC_ASSET_PREFIX) + for fname in directory_files: + assert s3_key_exists(s3, key=f"{asset_after.full_path}/{fname}") + assert not s3_key_exists(s3, key=f"{old_path}/{fname}") + + +def test_unpublish_directory_asset(db, client_admin, s3, private_circuit_with_directory_asset): + _entity_id, asset_id, directory_files = private_circuit_with_directory_asset + + asset_before = db.get(Asset, asset_id) + original_path = asset_before.full_path + + _publish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + + public_path = db.get(Asset, asset_id).full_path + assert public_path.startswith(PUBLIC_ASSET_PREFIX) + + _unpublish(client_admin, PROJECT_ID, dry_run=False) + db.expire_all() + + asset_after = db.get(Asset, asset_id) + assert asset_after.full_path == original_path + for fname in directory_files: + assert s3_key_exists(s3, key=f"{asset_after.full_path}/{fname}") + assert not s3_key_exists(s3, key=f"{public_path}/{fname}") diff --git a/tests/test_utils/test_s3.py b/tests/test_utils/test_s3.py index 80dcf4a6..e6aabbea 100644 --- a/tests/test_utils/test_s3.py +++ b/tests/test_utils/test_s3.py @@ -1,6 +1,9 @@ +import io import math from pathlib import Path +from unittest.mock import Mock +import botocore.exceptions import pytest from app.config import settings @@ -9,6 +12,24 @@ from tests.utils import PROJECT_ID, VIRTUAL_LAB_ID +pytestmark = pytest.mark.usefixtures("_create_buckets") + + +def _upload(s3, bucket, key, data=b"content"): + s3.upload_fileobj(io.BytesIO(data), Bucket=bucket, Key=key) + + +def _exists(s3, bucket, key): + try: + s3.head_object(Bucket=bucket, Key=key) + except botocore.exceptions.ClientError: + return False + return True + + +def _read(s3, bucket, key): + return s3.get_object(Bucket=bucket, Key=key)["Body"].read() + def test_build_s3_path_private(): entity_id = 123456 @@ -175,3 +196,215 @@ def test_validate_multipart_filesize(): max_size = settings.S3_MULTIPART_UPLOAD_MAX_SIZE assert test_module.validate_multipart_filesize(max_size) is True assert test_module.validate_multipart_filesize(max_size + 1) is False + + +def test_ensure_directory_prefix(): + assert test_module.ensure_directory_prefix("foo") == "foo/" + assert test_module.ensure_directory_prefix("foo/") == "foo/" + assert test_module.ensure_directory_prefix("a/b/c") == "a/b/c/" + + +def test_convert_s3_path_visibility_private_to_public(): + path = "private/vlab/proj/assets/morph/1/file.asc" + result = test_module.convert_s3_path_visibility(path, public=True) + assert result == "public/vlab/proj/assets/morph/1/file.asc" + + +def test_convert_s3_path_visibility_public_to_private(): + path = "public/vlab/proj/assets/morph/1/file.asc" + result = test_module.convert_s3_path_visibility(path, public=False) + assert result == "private/vlab/proj/assets/morph/1/file.asc" + + +def test_convert_s3_path_visibility_wrong_prefix(): + with pytest.raises(ValueError, match="must start with"): + test_module.convert_s3_path_visibility("public/file.txt", public=True) + with pytest.raises(ValueError, match="must start with"): + test_module.convert_s3_path_visibility("private/file.txt", public=False) + + +def test_copy_file(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "src/file.txt", b"hello") + + result = test_module.copy_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/file.txt", + dst_key="dst/file.txt", + ) + + assert result is True + assert _read(s3, bucket, "dst/file.txt") == b"hello" + assert _exists(s3, bucket, "src/file.txt") + + +def test_copy_file_source_missing(s3, s3_internal_bucket): + bucket = s3_internal_bucket + + result = test_module.copy_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/nonexistent.txt", + dst_key="dst/nonexistent.txt", + ) + + assert result is False + + +def test_move_file(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "src/move.txt", b"move me") + + result = test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/move.txt", + dst_key="dst/move.txt", + size=7, + dry_run=False, + ) + + assert result.size == 7 + assert result.error is None + assert _read(s3, bucket, "dst/move.txt") == b"move me" + assert not _exists(s3, bucket, "src/move.txt") + + +def test_move_file_dry_run(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "src/dry.txt", b"stay") + + result = test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/dry.txt", + dst_key="dst/dry.txt", + size=4, + dry_run=True, + ) + + assert result.size == 4 + assert result.error is None + assert _exists(s3, bucket, "src/dry.txt") + assert not _exists(s3, bucket, "dst/dry.txt") + + +def test_move_file_same_src_dst(s3, s3_internal_bucket): + bucket = s3_internal_bucket + with pytest.raises(ValueError, match="Source and destination cannot be the same"): + test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="same.txt", + dst_key="same.txt", + size=1, + dry_run=False, + ) + + +def test_move_file_already_moved(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "dst/already.txt", b"already there") + + result = test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/already.txt", + dst_key="dst/already.txt", + size=13, + dry_run=False, + ) + + assert result.size == 13 + assert result.error is None + assert _read(s3, bucket, "dst/already.txt") == b"already there" + + +def test_move_file_source_missing_dest_missing(s3, s3_internal_bucket): + bucket = s3_internal_bucket + result = test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/gone.txt", + dst_key="dst/gone.txt", + size=1, + dry_run=False, + ) + + assert result.size == 1 + assert result.error is not None + assert "Failed to get object" in result.error + + +def test_move_file_copy_failure(s3, s3_internal_bucket, monkeypatch): + """Test move_file when copy_file fails after head_object succeeds.""" + bucket = s3_internal_bucket + _upload(s3, bucket, "src/copy_fail.txt", b"data") + + copy_file_mock = Mock(return_value=False) + monkeypatch.setattr(test_module, "copy_file", copy_file_mock) + result = test_module.move_file( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="src/copy_fail.txt", + dst_key="dst/copy_fail.txt", + size=4, + dry_run=False, + ) + + assert copy_file_mock.call_count == 1 + assert result.size == 4 + assert result.error is not None + assert "Failed to copy object" in result.error + # source file should still exist since copy failed + assert _exists(s3, bucket, "src/copy_fail.txt") + + +def test_move_directory(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "srcdir/a.txt", b"aaa") + _upload(s3, bucket, "srcdir/b.txt", b"bb") + + result = test_module.move_directory( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="srcdir", + dst_key="dstdir", + dry_run=False, + ) + + assert result.file_count == 2 + assert result.size == 5 + assert _read(s3, bucket, "dstdir/a.txt") == b"aaa" + assert _read(s3, bucket, "dstdir/b.txt") == b"bb" + assert not _exists(s3, bucket, "srcdir/a.txt") + assert not _exists(s3, bucket, "srcdir/b.txt") + + +def test_move_directory_dry_run(s3, s3_internal_bucket): + bucket = s3_internal_bucket + _upload(s3, bucket, "srcdir2/c.txt", b"ccc") + + result = test_module.move_directory( + s3, + src_bucket_name=bucket, + dst_bucket_name=bucket, + src_key="srcdir2", + dst_key="dstdir2", + dry_run=True, + ) + + assert result.file_count == 1 + assert result.size == 3 + assert _exists(s3, bucket, "srcdir2/c.txt") + assert not _exists(s3, bucket, "dstdir2/c.txt")