Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 108 additions & 16 deletions exporters/spline_segmentation_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,76 @@
import json
from scipy.ndimage.morphology import binary_fill_holes
import PIL

import io
import codecs
from PIL import Image
import base64

class SplineSegmentationExporterForm(forms.Form):
    """Form collecting the configuration for the spline segmentation export.

    Asks for the destination storage path, whether existing data there should
    be deleted first, whether annotations are exported as labelme-style JSON
    files instead of metaimage segmentations, and which subjects to export.
    """

    # Destination folder on disk for the exported data.
    path = forms.CharField(label='Storage path', max_length=1000)
    # If checked, anything already present at the storage path is removed first.
    delete_existing_data = forms.BooleanField(label='Delete any existing data at storage path', initial=False, required=False)
    # If checked, per-frame JSON annotation files (labelme format) are written
    # instead of .mhd segmentation images.
    json_annotations = forms.BooleanField(label='Export annotations as geoJSON files for instance segmentation', initial=False,
                                          required=False)

    def __init__(self, task, data=None):
        """Build the form for *task*; *data* is the bound POST data, if any."""
        super().__init__(data)
        # The subjects field is added dynamically because its queryset
        # depends on which task is being exported.
        self.fields['subjects'] = forms.ModelMultipleChoiceField(
            queryset=Subject.objects.filter(dataset__task=task))


def img_b64_to_arr(img_b64):
    """Decode a base64-encoded image string into a numpy pixel array.

    Parameters
    ----------
    img_b64 : str or bytes
        Base64-encoded image file contents (e.g. a PNG), as stored in the
        labelme ``imageData`` field.

    Returns
    -------
    numpy.ndarray
        Pixel data of the decoded image.
    """
    decoded = base64.b64decode(img_b64)
    # Build the buffer directly from the decoded bytes instead of writing
    # into an empty BytesIO, so the read position is at the start and we do
    # not rely on PIL rewinding the file object for us.
    with io.BytesIO(decoded) as buffer:
        # np.array() forces PIL's lazy load while the buffer is still open.
        img_arr = np.array(PIL.Image.open(buffer))
    return img_arr

def img_arr_to_b64(img_arr):
    """Encode an image array as base64 PNG data (labelme ``imageData`` format).

    Parameters
    ----------
    img_arr : numpy.ndarray
        Pixel data to encode.

    Returns
    -------
    str
        Base64 encoding of the PNG-serialized image, without line breaks.
    """
    img_pil = Image.fromarray(img_arr)
    with io.BytesIO() as buffer:
        img_pil.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()
    # base64.b64encode emits no line breaks (unlike the 'base64' codec from
    # codecs, which wraps at 76 characters), so no newline stripping is needed.
    return base64.b64encode(png_bytes).decode('ascii')


def create_json(coord, image_size, image_data, image_path):
    """Build an annotation dictionary following the labelme JSON format.

    Parameters
    ----------
    coord : list
        Flat list alternating polygon point lists and their labels:
        ``[points_0, label_0, points_1, label_1, ...]``, where each
        ``points_i`` is a list of ``[x, y]`` pairs.
    image_size : sequence
        ``(height, width)`` of the annotated image.
    image_data : str
        Base64-encoded image data for the ``imageData`` field.
    image_path : str
        File name of the image the annotation refers to.

    Returns
    -------
    dict
        Dictionary matching the labelme JSON schema, ready for ``json.dump``.
    """
    shapes = []
    # coord alternates [points, label, points, label, ...]; walk it pairwise.
    for points, label in zip(coord[0::2], coord[1::2]):
        shapes.append(
            {
                "label": str(label),
                "points": points,
                # labelme expects JSON null (Python None) here, not the
                # string 'null'.
                "group_id": None,
                "shape_type": "polygon",
                "flags": {},
            }
        )

    return {
        "version": "4.5.6",
        "flags": {},
        "shapes": shapes,
        "imagePath": image_path,
        "imageData": image_data,
        "imageHeight": image_size[0],
        "imageWidth": image_size[1],
    }



class SplineSegmentationExporter(Exporter):
"""
asdads
Expand All @@ -37,6 +95,7 @@ def get_form(self, data=None):

def export(self, form):
delete_existing_data = form.cleaned_data['delete_existing_data']
json_annotations = form.cleaned_data['json_annotations']
# Create dir, delete old if it exists
path = form.cleaned_data['path']
if delete_existing_data:
Expand All @@ -51,11 +110,11 @@ def export(self, form):
# Create folder if it doesn't exist
create_folder(path)

self.add_subjects_to_path(path, form.cleaned_data['subjects'])
self.add_subjects_to_path(path, form.cleaned_data['subjects'], json_annotations)

return True, path

def add_subjects_to_path(self, path, data):
def add_subjects_to_path(self, path, data, json_annotations):

# For each subject
for subject in data:
Expand Down Expand Up @@ -89,21 +148,24 @@ def add_subjects_to_path(self, path, data):
image_pil = PIL.Image.open(new_filename)
image_size = image_pil.size
spacing = [1, 1]
self.save_segmentation(frame, image_size, join(subject_subfolder, target_gt_name), spacing)
self.save_segmentation(frame, image_size, join(subject_subfolder, target_gt_name), spacing, json_annotations)

return True, path

def get_object_segmentation(self, image_size, frame):
segmentation = np.zeros(image_size, dtype=np.uint8)
tension = 0.5

coordinates = []
labels = Label.objects.filter(task=frame.image_annotation.task).order_by('id')
counter = 1
xy_new_temp =0

for label in labels:
objects = ControlPoint.objects.filter(label=label, image=frame).only('object').distinct()
for object in objects:
previous_x = None
previous_y = None
xy = []
control_points = ControlPoint.objects.filter(label=label, image=frame, object=object.object).order_by('index')
max_index = len(control_points)
for i in range(max_index):
Expand Down Expand Up @@ -152,26 +214,56 @@ def get_object_segmentation(self, image_size, frame):
previous_x = x
previous_y = y

xy.append(previous_x)
xy.append(previous_y)

segmentation[y, x] = counter

xy_new = [xy[j:j + 2] for j in range(0, len(xy), 2)]

if i == max_index-1 and xy_new_temp != xy_new:
coordinates.append(xy_new)
coordinates.append(a.label)
xy_new_temp = xy_new

# Fill the hole
segmentation[binary_fill_holes(segmentation == counter)] = counter

counter += 1

return segmentation
return segmentation, coordinates


def save_segmentation(self, frame, image_size, filename, spacing):
def save_segmentation(self, frame, image_size, filename, spacing, json_annotations):
image_size = [image_size[1], image_size[0]]

# Create compounded segmentation object
segmentation = self.get_object_segmentation(image_size, frame)

segmentation_mhd = MetaImage(data=segmentation)
segmentation_mhd.set_attribute('ImageQuality', frame.image_annotation.image_quality)
segmentation_mhd.set_spacing(spacing)
metadata = ImageMetadata.objects.filter(image=frame.image_annotation.image)
for item in metadata:
segmentation_mhd.set_attribute(item.name, item.value)
segmentation_mhd.write(filename)
segmentation, coords = self.get_object_segmentation(image_size, frame)

if json_annotations:
image_filename = frame.image_annotation.image.format.replace('#', str(frame.frame_nr))
image_path = os.path.basename(os.path.normpath(image_filename))
if image_filename.endswith('.mhd'):
image_mhd = MetaImage(filename=image_filename)
image_array = image_mhd.get_pixel_data()

else:
image_pil = PIL.Image.open(image_filename)
image_array = np.asarray(image_pil)

image_data = img_arr_to_b64(image_array)
json_dict = create_json(coords, image_size, image_data, image_path)
with open(filename[:-7] + '.json', "w") as f:
print("The json file is created")
jason_str = json.dumps(json_dict)
f.write(jason_str)

else:
segmentation_mhd = MetaImage(data=segmentation)
segmentation_mhd.set_attribute('ImageQuality', frame.image_annotation.image_quality)
segmentation_mhd.set_spacing(spacing)
metadata = ImageMetadata.objects.filter(image=frame.image_annotation.image)
for item in metadata:
segmentation_mhd.set_attribute(item.name, item.value)
segmentation_mhd.write(filename)