Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum

ALLOWED_FILE_TYPES = ['.png', '.jpg']
S3_DELIMITER = "/"

class StorageOrigin(Enum):
LOCAL = 'local'
S3 = 's3'
209 changes: 180 additions & 29 deletions core/management/commands/importrois.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,218 @@
import os
import boto3
from pathlib import Path

from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth.models import User

from core.models import ROI, Annotation, ImageCollection, Label
from services.s3_service import S3Service
from constants import StorageOrigin, ALLOWED_FILE_TYPES, S3_DELIMITER



class Command(BaseCommand):
help = 'import rois'

def add_arguments(self, parser):
parser.add_argument('directory', type=str, help='directory containing images')
parser.add_argument('directory', type=str, help='directory (or prefix if using S3) containing images')
parser.add_argument('-c','--collection', type=str, help='image collection to create or add images to')
parser.add_argument('-b', '--bucket', type=str, help='the bucket when importing from S3')
parser.add_argument('-u','--user', type=str, help='username for any created annotations (user must exist)')
parser.add_argument('--include-rois', type=str, help='only import rois matching a comma separated list')
parser.add_argument('--include-rois-file', type=str, help='only import rois found in a file (new line separated)')
parser.add_argument('--prefix', type=str, help='a prefix to prepend to all ROIs used with --include-rois or --include-rois-file')

origin_choices = [StorageOrigin.LOCAL.value, StorageOrigin.S3.value,]
parser.add_argument('-o','--origin', type=str, choices=origin_choices, default='local', help='storage type to use (local or s3)')

def is_roi_included(self, filename, included_rois):
if not included_rois:
return True

is_included = any(roi in filename for roi in included_rois)

return is_included

def scan_local(self, directory, included_rois=None):
unlabeled = []
labeled = {}
folders = []

# First, loop through the directory and sort entries into unlabeled files and top level folders
for entry in os.listdir(directory):
name, ext = os.path.splitext(entry)
if ext in ALLOWED_FILE_TYPES and self.is_roi_included(name, included_rois):
unlabeled.append(entry)
continue

path = os.path.join(directory, entry)
if os.path.isdir(path):
folders.append(entry)
continue

# For each top level folder, use that as the label and get all the files inside
for folder in folders:
labeled[folder] = []

path = os.path.join(directory, folder)
for entry in os.listdir(path):
name, ext = os.path.splitext(entry)
if ext not in ALLOWED_FILE_TYPES or not self.is_roi_included(name, included_rois):
continue

labeled[folder].append(entry)

return unlabeled, labeled

def scan_s3(self, s3_client, bucket, directory, included_rois=None):
unlabeled = []
labeled = {}
folders = []

# Intentionally using "list_objects" here instead of "list_objects_v2" to work around potential permission or
# feature restrictions when using VAST as the backend storage resource
paginator = s3_client.get_paginator('list_objects')

for page in paginator.paginate(Bucket=bucket, Delimiter=S3_DELIMITER, Prefix=directory):
for cp in page.get("CommonPrefixes", []):
folder = cp.get("Prefix")

if directory != "":
folder = folder.removeprefix(directory)

folders.append(folder)

for obj in page.get("Contents", []):
filename = obj['Key']

# Ignore folders and files within folders
if filename.endswith('/'):
continue

# Remove the directory/path if there is one
if directory != "":
filename = filename.removeprefix(directory)

name, ext = os.path.splitext(filename)
if ext not in ALLOWED_FILE_TYPES or not self.is_roi_included(name, included_rois):
continue

unlabeled.append(filename)

for folder in folders:
key = folder.rstrip("/")
labeled[key] = []
prefix = os.path.join(directory, folder)

for page in paginator.paginate(Bucket=bucket, Delimiter=S3_DELIMITER, Prefix=prefix):

for obj in page.get("Contents", []):
filename = obj['Key'].removeprefix(prefix)

# Ignore subfolders and files within subfolders
if "/" in filename:
continue

# Remove the directory/path if there is one
if directory != "":
filename = filename.removeprefix(directory)

name, ext = os.path.splitext(filename)
if ext not in ALLOWED_FILE_TYPES or not self.is_roi_included(name, included_rois):
continue

labeled[key].append(filename)

return unlabeled, labeled

def handle(self, *args, **options):
# handle arguments
directory = options['directory']
collection_name = options.get('collection')
username = options.get('user')
origin = options.get('origin')
bucket = options.get('bucket')
rois_list = options.get('include_rois')
rois_file = options.get('include_rois_file')
prefix = options.get('prefix') or ''
s3_client = S3Service.get_client() if origin == StorageOrigin.S3.value else None

# validate arguments
if not os.path.exists(directory):

if rois_list and rois_file:
raise CommandError('the include-rois and include-rois-file arguments cannot be used at the same time')

if rois_file:
path = Path(rois_file)
if not path.exists() or not path.is_file():
raise CommandError(f"file not found: {path}")

# Only verify the path physically exists when using local storage
if origin == StorageOrigin.LOCAL.value and not os.path.exists(directory):
raise CommandError('specified directory does not exist')

# When using S3 for storage, a bucket is required
if origin == StorageOrigin.S3.value and (bucket or "") == "":
raise CommandError('bucket must be specified')

# For S3, if the user wants to look in root, the options are a bit unclear so we should allow them to an empty
# string (with ""), or a single slash. However, as far as AWS is concerned, directory in this case should be
# set to an empty string (slash will not work properly)
if origin == StorageOrigin.S3.value and directory == "/":
directory = ""

# For S3, if the user entered a directory, it must end in a trailing slash. Rather than require it, we can just
# add one if it's not there
if origin == StorageOrigin.S3.value and directory != "" and not directory.endswith("/"):
directory += "/"

# Load any filters
included_rois = []

if rois_list:
included_rois = [prefix + item.strip() for item in rois_list.split(',') if item.strip()]

if rois_file:
path = Path(rois_file)
try:
with path.open('r') as f:
included_rois = [prefix + line.strip() for line in f if line.strip()]
except Exception as e:
raise CommandError(f"could not read file {rois_file}: {e}")

user = None
if username:
try:
user = User.objects.get(username=username)
except:
Copy link

Copilot AI Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bare except clause catches all exceptions. Use specific exception types like User.DoesNotExist to handle the expected case properly.

Suggested change
except:
except User.DoesNotExist:

Copilot uses AI. Check for mistakes.
raise CommandError(f'unable to retrieve user {username}')

collection = None
if collection_name is not None:
collection, created = ImageCollection.objects.get_or_create(
name=collection_name)
# scan directory and one level of subdirectories
def scan(dir):
result = []
for fn in os.listdir(dir):
name, ext = os.path.splitext(fn)
if ext not in ['.png', '.jpg']:
continue
result.append(fn)
return result
unlabeled = scan(directory)
labeled = {}
for n in os.listdir(directory):
if os.path.isdir(os.path.join(directory, n)):
label = n
label_dir_path = os.path.join(directory, n)
labeled[label] = scan(label_dir_path)
collection, _ = ImageCollection.objects.get_or_create(name=collection_name)

if origin == StorageOrigin.S3.value:
unlabeled, labeled = self.scan_s3(s3_client, bucket, directory, included_rois=included_rois)
else:
unlabeled, labeled = self.scan_local(directory, included_rois=included_rois)

if len(labeled) > 0 and not user:
raise CommandError('labeled ROIs found but no username specified')

print(f'found {len(unlabeled)} unlabeled images and {len(labeled)} label directories')

# now create ROI records in the database
print(f'importing {len(unlabeled)} unlabeled ROIs...')
for roi_filename in unlabeled:
path = os.path.join(directory, roi_filename)
roi = ROI.objects.create_or_update_roi(path, collection=collection)
if len(unlabeled) > 0:
print(f'importing {len(unlabeled)} unlabeled ROIs...')
for roi_filename in unlabeled:
path = os.path.join(directory, roi_filename)
_ = ROI.objects.create_or_update_roi(path, collection=collection, origin=origin, bucket=bucket, s3_client=s3_client)

for label_name, rois in labeled.items():
print(f'importing {len(rois)} ROIs labeled "{label_name}"...')
label, created = Label.objects.get_or_create(name=label_name)
label, _ = Label.objects.get_or_create(name=label_name)
for roi_filename in rois:
roi_path = os.path.join(directory, label_name, roi_filename)
roi = ROI.objects.create_or_update_roi(roi_path, collection=collection)
roi = ROI.objects.create_or_update_roi(roi_path, collection=collection, origin=origin, bucket=bucket, s3_client=s3_client)
Annotation.objects.create_or_verify(roi, label, user)



23 changes: 23 additions & 0 deletions core/migrations/0009_roi_bucket_roi_origin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 5.2.3 on 2025-08-12 03:30

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('core', '0008_roi_winning_annotation'),
]

operations = [
migrations.AddField(
model_name='roi',
name='bucket',
field=models.CharField(blank=True, max_length=100, null=True),
),
migrations.AddField(
model_name='roi',
name='origin',
field=models.CharField(choices=[('LOCAL', 'local'), ('S3', 's3')], default='local', max_length=50),
),
]
44 changes: 37 additions & 7 deletions core/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import io
from datetime import datetime
import json

Expand All @@ -11,6 +12,8 @@
from django.utils import timezone

from PIL import Image
import boto3
from constants import StorageOrigin


class ROIQuerySet(models.QuerySet):
Expand Down Expand Up @@ -61,24 +64,31 @@ class ROIManager(models.Manager):
def get_queryset(self):
return ROIQuerySet(self.model, using=self._db)

def create_or_update_roi(self, path, collection=None):
def create_or_update_roi(self, path, collection=None, origin='', bucket=None, s3_client=None):
if not path.endswith('.png') and not path.endswith('.jpg'):
raise NameError(f'{path} is not the path to a ROI image')
roi_id = os.path.basename(path)[:-4] # we know it ends with a 3-character image extension

with transaction.atomic():
try:
roi = self.get(roi_id=roi_id)
if roi.path != path:
if roi.path != path or roi.origin != origin or roi.bucket != bucket:
roi.path = path
roi.bucket = bucket
roi.origin = origin
roi.save()
if collection is not None:
if not roi.collections.filter(id=collection.id).exists():
roi.collections.add(collection)
except ROI.DoesNotExist:
with Image.open(path) as image:
width, height = image.size
roi = self.create(roi_id=roi_id, width=width, height=height, path=path)
width, height = self.calculate_dimensions(path, origin, bucket, s3_client)
roi = self.create(
roi_id=roi_id,
width=width,
height=height,
path=path,
origin=origin,
bucket=bucket)
if collection is not None:
roi.collections.add(collection)
return roi
Expand All @@ -89,15 +99,35 @@ def with_label(self, label):
def unlabeled(self):
return self.get_queryset().unlabeled()

def calculate_dimensions(self, path, origin, bucket=None, s3_client=None):
try:
if origin == StorageOrigin.S3.value:
response = s3_client.get_object(Bucket=bucket, Key=path)

data = response["Body"].read()

with Image.open(io.BytesIO(data)) as image:
return image.size
else:
with Image.open(path) as image:
return image.size
except Exception as e:
print(f"Failed to download or read image from S3: {e}")
return 0, 0
Comment on lines +114 to +116
Copy link

Copilot AI Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using print() for error logging in production code. Consider using the logging module for consistent error handling and proper log levels.

Copilot uses AI. Check for mistakes.


class ROI(models.Model):
roi_id = models.CharField(max_length=255, unique=True)
width = models.IntegerField()
height = models.IntegerField()
path = models.CharField(max_length=512)
winning_annotation = models.ForeignKey('Annotation', on_delete=models.CASCADE, null=True,\
winning_annotation = models.ForeignKey('Annotation', on_delete=models.CASCADE, null=True, \
related_name='associated_roi')

bucket = models.CharField(max_length=100, null=True, blank=True)
origin = models.CharField(max_length=50, null=False, blank=False, default=StorageOrigin.LOCAL.value, choices=[
(StorageOrigin.LOCAL.value, StorageOrigin.LOCAL.value),
(StorageOrigin.S3.value, StorageOrigin.S3.value),
])
objects = ROIManager()

@property
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ version: '3.9'
services:
web:
image: harbor-registry.whoi.edu/photic/photic_web:1.0
env_file:
- .env
command: python manage.py runserver 0.0.0.0:8000
volumes:
- ${ROI_PATH}:/rois
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ services:
context: .

command: python manage.py runserver 0.0.0.0:8000
env_file:
- .env
volumes:
- .:/app
- ${ROI_PATH}:/rois
Expand Down
5 changes: 5 additions & 0 deletions env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ POSTGRES_DATA_PATH=/srv/postgresql/data
# Location of SSL certificate files
SSL_CERT=/etc/ssl/example.crt
SSL_KEY=/etc/ssl/example.key

# Credentials for S3/Vast (Vast requires an endpoint url, AWS does not)
S3_ACCESS_KEY=
S3_SECRET_KEY=
S3_ENDPOINT_URL=
Loading